diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 9031cf4298..122d257b6b 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -320,6 +320,7 @@ scan_pipeline_repeats: false scan_layers_per_stage: false set_remat_policy_on_pipeline_iterations: true set_remat_policy_on_layers_per_stage: false +pipeline_save_decoder_layer_input: true # set to false to reduce pipeline tmem at cost of recomputing decoder layer inputs in backward pass # Choose 'remat_policy' between 'minimal_with_context', 'minimal', 'save_dot_with_context_except_mlp', 'save_dot_except_mlpwi', 'save_dot_except_mlp', diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 1eca954e53..c4027da369 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1011,6 +1011,14 @@ class PipelineParallelism(BaseModel): scan_layers_per_stage: bool = Field(False, description="Use jax.lax.scan over layers within a stage.") set_remat_policy_on_pipeline_iterations: bool = Field(True, description="Set remat policy on the pipeline scan.") set_remat_policy_on_layers_per_stage: bool = Field(False, description="Set remat policy on the inner layer scan.") + pipeline_save_decoder_layer_input: bool = Field( + True, + description=( + "Whether to save 'decoder_layer_input' activations in the pipeline remat policy. " + "Setting to False reduces temporary memory (tmem) during pipeline execution at the cost " + "of recomputing decoder layer inputs in the backward pass." + ), + ) class RematAndOffload(BaseModel): @@ -2850,7 +2858,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de # For AOT compilation and correctness, always prioritize the 'stage' axis for sharding when pipelining. for rule in self.logical_axis_rules: if rule and rule[0] == "activation_embed_and_logits_batch": - rule[1] = ["stage", "data", "fsdp", "fsdp_transpose", "expert"] + rule[1] = [ax for ax in ["stage", "data", "fsdp", "fsdp_transpose", "expert"] if ax in self.mesh_axes] break if "stage" in self.mesh_axes: diff --git a/src/maxtext/kernels/gather_reduce_sc.py b/src/maxtext/kernels/gather_reduce_sc.py index 5b3b8e7597..c858b45bf5 100644 --- a/src/maxtext/kernels/gather_reduce_sc.py +++ b/src/maxtext/kernels/gather_reduce_sc.py @@ -55,6 +55,7 @@ def __getitem__(self, shape): _BF16 = VectorTypeHelper(ir.BF16Type.get) +# fmt: off @jax.jit( static_argnames=[ "reduce_group_size", @@ -69,6 +70,7 @@ def __getitem__(self, shape): "topk_wgt_zero_nan", ], ) +# fmt: on def sc_gather_reduce( op: jax.Array, idx: jax.Array, diff --git a/src/maxtext/layers/attention_op.py b/src/maxtext/layers/attention_op.py index b3c3f296f4..2c937385d2 100644 --- a/src/maxtext/layers/attention_op.py +++ b/src/maxtext/layers/attention_op.py @@ -1624,13 +1624,22 @@ def _sequence_descriptor(segment_ids): dummy_attn_mask = None mask_type = "causal" else: - # Default case: no packing, no context parallelism - dummy_attn_mask = jnp.zeros( - (1, 1, 1, self.max_target_length, self.max_target_length), - dtype=jnp.uint8, - ) - attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) - attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) + # Default case: no packing, no context parallelism. + # For synthetic data, segment IDs are always all-ones (one segment per sequence), so + # the segment mask is all-True and the combined mask reduces to pure causal masking. + # Use mask_type="causal" directly to avoid materializing f32/s32[seq,seq] tensors that + # XLA loop_broadcast_fusion hoists into the pipeline scan carry (+5 GiB temp memory). + if self.config.dataset_type == "synthetic": + attn_mask = None + dummy_attn_mask = None + mask_type = "causal" + else: + dummy_attn_mask = jnp.zeros( + (1, 1, 1, self.max_target_length, self.max_target_length), + dtype=jnp.uint8, + ) + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1).astype(jnp.uint8) dpa_layer = DotProductAttention( head_dim=head_dim, @@ -1643,12 +1652,10 @@ def _sequence_descriptor(segment_ids): dtype=self.dtype, float32_logits=self.float32_logits, qkv_layout=qkv_layout, - scale_factor=1.0, transpose_batch_sequence=False, window_size=sliding_window_size, context_parallel_causal_load_balanced=self.config.context_parallel_load_balance, context_parallel_axis=self.config.context_sharding, - context_parallel_strategy=self.config.context_parallel_strategy, max_segments_per_seq=max_segments_per_seq, ) diff --git a/src/maxtext/layers/attentions.py b/src/maxtext/layers/attentions.py index 679c891360..062fcda34f 100644 --- a/src/maxtext/layers/attentions.py +++ b/src/maxtext/layers/attentions.py @@ -553,6 +553,7 @@ def __init__( mesh=mesh, shard_mode=config.shard_mode, debug_sharding=config.debug_sharding, + skip_trivial_specs=True, ) def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: diff --git a/src/maxtext/layers/normalizations.py b/src/maxtext/layers/normalizations.py index e98977c60c..0b9d9ef6fc 100644 --- a/src/maxtext/layers/normalizations.py +++ b/src/maxtext/layers/normalizations.py @@ -22,7 +22,7 @@ import jax from jax import lax import jax.numpy as jnp -from jax.sharding import NamedSharding +from jax.sharding import NamedSharding, reshard from maxtext.common.common_types import Array, DType, ShardMode from maxtext.layers import nnx_wrappers from maxtext.layers.initializers import Initializer, variable_to_logically_partitioned @@ -78,7 +78,10 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> if not self.with_scale: if out_sharding is not None: - y = jax.lax.with_sharding_constraint(y, out_sharding) + if self.shard_mode == ShardMode.EXPLICIT: + y = reshard(y, out_sharding) + else: + y = jax.lax.with_sharding_constraint(y, out_sharding) return y scale = self.scale.get_value() @@ -88,8 +91,14 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) -> scale = jax.device_put(scale, max_utils.device_space()) scale = jnp.asarray(scale, self.dtype) - effective_scale = scale + self.scale_offset - return jnp.einsum("...k,k->...k", y, effective_scale, out_sharding=out_sharding) + effective_scale = scale + self.scale_offset if self.scale_offset != 0.0 else scale + y = y * effective_scale + if out_sharding is not None: + if self.shard_mode == ShardMode.EXPLICIT: + y = reshard(y, out_sharding) + else: + y = jax.lax.with_sharding_constraint(y, out_sharding) + return y class GlobalRMSNorm(RMSNorm): diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 047ddb97a8..b6b4d20ec5 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -36,6 +36,13 @@ import jax.numpy as jnp from jax.sharding import NamedSharding + +import flax + +try: + flax.config.update("flax_always_shard_variable", False) +except LookupError: + pass from flax import linen as nn, nnx from flax.linen import partitioning as nn_partitioning from flax.nnx import variablelib @@ -394,10 +401,11 @@ def diff_wrapper(curr_params, custom_params, rest, config, data): (loss, (aux, new_rest)), (raw_grads, custom_grads) = grad_func(curr_params, custom_params, rest, config, data) nnx.update(state.model, nnx.State.merge(custom_grads, new_rest)) - raw_grads = jax.tree_util.tree_map( - lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, - raw_grads, - ) + if config.grad_dtype != jnp.float32: + raw_grads = jax.tree_util.tree_map( + lambda x: x.astype(config.grad_dtype) if x.dtype == jnp.float32 else x, + raw_grads, + ) if config.parameter_memory_host_offload: raw_grads = jax.device_put( raw_grads, diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 836c425f09..cca5311a89 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -231,7 +231,11 @@ def jit_and_compile( def save_compiled(compiled, save_name): """Serialize and save the compiled function.""" - serialized, _, _ = serialize(compiled) + result = serialize(compiled) + # jax.experimental.serialize_executable.serialize() changed its return type: + # older JAX: (bytes, in_tree, out_tree) + # newer JAX: bytes + serialized = result[0] if isinstance(result, tuple) else result with open(save_name, "wb") as f: f.write(serialized) diff --git a/src/maxtext/utils/reference_hlo_deepseek3.txt b/src/maxtext/utils/reference_hlo_deepseek3.txt new file mode 100644 index 0000000000..a5324d4234 --- /dev/null +++ b/src/maxtext/utils/reference_hlo_deepseek3.txt @@ -0,0 +1,2000 @@ +HloModule jit_train_step, is_scheduled=true, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36, {}, may-alias), {37}: (37, {}, may-alias), {38}: (38, {}, may-alias), {39}: (39, {}, may-alias), {40}: (40, {}, may-alias), {41}: (41, {}, may-alias), {42}: (42, {}, may-alias), {43}: (43, {}, may-alias), {44}: (44, {}, may-alias), {45}: (45, {}, may-alias), {46}: (46, {}, may-alias), {47}: (47, {}, may-alias), {48}: (48, {}, may-alias), {49}: (49, {}, may-alias), {50}: (50, {}, may-alias), {51}: (51, {}, may-alias), {52}: (52, {}, may-alias), {53}: (53, {}, may-alias), {54}: (54, {}, may-alias), {55}: (55, {}, may-alias), {56}: (56, {}, may-alias), {57}: (57, {}, may-alias), {58}: (58, {}, may-alias), {59}: (59, {}, may-alias), {60}: (60, {}, may-alias), {61}: (61, {}, may-alias), {62}: (62, {}, may-alias), {63}: (63, {}, may-alias), {64}: (64, {}, may-alias), {65}: (65, {}, may-alias), {66}: (66, {}, may-alias), {67}: (67, {}, may-alias), {68}: (68, {}, may-alias), {69}: (69, {}, may-alias), {70}: (70, {}, may-alias), {71}: (71, {}, may-alias), {72}: (72, {}, may-alias), {73}: (73, {}, may-alias), {74}: (74, {}, may-alias), {75}: (75, {}, may-alias), {76}: (76, {}, may-alias), {77}: (77, {}, may-alias), {78}: (78, {}, may-alias), {79}: (79, {}, may-alias), {80}: (80, {}, may-alias), {81}: (81, {}, may-alias), {82}: (82, {}, may-alias), {83}: (83, {}, may-alias), {84}: (84, {}, may-alias), {85}: (85, {}, may-alias), {86}: (86, {}, may-alias), {87}: (87, {}, may-alias), {88}: (88, {}, may-alias), {89}: (89, {}, may-alias), {90}: (90, {}, may-alias), {91}: (91, {}, may-alias), {92}: (92, {}, may-alias), {93}: (93, {}, may-alias), {94}: (94, {}, may-alias), {95}: (95, {}, may-alias), {96}: (96, {}, may-alias), {97}: (97, {}, may-alias), {98}: (98, {}, may-alias) }, entry_computation_layout={(s32[]{:T(128)}, f32[512]{0:T(512)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, /*index=5*/f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, /*index=10*/f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, /*index=15*/f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, /*index=20*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, /*index=25*/f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, /*index=30*/f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, s32[]{:T(128)}, f32[512]{0:T(512)}, /*index=35*/f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, /*index=40*/f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, /*index=45*/f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, /*index=50*/f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, /*index=55*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, /*index=60*/f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, /*index=65*/f32[129280,512]{1,0:T(8,128)}, f32[512]{0:T(512)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, /*index=70*/f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, /*index=75*/f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, /*index=80*/f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, /*index=85*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, /*index=90*/f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, /*index=95*/f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, s32[]{:T(128)}, s32[4,128]{1,0:T(4,128)}, /*index=100*/s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)})->(s32[]{:T(128)}, f32[512]{0:T(512)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, /*index=5*/f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, /*index=10*/f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, /*index=15*/f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, /*index=20*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, /*index=25*/f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, /*index=30*/f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, s32[]{:T(128)}, f32[512]{0:T(512)}, /*index=35*/f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, /*index=40*/f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, /*index=45*/f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, /*index=50*/f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, /*index=55*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, /*index=60*/f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, /*index=65*/f32[129280,512]{1,0:T(8,128)}, f32[512]{0:T(512)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[512,3,18432]{2,0,1:T(8,128)}, f32[18432,3,512]{2,0,1:T(8,128)}, /*index=70*/f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[512,3]{0,1:T(4,128)}, f32[128,3,128,512]{3,2,1,0:T(8,128)}, f32[1536,3]{0,1:T(4,128)}, /*index=75*/f32[512,3,576]{0,2,1:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,1536]{2,0,1:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, /*index=80*/f32[256,1]{0,1:T(1,128)}, f32[512,1,256]{2,1,0:T(1,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, /*index=85*/f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1,512]{2,1,0:T(1,128)}, f32[512,1]{0,1:T(1,128)}, f32[512,1]{0,1:T(1,128)}, /*index=90*/f32[512,1]{0,1:T(1,128)}, f32[128,1,128,512]{3,2,1,0:T(8,128)}, f32[1536,1]{0,1:T(1,128)}, f32[512,1,576]{0,2,1:T(8,128)}, f32[512,1,128,256]{3,2,1,0:T(8,128)}, /*index=95*/f32[512,1,1536]{2,1,0:T(1,128)}, f32[1536,1,128,192]{2,3,1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, s32[]{:T(128)}, f32[]{:T(128)}, /*index=100*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, /*index=105*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, s32[]{:T(128)}, f32[]{:T(128)})}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false}, allow_spmd_sharding_propagation_to_output={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,true,true,true,true,true,true,true,true,true,true,true}, num_partitions=4 + +FileNames + +FunctionNames + +FileLocations + +StackFrames + + +%region_46.56 (top_k.25: bf16[], top_k.26: bf16[], top_k.27: s32[], top_k.28: s32[]) -> pred[] { + %constant.1377 = s32[]{:T(128)} constant(0) + %constant.1378 = s32[]{:T(128)} constant(2147483647) + %top_k.25 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} + %top_k.26 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} + %top_k.27 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} + %top_k.28 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} + %convert.269 = f32[]{:T(128)S(6)} convert(%top_k.25), metadata={op_name="convert.18"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.39 = s32[]{:T(128)S(6)} bitcast-convert(%convert.269), metadata={op_name="bitcast-convert.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.144 = pred[]{:T(512)S(6)} compare(%bitcast-convert.39, %constant.1377), direction=LT, metadata={op_name="compare.38"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.40 = s32[]{:T(128)S(6)} xor(%constant.1378, %bitcast-convert.39), metadata={op_name="xor.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %select.127 = s32[]{:T(128)S(6)} select(%compare.144, %xor.40, %bitcast-convert.39), metadata={op_name="select.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} + %convert.270 = f32[]{:T(128)S(6)} convert(%top_k.26), metadata={op_name="convert.19"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.40 = s32[]{:T(128)S(6)} bitcast-convert(%convert.270), metadata={op_name="bitcast-convert.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.145 = pred[]{:T(512)S(6)} compare(%bitcast-convert.40, %constant.1377), direction=LT, metadata={op_name="compare.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.41 = s32[]{:T(128)S(6)} xor(%constant.1378, %bitcast-convert.40), metadata={op_name="xor.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %select.128 = s32[]{:T(128)S(6)} select(%compare.145, %xor.41, %bitcast-convert.40), metadata={op_name="select.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} + %compare.146 = pred[]{:T(512)S(6)} compare(%select.127, %select.128), direction=GT, metadata={op_name="compare.0"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.147 = pred[]{:T(512)S(6)} compare(%select.128, %select.127), direction=GT, metadata={op_name="compare.117"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.148 = pred[]{:T(512)S(6)} compare(%compare.146, %compare.147), direction=EQ, metadata={op_name="compare.118"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.149 = pred[]{:T(512)S(6)} compare(%top_k.27, %top_k.28), direction=LT, metadata={op_name="compare.119"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.129 = pred[]{:T(512)} select(%compare.148, %compare.149, %compare.146), metadata={op_name="select.113"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_47.57 (sort.64: s32[], sort.65: s32[], sort.66: s32[], sort.67: s32[]) -> pred[] { + %sort.64 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} + %sort.65 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(argsort)/sort"} + %sort.66 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} + %sort.67 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} + %lt_to.32 = pred[]{:T(512)S(6)} compare(%sort.64, %sort.65), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.33 = pred[]{:T(512)S(6)} compare(%sort.65, %sort.64), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.150 = pred[]{:T(512)S(6)} compare(%lt_to.32, %lt_to.33), direction=EQ, metadata={op_name="compare.120"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.151 = pred[]{:T(512)S(6)} compare(%sort.66, %sort.67), direction=LT, metadata={op_name="compare.121"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.130 = pred[]{:T(512)} select(%compare.150, %compare.151, %lt_to.32), metadata={op_name="select.114"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.58 (sort.68: s32[], sort.69: s32[], sort.70: s32[], sort.71: s32[]) -> pred[] { + %sort.68 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} + %sort.69 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(argsort)/sort"} + %sort.70 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} + %sort.71 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} + %lt_to.34 = pred[]{:T(512)S(6)} compare(%sort.68, %sort.69), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.35 = pred[]{:T(512)S(6)} compare(%sort.69, %sort.68), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.152 = pred[]{:T(512)S(6)} compare(%lt_to.34, %lt_to.35), direction=EQ, metadata={op_name="compare.122"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.153 = pred[]{:T(512)S(6)} compare(%sort.70, %sort.71), direction=LT, metadata={op_name="compare.123"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.131 = pred[]{:T(512)} select(%compare.152, %compare.153, %lt_to.34), metadata={op_name="select.115"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_67.80 (sort.78: s32[], sort.79: s32[], sort.80: s32[], sort.81: s32[]) -> pred[] { + %sort.78 = s32[]{:T(128)} parameter(0), metadata={op_name="sort_activations/jit(argsort)/sort"} + %sort.79 = s32[]{:T(128)} parameter(1), metadata={op_name="sort_activations/jit(argsort)/sort"} + %sort.80 = s32[]{:T(128)} parameter(2), metadata={op_name="sort_activations/jit(argsort)/sort"} + %sort.81 = s32[]{:T(128)} parameter(3), metadata={op_name="sort_activations/jit(argsort)/sort"} + %lt_to.37 = pred[]{:T(512)S(6)} compare(%sort.78, %sort.79), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.38 = pred[]{:T(512)S(6)} compare(%sort.79, %sort.78), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.156 = pred[]{:T(512)S(6)} compare(%lt_to.37, %lt_to.38), direction=EQ, metadata={op_name="compare.124"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.157 = pred[]{:T(512)S(6)} compare(%sort.80, %sort.81), direction=LT, metadata={op_name="compare.125"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.134 = pred[]{:T(512)} select(%compare.156, %compare.157, %lt_to.37), metadata={op_name="select.116"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_119.141 (reduce_sum.151: bf16[], reduce_sum.152: bf16[]) -> bf16[] { + %reduce_sum.151 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} + %reduce_sum.152 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} + ROOT %reduce_sum.153 = bf16[]{:T(256)} add(%reduce_sum.151, %reduce_sum.152), metadata={op_name="checkpoint/moe_layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_107.126 (psum.6: bf16[], psum.9: bf16[]) -> bf16[] { + %psum.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} + %psum.9 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} + ROOT %add.1411 = bf16[]{:T(256)} add(%psum.6, %psum.9), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_108.127 (psum.10: bf16[], psum.11: bf16[]) -> bf16[] { + %psum.10 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} + %psum.11 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} + ROOT %add.1412 = bf16[]{:T(256)} add(%psum.10, %psum.11), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_109.128 (psum.14: bf16[], psum.15: bf16[]) -> bf16[] { + %psum.14 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} + %psum.15 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} + ROOT %add.1413 = bf16[]{:T(256)} add(%psum.14, %psum.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_62.73 (reduce-window.111: s32[], reduce-window.112: s32[]) -> s32[] { + %reduce-window.111 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.35"} + %reduce-window.112 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.35"} + ROOT %reduce_window_sum.108 = s32[]{:T(128)} add(%reduce-window.111, %reduce-window.112), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_64.75 (reduce-window.113: s32[], reduce-window.114: s32[]) -> s32[] { + %reduce-window.113 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.36"} + %reduce-window.114 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.36"} + ROOT %reduce_window_sum.109 = s32[]{:T(128)} add(%reduce-window.113, %reduce-window.114), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_65.76 (reduce-window.115: s32[], reduce-window.116: s32[]) -> s32[] { + %reduce-window.115 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.63"} + %reduce-window.116 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.63"} + ROOT %reduce_window_sum.110 = s32[]{:T(128)} add(%reduce-window.115, %reduce-window.116), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_68.81.clone (reduce-window.396: s32[], reduce-window.397: s32[]) -> s32[] { + %reduce-window.396 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.38"} + %reduce-window.397 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.38"} + ROOT %reduce_window_sum.317 = s32[]{:T(128)} add(%reduce-window.396, %reduce-window.397), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_69.82.clone (reduce-window.400: s32[], reduce-window.401: s32[]) -> s32[] { + %reduce-window.400 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.64"} + %reduce-window.401 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.64"} + ROOT %reduce_window_sum.319 = s32[]{:T(128)} add(%reduce-window.400, %reduce-window.401), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_71.84.clone (reduce-window.404: s32[], reduce-window.405: s32[]) -> s32[] { + %reduce-window.404 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.40"} + %reduce-window.405 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.40"} + ROOT %reduce_window_sum.321 = s32[]{:T(128)} add(%reduce-window.404, %reduce-window.405), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_72.85.clone (reduce-window.408: s32[], reduce-window.409: s32[]) -> s32[] { + %reduce-window.408 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.65"} + %reduce-window.409 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.65"} + ROOT %reduce_window_sum.323 = s32[]{:T(128)} add(%reduce-window.408, %reduce-window.409), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_74.87.clone (reduce-window.412: s32[], reduce-window.413: s32[]) -> s32[] { + %reduce-window.412 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.42"} + %reduce-window.413 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.42"} + ROOT %reduce_window_sum.325 = s32[]{:T(128)} add(%reduce-window.412, %reduce-window.413), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_75.88.clone (reduce-window.416: s32[], reduce-window.417: s32[]) -> s32[] { + %reduce-window.416 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.66"} + %reduce-window.417 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.66"} + ROOT %reduce_window_sum.327 = s32[]{:T(128)} add(%reduce-window.416, %reduce-window.417), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_80.96.clone (reduce-window.420: s32[], reduce-window.421: s32[]) -> s32[] { + %reduce-window.420 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.44"} + %reduce-window.421 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.44"} + ROOT %reduce_window_sum.329 = s32[]{:T(128)} add(%reduce-window.420, %reduce-window.421), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_82.98.clone (reduce-window.424: s32[], reduce-window.425: s32[]) -> s32[] { + %reduce-window.424 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.45"} + %reduce-window.425 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.45"} + ROOT %reduce_window_sum.331 = s32[]{:T(128)} add(%reduce-window.424, %reduce-window.425), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_83.99.clone (reduce-window.428: s32[], reduce-window.429: s32[]) -> s32[] { + %reduce-window.428 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.67"} + %reduce-window.429 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.67"} + ROOT %reduce_window_sum.333 = s32[]{:T(128)} add(%reduce-window.428, %reduce-window.429), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_94.112 (reduce-window.174: s32[], reduce-window.175: s32[]) -> s32[] { + %reduce-window.174 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.49"} + %reduce-window.175 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.49"} + ROOT %reduce_window_sum.138 = s32[]{:T(128)} add(%reduce-window.174, %reduce-window.175), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_95.113 (reduce-window.179: s32[], reduce-window.180: s32[]) -> s32[] { + %reduce-window.179 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.69"} + %reduce-window.180 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.69"} + ROOT %reduce_window_sum.139 = s32[]{:T(128)} add(%reduce-window.179, %reduce-window.180), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_97.115 (reduce-window.184: s32[], reduce-window.185: s32[]) -> s32[] { + %reduce-window.184 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.51"} + %reduce-window.185 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.51"} + ROOT %reduce_window_sum.140 = s32[]{:T(128)} add(%reduce-window.184, %reduce-window.185), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_98.116 (reduce-window.189: s32[], reduce-window.190: s32[]) -> s32[] { + %reduce-window.189 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.70"} + %reduce-window.190 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.70"} + ROOT %reduce_window_sum.141 = s32[]{:T(128)} add(%reduce-window.189, %reduce-window.190), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_103.121 (reduce-window.194: s32[], reduce-window.195: s32[]) -> s32[] { + %reduce-window.194 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.53"} + %reduce-window.195 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.53"} + ROOT %reduce_window_sum.142 = s32[]{:T(128)} add(%reduce-window.194, %reduce-window.195), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_105.123 (reduce-window.199: s32[], reduce-window.200: s32[]) -> s32[] { + %reduce-window.199 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.54"} + %reduce-window.200 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.54"} + ROOT %reduce_window_sum.143 = s32[]{:T(128)} add(%reduce-window.199, %reduce-window.200), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_106.124 (reduce-window.204: s32[], reduce-window.205: s32[]) -> s32[] { + %reduce-window.204 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.71"} + %reduce-window.205 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.71"} + ROOT %reduce_window_sum.144 = s32[]{:T(128)} add(%reduce-window.204, %reduce-window.205), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.5 (param_0.17: bf16[129280,512], param_1.108: s32[1024]) -> bf16[512,512] { + %param_0.17 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.108 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.13 = s32[1024]{0:T(1024)} custom-call(%param_1.108), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %slice.892 = s32[512]{0:T(512)} slice(%custom-call.13), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.3433 = s32[4,128]{1,0:T(4,128)} reshape(%slice.892), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.604 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3433), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.183 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} gather(%param_0.17, %transpose.604), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.603 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%gather.183), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.3432 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.603), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} +} + +%fused_computation.6 (param_0.20: f32[163840,32], param_1.110: s32[1024]) -> f32[512,32] { + %param_0.20 = f32[163840,32]{1,0:T(8,128)} parameter(0) + %param_1.110 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.15 = s32[1024]{0:T(1024)} custom-call(%param_1.110), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %slice.894 = s32[512]{0:T(512)} slice(%custom-call.15), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %reshape.3441 = s32[4,128]{1,0:T(4,128)} reshape(%slice.894), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %transpose.610 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3441), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %gather.185 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.20, %transpose.610), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %transpose.609 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.185), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + ROOT %reshape.3440 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.609), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} +} + +%fused_computation.7 (param_0.23: f32[163840,32], param_1.112: s32[1024]) -> f32[512,32] { + %param_0.23 = f32[163840,32]{1,0:T(8,128)} parameter(0) + %param_1.112 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.17 = s32[1024]{0:T(1024)} custom-call(%param_1.112), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %slice.896 = s32[512]{0:T(512)} slice(%custom-call.17), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %reshape.3449 = s32[4,128]{1,0:T(4,128)} reshape(%slice.896), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %transpose.616 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3449), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %gather.187 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.23, %transpose.616), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %transpose.615 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.187), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + ROOT %reshape.3448 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.615), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} +} + +%fused_computation.8 (param_0.26: f32[163840,32], param_1.120: s32[1024]) -> f32[512,32] { + %param_0.26 = f32[163840,32]{1,0:T(8,128)} parameter(0) + %param_1.120 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.25 = s32[1024]{0:T(1024)} custom-call(%param_1.120), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %slice.904 = s32[512]{0:T(512)} slice(%custom-call.25), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %reshape.3457 = s32[4,128]{1,0:T(4,128)} reshape(%slice.904), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %transpose.622 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3457), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %gather.189 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.26, %transpose.622), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %transpose.621 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.189), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + ROOT %reshape.3456 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.621), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} +} + +%fused_computation.9 (param_0.29: f32[163840,32], param_1.122: s32[1024]) -> f32[512,32] { + %param_0.29 = f32[163840,32]{1,0:T(8,128)} parameter(0) + %param_1.122 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.27 = s32[1024]{0:T(1024)} custom-call(%param_1.122), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %slice.906 = s32[512]{0:T(512)} slice(%custom-call.27), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %reshape.3465 = s32[4,128]{1,0:T(4,128)} reshape(%slice.906), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %transpose.628 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3465), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %gather.191 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.29, %transpose.628), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %transpose.627 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.191), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + ROOT %reshape.3464 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.627), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} +} + +%fused_computation.10 (param_0.32: bf16[4096,512], param_1.126: s32[4096]) -> bf16[4096,512] { + %param_0.32 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.126 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.31 = s32[4096]{0:T(1024)} custom-call(%param_1.126), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.910 = s32[4096]{0:T(1024)} slice(%custom-call.31), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3473 = s32[4096]{0:T(1024)} reshape(%slice.910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.634 = s32[4096]{0:T(1024)} transpose(%reshape.3473), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.193 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.32, %transpose.634), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.633 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.193), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3472 = bf16[4096,512]{1,0:T(8,128)(2,1)} reshape(%transpose.633), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} +} + +%fused_computation.11 (param_0.35: bf16[4096,512], param_1.128: s32[4096]) -> bf16[4096,512] { + %param_0.35 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.128 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.33 = s32[4096]{0:T(1024)} custom-call(%param_1.128), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.912 = s32[4096]{0:T(1024)} slice(%custom-call.33), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3481 = s32[4096]{0:T(1024)} reshape(%slice.912), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.640 = s32[4096]{0:T(1024)} transpose(%reshape.3481), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.195 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.35, %transpose.640), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.639 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.195), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3480 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.639), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} +} + +%fused_computation.12 (param_0.38: bf16[4096,512], param_1.130: s32[4096]) -> bf16[4096,512] { + %param_0.38 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.130 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.35 = s32[4096]{0:T(1024)} custom-call(%param_1.130), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.914 = s32[4096]{0:T(1024)} slice(%custom-call.35), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3489 = s32[4096]{0:T(1024)} reshape(%slice.914), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.646 = s32[4096]{0:T(1024)} transpose(%reshape.3489), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.197 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.38, %transpose.646), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.645 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.197), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3488 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.645), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} +} + +%fused_computation.13 (param_0.41: bf16[4096,512], param_1.132: s32[4096]) -> bf16[4096,512] { + %param_0.41 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.132 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.37 = s32[4096]{0:T(1024)} custom-call(%param_1.132), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.916 = s32[4096]{0:T(1024)} slice(%custom-call.37), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3497 = s32[4096]{0:T(1024)} reshape(%slice.916), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.652 = s32[4096]{0:T(1024)} transpose(%reshape.3497), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.199 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.41, %transpose.652), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.651 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.199), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3496 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.651), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} +} + +%fused_computation.15 (param_0.47: s32[256], param_1.124: s32[1024]) -> s32[263] { + %param_0.47 = s32[256]{0:T(256)S(1)} parameter(0) + %param_1.124 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.29 = s32[1024]{0:T(1024)} custom-call(%param_1.124), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %slice.908 = s32[263]{0:T(512)} slice(%custom-call.29), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3528 = s32[263]{0:T(512)} reshape(%slice.908), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.668 = s32[263]{0:T(512)} transpose(%reshape.3528), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.204 = s32[263]{0:T(512)} gather(%param_0.47, %transpose.668), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.667 = s32[263]{0:T(512)} transpose(%gather.204), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3527 = s32[263]{0:T(512)S(1)} reshape(%transpose.667), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} +} + +%fused_computation.16 (param_0.50: s32[256], param_1.134: s32[1024]) -> s32[263] { + %param_0.50 = s32[256]{0:T(256)S(1)} parameter(0) + %param_1.134 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.39 = s32[1024]{0:T(1024)} custom-call(%param_1.134), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %slice.918 = s32[263]{0:T(512)} slice(%custom-call.39), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3551 = s32[263]{0:T(512)} reshape(%slice.918), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.678 = s32[263]{0:T(512)} transpose(%reshape.3551), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.207 = s32[263]{0:T(512)} gather(%param_0.50, %transpose.678), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.677 = s32[263]{0:T(512)} transpose(%gather.207), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3550 = s32[263]{0:T(512)S(1)} reshape(%transpose.677), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} +} + +%region_173.198.clone (scatter-add.94: bf16[], scatter-add.96: bf16[]) -> bf16[] { + %scatter-add.94 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + %scatter-add.96 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + ROOT %add.1876 = bf16[]{:T(256)} add(%scatter-add.94, %scatter-add.96), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.21 (param_0.55: bf16[129280,512], param_1.65: s32[512], param_2.24: bf16[512,512]) -> bf16[129280,512] { + %param_0.55 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.65 = s32[512]{0:T(512)S(1)} parameter(1) + %reshape.3605 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.65), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.711 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3605), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %param_2.24 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} parameter(2) + %reshape.3606 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} reshape(%param_2.24), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} + %transpose.712 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%reshape.3606), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} + ROOT %scatter.73 = bf16[129280,512]{1,0:T(8,128)(2,1)} scatter(%param_0.55, %transpose.711, %transpose.712), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_173.198.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} +} + +%region_11.17 (top_k.0: bf16[], top_k.6: bf16[], top_k.7: s32[], top_k.8: s32[]) -> pred[] { + %constant.1338 = s32[]{:T(128)} constant(0) + %constant.1339 = s32[]{:T(128)} constant(2147483647) + %top_k.0 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} + %top_k.6 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} + %top_k.7 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} + %top_k.8 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} + %convert.263 = f32[]{:T(128)S(6)} convert(%top_k.0), metadata={op_name="convert.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.35 = s32[]{:T(128)S(6)} bitcast-convert(%convert.263), metadata={op_name="bitcast-convert.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.128 = pred[]{:T(512)S(6)} compare(%bitcast-convert.35, %constant.1338), direction=LT, metadata={op_name="compare.35"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.36 = s32[]{:T(128)S(6)} xor(%constant.1339, %bitcast-convert.35), metadata={op_name="xor.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %select.118 = s32[]{:T(128)S(6)} select(%compare.128, %xor.36, %bitcast-convert.35), metadata={op_name="select.14"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} + %convert.264 = f32[]{:T(128)S(6)} convert(%top_k.6), metadata={op_name="convert.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.36 = s32[]{:T(128)S(6)} bitcast-convert(%convert.264), metadata={op_name="bitcast-convert.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.129 = pred[]{:T(512)S(6)} compare(%bitcast-convert.36, %constant.1338), direction=LT, metadata={op_name="compare.36"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.37 = s32[]{:T(128)S(6)} xor(%constant.1339, %bitcast-convert.36), metadata={op_name="xor.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %select.119 = s32[]{:T(128)S(6)} select(%compare.129, %xor.37, %bitcast-convert.36), metadata={op_name="select.15"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} + %compare.130 = pred[]{:T(512)S(6)} compare(%select.118, %select.119), direction=GT, metadata={op_name="compare.1"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.131 = pred[]{:T(512)S(6)} compare(%select.119, %select.118), direction=GT, metadata={op_name="compare.108"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.132 = pred[]{:T(512)S(6)} compare(%compare.130, %compare.131), direction=EQ, metadata={op_name="compare.109"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.133 = pred[]{:T(512)S(6)} compare(%top_k.7, %top_k.8), direction=LT, metadata={op_name="compare.110"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.120 = pred[]{:T(512)} select(%compare.132, %compare.133, %compare.130), metadata={op_name="select.108"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_14.20.clone.1 (reduce-window.326: s32[], reduce-window.327: s32[]) -> s32[] { + %reduce-window.326 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.20"} + %reduce-window.327 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.20"} + ROOT %reduce_window_sum.282 = s32[]{:T(128)} add(%reduce-window.326, %reduce-window.327), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_15.21.clone.1 (reduce-window.330: s32[], reduce-window.331: s32[]) -> s32[] { + %reduce-window.330 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.56"} + %reduce-window.331 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.56"} + ROOT %reduce_window_sum.284 = s32[]{:T(128)} add(%reduce-window.330, %reduce-window.331), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_17.23.clone.1 (reduce-window.334: s32[], reduce-window.335: s32[]) -> s32[] { + %reduce-window.334 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.22"} + %reduce-window.335 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.22"} + ROOT %reduce_window_sum.286 = s32[]{:T(128)} add(%reduce-window.334, %reduce-window.335), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_18.24.clone.1 (reduce-window.338: s32[], reduce-window.339: s32[]) -> s32[] { + %reduce-window.338 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.57"} + %reduce-window.339 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.57"} + ROOT %reduce_window_sum.288 = s32[]{:T(128)} add(%reduce-window.338, %reduce-window.339), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_20.26.clone.1 (reduce-window.342: s32[], reduce-window.343: s32[]) -> s32[] { + %reduce-window.342 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.24"} + %reduce-window.343 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.24"} + ROOT %reduce_window_sum.290 = s32[]{:T(128)} add(%reduce-window.342, %reduce-window.343), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_21.27.clone.1 (reduce-window.346: s32[], reduce-window.347: s32[]) -> s32[] { + %reduce-window.346 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.58"} + %reduce-window.347 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.58"} + ROOT %reduce_window_sum.292 = s32[]{:T(128)} add(%reduce-window.346, %reduce-window.347), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.4.clone (param_0.68: s32[256], param_1.114: s32[1024]) -> s32[263] { + %param_0.68 = s32[256]{0:T(256)S(1)} parameter(0) + %param_1.114 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.19 = s32[1024]{0:T(1024)} custom-call(%param_1.114), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %slice.898 = s32[263]{0:T(512)} slice(%custom-call.19), slice={[0:263]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3749 = s32[263]{0:T(512)} reshape(%slice.898), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.794 = s32[263]{0:T(512)} transpose(%reshape.3749), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.209 = s32[263]{0:T(512)} gather(%param_0.68, %transpose.794), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.793 = s32[263]{0:T(512)} transpose(%gather.209), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3748 = s32[263]{0:T(512)S(1)} reshape(%transpose.793), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} +} + +%region_26.33.clone.1 (reduce-window.350: s32[], reduce-window.351: s32[]) -> s32[] { + %reduce-window.350 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.26"} + %reduce-window.351 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.26"} + ROOT %reduce_window_sum.294 = s32[]{:T(128)} add(%reduce-window.350, %reduce-window.351), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_28.35.clone.1 (reduce-window.354: s32[], reduce-window.355: s32[]) -> s32[] { + %reduce-window.354 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.27"} + %reduce-window.355 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.27"} + ROOT %reduce_window_sum.296 = s32[]{:T(128)} add(%reduce-window.354, %reduce-window.355), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_29.36.clone.1 (reduce-window.358: s32[], reduce-window.359: s32[]) -> s32[] { + %reduce-window.358 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.59"} + %reduce-window.359 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.59"} + ROOT %reduce_window_sum.298 = s32[]{:T(128)} add(%reduce-window.358, %reduce-window.359), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_12.18 (sort.44: s32[], sort.45: s32[], sort.46: s32[], sort.47: s32[], sort.48: s32[], sort.49: s32[]) -> pred[] { + %sort.46 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} + %sort.47 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} + %sort.44 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} + %sort.45 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(argsort)/sort"} + %sort.48 = s32[]{:T(128)} parameter(4), metadata={op_name="jit(argsort)/sort"} + %sort.49 = s32[]{:T(128)} parameter(5), metadata={op_name="jit(argsort)/sort"} + %lt_to.27 = pred[]{:T(512)S(6)} compare(%sort.44, %sort.45), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.28 = pred[]{:T(512)S(6)} compare(%sort.45, %sort.44), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.134 = pred[]{:T(512)S(6)} compare(%lt_to.27, %lt_to.28), direction=EQ, metadata={op_name="compare.111"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.135 = pred[]{:T(512)S(6)} compare(%sort.48, %sort.49), direction=LT, metadata={op_name="compare.112"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.121 = pred[]{:T(512)} select(%compare.134, %compare.135, %lt_to.27), metadata={op_name="select.109"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.2.clone (param_0.71: bf16[4096,512], param_1.116: s32[4096]) -> bf16[4096,512] { + %param_0.71 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.116 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.21 = s32[4096]{0:T(1024)} custom-call(%param_1.116), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.900 = s32[4096]{0:T(1024)} slice(%custom-call.21), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3772 = s32[4096]{0:T(1024)} reshape(%slice.900), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.800 = s32[4096]{0:T(1024)} transpose(%reshape.3772), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.210 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.71, %transpose.800), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.799 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.210), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3771 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.799), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} +} + +%region_30.38 (sort.50: s32[], sort.51: s32[], sort.52: s32[], sort.53: s32[], sort.54: s32[], sort.55: s32[]) -> pred[] { + %sort.52 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} + %sort.53 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} + %sort.50 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} + %sort.51 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(argsort)/sort"} + %sort.54 = s32[]{:T(128)} parameter(4), metadata={op_name="jit(argsort)/sort"} + %sort.55 = s32[]{:T(128)} parameter(5), metadata={op_name="jit(argsort)/sort"} + %lt_to.30 = pred[]{:T(512)S(6)} compare(%sort.50, %sort.51), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %lt_to.31 = pred[]{:T(512)S(6)} compare(%sort.51, %sort.50), direction=LT, metadata={op_name="lt_to"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.142 = pred[]{:T(512)S(6)} compare(%lt_to.30, %lt_to.31), direction=EQ, metadata={op_name="compare.113"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.143 = pred[]{:T(512)S(6)} compare(%sort.54, %sort.55), direction=LT, metadata={op_name="compare.114"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %select.126 = pred[]{:T(512)} select(%compare.142, %compare.143, %lt_to.30), metadata={op_name="select.110"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.3.clone (param_0.72: bf16[4096,512], param_1.118: s32[4096]) -> bf16[4096,512] { + %param_0.72 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} parameter(0) + %param_1.118 = s32[4096]{0:T(1024)S(1)} parameter(1) + %custom-call.23 = s32[4096]{0:T(1024)} custom-call(%param_1.118), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %slice.902 = s32[4096]{0:T(1024)} slice(%custom-call.23), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3774 = s32[4096]{0:T(1024)} reshape(%slice.902), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.802 = s32[4096]{0:T(1024)} transpose(%reshape.3774), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.211 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.72, %transpose.802), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.801 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.211), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3773 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.801), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} +} + +%compare (name: s32[], name.1: s32[], name.2: bf16[], name.3: bf16[]) -> pred[] { + %name.2 = bf16[] parameter(2) + %name.3 = bf16[] parameter(3) + %name = s32[] parameter(0) + %name.1 = s32[] parameter(1) + ROOT %compare.377 = pred[] compare(%name, %name.1), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%compare.1 (name.4: s32[], name.5: s32[], name.6: f32[], name.7: f32[]) -> pred[] { + %name.6 = f32[] parameter(2) + %name.7 = f32[] parameter(3) + %name.4 = s32[] parameter(0) + %name.5 = s32[] parameter(1) + ROOT %compare.378 = pred[] compare(%name.4, %name.5), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%compare.2 (name.8: s32[], name.9: s32[], name.10: f32[], name.11: f32[]) -> pred[] { + %name.10 = f32[] parameter(2) + %name.11 = f32[] parameter(3) + %name.8 = s32[] parameter(0) + %name.9 = s32[] parameter(1) + ROOT %compare.379 = pred[] compare(%name.8, %name.9), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%compare.3 (name.12: s32[], name.13: s32[], name.14: f32[], name.15: f32[]) -> pred[] { + %name.14 = f32[] parameter(2) + %name.15 = f32[] parameter(3) + %name.12 = s32[] parameter(0) + %name.13 = s32[] parameter(1) + ROOT %compare.380 = pred[] compare(%name.12, %name.13), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%compare.4 (name.16: s32[], name.17: s32[], name.18: f32[], name.19: f32[]) -> pred[] { + %name.18 = f32[] parameter(2) + %name.19 = f32[] parameter(3) + %name.16 = s32[] parameter(0) + %name.17 = s32[] parameter(1) + ROOT %compare.381 = pred[] compare(%name.16, %name.17), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%called_computation.13 (param_0.4538: s32[256]) -> s32[256] { + %param_0.4538 = s32[256]{0:T(256)} parameter(0) + ROOT %copy.2071 = s32[256]{0:T(256)} copy(%param_0.4538), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"1134","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.13 (param_0.4539: s32[256]) -> s32[256] { + %param_0.4539 = s32[256]{0:T(256)} parameter(0) + ROOT %copy.2072.cloned.1 = s32[256]{0:T(256)} call(%param_0.4539), to_apply=%called_computation.13 +}, execution_thread="sparsecore" + +%region_49.59 (scatter-add.14: s32[], scatter-add.15: s32[]) -> s32[] { + %scatter-add.14 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.15 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.1353 = s32[]{:T(128)S(7)} add(%scatter-add.14, %scatter-add.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.22.clone.clone (param_0.4540: s32[256], param_1.5328: s32[4096], param_2.4484: s32[4096]) -> s32[256] { + %param_0.4540 = s32[256]{0:T(256)} parameter(0) + %param_1.5328 = s32[4096]{0:T(1024)} parameter(1) + %reshape.4039 = s32[4096]{0:T(1024)} reshape(%param_1.5328), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} + %transpose.857 = s32[4096]{0:T(1024)} transpose(%reshape.4039), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} + %param_2.4484 = s32[4096]{0:T(1024)} parameter(2) + %reshape.4040 = s32[4096]{0:T(1024)} reshape(%param_2.4484), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + %transpose.858 = s32[4096]{0:T(1024)} transpose(%reshape.4040), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.231 = s32[256]{0:T(256)} scatter(%param_0.4540, %transpose.857, %transpose.858), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_49.59, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.14 (param_0.4541: s32[256], param_1.5329: s32[4096], param_2.4485: s32[4096]) -> s32[256] { + %param_0.4541 = s32[256]{0:T(256)} parameter(0) + %param_1.5329 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4485 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.39 = s32[256]{0:T(256)} fusion(%param_0.4541, %param_1.5329, %param_2.4485), kind=kCustom, calls=%fused_computation.22.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.14 (param_0.4542: s32[256], param_1.5330: s32[4096], param_2.4486: s32[4096]) -> s32[256] { + %param_0.4542 = s32[256]{0:T(256)} parameter(0) + %param_1.5330 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4486 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.40.cloned.1 = s32[256]{0:T(256)} call(%param_0.4542, %param_1.5330, %param_2.4486), to_apply=%called_computation.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation (param_0.84: s32[256], param_1.136: s32[4096], param_2.80: s32[4096], param_3.3090: token[]) -> s32[256] { + %param_3.3090 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.84 = s32[256]{0:T(256)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.136 = s32[4096]{0:T(1024)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.80 = s32[4096]{0:T(1024)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2072.cloned.1.call-start = ((s32[256]{0:T(256)}), s32[256]{0:T(256)}, u32[]{:S(8)}) async-start(%param_0.84), async_execution_thread="sparsecore", calls=%async_computation.13 + %copy.2072.cloned.1.call-done = s32[256]{0:T(256)} async-done(%copy.2072.cloned.1.call-start) + %scatter_offload_custom_fusion.40.cloned.1.call-start = ((s32[256]{0:T(256)}, s32[4096]{0:T(1024)}, s32[4096]{0:T(1024)}), s32[256]{0:T(256)}, u32[]{:S(8)}) async-start(%copy.2072.cloned.1.call-done, %param_1.136, %param_2.80), async_execution_thread="sparsecore", calls=%async_computation.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.40.cloned.1.call-done = s32[256]{0:T(256)} async-done(%scatter_offload_custom_fusion.40.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation (param_0.85: s32[256], param_1.137: s32[4096], param_2.81: s32[4096], param_3.3089: token[]) -> s32[256] { + %param_3.3089 = token[] parameter(3) + %param_0.85 = s32[256]{0:T(256)} parameter(0) + %param_1.137 = s32[4096]{0:T(1024)} parameter(1) + %param_2.81 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.2.cloned.1 = s32[256]{0:T(256)} call(%param_0.85, %param_1.137, %param_2.81, %param_3.3089), to_apply=%called_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.15 (param_0.4543: f32[9]) -> f32[9] { + %param_0.4543 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2073 = f32[9]{0:T(128)} copy(%param_0.4543), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.15 (param_0.4544: f32[9]) -> f32[9] { + %param_0.4544 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2074.cloned.1 = f32[9]{0:T(128)} call(%param_0.4544), to_apply=%called_computation.15 +}, execution_thread="sparsecore" + +%region_61.72 (scatter-add.24: f32[], scatter-add.25: f32[]) -> f32[] { + %scatter-add.24 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.25 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.1359 = f32[]{:T(128)S(7)} add(%scatter-add.24, %scatter-add.25), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.24.clone.clone (param_0.4545: f32[9], param_1.5331: s32[256], param_2.4487: f32[256]) -> f32[9] { + %param_0.4545 = f32[9]{0:T(128)} parameter(0) + %param_1.5331 = s32[256]{0:T(256)} parameter(1) + %reshape.4041 = s32[256]{0:T(256)} reshape(%param_1.5331), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.859 = s32[256]{0:T(256)} transpose(%reshape.4041), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4487 = f32[256]{0:T(256)} parameter(2) + %reshape.4042 = f32[256]{0:T(256)} reshape(%param_2.4487), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.860 = f32[256]{0:T(256)} transpose(%reshape.4042), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.232 = f32[9]{0:T(128)} scatter(%param_0.4545, %transpose.859, %transpose.860), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_61.72, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.16 (param_0.4546: f32[9], param_1.5332: s32[256], param_2.4488: f32[256]) -> f32[9] { + %param_0.4546 = f32[9]{0:T(128)} parameter(0) + %param_1.5332 = s32[256]{0:T(256)} parameter(1) + %param_2.4488 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.41 = f32[9]{0:T(128)} fusion(%param_0.4546, %param_1.5332, %param_2.4488), kind=kCustom, calls=%fused_computation.24.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.16 (param_0.4547: f32[9], param_1.5333: s32[256], param_2.4489: f32[256]) -> f32[9] { + %param_0.4547 = f32[9]{0:T(128)} parameter(0) + %param_1.5333 = s32[256]{0:T(256)} parameter(1) + %param_2.4489 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.42.cloned.1 = f32[9]{0:T(128)} call(%param_0.4547, %param_1.5333, %param_2.4489), to_apply=%called_computation.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.1 (param_0.87: f32[9], param_1.139: s32[256], param_2.83: f32[256], param_3.3104: token[]) -> f32[9] { + %param_3.3104 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.87 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.139 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.83 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2074.cloned.1.call-start = ((f32[9]{0:T(128)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.87), async_execution_thread="sparsecore", calls=%async_computation.15 + %copy.2074.cloned.1.call-done = f32[9]{0:T(128)} async-done(%copy.2074.cloned.1.call-start) + %scatter_offload_custom_fusion.42.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%copy.2074.cloned.1.call-done, %param_1.139, %param_2.83), async_execution_thread="sparsecore", calls=%async_computation.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.42.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.42.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.1 (param_0.88: f32[9], param_1.140: s32[256], param_2.84: f32[256], param_3.3103: token[]) -> f32[9] { + %param_3.3103 = token[] parameter(3) + %param_0.88 = f32[9]{0:T(128)} parameter(0) + %param_1.140 = s32[256]{0:T(256)} parameter(1) + %param_2.84 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.5.cloned.1 = f32[9]{0:T(128)} call(%param_0.88, %param_1.140, %param_2.84, %param_3.3103), to_apply=%called_computation.1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.17 (param_0.4548: s32[263]) -> s32[263] { + %param_0.4548 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2075 = s32[263]{0:T(512)} copy(%param_0.4548), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.17 (param_0.4549: s32[263]) -> s32[263] { + %param_0.4549 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2076.cloned.1 = s32[263]{0:T(512)} call(%param_0.4549), to_apply=%called_computation.17 +}, execution_thread="sparsecore" + +%region_63.74 (scatter-add.28: s32[], scatter-add.29: s32[]) -> s32[] { + %scatter-add.28 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.29 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.1360 = s32[]{:T(128)S(7)} add(%scatter-add.28, %scatter-add.29), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.25.clone.clone (param_0.4550: s32[263], param_1.5334: s32[8], param_2.4490: s32[8]) -> s32[263] { + %param_0.4550 = s32[263]{0:T(512)} parameter(0) + %param_1.5334 = s32[8]{0:T(128)} parameter(1) + %reshape.4043 = s32[8]{0:T(128)} reshape(%param_1.5334), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.861 = s32[8]{0:T(128)} transpose(%reshape.4043), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4490 = s32[8]{0:T(128)} parameter(2) + %reshape.4044 = s32[8]{0:T(128)} reshape(%param_2.4490), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.862 = s32[8]{0:T(128)} transpose(%reshape.4044), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.233 = s32[263]{0:T(512)} scatter(%param_0.4550, %transpose.861, %transpose.862), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_63.74, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.18 (param_0.4551: s32[263], param_1.5335: s32[8], param_2.4491: s32[8]) -> s32[263] { + %param_0.4551 = s32[263]{0:T(512)} parameter(0) + %param_1.5335 = s32[8]{0:T(128)} parameter(1) + %param_2.4491 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.43 = s32[263]{0:T(512)} fusion(%param_0.4551, %param_1.5335, %param_2.4491), kind=kCustom, calls=%fused_computation.25.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.18 (param_0.4552: s32[263], param_1.5336: s32[8], param_2.4492: s32[8]) -> s32[263] { + %param_0.4552 = s32[263]{0:T(512)} parameter(0) + %param_1.5336 = s32[8]{0:T(128)} parameter(1) + %param_2.4492 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.44.cloned.1 = s32[263]{0:T(512)} call(%param_0.4552, %param_1.5336, %param_2.4492), to_apply=%called_computation.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.2 (param_0.90: s32[263], param_1.142: s32[8], param_2.86: s32[8], param_3.3110: token[]) -> s32[263] { + %param_3.3110 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.90 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.142 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.86 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2076.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.90), async_execution_thread="sparsecore", calls=%async_computation.17 + %copy.2076.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2076.cloned.1.call-start) + %scatter_offload_custom_fusion.44.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2076.cloned.1.call-done, %param_1.142, %param_2.86), async_execution_thread="sparsecore", calls=%async_computation.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.44.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.44.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.2 (param_0.91: s32[263], param_1.143: s32[8], param_2.87: s32[8], param_3.3109: token[]) -> s32[263] { + %param_3.3109 = token[] parameter(3) + %param_0.91 = s32[263]{0:T(512)} parameter(0) + %param_1.143 = s32[8]{0:T(128)} parameter(1) + %param_2.87 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.8.cloned.1 = s32[263]{0:T(512)} call(%param_0.91, %param_1.143, %param_2.87, %param_3.3109), to_apply=%called_computation.2, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.19 (param_0.4553: s32[263]) -> s32[263] { + %param_0.4553 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2077 = s32[263]{0:T(512)} copy(%param_0.4553), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.19 (param_0.4554: s32[263]) -> s32[263] { + %param_0.4554 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2078.cloned.1 = s32[263]{0:T(512)} call(%param_0.4554), to_apply=%called_computation.19 +}, execution_thread="sparsecore" + +%region_73.86.clone (scatter-add.163: s32[], scatter-add.164: s32[]) -> s32[] { + %scatter-add.163 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.164 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.2485 = s32[]{:T(128)S(7)} add(%scatter-add.163, %scatter-add.164), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.26.clone.clone (param_0.4555: s32[263], param_1.5337: s32[256], param_2.4493: s32[256]) -> s32[263] { + %param_0.4555 = s32[263]{0:T(512)} parameter(0) + %param_1.5337 = s32[256]{0:T(256)} parameter(1) + %reshape.4045 = s32[256]{0:T(256)} reshape(%param_1.5337), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.863 = s32[256]{0:T(256)} transpose(%reshape.4045), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4493 = s32[256]{0:T(256)} parameter(2) + %reshape.4046 = s32[256]{0:T(256)} reshape(%param_2.4493), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.864 = s32[256]{0:T(256)} transpose(%reshape.4046), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.234 = s32[263]{0:T(512)} scatter(%param_0.4555, %transpose.863, %transpose.864), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_73.86.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.20 (param_0.4556: s32[263], param_1.5338: s32[256], param_2.4494: s32[256]) -> s32[263] { + %param_0.4556 = s32[263]{0:T(512)} parameter(0) + %param_1.5338 = s32[256]{0:T(256)} parameter(1) + %param_2.4494 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.45 = s32[263]{0:T(512)} fusion(%param_0.4556, %param_1.5338, %param_2.4494), kind=kCustom, calls=%fused_computation.26.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.20 (param_0.4557: s32[263], param_1.5339: s32[256], param_2.4495: s32[256]) -> s32[263] { + %param_0.4557 = s32[263]{0:T(512)} parameter(0) + %param_1.5339 = s32[256]{0:T(256)} parameter(1) + %param_2.4495 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.46.cloned.1 = s32[263]{0:T(512)} call(%param_0.4557, %param_1.5339, %param_2.4495), to_apply=%called_computation.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.3 (param_0.93: s32[263], param_1.145: s32[256], param_2.89: s32[256], param_3.3096: token[]) -> s32[263] { + %param_3.3096 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.93 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.145 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.89 = s32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2078.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.93), async_execution_thread="sparsecore", calls=%async_computation.19 + %copy.2078.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2078.cloned.1.call-start) + %scatter_offload_custom_fusion.46.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[256]{0:T(256)}, s32[256]{0:T(256)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2078.cloned.1.call-done, %param_1.145, %param_2.89), async_execution_thread="sparsecore", calls=%async_computation.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.46.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.46.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.3 (param_0.94: s32[263], param_1.146: s32[256], param_2.90: s32[256], param_3.3095: token[]) -> s32[263] { + %param_3.3095 = token[] parameter(3) + %param_0.94 = s32[263]{0:T(512)} parameter(0) + %param_1.146 = s32[256]{0:T(256)} parameter(1) + %param_2.90 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.11.cloned.1 = s32[263]{0:T(512)} call(%param_0.94, %param_1.146, %param_2.90, %param_3.3095), to_apply=%called_computation.3, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.21 (param_0.4558: f32[9]) -> f32[9] { + %param_0.4558 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2079 = f32[9]{0:T(128)} copy(%param_0.4558), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.21 (param_0.4559: f32[9]) -> f32[9] { + %param_0.4559 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2080.cloned.1 = f32[9]{0:T(128)} call(%param_0.4559), to_apply=%called_computation.21 +}, execution_thread="sparsecore" + +%region_79.95.clone (scatter-add.167: f32[], scatter-add.168: f32[]) -> f32[] { + %scatter-add.167 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.168 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.2487 = f32[]{:T(128)S(7)} add(%scatter-add.167, %scatter-add.168), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.27.clone.clone (param_0.4560: f32[9], param_1.5340: s32[256], param_2.4496: f32[256]) -> f32[9] { + %param_0.4560 = f32[9]{0:T(128)} parameter(0) + %param_1.5340 = s32[256]{0:T(256)} parameter(1) + %reshape.4047 = s32[256]{0:T(256)} reshape(%param_1.5340), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.865 = s32[256]{0:T(256)} transpose(%reshape.4047), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4496 = f32[256]{0:T(256)} parameter(2) + %reshape.4048 = f32[256]{0:T(256)} reshape(%param_2.4496), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.866 = f32[256]{0:T(256)} transpose(%reshape.4048), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.235 = f32[9]{0:T(128)} scatter(%param_0.4560, %transpose.865, %transpose.866), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_79.95.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.22 (param_0.4561: f32[9], param_1.5341: s32[256], param_2.4497: f32[256]) -> f32[9] { + %param_0.4561 = f32[9]{0:T(128)} parameter(0) + %param_1.5341 = s32[256]{0:T(256)} parameter(1) + %param_2.4497 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.47 = f32[9]{0:T(128)} fusion(%param_0.4561, %param_1.5341, %param_2.4497), kind=kCustom, calls=%fused_computation.27.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.22 (param_0.4562: f32[9], param_1.5342: s32[256], param_2.4498: f32[256]) -> f32[9] { + %param_0.4562 = f32[9]{0:T(128)} parameter(0) + %param_1.5342 = s32[256]{0:T(256)} parameter(1) + %param_2.4498 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.48.cloned.1 = f32[9]{0:T(128)} call(%param_0.4562, %param_1.5342, %param_2.4498), to_apply=%called_computation.22, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.4 (param_0.96: f32[9], param_1.148: s32[256], param_2.92: f32[256], param_3.3102: token[]) -> f32[9] { + %param_3.3102 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.96 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.148 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.92 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2080.cloned.1.call-start = ((f32[9]{0:T(128)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.96), async_execution_thread="sparsecore", calls=%async_computation.21 + %copy.2080.cloned.1.call-done = f32[9]{0:T(128)} async-done(%copy.2080.cloned.1.call-start) + %scatter_offload_custom_fusion.48.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%copy.2080.cloned.1.call-done, %param_1.148, %param_2.92), async_execution_thread="sparsecore", calls=%async_computation.22, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.48.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.48.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.4 (param_0.97: f32[9], param_1.149: s32[256], param_2.93: f32[256], param_3.3101: token[]) -> f32[9] { + %param_3.3101 = token[] parameter(3) + %param_0.97 = f32[9]{0:T(128)} parameter(0) + %param_1.149 = s32[256]{0:T(256)} parameter(1) + %param_2.93 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.14.cloned.1 = f32[9]{0:T(128)} call(%param_0.97, %param_1.149, %param_2.93, %param_3.3101), to_apply=%called_computation.4, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.23 (param_0.4563: s32[263]) -> s32[263] { + %param_0.4563 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2081 = s32[263]{0:T(512)} copy(%param_0.4563), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.23 (param_0.4564: s32[263]) -> s32[263] { + %param_0.4564 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2082.cloned.1 = s32[263]{0:T(512)} call(%param_0.4564), to_apply=%called_computation.23 +}, execution_thread="sparsecore" + +%region_81.97.clone (scatter-add.171: s32[], scatter-add.172: s32[]) -> s32[] { + %scatter-add.171 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.172 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.2489 = s32[]{:T(128)S(7)} add(%scatter-add.171, %scatter-add.172), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.28.clone.clone (param_0.4565: s32[263], param_1.5343: s32[8], param_2.4499: s32[8]) -> s32[263] { + %param_0.4565 = s32[263]{0:T(512)} parameter(0) + %param_1.5343 = s32[8]{0:T(128)} parameter(1) + %reshape.4049 = s32[8]{0:T(128)} reshape(%param_1.5343), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.867 = s32[8]{0:T(128)} transpose(%reshape.4049), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4499 = s32[8]{0:T(128)} parameter(2) + %reshape.4050 = s32[8]{0:T(128)} reshape(%param_2.4499), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.868 = s32[8]{0:T(128)} transpose(%reshape.4050), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.236 = s32[263]{0:T(512)} scatter(%param_0.4565, %transpose.867, %transpose.868), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_81.97.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.24 (param_0.4566: s32[263], param_1.5344: s32[8], param_2.4500: s32[8]) -> s32[263] { + %param_0.4566 = s32[263]{0:T(512)} parameter(0) + %param_1.5344 = s32[8]{0:T(128)} parameter(1) + %param_2.4500 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.49 = s32[263]{0:T(512)} fusion(%param_0.4566, %param_1.5344, %param_2.4500), kind=kCustom, calls=%fused_computation.28.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.24 (param_0.4567: s32[263], param_1.5345: s32[8], param_2.4501: s32[8]) -> s32[263] { + %param_0.4567 = s32[263]{0:T(512)} parameter(0) + %param_1.5345 = s32[8]{0:T(128)} parameter(1) + %param_2.4501 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.50.cloned.1 = s32[263]{0:T(512)} call(%param_0.4567, %param_1.5345, %param_2.4501), to_apply=%called_computation.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.5 (param_0.99: s32[263], param_1.151: s32[8], param_2.95: s32[8], param_3.3112: token[]) -> s32[263] { + %param_3.3112 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.99 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.151 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.95 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2082.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.99), async_execution_thread="sparsecore", calls=%async_computation.23 + %copy.2082.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2082.cloned.1.call-start) + %scatter_offload_custom_fusion.50.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2082.cloned.1.call-done, %param_1.151, %param_2.95), async_execution_thread="sparsecore", calls=%async_computation.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.50.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.50.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.5 (param_0.100: s32[263], param_1.152: s32[8], param_2.96: s32[8], param_3.3111: token[]) -> s32[263] { + %param_3.3111 = token[] parameter(3) + %param_0.100 = s32[263]{0:T(512)} parameter(0) + %param_1.152 = s32[8]{0:T(128)} parameter(1) + %param_2.96 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.17.cloned.1 = s32[263]{0:T(512)} call(%param_0.100, %param_1.152, %param_2.96, %param_3.3111), to_apply=%called_computation.5, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.25 (param_0.4568: s32[263]) -> s32[263] { + %param_0.4568 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2083 = s32[263]{0:T(512)} copy(%param_0.4568), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.25 (param_0.4569: s32[263]) -> s32[263] { + %param_0.4569 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2084.cloned.1 = s32[263]{0:T(512)} call(%param_0.4569), to_apply=%called_computation.25 +}, execution_thread="sparsecore" + +%region_96.114 (scatter-add.48: s32[], scatter-add.49: s32[]) -> s32[] { + %scatter-add.48 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.49 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.1400 = s32[]{:T(128)S(7)} add(%scatter-add.48, %scatter-add.49), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.29.clone.clone (param_0.4570: s32[263], param_1.5346: s32[256], param_2.4502: s32[256]) -> s32[263] { + %param_0.4570 = s32[263]{0:T(512)} parameter(0) + %param_1.5346 = s32[256]{0:T(256)} parameter(1) + %reshape.4051 = s32[256]{0:T(256)} reshape(%param_1.5346), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %transpose.869 = s32[256]{0:T(256)} transpose(%reshape.4051), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %param_2.4502 = s32[256]{0:T(256)} parameter(2) + %reshape.4052 = s32[256]{0:T(256)} reshape(%param_2.4502), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.870 = s32[256]{0:T(256)} transpose(%reshape.4052), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.237 = s32[263]{0:T(512)} scatter(%param_0.4570, %transpose.869, %transpose.870), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_96.114, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.26 (param_0.4571: s32[263], param_1.5347: s32[256], param_2.4503: s32[256]) -> s32[263] { + %param_0.4571 = s32[263]{0:T(512)} parameter(0) + %param_1.5347 = s32[256]{0:T(256)} parameter(1) + %param_2.4503 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.51 = s32[263]{0:T(512)} fusion(%param_0.4571, %param_1.5347, %param_2.4503), kind=kCustom, calls=%fused_computation.29.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.26 (param_0.4572: s32[263], param_1.5348: s32[256], param_2.4504: s32[256]) -> s32[263] { + %param_0.4572 = s32[263]{0:T(512)} parameter(0) + %param_1.5348 = s32[256]{0:T(256)} parameter(1) + %param_2.4504 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.52.cloned.1 = s32[263]{0:T(512)} call(%param_0.4572, %param_1.5348, %param_2.4504), to_apply=%called_computation.26, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.6 (param_0.102: s32[263], param_1.154: s32[256], param_2.98: s32[256], param_3.3098: token[]) -> s32[263] { + %param_3.3098 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.102 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.154 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.98 = s32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2084.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.102), async_execution_thread="sparsecore", calls=%async_computation.25 + %copy.2084.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2084.cloned.1.call-start) + %scatter_offload_custom_fusion.52.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[256]{0:T(256)}, s32[256]{0:T(256)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2084.cloned.1.call-done, %param_1.154, %param_2.98), async_execution_thread="sparsecore", calls=%async_computation.26, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.52.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.52.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.6 (param_0.103: s32[263], param_1.155: s32[256], param_2.99: s32[256], param_3.3097: token[]) -> s32[263] { + %param_3.3097 = token[] parameter(3) + %param_0.103 = s32[263]{0:T(512)} parameter(0) + %param_1.155 = s32[256]{0:T(256)} parameter(1) + %param_2.99 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.20.cloned.1 = s32[263]{0:T(512)} call(%param_0.103, %param_1.155, %param_2.99, %param_3.3097), to_apply=%called_computation.6, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%region_102.120 (scatter-add.52: f32[], scatter-add.53: f32[]) -> f32[] { + %scatter-add.52 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.53 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.1403 = f32[]{:T(128)S(7)} add(%scatter-add.52, %scatter-add.53), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.30.clone.clone (param_0.4575: f32[9], param_1.5349: s32[256], param_2.4505: f32[256]) -> f32[9] { + %param_0.4575 = f32[9]{0:T(128)} parameter(0) + %param_1.5349 = s32[256]{0:T(256)} parameter(1) + %reshape.4053 = s32[256]{0:T(256)} reshape(%param_1.5349), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.871 = s32[256]{0:T(256)} transpose(%reshape.4053), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4505 = f32[256]{0:T(256)} parameter(2) + %reshape.4054 = f32[256]{0:T(256)} reshape(%param_2.4505), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.872 = f32[256]{0:T(256)} transpose(%reshape.4054), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.238 = f32[9]{0:T(128)} scatter(%param_0.4575, %transpose.871, %transpose.872), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_102.120, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.28 (param_0.4576: f32[9], param_1.5350: s32[256], param_2.4506: f32[256]) -> f32[9] { + %param_0.4576 = f32[9]{0:T(128)} parameter(0) + %param_1.5350 = s32[256]{0:T(256)} parameter(1) + %param_2.4506 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.53 = f32[9]{0:T(128)} fusion(%param_0.4576, %param_1.5350, %param_2.4506), kind=kCustom, calls=%fused_computation.30.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.28 (param_0.4577: f32[9], param_1.5351: s32[256], param_2.4507: f32[256]) -> f32[9] { + %param_0.4577 = f32[9]{0:T(128)} parameter(0) + %param_1.5351 = s32[256]{0:T(256)} parameter(1) + %param_2.4507 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.54.cloned.1 = f32[9]{0:T(128)} call(%param_0.4577, %param_1.5351, %param_2.4507), to_apply=%called_computation.28, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.7 (param_0.105: f32[9], param_1.157: s32[256], param_2.101: f32[256], param_3.3106: token[]) -> f32[9] { + %param_3.3106 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.105 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.157 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.101 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %scatter_offload_custom_fusion.54.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.105, %param_1.157, %param_2.101), async_execution_thread="sparsecore", calls=%async_computation.28, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.54.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.54.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.7 (param_0.106: f32[9], param_1.158: s32[256], param_2.102: f32[256], param_3.3105: token[]) -> f32[9] { + %param_3.3105 = token[] parameter(3) + %param_0.106 = f32[9]{0:T(128)} parameter(0) + %param_1.158 = s32[256]{0:T(256)} parameter(1) + %param_2.102 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.23.cloned.1 = f32[9]{0:T(128)} call(%param_0.106, %param_1.158, %param_2.102, %param_3.3105), to_apply=%called_computation.7, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%region_104.122 (scatter-add.83: s32[], scatter-add.84: s32[]) -> s32[] { + %scatter-add.83 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.84 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.1404 = s32[]{:T(128)S(7)} add(%scatter-add.83, %scatter-add.84), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.31.clone.clone (param_0.4580: s32[263], param_1.5352: s32[8], param_2.4508: s32[8]) -> s32[263] { + %param_0.4580 = s32[263]{0:T(512)} parameter(0) + %param_1.5352 = s32[8]{0:T(128)} parameter(1) + %reshape.4055 = s32[8]{0:T(128)} reshape(%param_1.5352), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %transpose.873 = s32[8]{0:T(128)} transpose(%reshape.4055), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %param_2.4508 = s32[8]{0:T(128)} parameter(2) + %reshape.4056 = s32[8]{0:T(128)} reshape(%param_2.4508), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.874 = s32[8]{0:T(128)} transpose(%reshape.4056), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.239 = s32[263]{0:T(512)} scatter(%param_0.4580, %transpose.873, %transpose.874), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_104.122, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.30 (param_0.4581: s32[263], param_1.5353: s32[8], param_2.4509: s32[8]) -> s32[263] { + %param_0.4581 = s32[263]{0:T(512)} parameter(0) + %param_1.5353 = s32[8]{0:T(128)} parameter(1) + %param_2.4509 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.55 = s32[263]{0:T(512)} fusion(%param_0.4581, %param_1.5353, %param_2.4509), kind=kCustom, calls=%fused_computation.31.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.30 (param_0.4582: s32[263], param_1.5354: s32[8], param_2.4510: s32[8]) -> s32[263] { + %param_0.4582 = s32[263]{0:T(512)} parameter(0) + %param_1.5354 = s32[8]{0:T(128)} parameter(1) + %param_2.4510 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.56.cloned.1 = s32[263]{0:T(512)} call(%param_0.4582, %param_1.5354, %param_2.4510), to_apply=%called_computation.30, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.8 (param_0.108: s32[263], param_1.160: s32[8], param_2.104: s32[8], param_3.3114: token[]) -> s32[263] { + %param_3.3114 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.108 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.160 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.104 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %scatter_offload_custom_fusion.56.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.108, %param_1.160, %param_2.104), async_execution_thread="sparsecore", calls=%async_computation.30, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.56.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.56.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.8 (param_0.109: s32[263], param_1.161: s32[8], param_2.105: s32[8], param_3.3113: token[]) -> s32[263] { + %param_3.3113 = token[] parameter(3) + %param_0.109 = s32[263]{0:T(512)} parameter(0) + %param_1.161 = s32[8]{0:T(128)} parameter(1) + %param_2.105 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.26.cloned.1 = s32[263]{0:T(512)} call(%param_0.109, %param_1.161, %param_2.105, %param_3.3113), to_apply=%called_computation.8, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%region_13.19 (scatter-add.0: s32[], scatter-add.1: s32[]) -> s32[] { + %scatter-add.0 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.1 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.1312 = s32[]{:T(128)S(7)} add(%scatter-add.0, %scatter-add.1), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.17.clone.clone.clone (param_0.4585: s32[256], param_1.5355: s32[4096], param_2.4511: s32[4096]) -> s32[256] { + %param_0.4585 = s32[256]{0:T(256)} parameter(0) + %param_1.5355 = s32[4096]{0:T(1024)} parameter(1) + %reshape.4057 = s32[4096]{0:T(1024)} reshape(%param_1.5355), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} + %transpose.875 = s32[4096]{0:T(1024)} transpose(%reshape.4057), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} + %param_2.4511 = s32[4096]{0:T(1024)} parameter(2) + %reshape.4058 = s32[4096]{0:T(1024)} reshape(%param_2.4511), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + %transpose.876 = s32[4096]{0:T(1024)} transpose(%reshape.4058), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.240 = s32[256]{0:T(256)} scatter(%param_0.4585, %transpose.875, %transpose.876), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_13.19, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.32 (param_0.4586: s32[256], param_1.5356: s32[4096], param_2.4512: s32[4096]) -> s32[256] { + %param_0.4586 = s32[256]{0:T(256)} parameter(0) + %param_1.5356 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4512 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.57 = s32[256]{0:T(256)} fusion(%param_0.4586, %param_1.5356, %param_2.4512), kind=kCustom, calls=%fused_computation.17.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.32 (param_0.4587: s32[256], param_1.5357: s32[4096], param_2.4513: s32[4096]) -> s32[256] { + %param_0.4587 = s32[256]{0:T(256)} parameter(0) + %param_1.5357 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4513 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.58.cloned.1 = s32[256]{0:T(256)} call(%param_0.4587, %param_1.5357, %param_2.4513), to_apply=%called_computation.32, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.9 (param_0.111: s32[256], param_1.163: s32[4096], param_2.107: s32[4096], param_3.3092: token[]) -> s32[256] { + %param_3.3092 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.111 = s32[256]{0:T(256)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.163 = s32[4096]{0:T(1024)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.107 = s32[4096]{0:T(1024)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %scatter_offload_custom_fusion.58.cloned.1.call-start = ((s32[256]{0:T(256)}, s32[4096]{0:T(1024)}, s32[4096]{0:T(1024)}), s32[256]{0:T(256)}, u32[]{:S(8)}) async-start(%param_0.111, %param_1.163, %param_2.107), async_execution_thread="sparsecore", calls=%async_computation.32, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.58.cloned.1.call-done = s32[256]{0:T(256)} async-done(%scatter_offload_custom_fusion.58.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.9 (param_0.112: s32[256], param_1.164: s32[4096], param_2.108: s32[4096], param_3.3091: token[]) -> s32[256] { + %param_3.3091 = token[] parameter(3) + %param_0.112 = s32[256]{0:T(256)} parameter(0) + %param_1.164 = s32[4096]{0:T(1024)} parameter(1) + %param_2.108 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.29.cloned.1 = s32[256]{0:T(256)} call(%param_0.112, %param_1.164, %param_2.108, %param_3.3091), to_apply=%called_computation.9, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.33 (param_0.4588: s32[263]) -> s32[263] { + %param_0.4588 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2091 = s32[263]{0:T(512)} copy(%param_0.4588), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.33 (param_0.4589: s32[263]) -> s32[263] { + %param_0.4589 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2092.cloned.1 = s32[263]{0:T(512)} call(%param_0.4589), to_apply=%called_computation.33 +}, execution_thread="sparsecore" + +%region_19.25.clone.1 (scatter-add.141: s32[], scatter-add.142: s32[]) -> s32[] { + %scatter-add.141 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.142 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.2474 = s32[]{:T(128)S(7)} add(%scatter-add.141, %scatter-add.142), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.18.clone.clone.clone (param_0.4590: s32[263], param_1.5358: s32[256], param_2.4514: s32[256]) -> s32[263] { + %param_0.4590 = s32[263]{0:T(512)} parameter(0) + %param_1.5358 = s32[256]{0:T(256)} parameter(1) + %reshape.4059 = s32[256]{0:T(256)} reshape(%param_1.5358), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.877 = s32[256]{0:T(256)} transpose(%reshape.4059), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4514 = s32[256]{0:T(256)} parameter(2) + %reshape.4060 = s32[256]{0:T(256)} reshape(%param_2.4514) + %transpose.878 = s32[256]{0:T(256)} transpose(%reshape.4060), dimensions={0} + ROOT %scatter-add.241 = s32[263]{0:T(512)} scatter(%param_0.4590, %transpose.877, %transpose.878), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_19.25.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.34 (param_0.4591: s32[263], param_1.5359: s32[256], param_2.4515: s32[256]) -> s32[263] { + %param_0.4591 = s32[263]{0:T(512)} parameter(0) + %param_1.5359 = s32[256]{0:T(256)} parameter(1) + %param_2.4515 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.59 = s32[263]{0:T(512)} fusion(%param_0.4591, %param_1.5359, %param_2.4515), kind=kCustom, calls=%fused_computation.18.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.34 (param_0.4592: s32[263], param_1.5360: s32[256], param_2.4516: s32[256]) -> s32[263] { + %param_0.4592 = s32[263]{0:T(512)} parameter(0) + %param_1.5360 = s32[256]{0:T(256)} parameter(1) + %param_2.4516 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.60.cloned.1 = s32[263]{0:T(512)} call(%param_0.4592, %param_1.5360, %param_2.4516), to_apply=%called_computation.34, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.10 (param_0.114: s32[263], param_1.166: s32[256], param_2.110: s32[256], param_3.3094: token[]) -> s32[263] { + %param_3.3094 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.114 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.166 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.110 = s32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2092.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.114), async_execution_thread="sparsecore", calls=%async_computation.33 + %copy.2092.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2092.cloned.1.call-start) + %scatter_offload_custom_fusion.60.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[256]{0:T(256)}, s32[256]{0:T(256)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2092.cloned.1.call-done, %param_1.166, %param_2.110), async_execution_thread="sparsecore", calls=%async_computation.34, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.60.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.60.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.10 (param_0.115: s32[263], param_1.167: s32[256], param_2.111: s32[256], param_3.3093: token[]) -> s32[263] { + %param_3.3093 = token[] parameter(3) + %param_0.115 = s32[263]{0:T(512)} parameter(0) + %param_1.167 = s32[256]{0:T(256)} parameter(1) + %param_2.111 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.32.cloned.1 = s32[263]{0:T(512)} call(%param_0.115, %param_1.167, %param_2.111, %param_3.3093), to_apply=%called_computation.10, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.35 (param_0.4593: f32[9]) -> f32[9] { + %param_0.4593 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2093 = f32[9]{0:T(128)} copy(%param_0.4593), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.35 (param_0.4594: f32[9]) -> f32[9] { + %param_0.4594 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2094.cloned.1 = f32[9]{0:T(128)} call(%param_0.4594), to_apply=%called_computation.35 +}, execution_thread="sparsecore" + +%region_25.32.clone.1 (scatter-add.145: f32[], scatter-add.146: f32[]) -> f32[] { + %scatter-add.145 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.146 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.2476 = f32[]{:T(128)S(7)} add(%scatter-add.145, %scatter-add.146), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.19.clone.clone.clone (param_0.4595: f32[9], param_1.5361: s32[256], param_2.4517: f32[256]) -> f32[9] { + %param_0.4595 = f32[9]{0:T(128)} parameter(0) + %param_1.5361 = s32[256]{0:T(256)} parameter(1) + %reshape.4061 = s32[256]{0:T(256)} reshape(%param_1.5361), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.879 = s32[256]{0:T(256)} transpose(%reshape.4061), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4517 = f32[256]{0:T(256)} parameter(2) + %reshape.4062 = f32[256]{0:T(256)} reshape(%param_2.4517) + %transpose.880 = f32[256]{0:T(256)} transpose(%reshape.4062), dimensions={0} + ROOT %scatter-add.242 = f32[9]{0:T(128)} scatter(%param_0.4595, %transpose.879, %transpose.880), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_25.32.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.36 (param_0.4596: f32[9], param_1.5362: s32[256], param_2.4518: f32[256]) -> f32[9] { + %param_0.4596 = f32[9]{0:T(128)} parameter(0) + %param_1.5362 = s32[256]{0:T(256)} parameter(1) + %param_2.4518 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.61 = f32[9]{0:T(128)} fusion(%param_0.4596, %param_1.5362, %param_2.4518), kind=kCustom, calls=%fused_computation.19.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.36 (param_0.4597: f32[9], param_1.5363: s32[256], param_2.4519: f32[256]) -> f32[9] { + %param_0.4597 = f32[9]{0:T(128)} parameter(0) + %param_1.5363 = s32[256]{0:T(256)} parameter(1) + %param_2.4519 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.62.cloned.1 = f32[9]{0:T(128)} call(%param_0.4597, %param_1.5363, %param_2.4519), to_apply=%called_computation.36, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.11 (param_0.117: f32[9], param_1.169: s32[256], param_2.113: f32[256], param_3.3100: token[]) -> f32[9] { + %param_3.3100 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.117 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.169 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.113 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2094.cloned.1.call-start = ((f32[9]{0:T(128)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%param_0.117), async_execution_thread="sparsecore", calls=%async_computation.35 + %copy.2094.cloned.1.call-done = f32[9]{0:T(128)} async-done(%copy.2094.cloned.1.call-start) + %scatter_offload_custom_fusion.62.cloned.1.call-start = ((f32[9]{0:T(128)}, s32[256]{0:T(256)}, f32[256]{0:T(256)}), f32[9]{0:T(128)}, u32[]{:S(8)}) async-start(%copy.2094.cloned.1.call-done, %param_1.169, %param_2.113), async_execution_thread="sparsecore", calls=%async_computation.36, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.62.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.62.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.11 (param_0.118: f32[9], param_1.170: s32[256], param_2.114: f32[256], param_3.3099: token[]) -> f32[9] { + %param_3.3099 = token[] parameter(3) + %param_0.118 = f32[9]{0:T(128)} parameter(0) + %param_1.170 = s32[256]{0:T(256)} parameter(1) + %param_2.114 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.35.cloned.1 = f32[9]{0:T(128)} call(%param_0.118, %param_1.170, %param_2.114, %param_3.3099), to_apply=%called_computation.11, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.37 (param_0.4598: s32[263]) -> s32[263] { + %param_0.4598 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2095 = s32[263]{0:T(512)} copy(%param_0.4598), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.37 (param_0.4599: s32[263]) -> s32[263] { + %param_0.4599 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2096.cloned.1 = s32[263]{0:T(512)} call(%param_0.4599), to_apply=%called_computation.37 +}, execution_thread="sparsecore" + +%region_27.34.clone.1 (scatter-add.149: s32[], scatter-add.150: s32[]) -> s32[] { + %scatter-add.149 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} + %scatter-add.150 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} + ROOT %add.2478 = s32[]{:T(128)S(7)} add(%scatter-add.149, %scatter-add.150), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%fused_computation.20.clone.clone.clone (param_0.4600: s32[263], param_1.5364: s32[8], param_2.4520: s32[8]) -> s32[263] { + %param_0.4600 = s32[263]{0:T(512)} parameter(0) + %param_1.5364 = s32[8]{0:T(128)} parameter(1) + %reshape.4063 = s32[8]{0:T(128)} reshape(%param_1.5364), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.881 = s32[8]{0:T(128)} transpose(%reshape.4063), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4520 = s32[8]{0:T(128)} parameter(2) + %reshape.4064 = s32[8]{0:T(128)} reshape(%param_2.4520) + %transpose.882 = s32[8]{0:T(128)} transpose(%reshape.4064), dimensions={0} + ROOT %scatter-add.243 = s32[263]{0:T(512)} scatter(%param_0.4600, %transpose.881, %transpose.882), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_27.34.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.38 (param_0.4601: s32[263], param_1.5365: s32[8], param_2.4521: s32[8]) -> s32[263] { + %param_0.4601 = s32[263]{0:T(512)} parameter(0) + %param_1.5365 = s32[8]{0:T(128)} parameter(1) + %param_2.4521 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.63 = s32[263]{0:T(512)} fusion(%param_0.4601, %param_1.5365, %param_2.4521), kind=kCustom, calls=%fused_computation.20.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +}, execution_thread="sparsecore" + +%async_computation.38 (param_0.4602: s32[263], param_1.5366: s32[8], param_2.4522: s32[8]) -> s32[263] { + %param_0.4602 = s32[263]{0:T(512)} parameter(0) + %param_1.5366 = s32[8]{0:T(128)} parameter(1) + %param_2.4522 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.64.cloned.1 = s32[263]{0:T(512)} call(%param_0.4602, %param_1.5366, %param_2.4522), to_apply=%called_computation.38, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%called_computation.12 (param_0.120: s32[263], param_1.172: s32[8], param_2.116: s32[8], param_3.3108: token[]) -> s32[263] { + %param_3.3108 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} + %param_0.120 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_1.172 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %param_2.116 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} + %copy.2096.cloned.1.call-start = ((s32[263]{0:T(512)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%param_0.120), async_execution_thread="sparsecore", calls=%async_computation.37 + %copy.2096.cloned.1.call-done = s32[263]{0:T(512)} async-done(%copy.2096.cloned.1.call-start) + %scatter_offload_custom_fusion.64.cloned.1.call-start = ((s32[263]{0:T(512)}, s32[8]{0:T(128)}, s32[8]{0:T(128)}), s32[263]{0:T(512)}, u32[]{:S(8)}) async-start(%copy.2096.cloned.1.call-done, %param_1.172, %param_2.116), async_execution_thread="sparsecore", calls=%async_computation.38, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.64.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.64.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%async_computation.12 (param_0.121: s32[263], param_1.173: s32[8], param_2.117: s32[8], param_3.3107: token[]) -> s32[263] { + %param_3.3107 = token[] parameter(3) + %param_0.121 = s32[263]{0:T(512)} parameter(0) + %param_1.173 = s32[8]{0:T(128)} parameter(1) + %param_2.117 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.38.cloned.1 = s32[263]{0:T(512)} call(%param_0.121, %param_1.173, %param_2.117, %param_3.3107), to_apply=%called_computation.12, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +}, execution_thread="sparsecore" + +%region_154.179 (reduce_sum.502: f32[], reduce_sum.336: f32[]) -> f32[] { + %reduce_sum.502 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.336 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.337 = f32[]{:T(128)} add(%reduce_sum.502, %reduce_sum.336), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.465 (param_0.4187: f32[3,1536,128,192]) -> f32[] { + %param_0.4187 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(0) + %bitcast.654 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_0.4187), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.564 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%bitcast.654, %bitcast.654), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5046 = f32[]{:T(128)} constant(0) + ROOT %reduce.612 = f32[]{:T(128)} reduce(%square.564, %constant.5046), dimensions={0,1,2,3}, to_apply=%region_154.179, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%fused_computation.466 (param_0.1438: f32[1536,3,128,192]) -> bf16[3,1536,128,192] { + %param_0.1438 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) + %copy.1548 = bf16[1536,3,128,192]{2,0,3,1:T(8,128)(2,1)} copy(%param_0.1438), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wq_b\'][\'kernel\']"} + ROOT %bitcast.655 = bf16[3,1536,128,192]{2,1,3,0:T(8,128)(2,1)} bitcast(%copy.1548), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} +} + +%region_221.246 (reduce_sum.964: f32[], reduce_sum.965: f32[]) -> f32[] { + %reduce_sum.964 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.965 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.534 = f32[]{:T(128)} add(%reduce_sum.964, %reduce_sum.965), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_187.212 (reduce_sum.726: f32[], reduce_sum.727: f32[]) -> f32[] { + %reduce_sum.726 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.727 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.452 = f32[]{:T(128)} add(%reduce_sum.726, %reduce_sum.727), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.467 (param_0.4157: f32[1536,3,128,192], param_1.5030: f32[], param_2.4288: f32[], param_3.2956: f32[], param_4.2203: f32[1536,3,128,192], param_5.2003: f32[], param_6.1444: f32[3,1536,128,192], param_7.1124: pred[], param_8.889: f32[1536,3,128,192]) -> (f32[], f32[1536,3,128,192], f32[1536,3,128,192], f32[1536,3,128,192], f32[]) { + %param_0.4157 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) + %param_3.2956 = f32[]{:T(128)S(6)} parameter(3) + %mul.5439.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_3.2956), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.1124 = pred[]{:T(512)S(6)} parameter(7) + %select_n.2121.clone.1 = pred[1536,3,128,192]{2,3,1,0:T(8,128)(4,1)} broadcast(%param_7.1124), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.1444 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(6) + %bitcast.1356.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_6.1444), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %param_5.2003 = f32[]{:T(128)} parameter(5) + %div.2562.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_5.2003), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2561.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%bitcast.1356.clone.1, %div.2562.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2120.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%select_n.2121.clone.1, %bitcast.1356.clone.1, %div.2561.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4805.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4133.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4805.clone.1), dimensions={}, metadata={op_name="broadcast.334"} + %mul.5445.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2120.clone.1, %broadcast.4133.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.889 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(8) + %constant.4809.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5446.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4809.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5444.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_8.889, %mul.5446.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3446.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.5445.clone.1, %mul.5444.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4288 = f32[]{:T(128)S(6)} parameter(2) + %div.2558.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_2.4288), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.399.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2120.clone.1, %select_n.2120.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4808.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5443.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4808.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5441.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%integer_pow.399.clone.1, %mul.5443.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.2203 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(4) + %constant.4807.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5442.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4807.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5440.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_4.2203, %mul.5442.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3445.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.5441.clone.1, %mul.5440.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5030 = f32[]{:T(128)S(6)} parameter(1) + %div.2557.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_1.5030), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2556.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3445.clone.1, %div.2557.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.157.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} sqrt(%div.2556.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4806.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3444.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4806.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3443.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%sqrt.157.clone.1, %add.3444.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1086.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%div.2558.clone.1, %add.3443.clone.1), metadata={op_name="multiply.263"} + %div.2555.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3446.clone.1, %multiply.1086.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5438.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_0.4157, %broadcast.4133.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3442.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%div.2555.clone.1, %mul.5438.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5437.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%mul.5439.clone.1, %add.3442.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3441.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%param_0.4157, %mul.5437.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.565 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%add.3441.clone.1, %add.3441.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5016 = f32[]{:T(128)} constant(0) + %reduce.613 = f32[]{:T(128)} reduce(%square.565, %constant.5016), dimensions={0,1,2,3}, to_apply=%region_221.246, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.614.clone.1 = f32[]{:T(128)} reduce(%integer_pow.399.clone.1, %constant.5016), dimensions={0,1,2,3}, to_apply=%region_187.212, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.656 = (f32[]{:T(128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.613, %add.3441.clone.1, %add.3445.clone.1, %add.3446.clone.1, %reduce.614.clone.1) +} + +%region_160.185 (reduce_sum.544: f32[], reduce_sum.364: f32[]) -> f32[] { + %reduce_sum.544 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.364 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.365 = f32[]{:T(128)} add(%reduce_sum.544, %reduce_sum.364), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_158.183 (reduce_sum.530: f32[], reduce_sum.352: f32[]) -> f32[] { + %reduce_sum.530 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.352 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.357 = f32[]{:T(128)} add(%reduce_sum.530, %reduce_sum.352), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.493 (param_0.4183: bf16[256,512,512], param_1.5052: bf16[256,512,512]) -> (f32[], f32[]) { + %param_0.4183 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) + %broadcast_in_dim.1270 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4183), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.677 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1270), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.570 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.677, %bitcast.677), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5042 = f32[]{:T(128)} constant(0) + %reduce.615 = f32[]{:T(128)} reduce(%square.570, %constant.5042), dimensions={0,1,2,3}, to_apply=%region_160.185, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.5052 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(1) + %broadcast_in_dim.1278.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_1.5052), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.685.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1278.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.576.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.685.clone.1, %bitcast.685.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.617.clone.1 = f32[]{:T(128)} reduce(%square.576.clone.1, %constant.5042), dimensions={0,1,2,3}, to_apply=%region_158.183, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.764 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.615, %reduce.617.clone.1) +} + +%region_159.184 (reduce_sum.537: f32[], reduce_sum.358: f32[]) -> f32[] { + %reduce_sum.537 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.358 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.359 = f32[]{:T(128)} add(%reduce_sum.537, %reduce_sum.358), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.495 (param_0.4182: bf16[256,512,512]) -> f32[] { + %param_0.4182 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) + %broadcast_in_dim.1274 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4182), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.681 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1274), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.573 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.681, %bitcast.681), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5041 = f32[]{:T(128)} constant(0) + ROOT %reduce.616 = f32[]{:T(128)} reduce(%square.573, %constant.5041), dimensions={0,1,2,3}, to_apply=%region_159.184, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_227.252 (reduce_sum.1006: f32[], reduce_sum.1007: f32[]) -> f32[] { + %reduce_sum.1006 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1007 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.548 = f32[]{:T(128)} add(%reduce_sum.1006, %reduce_sum.1007), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_193.218 (reduce_sum.768: f32[], reduce_sum.769: f32[]) -> f32[] { + %reduce_sum.768 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.769 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.466 = f32[]{:T(128)} add(%reduce_sum.768, %reduce_sum.769), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.513 (param_0.4151: f32[], param_1.5024: f32[256,1,512,512], param_2.4282: f32[], param_3.2950: f32[256,1,512,512], param_4.2197: f32[], param_5.1997: bf16[256,512,512], param_6.1438: pred[], param_7.1118: f32[], param_8.883: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %param_8.883 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) + %bitcast.1341.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.883), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %param_7.1118 = f32[]{:T(128)S(6)} parameter(7) + %mul.5388.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1118), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1438 = pred[]{:T(512)S(6)} parameter(6) + %select_n.2103.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1438), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_5.1997 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.1484.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.1997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1343.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1484.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %param_4.2197 = f32[]{:T(128)} parameter(4) + %div.2520.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2197), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2519.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1343.clone.1, %div.2520.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2102.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2103.clone.1, %bitcast.1343.clone.1, %div.2519.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4775.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4113.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4775.clone.1), dimensions={}, metadata={op_name="broadcast.2239"} + %mul.5390.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2102.clone.1, %broadcast.4113.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_3.2950 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1342.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2950), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %constant.4774.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4112.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4774.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.5389.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1342.clone.1, %broadcast.4112.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3411.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5390.clone.1, %mul.5389.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4282 = f32[]{:T(128)S(6)} parameter(2) + %div.2518.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4282), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.393.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2102.clone.1, %select_n.2102.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4773.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4115.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4773.clone.1), dimensions={}, metadata={op_name="broadcast.2242"} + %mul.5392.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.393.clone.1, %broadcast.4115.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_1.5024 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1344.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5024), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %constant.4772.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4114.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4772.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.5391.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1344.clone.1, %broadcast.4114.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3412.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5392.clone.1, %mul.5391.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_0.4151 = f32[]{:T(128)S(6)} parameter(0) + %div.2517.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4151), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2516.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3412.clone.1, %div.2517.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.151.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2516.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4776.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4111.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4776.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3410.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.151.clone.1, %broadcast.4111.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1080.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2518.clone.1, %add.3410.clone.1), metadata={op_name="multiply.269"} + %div.2515.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3411.clone.1, %multiply.1080.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5387.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1341.clone.1, %broadcast.4113.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3409.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2515.clone.1, %mul.5387.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5386.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.5388.clone.1, %add.3409.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3408.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1341.clone.1, %mul.5386.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.577 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3408.clone.1, %add.3408.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5010 = f32[]{:T(128)} constant(0) + %reduce.618 = f32[]{:T(128)} reduce(%square.577, %constant.5010), dimensions={0,1,2,3}, to_apply=%region_227.252, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.831.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3412.clone.1) + %bitcast.804.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3411.clone.1) + %reduce.627.clone.1 = f32[]{:T(128)} reduce(%integer_pow.393.clone.1, %constant.5010), dimensions={0,1,2,3}, to_apply=%region_193.218, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.666 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.618, %add.3408.clone.1, %bitcast.831.clone.1, %bitcast.804.clone.1, %reduce.627.clone.1) +} + +%region_226.251 (reduce_sum.999: f32[], reduce_sum.1000: f32[]) -> f32[] { + %reduce_sum.999 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1000 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.543 = f32[]{:T(128)} add(%reduce_sum.999, %reduce_sum.1000), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_192.217 (reduce_sum.761: f32[], reduce_sum.762: f32[]) -> f32[] { + %reduce_sum.761 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.762 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.465 = f32[]{:T(128)} add(%reduce_sum.761, %reduce_sum.762), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.514 (param_0.4152: f32[], param_1.5025: f32[256,1,512,512], param_2.4283: f32[], param_3.2951: f32[256,1,512,512], param_4.2198: f32[], param_5.1998: bf16[256,512,512], param_6.1439: pred[], param_7.1119: f32[], param_8.884: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %param_8.884 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) + %bitcast.1345.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.884), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %param_7.1119 = f32[]{:T(128)S(6)} parameter(7) + %mul.5395.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1119), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1439 = pred[]{:T(512)S(6)} parameter(6) + %select_n.2105.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1439), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_5.1998 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.1485.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.1998), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1347.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1485.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %param_4.2198 = f32[]{:T(128)} parameter(4) + %div.2526.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2198), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2525.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1347.clone.1, %div.2526.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2104.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2105.clone.1, %bitcast.1347.clone.1, %div.2525.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4780.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4118.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4780.clone.1), dimensions={}, metadata={op_name="broadcast.2239"} + %mul.5397.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2104.clone.1, %broadcast.4118.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_3.2951 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1346.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2951), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %constant.4779.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4117.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4779.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.5396.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1346.clone.1, %broadcast.4117.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3416.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5397.clone.1, %mul.5396.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4283 = f32[]{:T(128)S(6)} parameter(2) + %div.2524.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4283), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.394.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2104.clone.1, %select_n.2104.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4778.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4120.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4778.clone.1), dimensions={}, metadata={op_name="broadcast.2242"} + %mul.5399.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.394.clone.1, %broadcast.4120.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_1.5025 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1348.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5025), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %constant.4777.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4119.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4777.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.5398.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1348.clone.1, %broadcast.4119.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3417.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5399.clone.1, %mul.5398.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_0.4152 = f32[]{:T(128)S(6)} parameter(0) + %div.2523.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4152), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2522.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3417.clone.1, %div.2523.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.152.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2522.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4781.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4116.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4781.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3415.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.152.clone.1, %broadcast.4116.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1081.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2524.clone.1, %add.3415.clone.1), metadata={op_name="multiply.268"} + %div.2521.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3416.clone.1, %multiply.1081.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5394.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1345.clone.1, %broadcast.4118.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3414.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2521.clone.1, %mul.5394.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5393.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.5395.clone.1, %add.3414.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3413.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1345.clone.1, %mul.5393.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.578 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3413.clone.1, %add.3413.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5011 = f32[]{:T(128)} constant(0) + %reduce.619 = f32[]{:T(128)} reduce(%square.578, %constant.5011), dimensions={0,1,2,3}, to_apply=%region_226.251, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.822.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3417.clone.1) + %bitcast.795.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3416.clone.1) + %reduce.628.clone.1 = f32[]{:T(128)} reduce(%integer_pow.394.clone.1, %constant.5011), dimensions={0,1,2,3}, to_apply=%region_192.217, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.665 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.619, %add.3413.clone.1, %bitcast.822.clone.1, %bitcast.795.clone.1, %reduce.628.clone.1) +} + +%region_225.250 (reduce_sum.992: f32[], reduce_sum.993: f32[]) -> f32[] { + %reduce_sum.992 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.993 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.542 = f32[]{:T(128)} add(%reduce_sum.992, %reduce_sum.993), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_191.216 (reduce_sum.754: f32[], reduce_sum.755: f32[]) -> f32[] { + %reduce_sum.754 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.755 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.464 = f32[]{:T(128)} add(%reduce_sum.754, %reduce_sum.755), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.515 (param_0.4153: f32[], param_1.5026: f32[256,1,512,512], param_2.4284: f32[], param_3.2952: f32[256,1,512,512], param_4.2199: f32[], param_5.1999: bf16[256,512,512], param_6.1440: pred[], param_7.1120: f32[], param_8.885: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %param_8.885 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) + %bitcast.1349.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.885), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %param_7.1120 = f32[]{:T(128)S(6)} parameter(7) + %mul.5402.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1120), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1440 = pred[]{:T(512)S(6)} parameter(6) + %select_n.2107.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1440), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_5.1999 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.1486.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.1999), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1351.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1486.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %param_4.2199 = f32[]{:T(128)} parameter(4) + %div.2532.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2199), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2531.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1351.clone.1, %div.2532.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2106.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2107.clone.1, %bitcast.1351.clone.1, %div.2531.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4785.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4123.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4785.clone.1), dimensions={}, metadata={op_name="broadcast.2239"} + %mul.5404.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2106.clone.1, %broadcast.4123.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_3.2952 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1350.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2952), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %constant.4784.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4122.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4784.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.5403.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1350.clone.1, %broadcast.4122.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3421.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5404.clone.1, %mul.5403.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4284 = f32[]{:T(128)S(6)} parameter(2) + %div.2530.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4284), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.395.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2106.clone.1, %select_n.2106.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4783.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4125.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4783.clone.1), dimensions={}, metadata={op_name="broadcast.2242"} + %mul.5406.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.395.clone.1, %broadcast.4125.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_1.5026 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1352.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5026), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %constant.4782.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4124.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4782.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.5405.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1352.clone.1, %broadcast.4124.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3422.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5406.clone.1, %mul.5405.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_0.4153 = f32[]{:T(128)S(6)} parameter(0) + %div.2529.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4153), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2528.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3422.clone.1, %div.2529.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.153.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2528.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4786.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4121.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4786.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3420.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.153.clone.1, %broadcast.4121.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1082.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2530.clone.1, %add.3420.clone.1), metadata={op_name="multiply.267"} + %div.2527.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3421.clone.1, %multiply.1082.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5401.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1349.clone.1, %broadcast.4123.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3419.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2527.clone.1, %mul.5401.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5400.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.5402.clone.1, %add.3419.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3418.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1349.clone.1, %mul.5400.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.579 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3418.clone.1, %add.3418.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5012 = f32[]{:T(128)} constant(0) + %reduce.620 = f32[]{:T(128)} reduce(%square.579, %constant.5012), dimensions={0,1,2,3}, to_apply=%region_225.250, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.813.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3422.clone.1) + %bitcast.786.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3421.clone.1) + %reduce.629.clone.1 = f32[]{:T(128)} reduce(%integer_pow.395.clone.1, %constant.5012), dimensions={0,1,2,3}, to_apply=%region_191.216, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.664 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.620, %add.3418.clone.1, %bitcast.813.clone.1, %bitcast.786.clone.1, %reduce.629.clone.1) +} + +%region_155.180 (reduce_sum.509: f32[], reduce_sum.338: f32[]) -> f32[] { + %reduce_sum.509 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.338 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.343 = f32[]{:T(128)} add(%reduce_sum.509, %reduce_sum.338), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.527.clone.clone.clone (param_0.4096: bf16[4,128,129280], param_1.4958: s32[4,128], param_2.4215: f32[4,128], param_3.2918: f32[4,128], param_4.2170: bf16[4,128], param_5.1975: f32[4,128]) -> bf16[4,128,129280] { + %param_5.1975 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.5639 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_5.1975), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.2918 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.5638 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2918), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.4096 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.3099 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4096), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.2170 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.791 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_4.2170), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.790 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.3099, %sub.791), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.534 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.790), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.5637 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.5638, %exp.534), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.4215 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.2685 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4215), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.2684 = f32[4,128,129280]{2,1,0:T(8,128)} divide(%mul.5637, %div.2685), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.4958 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.363 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.4958), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.362 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.361 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.363, %eq.362), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %convert_element_type.3098 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%eq.361), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.789 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%div.2684, %convert_element_type.3098), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.5636 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.5639, %sub.789), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.3097 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} convert(%mul.5636), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.932.clone.clone (param_0.4097: f32[4,128], param_1.4959: bf16[4,128,512], param_2.4217: bf16[512]) -> bf16[4,128,512] { + %param_1.4959 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.3101 = f32[4,128,512]{2,1,0:T(8,128)} convert(%param_1.4959), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.4097 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.5642 = f32[4,128,512]{2,1,0:T(8,128)} broadcast(%param_0.4097), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.5641 = f32[4,128,512]{2,1,0:T(8,128)} multiply(%convert_element_type.3101, %mul.5642), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.3100 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} convert(%mul.5641), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.4217 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(2) + %mul.5643 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4217), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.5640 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.3100, %mul.5643), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.516 (param_0.4186: bf16[4,128,129280], param_1.5054: s32[4,128], param_2.4309: f32[4,128], param_3.2974: f32[4,128], param_4.2219: bf16[4,128], param_5.2017: f32[4,128], param_6.1458: f32[4,128], param_7.1138: bf16[4,128,512], param_8.902: bf16[512]) -> (f32[], bf16[512,129280,1]) { + %param_6.1458 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.1138 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(7) + %param_8.902 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(8) + %fusion.571.clone.1 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} fusion(%param_6.1458, %param_7.1138, %param_8.902), kind=kLoop, calls=%fused_computation.932.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.4186 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.5054 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.4309 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.2974 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %param_4.2219 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %param_5.2017 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %multiply_convert_fusion.1.clone.1 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} fusion(%param_0.4186, %param_1.5054, %param_2.4309, %param_3.2974, %param_4.2219, /*index=5*/%param_5.2017), kind=kLoop, calls=%fused_computation.527.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %convolution.141.clone.1 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.571.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %bitcast.758 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%convolution.141.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.2606 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.758), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %square.581 = f32[512,129280]{1,0:T(8,128)} multiply(%convert_element_type.2606, %convert_element_type.2606), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5045 = f32[]{:T(128)} constant(0) + %reduce.621 = f32[]{:T(128)} reduce(%square.581, %constant.5045), dimensions={0,1}, to_apply=%region_155.180, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.754 = (f32[]{:T(128)}, bf16[512,129280,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.621, %convolution.141.clone.1) +} + +%region_174.199 (reduce_sum.635: f32[], reduce_sum.636: f32[]) -> f32[] { + %reduce_sum.635 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.636 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.424 = f32[]{:T(128)} add(%reduce_sum.635, %reduce_sum.636), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.517 (param_0.4170: bf16[129280,512]) -> f32[] { + %param_0.4170 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2608 = f32[129280,512]{1,0:T(8,128)} convert(%param_0.4170), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %square.583 = f32[129280,512]{1,0:T(8,128)} multiply(%convert_element_type.2608, %convert_element_type.2608), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5029 = f32[]{:T(128)} constant(0) + ROOT %reduce.622 = f32[]{:T(128)} reduce(%square.583, %constant.5029), dimensions={0,1}, to_apply=%region_174.199, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_240.265 (reduce_sum.1097: f32[], reduce_sum.1098: f32[]) -> f32[] { + %reduce_sum.1097 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1098 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.577 = f32[]{:T(128)} add(%reduce_sum.1097, %reduce_sum.1098), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_206.231 (reduce_sum.859: f32[], reduce_sum.860: f32[]) -> f32[] { + %reduce_sum.859 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.860 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.499 = f32[]{:T(128)} add(%reduce_sum.859, %reduce_sum.860), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.518 (param_0.4138: f32[129280,512], param_1.5011: f32[], param_2.4269: f32[], param_3.2937: f32[], param_4.2184: f32[129280,512], param_5.1984: f32[], param_6.1425: bf16[129280,512], param_7.1105: pred[], param_8.870: f32[129280,512]) -> (f32[], f32[129280,512], f32[129280,512], f32[129280,512], f32[]) { + %param_0.4138 = f32[129280,512]{1,0:T(8,128)} parameter(0) + %param_3.2937 = f32[]{:T(128)S(6)} parameter(3) + %mul.5276.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_3.2937), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.1105 = pred[]{:T(512)S(6)} parameter(7) + %select_n.2061.clone.1 = pred[129280,512]{1,0:T(8,128)(4,1)} broadcast(%param_7.1105), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.1425 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.3042.clone.1 = f32[129280,512]{1,0:T(8,128)} convert(%param_6.1425), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %param_5.1984 = f32[]{:T(128)} parameter(5) + %div.2426.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_5.1984), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2425.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%convert_element_type.3042.clone.1, %div.2426.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2060.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%select_n.2061.clone.1, %convert_element_type.3042.clone.1, %div.2425.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4695.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4063.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4695.clone.1), dimensions={}, metadata={op_name="broadcast.318"} + %mul.5282.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2060.clone.1, %broadcast.4063.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.870 = f32[129280,512]{1,0:T(8,128)} parameter(8) + %constant.4699.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5283.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4699.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5281.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_8.870, %mul.5283.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3341.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.5282.clone.1, %mul.5281.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4269 = f32[]{:T(128)S(6)} parameter(2) + %div.2422.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_2.4269), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.380.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2060.clone.1, %select_n.2060.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4698.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5280.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4698.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5278.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%integer_pow.380.clone.1, %mul.5280.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.2184 = f32[129280,512]{1,0:T(8,128)} parameter(4) + %constant.4697.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5279.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4697.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5277.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_4.2184, %mul.5279.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3340.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.5278.clone.1, %mul.5277.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5011 = f32[]{:T(128)S(6)} parameter(1) + %div.2421.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_1.5011), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2420.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3340.clone.1, %div.2421.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.138.clone.1 = f32[129280,512]{1,0:T(8,128)} sqrt(%div.2420.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4696.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3339.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4696.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3338.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%sqrt.138.clone.1, %add.3339.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1067.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%div.2422.clone.1, %add.3338.clone.1), metadata={op_name="multiply.282"} + %div.2419.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3341.clone.1, %multiply.1067.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5275.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_0.4138, %broadcast.4063.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3337.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%div.2419.clone.1, %mul.5275.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5274.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%mul.5276.clone.1, %add.3337.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3336.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%param_0.4138, %mul.5274.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.584 = f32[129280,512]{1,0:T(8,128)} multiply(%add.3336.clone.1, %add.3336.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.4997 = f32[]{:T(128)} constant(0) + %reduce.623 = f32[]{:T(128)} reduce(%square.584, %constant.4997), dimensions={0,1}, to_apply=%region_240.265, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.630.clone.1 = f32[]{:T(128)} reduce(%integer_pow.380.clone.1, %constant.4997), dimensions={0,1}, to_apply=%region_206.231, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.667 = (f32[]{:T(128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.623, %add.3336.clone.1, %add.3340.clone.1, %add.3341.clone.1, %reduce.630.clone.1) +} + +%region_222.247 (reduce_sum.971: f32[], reduce_sum.972: f32[]) -> f32[] { + %reduce_sum.971 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.972 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.535 = f32[]{:T(128)} add(%reduce_sum.971, %reduce_sum.972), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_188.213 (reduce_sum.733: f32[], reduce_sum.734: f32[]) -> f32[] { + %reduce_sum.733 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.734 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.457 = f32[]{:T(128)} add(%reduce_sum.733, %reduce_sum.734), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.519 (param_0.4156: f32[512,129280], param_1.5029: f32[], param_2.4287: f32[], param_3.2955: f32[], param_4.2202: f32[512,129280], param_5.2002: f32[], param_6.1443: bf16[512,129280,1], param_7.1123: pred[], param_8.888: f32[512,129280]) -> (f32[], f32[512,129280], f32[512,129280], f32[512,129280], f32[]) { + %param_0.4156 = f32[512,129280]{1,0:T(8,128)} parameter(0) + %param_3.2955 = f32[]{:T(128)S(6)} parameter(3) + %mul.5429.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_3.2955), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.1123 = pred[]{:T(512)S(6)} parameter(7) + %select_n.2117.clone.1 = pred[512,129280]{1,0:T(8,128)(4,1)} broadcast(%param_7.1123), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.1443 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} parameter(6) + %bitcast.1354.clone.1 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%param_6.1443), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.3044.clone.1 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.1354.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %param_5.2002 = f32[]{:T(128)} parameter(5) + %div.2554.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_5.2002), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2553.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%convert_element_type.3044.clone.1, %div.2554.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2116.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%select_n.2117.clone.1, %convert_element_type.3044.clone.1, %div.2553.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4799.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4131.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4799.clone.1), dimensions={}, metadata={op_name="broadcast.333"} + %mul.5435.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2116.clone.1, %broadcast.4131.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.888 = f32[512,129280]{1,0:T(8,128)} parameter(8) + %constant.4803.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5436.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4803.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5434.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_8.888, %mul.5436.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3440.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.5435.clone.1, %mul.5434.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4287 = f32[]{:T(128)S(6)} parameter(2) + %div.2550.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_2.4287), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.398.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2116.clone.1, %select_n.2116.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4802.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5433.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4802.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5431.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%integer_pow.398.clone.1, %mul.5433.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.2202 = f32[512,129280]{1,0:T(8,128)} parameter(4) + %constant.4801.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5432.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4801.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5430.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_4.2202, %mul.5432.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3439.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.5431.clone.1, %mul.5430.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5029 = f32[]{:T(128)S(6)} parameter(1) + %div.2549.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_1.5029), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2548.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3439.clone.1, %div.2549.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.156.clone.1 = f32[512,129280]{1,0:T(8,128)} sqrt(%div.2548.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4800.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3438.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4800.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3437.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%sqrt.156.clone.1, %add.3438.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1085.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%div.2550.clone.1, %add.3437.clone.1), metadata={op_name="multiply.264"} + %div.2547.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3440.clone.1, %multiply.1085.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5428.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_0.4156, %broadcast.4131.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3436.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%div.2547.clone.1, %mul.5428.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5427.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%mul.5429.clone.1, %add.3436.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3435.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%param_0.4156, %mul.5427.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.585 = f32[512,129280]{1,0:T(8,128)} multiply(%add.3435.clone.1, %add.3435.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5015 = f32[]{:T(128)} constant(0) + %reduce.624 = f32[]{:T(128)} reduce(%square.585, %constant.5015), dimensions={0,1}, to_apply=%region_222.247, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.631.clone.1 = f32[]{:T(128)} reduce(%integer_pow.398.clone.1, %constant.5015), dimensions={0,1}, to_apply=%region_188.213, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.668 = (f32[]{:T(128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.624, %add.3435.clone.1, %add.3439.clone.1, %add.3440.clone.1, %reduce.631.clone.1) +} + +%region_207.232 (reduce_sum.866: f32[], reduce_sum.867: f32[]) -> f32[] { + %reduce_sum.866 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.867 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.500 = f32[]{:T(128)} add(%reduce_sum.866, %reduce_sum.867), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.520 (param_0.4206: bf16[4,128,129280], param_1.5067: f32[4,128], param_2.4319: s32[4,128], param_3.2982: bf16[4,128]) -> f32[4,128] { + %param_2.4319 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.299 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4319), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.294 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.293 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.299, %eq.294), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %param_0.4206 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2613 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4206), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_3.2982 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.652 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2982), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.643 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2613, %sub.652), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_1.5067 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.650 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5067), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.639 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%sub.643, %sub.650), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %constant.5068 = f32[]{:T(128)} constant(0) + %broadcast.3638 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%constant.5068), dimensions={}, metadata={op_name="broadcast.496"} + %mul.4227 = f32[4,128,129280]{2,1,0:T(8,128)} select(%eq.293, %sub.639, %broadcast.3638), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.625 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.4227, %constant.5068), dimensions={2}, to_apply=%region_207.232, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%region_37.47 (reduce_sum.76: f32[], reduce_sum.82: f32[]) -> f32[] { + %reduce_sum.76 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.82 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.88 = f32[]{:T(128)} add(%reduce_sum.76, %reduce_sum.82), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.531 (param_0.4207: bf16[4,128,129280], param_1.5068: bf16[4,128]) -> f32[4,128] { + %param_0.4207 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2619 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4207), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.5068 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.653 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5068), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.649 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2619, %sub.653), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.448 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.649), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %constant.5069 = f32[]{:T(128)} constant(0) + ROOT %reduce.626 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.448, %constant.5069), dimensions={2}, to_apply=%region_37.47, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%region_152.177 (reduce_sum.488: f32[], reduce_sum.324: f32[]) -> f32[] { + %reduce_sum.488 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.324 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.329 = f32[]{:T(128)} add(%reduce_sum.488, %reduce_sum.324), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.539 (param_0.4189: f32[3,512,128,256]) -> f32[] { + %param_0.4189 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.734 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_0.4189), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.588 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%bitcast.734, %bitcast.734), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5048 = f32[]{:T(128)} constant(0) + ROOT %reduce.632 = f32[]{:T(128)} reduce(%square.588, %constant.5048), dimensions={0,1,2,3}, to_apply=%region_152.177, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%fused_computation.540 (param_0.1619: f32[512,3,128,256]) -> bf16[3,512,128,256] { + %param_0.1619 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) + %copy.1549 = bf16[512,3,128,256]{3,0,2,1:T(8,128)(2,1)} copy(%param_0.1619), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wkv_b\'][\'kernel\']"} + ROOT %bitcast.735 = bf16[3,512,128,256]{3,1,2,0:T(8,128)(2,1)} bitcast(%copy.1549), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} +} + +%region_219.244 (reduce_sum.950: f32[], reduce_sum.951: f32[]) -> f32[] { + %reduce_sum.950 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.951 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.528 = f32[]{:T(128)} add(%reduce_sum.950, %reduce_sum.951), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_185.210 (reduce_sum.712: f32[], reduce_sum.713: f32[]) -> f32[] { + %reduce_sum.712 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.713 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.450 = f32[]{:T(128)} add(%reduce_sum.712, %reduce_sum.713), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.541 (param_0.4159: f32[512,3,128,256], param_1.5032: f32[], param_2.4290: f32[], param_3.2958: f32[], param_4.2205: f32[512,3,128,256], param_5.2005: f32[], param_6.1446: f32[3,512,128,256], param_7.1126: pred[], param_8.891: f32[512,3,128,256]) -> (f32[], f32[512,3,128,256], f32[512,3,128,256], f32[512,3,128,256], f32[]) { + %param_0.4159 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) + %param_3.2958 = f32[]{:T(128)S(6)} parameter(3) + %mul.5459.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_3.2958), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.1126 = pred[]{:T(512)S(6)} parameter(7) + %select_n.2129.clone.1 = pred[512,3,128,256]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.1126), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.1446 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.1360.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_6.1446), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %param_5.2005 = f32[]{:T(128)} parameter(5) + %div.2578.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_5.2005), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2577.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%bitcast.1360.clone.1, %div.2578.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2128.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%select_n.2129.clone.1, %bitcast.1360.clone.1, %div.2577.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4817.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4137.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4817.clone.1), dimensions={}, metadata={op_name="broadcast.336"} + %mul.5465.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2128.clone.1, %broadcast.4137.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.891 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(8) + %constant.4821.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5466.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4821.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5464.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_8.891, %mul.5466.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3458.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.5465.clone.1, %mul.5464.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4290 = f32[]{:T(128)S(6)} parameter(2) + %div.2574.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_2.4290), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.401.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2128.clone.1, %select_n.2128.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.4820.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5463.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4820.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5461.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%integer_pow.401.clone.1, %mul.5463.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.2205 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(4) + %constant.4819.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5462.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4819.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5460.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_4.2205, %mul.5462.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3457.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.5461.clone.1, %mul.5460.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5032 = f32[]{:T(128)S(6)} parameter(1) + %div.2573.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_1.5032), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2572.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3457.clone.1, %div.2573.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.159.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} sqrt(%div.2572.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.4818.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3456.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4818.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3455.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%sqrt.159.clone.1, %add.3456.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1088.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%div.2574.clone.1, %add.3455.clone.1), metadata={op_name="multiply.261"} + %div.2571.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3458.clone.1, %multiply.1088.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5458.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_0.4159, %broadcast.4137.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3454.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%div.2571.clone.1, %mul.5458.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5457.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%mul.5459.clone.1, %add.3454.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3453.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%param_0.4159, %mul.5457.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.589 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%add.3453.clone.1, %add.3453.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5018 = f32[]{:T(128)} constant(0) + %reduce.633 = f32[]{:T(128)} reduce(%square.589, %constant.5018), dimensions={0,1,2,3}, to_apply=%region_219.244, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.634.clone.1 = f32[]{:T(128)} reduce(%integer_pow.401.clone.1, %constant.5018), dimensions={0,1,2,3}, to_apply=%region_185.210, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.663 = (f32[]{:T(128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.633, %add.3453.clone.1, %add.3457.clone.1, %add.3458.clone.1, %reduce.634.clone.1) +} + +%region_172.197 (reduce_sum.628: f32[], reduce_sum.629: f32[]) -> f32[] { + %reduce_sum.628 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.629 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.423 = f32[]{:T(128)} add(%reduce_sum.628, %reduce_sum.629), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.780.clone.clone (param_0.4123: f32[4,128], param_1.5003: bf16[4,128,1536], param_2.4251: bf16[1536]) -> bf16[4,128,1536,1] { + %param_1.5003 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.3123 = f32[4,128,1536]{2,1,0:T(8,128)} convert(%param_1.5003), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} + %param_0.4123 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.5708 = f32[4,128,1536]{2,1,0:T(8,128)} broadcast(%param_0.4123), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %mul.5707 = f32[4,128,1536]{2,1,0:T(8,128)} multiply(%convert_element_type.3123, %mul.5708), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %convert_element_type.3122 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} convert(%mul.5707), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} + %param_2.4251 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.5709 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4251), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %mul.5706 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.3122, %mul.5709), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + ROOT %bitcast.1448 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%mul.5706), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} +} + +%bitcast_fusion.12 (bitcast_input.12: bf16[4,128,128,192]) -> bf16[4,128,128,192] { + %bitcast_input.12 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.1470 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} bitcast(%bitcast_input.12) +} + +%fused_computation.550 (param_0.4171: bf16[4,128,128,192], param_1.5043: f32[4,128], param_2.4301: bf16[4,128,1536], param_3.2969: bf16[1536]) -> (f32[], bf16[1536,128,192,1]) { + %param_1.5043 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.4301 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.2969 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.457.clone.1 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.5043, %param_2.4301, %param_3.2969), kind=kLoop, calls=%fused_computation.780.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %param_0.4171 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) + %fusion.744 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.4171), kind=kLoop, calls=%bitcast_fusion.12 + %convolution.146.clone.1 = bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)} convolution(%fusion.457.clone.1, %fusion.744), window={size=192x4 pad=191_191x0_0 rhs_reversal=1x0}, dim_labels=1fb0_1io0->bf01, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} + %bitcast.843 = bf16[1536,128,192]{1,0,2:T(8,128)(2,1)} bitcast(%convolution.146.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} + %broadcast_in_dim.1300 = f32[1536,128,192]{1,0,2:T(8,128)} convert(%bitcast.843), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.745 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} bitcast(%broadcast_in_dim.1300), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.592 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%bitcast.745, %bitcast.745), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5030 = f32[]{:T(128)} constant(0) + %reduce.635 = f32[]{:T(128)} reduce(%square.592, %constant.5030), dimensions={0,1,2,3}, to_apply=%region_172.197, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.763 = (f32[]{:T(128)}, bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)}) tuple(%reduce.635, %convolution.146.clone.1) +} + +%region_239.264 (reduce_sum.1090: f32[], reduce_sum.1091: f32[]) -> f32[] { + %reduce_sum.1090 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1091 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.576 = f32[]{:T(128)} add(%reduce_sum.1090, %reduce_sum.1091), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_205.230 (reduce_sum.852: f32[], reduce_sum.853: f32[]) -> f32[] { + %reduce_sum.852 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.853 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.494 = f32[]{:T(128)} add(%reduce_sum.852, %reduce_sum.853), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.555 (param_0.4139: f32[], param_1.5012: f32[], param_2.4270: f32[], param_3.2938: f32[1536,1,128,192], param_4.2185: f32[1536,1,128,192], param_5.1985: f32[], param_6.1426: bf16[1536,128,192,1], param_7.1106: pred[], param_8.871: f32[1536,1,128,192]) -> (f32[], f32[1536,1,128,192], f32[1536,1,128,192], f32[1536,1,128,192], f32[]) { diff --git a/src/maxtext/utils/reference_hlo_llama3_8b.txt b/src/maxtext/utils/reference_hlo_llama3_8b.txt new file mode 100644 index 0000000000..19a83b1a84 --- /dev/null +++ b/src/maxtext/utils/reference_hlo_llama3_8b.txt @@ -0,0 +1,2000 @@ +HloModule jit_train_step, is_scheduled=true, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36, {}, may-alias), {37}: (37, {}, may-alias), {38}: (38, {}, may-alias) }, entry_computation_layout={(s32[]{:T(128)}, f32[4096]{0:T(1024)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, /*index=5*/f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, /*index=10*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, s32[]{:T(128)}, f32[4096]{0:T(1024)}, /*index=15*/f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, /*index=20*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, /*index=25*/f32[128256,4096]{1,0:T(8,128)}, f32[4096]{0:T(1024)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, /*index=30*/f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, /*index=35*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, s32[]{:T(128)}, s32[4,128]{1,0:T(4,128)}, /*index=40*/s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)})->(s32[]{:T(128)}, f32[4096]{0:T(1024)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, /*index=5*/f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, /*index=10*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, s32[]{:T(128)}, f32[4096]{0:T(1024)}, /*index=15*/f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, /*index=20*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, /*index=25*/f32[128256,4096]{1,0:T(8,128)}, f32[4096]{0:T(1024)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, /*index=30*/f32[4096,4]{0,1:T(4,128)}, f32[4096,4]{0,1:T(4,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, /*index=35*/f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, s32[]{:T(128)}, f32[]{:T(128)}, /*index=40*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, /*index=45*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, s32[]{:T(128)}, f32[]{:T(128)})}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false}, allow_spmd_sharding_propagation_to_output={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,true,true,true,true,true,true,true,true,true,true,true}, num_partitions=4 + +FileNames + +FunctionNames + +FileLocations + +StackFrames + + +%fused_computation (param_0.2: bf16[128256,4096], param_1.7: s32[1024]) -> bf16[512,4096] { + %param_0.2 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.7 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.1 = s32[1024]{0:T(1024)} custom-call(%param_1.7), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %slice.6 = s32[512]{0:T(512)} slice(%custom-call.1), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.380 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.241 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.380), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.4 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.241), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,4096}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.240 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.379 = bf16[512,4096]{1,0:T(8,128)(2,1)} reshape(%transpose.240), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} +} + +%region_33.38.clone (scatter-add.6: bf16[], scatter-add.7: bf16[]) -> bf16[] { + %scatter-add.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + %scatter-add.7 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + ROOT %add.462 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.1 (param_0.3: bf16[128256,4096], param_1.5: s32[512], param_2.4: bf16[512,4096]) -> bf16[128256,4096] { + %param_0.3 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.5 = s32[512]{0:T(512)S(1)} parameter(1) + %reshape.387 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.246 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.387), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %param_2.4 = bf16[512,4096]{1,0:T(8,128)(2,1)S(1)} parameter(2) + %reshape.388 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + %transpose.247 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%reshape.388), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + ROOT %scatter.2 = bf16[128256,4096]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.246, %transpose.247), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_33.38.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} +} + +%region_32.37 (reduce_sum.244: f32[], reduce_sum.245: f32[]) -> f32[] { + %reduce_sum.244 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.245 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.249 = f32[]{:T(128)} add(%reduce_sum.244, %reduce_sum.245), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.280.clone.clone.clone (param_0.1106: bf16[4,128,128256], param_1.1269: s32[4,128], param_2.1080: f32[4,128], param_3.773: f32[4,128], param_4.481: bf16[4,128], param_5.412: f32[4,128]) -> bf16[4,128,128256] { + %param_5.412 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.1937 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.412), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.773 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.1936 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.773), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1106 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1060 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1106), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.481 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.94 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.481), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.93 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1060, %sub.94), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.62 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.93), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.1935 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1936, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1080 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.823 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1080), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.822 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1935, %div.823), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1269 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.49 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1269), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.48 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.47 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.49, %eq.48), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %convert_element_type.1059 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.92 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.822, %convert_element_type.1059), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.1934 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1937, %sub.92), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1058 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1934), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.316.clone.clone (param_0.1107: f32[4,128], param_1.1270: bf16[4,128,4096], param_2.1082: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1270 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1062 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1270), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1107 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1940 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1107), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1939 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1062, %mul.1940), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1061 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1939), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1082 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.1941 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1082), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.1938 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1061, %mul.1941), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.219 (param_0.1126: bf16[4,128,128256], param_1.1285: s32[4,128], param_2.1106: f32[4,128], param_3.789: f32[4,128], param_4.496: bf16[4,128], param_5.427: f32[4,128], param_6.300: f32[4,128], param_7.198: bf16[4,128,4096], param_8.116: bf16[4096]) -> (f32[], bf16[4096,128256,1]) { + %param_6.300 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.198 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(7) + %param_8.116 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(8) + %fusion.239.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_6.300, %param_7.198, %param_8.116), kind=kLoop, calls=%fused_computation.316.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1126 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1285 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1106 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.789 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %param_4.496 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %param_5.427 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %multiply_convert_fusion.1.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1126, %param_1.1285, %param_2.1106, %param_3.789, %param_4.496, /*index=5*/%param_5.427), kind=kLoop, calls=%fused_computation.280.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %convolution.88.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.239.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %bitcast.306 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%convolution.88.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.939 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.306), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %square.157 = f32[4096,128256]{1,0:T(8,128)} multiply(%convert_element_type.939, %convert_element_type.939), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.998 = f32[]{:T(128)} constant(0) + %reduce.79 = f32[]{:T(128)} reduce(%square.157, %constant.998), dimensions={0,1}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.154 = (f32[]{:T(128)}, bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.79, %convolution.88.clone.1) +} + +%region_34.39 (reduce_sum.250: f32[], reduce_sum.251: f32[]) -> f32[] { + %reduce_sum.250 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.251 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.252 = f32[]{:T(128)} add(%reduce_sum.250, %reduce_sum.251), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.220 (param_0.1125: bf16[128256,4096]) -> f32[] { + %param_0.1125 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.941 = f32[128256,4096]{1,0:T(8,128)} convert(%param_0.1125), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %square.159 = f32[128256,4096]{1,0:T(8,128)} multiply(%convert_element_type.941, %convert_element_type.941), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.997 = f32[]{:T(128)} constant(0) + ROOT %reduce.80 = f32[]{:T(128)} reduce(%square.159, %constant.997), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_60.65 (reduce_sum.385: f32[], reduce_sum.389: f32[]) -> f32[] { + %reduce_sum.385 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.389 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.390 = f32[]{:T(128)} add(%reduce_sum.385, %reduce_sum.389), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_46.51 (reduce_sum.313: f32[], reduce_sum.314: f32[]) -> f32[] { + %reduce_sum.313 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.314 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.315 = f32[]{:T(128)} add(%reduce_sum.313, %reduce_sum.314), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.221 (param_0.1113: f32[128256,4096], param_1.1273: f32[], param_2.1094: f32[], param_3.777: f32[], param_4.484: f32[128256,4096], param_5.415: f32[], param_6.288: bf16[128256,4096], param_7.186: pred[], param_8.104: f32[128256,4096]) -> (f32[], f32[128256,4096], f32[128256,4096], f32[128256,4096], f32[]) { + %param_0.1113 = f32[128256,4096]{1,0:T(8,128)} parameter(0) + %param_3.777 = f32[]{:T(128)S(6)} parameter(3) + %mul.1800.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_3.777), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.186 = pred[]{:T(512)S(6)} parameter(7) + %select_n.242.clone.1 = pred[128256,4096]{1,0:T(8,128)(4,1)} broadcast(%param_7.186), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.288 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.1033.clone.1 = f32[128256,4096]{1,0:T(8,128)} convert(%param_6.288), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %param_5.415 = f32[]{:T(128)} parameter(5) + %div.725.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_5.415), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.724.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%convert_element_type.1033.clone.1, %div.725.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.241.clone.1 = f32[128256,4096]{1,0:T(8,128)} select(%select_n.242.clone.1, %convert_element_type.1033.clone.1, %div.724.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.899.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.515.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.899.clone.1), dimensions={}, metadata={op_name="broadcast.61"} + %mul.1806.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.241.clone.1, %broadcast.515.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.104 = f32[128256,4096]{1,0:T(8,128)} parameter(8) + %constant.903.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1807.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.903.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1805.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_8.104, %mul.1807.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.762.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1806.clone.1, %mul.1805.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1094 = f32[]{:T(128)S(6)} parameter(2) + %div.721.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_2.1094), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.60.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.241.clone.1, %select_n.241.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.902.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1804.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.902.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1802.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%integer_pow.60.clone.1, %mul.1804.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.484 = f32[128256,4096]{1,0:T(8,128)} parameter(4) + %constant.901.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1803.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.901.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1801.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_4.484, %mul.1803.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.761.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1802.clone.1, %mul.1801.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1273 = f32[]{:T(128)S(6)} parameter(1) + %div.720.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_1.1273), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.719.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.761.clone.1, %div.720.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.58.clone.1 = f32[128256,4096]{1,0:T(8,128)} sqrt(%div.719.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.900.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.760.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.900.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.759.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%sqrt.58.clone.1, %add.760.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.183.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%div.721.clone.1, %add.759.clone.1), metadata={op_name="multiply.33"} + %div.718.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.762.clone.1, %multiply.183.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1799.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_0.1113, %broadcast.515.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.758.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%div.718.clone.1, %mul.1799.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1798.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%mul.1800.clone.1, %add.758.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.757.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%param_0.1113, %mul.1798.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.160 = f32[128256,4096]{1,0:T(8,128)} multiply(%add.757.clone.1, %add.757.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.985 = f32[]{:T(128)} constant(0) + %reduce.81 = f32[]{:T(128)} reduce(%square.160, %constant.985), dimensions={0,1}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.83.clone.1 = f32[]{:T(128)} reduce(%integer_pow.60.clone.1, %constant.985), dimensions={0,1}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.135 = (f32[]{:T(128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.81, %add.757.clone.1, %add.761.clone.1, %add.762.clone.1, %reduce.83.clone.1) +} + +%region_59.64 (reduce_sum.382: f32[], reduce_sum.383: f32[]) -> f32[] { + %reduce_sum.382 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.383 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.384 = f32[]{:T(128)} add(%reduce_sum.382, %reduce_sum.383), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_45.50 (reduce_sum.307: f32[], reduce_sum.308: f32[]) -> f32[] { + %reduce_sum.307 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.308 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.312 = f32[]{:T(128)} add(%reduce_sum.307, %reduce_sum.308), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.222 (param_0.1114: f32[4096,128256], param_1.1274: f32[], param_2.1095: f32[], param_3.778: f32[], param_4.485: f32[4096,128256], param_5.416: f32[], param_6.289: bf16[4096,128256,1], param_7.187: pred[], param_8.105: f32[4096,128256]) -> (f32[], f32[4096,128256], f32[4096,128256], f32[4096,128256], f32[]) { + %param_0.1114 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %param_3.778 = f32[]{:T(128)S(6)} parameter(3) + %mul.1810.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_3.778), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.187 = pred[]{:T(512)S(6)} parameter(7) + %select_n.246.clone.1 = pred[4096,128256]{1,0:T(8,128)(4,1)} broadcast(%param_7.187), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.289 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} parameter(6) + %bitcast.409.clone.1 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%param_6.289), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.1035.clone.1 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.409.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %param_5.416 = f32[]{:T(128)} parameter(5) + %div.733.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_5.416), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.732.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%convert_element_type.1035.clone.1, %div.733.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.245.clone.1 = f32[4096,128256]{1,0:T(8,128)} select(%select_n.246.clone.1, %convert_element_type.1035.clone.1, %div.732.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.905.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.517.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.905.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %mul.1816.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.245.clone.1, %broadcast.517.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.105 = f32[4096,128256]{1,0:T(8,128)} parameter(8) + %constant.909.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1817.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.909.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1815.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_8.105, %mul.1817.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.768.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1816.clone.1, %mul.1815.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1095 = f32[]{:T(128)S(6)} parameter(2) + %div.729.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_2.1095), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.61.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.245.clone.1, %select_n.245.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.908.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1814.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.908.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1812.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%integer_pow.61.clone.1, %mul.1814.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.485 = f32[4096,128256]{1,0:T(8,128)} parameter(4) + %constant.907.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1813.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.907.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1811.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_4.485, %mul.1813.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.767.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1812.clone.1, %mul.1811.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1274 = f32[]{:T(128)S(6)} parameter(1) + %div.728.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_1.1274), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.727.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.767.clone.1, %div.728.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.59.clone.1 = f32[4096,128256]{1,0:T(8,128)} sqrt(%div.727.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.906.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.766.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.906.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.765.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%sqrt.59.clone.1, %add.766.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.184.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%div.729.clone.1, %add.765.clone.1), metadata={op_name="multiply.32"} + %div.726.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.768.clone.1, %multiply.184.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1809.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_0.1114, %broadcast.517.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.764.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%div.726.clone.1, %mul.1809.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1808.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%mul.1810.clone.1, %add.764.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.763.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%param_0.1114, %mul.1808.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.161 = f32[4096,128256]{1,0:T(8,128)} multiply(%add.763.clone.1, %add.763.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.986 = f32[]{:T(128)} constant(0) + %reduce.82 = f32[]{:T(128)} reduce(%square.161, %constant.986), dimensions={0,1}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.84.clone.1 = f32[]{:T(128)} reduce(%integer_pow.61.clone.1, %constant.986), dimensions={0,1}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.136 = (f32[]{:T(128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.82, %add.763.clone.1, %add.767.clone.1, %add.768.clone.1, %reduce.84.clone.1) +} + +%region_25.30 (reduce_sum.208: f32[], reduce_sum.209: f32[]) -> f32[] { + %reduce_sum.208 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.209 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.210 = f32[]{:T(128)} add(%reduce_sum.208, %reduce_sum.209), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.239 (param_0.1131: f32[4,14336,4096]) -> f32[] { + %param_0.1131 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(0) + %bitcast.314 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_0.1131), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.164 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%bitcast.314, %bitcast.314), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1003 = f32[]{:T(128)} constant(0) + ROOT %reduce.85 = f32[]{:T(128)} reduce(%square.164, %constant.1003), dimensions={0,1,2}, to_apply=%region_25.30, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_24.29 (reduce_sum.202: f32[], reduce_sum.203: f32[]) -> f32[] { + %reduce_sum.202 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.203 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.207 = f32[]{:T(128)} add(%reduce_sum.202, %reduce_sum.203), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_23.28 (reduce_sum.196: f32[], reduce_sum.200: f32[]) -> f32[] { + %reduce_sum.196 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.200 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.201 = f32[]{:T(128)} add(%reduce_sum.196, %reduce_sum.200), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.241 (param_0.1132: f32[4,4096,14336], param_1.1288: f32[4,4096,14336]) -> (f32[], f32[]) { + %param_0.1132 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(0) + %bitcast.318 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_0.1132), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.167 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%bitcast.318, %bitcast.318), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1004 = f32[]{:T(128)} constant(0) + %reduce.86 = f32[]{:T(128)} reduce(%square.167, %constant.1004), dimensions={0,1,2}, to_apply=%region_24.29, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1288 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(1) + %bitcast.322.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_1.1288), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.170.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%bitcast.322.clone.1, %bitcast.322.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.87.clone.1 = f32[]{:T(128)} reduce(%square.170.clone.1, %constant.1004), dimensions={0,1,2}, to_apply=%region_23.28, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.155 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.86, %reduce.87.clone.1) +} + +%fused_computation.244 (param_0.699: f32[14336,4,4096]) -> bf16[4,14336,4096] { + %param_0.699 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) + %copy.234 = bf16[14336,4,4096]{2,0,1:T(8,128)(2,1)} copy(%param_0.699), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} + ROOT %bitcast.323 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} bitcast(%copy.234), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%fused_computation.245 (param_0.701: f32[4096,4,14336]) -> bf16[4,4096,14336] { + %param_0.701 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %copy.235 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.701), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} + ROOT %bitcast.324 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} bitcast(%copy.235), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%fused_computation.246 (param_0.703: f32[4096,4,14336]) -> bf16[4,4096,14336] { + %param_0.703 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %copy.236 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.703), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} + ROOT %bitcast.325 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} bitcast(%copy.236), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%region_52.57 (reduce_sum.343: f32[], reduce_sum.347: f32[]) -> f32[] { + %reduce_sum.343 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.347 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.348 = f32[]{:T(128)} add(%reduce_sum.343, %reduce_sum.347), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_38.43 (reduce_sum.271: f32[], reduce_sum.272: f32[]) -> f32[] { + %reduce_sum.271 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.272 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.273 = f32[]{:T(128)} add(%reduce_sum.271, %reduce_sum.272), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.247 (param_0.1121: f32[14336,4,4096], param_1.1281: f32[], param_2.1102: f32[], param_3.785: f32[], param_4.492: f32[14336,4,4096], param_5.423: f32[], param_6.296: f32[4,14336,4096], param_7.194: pred[], param_8.112: f32[14336,4,4096]) -> (f32[], f32[14336,4,4096], f32[14336,4,4096], f32[14336,4,4096], f32[]) { + %param_0.1121 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) + %param_3.785 = f32[]{:T(128)S(6)} parameter(3) + %mul.1868.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_3.785), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.194 = pred[]{:T(512)S(6)} parameter(7) + %select_n.274.clone.1 = pred[14336,4,4096]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.194), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.296 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(6) + %bitcast.423.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_6.296), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.423 = f32[]{:T(128)} parameter(5) + %div.789.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_5.423), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.788.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%bitcast.423.clone.1, %div.789.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.273.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} select(%select_n.274.clone.1, %bitcast.423.clone.1, %div.788.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.947.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.547.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.947.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.1874.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.273.clone.1, %broadcast.547.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.112 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(8) + %constant.951.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1875.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.951.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1873.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_8.112, %mul.1875.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.806.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1874.clone.1, %mul.1873.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1102 = f32[]{:T(128)S(6)} parameter(2) + %div.785.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_2.1102), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.68.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.273.clone.1, %select_n.273.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.950.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1872.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.950.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1870.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%integer_pow.68.clone.1, %mul.1872.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.492 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(4) + %constant.949.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1871.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.949.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1869.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_4.492, %mul.1871.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.805.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1870.clone.1, %mul.1869.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1281 = f32[]{:T(128)S(6)} parameter(1) + %div.784.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_1.1281), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.783.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.805.clone.1, %div.784.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.66.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} sqrt(%div.783.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.948.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.804.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.948.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.803.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%sqrt.66.clone.1, %add.804.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.191.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%div.785.clone.1, %add.803.clone.1), metadata={op_name="multiply.25"} + %div.782.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.806.clone.1, %multiply.191.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1867.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_0.1121, %broadcast.547.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.802.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%div.782.clone.1, %mul.1867.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1866.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%mul.1868.clone.1, %add.802.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.801.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%param_0.1121, %mul.1866.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.171 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%add.801.clone.1, %add.801.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.993 = f32[]{:T(128)} constant(0) + %reduce.88 = f32[]{:T(128)} reduce(%square.171, %constant.993), dimensions={0,1,2}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.91.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.993), dimensions={0,1,2}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.137 = (f32[]{:T(128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.88, %add.801.clone.1, %add.805.clone.1, %add.806.clone.1, %reduce.91.clone.1) +} + +%region_51.56 (reduce_sum.340: f32[], reduce_sum.341: f32[]) -> f32[] { + %reduce_sum.340 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.341 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.342 = f32[]{:T(128)} add(%reduce_sum.340, %reduce_sum.341), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_37.42 (reduce_sum.265: f32[], reduce_sum.266: f32[]) -> f32[] { + %reduce_sum.265 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.266 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.270 = f32[]{:T(128)} add(%reduce_sum.265, %reduce_sum.266), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.248 (param_0.1122: f32[4096,4,14336], param_1.1282: f32[], param_2.1103: f32[], param_3.786: f32[], param_4.493: f32[4096,4,14336], param_5.424: f32[], param_6.297: f32[4,4096,14336], param_7.195: pred[], param_8.113: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { + %param_0.1122 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %param_3.786 = f32[]{:T(128)S(6)} parameter(3) + %mul.1878.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.786), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.195 = pred[]{:T(512)S(6)} parameter(7) + %select_n.278.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.195), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.297 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) + %bitcast.425.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.297), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.424 = f32[]{:T(128)} parameter(5) + %div.797.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_5.424), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.796.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%bitcast.425.clone.1, %div.797.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.277.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%select_n.278.clone.1, %bitcast.425.clone.1, %div.796.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.953.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.553.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.953.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1882.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.277.clone.1, %broadcast.553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.113 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(8) + %constant.957.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.552.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.957.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1881.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.113, %broadcast.552.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.811.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1882.clone.1, %mul.1881.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1103 = f32[]{:T(128)S(6)} parameter(2) + %div.793.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1103), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.69.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.277.clone.1, %select_n.277.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.956.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.551.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.956.clone.1), dimensions={}, metadata={op_name="broadcast.60"} + %mul.1880.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.69.clone.1, %broadcast.551.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.493 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) + %constant.955.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.550.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.955.clone.1), dimensions={}, metadata={op_name="broadcast.59"} + %mul.1879.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.493, %broadcast.550.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.810.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1880.clone.1, %mul.1879.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1282 = f32[]{:T(128)S(6)} parameter(1) + %div.792.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1282), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.791.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.810.clone.1, %div.792.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.67.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} sqrt(%div.791.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.954.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.548.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.954.clone.1), dimensions={}, metadata={op_name="broadcast.54"} + %add.809.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.67.clone.1, %broadcast.548.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.192.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.793.clone.1, %add.809.clone.1), metadata={op_name="multiply.24"} + %div.790.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.811.clone.1, %multiply.192.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1877.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1122, %broadcast.553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.808.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.790.clone.1, %mul.1877.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1876.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1878.clone.1, %add.808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.807.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1122, %mul.1876.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.172 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.807.clone.1, %add.807.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.994 = f32[]{:T(128)} constant(0) + %reduce.89 = f32[]{:T(128)} reduce(%square.172, %constant.994), dimensions={0,1,2}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.92.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.994), dimensions={0,1,2}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.138 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.89, %add.807.clone.1, %add.810.clone.1, %add.811.clone.1, %reduce.92.clone.1) +} + +%region_50.55 (reduce_sum.334: f32[], reduce_sum.335: f32[]) -> f32[] { + %reduce_sum.334 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.335 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.336 = f32[]{:T(128)} add(%reduce_sum.334, %reduce_sum.335), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_36.41 (reduce_sum.259: f32[], reduce_sum.263: f32[]) -> f32[] { + %reduce_sum.259 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.263 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.264 = f32[]{:T(128)} add(%reduce_sum.259, %reduce_sum.263), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.249 (param_0.1123: f32[4096,4,14336], param_1.1283: f32[], param_2.1104: f32[], param_3.787: f32[], param_4.494: f32[4096,4,14336], param_5.425: f32[], param_6.298: f32[4,4096,14336], param_7.196: pred[], param_8.114: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { + %param_0.1123 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %param_3.787 = f32[]{:T(128)S(6)} parameter(3) + %mul.1885.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.787), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.196 = pred[]{:T(512)S(6)} parameter(7) + %select_n.282.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.196), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.298 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) + %bitcast.427.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.298), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.425 = f32[]{:T(128)} parameter(5) + %div.805.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_5.425), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.804.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%bitcast.427.clone.1, %div.805.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.281.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%select_n.282.clone.1, %bitcast.427.clone.1, %div.804.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.959.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.559.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.959.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1889.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.281.clone.1, %broadcast.559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.114 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(8) + %constant.963.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.558.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.963.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1888.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.114, %broadcast.558.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.816.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1889.clone.1, %mul.1888.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1104 = f32[]{:T(128)S(6)} parameter(2) + %div.801.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1104), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.70.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.281.clone.1, %select_n.281.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.962.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.557.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.962.clone.1), dimensions={}, metadata={op_name="broadcast.60"} + %mul.1887.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.557.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.494 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) + %constant.961.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.556.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.961.clone.1), dimensions={}, metadata={op_name="broadcast.59"} + %mul.1886.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.494, %broadcast.556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.815.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1887.clone.1, %mul.1886.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1283 = f32[]{:T(128)S(6)} parameter(1) + %div.800.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1283), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.799.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.815.clone.1, %div.800.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.68.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} sqrt(%div.799.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.960.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.554.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.960.clone.1), dimensions={}, metadata={op_name="broadcast.54"} + %add.814.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.68.clone.1, %broadcast.554.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.193.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.801.clone.1, %add.814.clone.1), metadata={op_name="multiply.23"} + %div.798.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.816.clone.1, %multiply.193.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1884.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1123, %broadcast.559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.813.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.798.clone.1, %mul.1884.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1883.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1885.clone.1, %add.813.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.812.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1123, %mul.1883.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.173 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.812.clone.1, %add.812.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.995 = f32[]{:T(128)} constant(0) + %reduce.90 = f32[]{:T(128)} reduce(%square.173, %constant.995), dimensions={0,1,2}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.93.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.995), dimensions={0,1,2}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.139 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.90, %add.812.clone.1, %add.815.clone.1, %add.816.clone.1, %reduce.93.clone.1) +} + +%region_30.35 (reduce_sum.235: f32[], reduce_sum.236: f32[]) -> f32[] { + %reduce_sum.235 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.236 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.237 = f32[]{:T(128)} add(%reduce_sum.235, %reduce_sum.236), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.267 (param_0.1127: f32[4,4096,32,128]) -> f32[] { + %param_0.1127 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.329 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1127), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.176 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%bitcast.329, %bitcast.329), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.999 = f32[]{:T(128)} constant(0) + ROOT %reduce.94 = f32[]{:T(128)} reduce(%square.176, %constant.999), dimensions={0,1,2,3}, to_apply=%region_30.35, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_29.34 (reduce_sum.229: f32[], reduce_sum.230: f32[]) -> f32[] { + %reduce_sum.229 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.230 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.231 = f32[]{:T(128)} add(%reduce_sum.229, %reduce_sum.230), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.269 (param_0.1128: f32[4,32,128,4096]) -> f32[] { + %param_0.1128 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.333 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_0.1128), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.179 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%bitcast.333, %bitcast.333), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1000 = f32[]{:T(128)} constant(0) + ROOT %reduce.95 = f32[]{:T(128)} reduce(%square.179, %constant.1000), dimensions={0,1,2,3}, to_apply=%region_29.34, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%fused_computation.270 (param_0.753: f32[32,4,128,4096]) -> bf16[4,32,128,4096] { + %param_0.753 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) + %copy.237 = bf16[32,4,128,4096]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.753), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} + ROOT %bitcast.334 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.237), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%region_57.62 (reduce_sum.370: f32[], reduce_sum.371: f32[]) -> f32[] { + %reduce_sum.370 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.371 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.375 = f32[]{:T(128)} add(%reduce_sum.370, %reduce_sum.371), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_43.48 (reduce_sum.298: f32[], reduce_sum.299: f32[]) -> f32[] { + %reduce_sum.298 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.299 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.300 = f32[]{:T(128)} add(%reduce_sum.298, %reduce_sum.299), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.271 (param_0.1116: f32[4096,4,32,128], param_1.1276: f32[], param_2.1097: f32[], param_3.780: f32[], param_4.487: f32[4096,4,32,128], param_5.418: f32[], param_6.291: f32[4,4096,32,128], param_7.189: pred[], param_8.107: f32[4096,4,32,128]) -> (f32[], f32[4096,4,32,128], f32[4096,4,32,128], f32[4096,4,32,128], f32[]) { + %param_0.1116 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.780 = f32[]{:T(128)S(6)} parameter(3) + %mul.1827.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_3.780), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.189 = pred[]{:T(512)S(6)} parameter(7) + %select_n.254.clone.1 = pred[4096,4,32,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.189), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.291 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.413.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_6.291), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.418 = f32[]{:T(128)} parameter(5) + %div.749.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_5.418), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.748.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%bitcast.413.clone.1, %div.749.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.253.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} select(%select_n.254.clone.1, %bitcast.413.clone.1, %div.748.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.917.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.525.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.917.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %mul.1833.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.253.clone.1, %broadcast.525.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.107 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.921.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1834.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.921.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1832.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_8.107, %mul.1834.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.779.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1833.clone.1, %mul.1832.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1097 = f32[]{:T(128)S(6)} parameter(2) + %div.745.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1097), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.63.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.253.clone.1, %select_n.253.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.920.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1831.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.920.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1829.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.63.clone.1, %mul.1831.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.487 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.919.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1830.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.919.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1828.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_4.487, %mul.1830.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.778.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1829.clone.1, %mul.1828.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1276 = f32[]{:T(128)S(6)} parameter(1) + %div.744.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1276), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.743.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.778.clone.1, %div.744.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.61.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} sqrt(%div.743.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.918.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.777.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.918.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.776.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%sqrt.61.clone.1, %add.777.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.186.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%div.745.clone.1, %add.776.clone.1), metadata={op_name="multiply.30"} + %div.742.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.779.clone.1, %multiply.186.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1826.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_0.1116, %broadcast.525.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.775.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%div.742.clone.1, %mul.1826.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1825.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%mul.1827.clone.1, %add.775.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.774.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%param_0.1116, %mul.1825.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.180 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%add.774.clone.1, %add.774.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.988 = f32[]{:T(128)} constant(0) + %reduce.96 = f32[]{:T(128)} reduce(%square.180, %constant.988), dimensions={0,1,2,3}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.100.clone.1 = f32[]{:T(128)} reduce(%integer_pow.63.clone.1, %constant.988), dimensions={0,1,2,3}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.140 = (f32[]{:T(128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.96, %add.774.clone.1, %add.778.clone.1, %add.779.clone.1, %reduce.100.clone.1) +} + +%region_56.61 (reduce_sum.364: f32[], reduce_sum.368: f32[]) -> f32[] { + %reduce_sum.364 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.368 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.369 = f32[]{:T(128)} add(%reduce_sum.364, %reduce_sum.368), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_42.47 (reduce_sum.292: f32[], reduce_sum.293: f32[]) -> f32[] { + %reduce_sum.292 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.293 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.294 = f32[]{:T(128)} add(%reduce_sum.292, %reduce_sum.293), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.272 (param_0.1117: f32[32,4,128,4096], param_1.1277: f32[], param_2.1098: f32[], param_3.781: f32[], param_4.488: f32[32,4,128,4096], param_5.419: f32[], param_6.292: f32[4,32,128,4096], param_7.190: pred[], param_8.108: f32[32,4,128,4096]) -> (f32[], f32[32,4,128,4096], f32[32,4,128,4096], f32[32,4,128,4096], f32[]) { + %param_0.1117 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) + %param_3.781 = f32[]{:T(128)S(6)} parameter(3) + %mul.1837.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_3.781), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.190 = pred[]{:T(512)S(6)} parameter(7) + %select_n.258.clone.1 = pred[32,4,128,4096]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.190), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.292 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.415.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_6.292), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.419 = f32[]{:T(128)} parameter(5) + %div.757.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_5.419), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.756.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%bitcast.415.clone.1, %div.757.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.257.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} select(%select_n.258.clone.1, %bitcast.415.clone.1, %div.756.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.923.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.527.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.923.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %mul.1843.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.257.clone.1, %broadcast.527.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.108 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(8) + %constant.927.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1844.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.927.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1842.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_8.108, %mul.1844.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.785.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1843.clone.1, %mul.1842.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1098 = f32[]{:T(128)S(6)} parameter(2) + %div.753.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_2.1098), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.64.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.257.clone.1, %select_n.257.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.926.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1841.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.926.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1839.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%integer_pow.64.clone.1, %mul.1841.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.488 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(4) + %constant.925.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1840.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.925.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1838.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_4.488, %mul.1840.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.784.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1839.clone.1, %mul.1838.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1277 = f32[]{:T(128)S(6)} parameter(1) + %div.752.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_1.1277), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.751.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.784.clone.1, %div.752.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.62.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} sqrt(%div.751.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.924.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.783.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.924.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.782.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%sqrt.62.clone.1, %add.783.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.187.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%div.753.clone.1, %add.782.clone.1), metadata={op_name="multiply.29"} + %div.750.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.785.clone.1, %multiply.187.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1836.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_0.1117, %broadcast.527.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.781.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%div.750.clone.1, %mul.1836.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1835.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%mul.1837.clone.1, %add.781.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.780.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%param_0.1117, %mul.1835.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.181 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%add.780.clone.1, %add.780.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.989 = f32[]{:T(128)} constant(0) + %reduce.97 = f32[]{:T(128)} reduce(%square.181, %constant.989), dimensions={0,1,2,3}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.101.clone.1 = f32[]{:T(128)} reduce(%integer_pow.64.clone.1, %constant.989), dimensions={0,1,2,3}, to_apply=%region_42.47, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.141 = (f32[]{:T(128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.97, %add.780.clone.1, %add.784.clone.1, %add.785.clone.1, %reduce.101.clone.1) +} + +%region_47.52 (reduce_sum.319: f32[], reduce_sum.320: f32[]) -> f32[] { + %reduce_sum.319 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.320 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.321 = f32[]{:T(128)} add(%reduce_sum.319, %reduce_sum.320), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.279 (param_0.1136: bf16[4,128,128256], param_1.1292: f32[4,128], param_2.1109: s32[4,128], param_3.791: bf16[4,128]) -> f32[4,128] { + %param_2.1109 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.30 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1109), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.25 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.24 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.30, %eq.25), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %param_0.1136 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.966 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1136), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_3.791 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.73 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.791), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.64 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.966, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_1.1292 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.71 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1292), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.60 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%sub.64, %sub.71), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %constant.1009 = f32[]{:T(128)} constant(0) + %broadcast.472 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%constant.1009), dimensions={}, metadata={op_name="broadcast.39"} + %mul.1674 = f32[4,128,128256]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.472), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.98 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1674, %constant.1009), dimensions={2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%region_7.10 (reduce_sum.123: f32[], reduce_sum.127: f32[]) -> f32[] { + %reduce_sum.123 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.127 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.128 = f32[]{:T(128)} add(%reduce_sum.123, %reduce_sum.127), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.284 (param_0.1137: bf16[4,128,128256], param_1.1293: bf16[4,128]) -> f32[4,128] { + %param_0.1137 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.972 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1137), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.1293 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.74 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1293), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.70 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.972, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.54 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.70), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %constant.1010 = f32[]{:T(128)} constant(0) + ROOT %reduce.99 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1010), dimensions={2}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%region_31.36 (reduce_sum.238: f32[], reduce_sum.242: f32[]) -> f32[] { + %reduce_sum.238 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.242 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.243 = f32[]{:T(128)} add(%reduce_sum.238, %reduce_sum.242), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_28.33 (reduce_sum.223: f32[], reduce_sum.224: f32[]) -> f32[] { + %reduce_sum.223 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.224 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.228 = f32[]{:T(128)} add(%reduce_sum.223, %reduce_sum.224), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.290 (param_0.1129: f32[4,4096,8,128], param_1.1286: f32[4,4096,8,128]) -> (f32[], f32[]) { + %param_0.1129 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.350 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1129), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.184 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.350, %bitcast.350), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1001 = f32[]{:T(128)} constant(0) + %reduce.102 = f32[]{:T(128)} reduce(%square.184, %constant.1001), dimensions={0,1,2,3}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1286 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(1) + %bitcast.354.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1286), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.187.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.354.clone.1, %bitcast.354.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.103.clone.1 = f32[]{:T(128)} reduce(%square.187.clone.1, %constant.1001), dimensions={0,1,2,3}, to_apply=%region_28.33, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.156 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.102, %reduce.103.clone.1) +} + +%fused_computation.293 (param_0.812: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { + %param_0.812 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %copy.238 = bf16[4096,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.812), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} + ROOT %bitcast.355 = bf16[4,4096,8,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%copy.238), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%region_58.63 (reduce_sum.376: f32[], reduce_sum.377: f32[]) -> f32[] { + %reduce_sum.376 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.377 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.378 = f32[]{:T(128)} add(%reduce_sum.376, %reduce_sum.377), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_44.49 (reduce_sum.301: f32[], reduce_sum.305: f32[]) -> f32[] { + %reduce_sum.301 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.305 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.306 = f32[]{:T(128)} add(%reduce_sum.301, %reduce_sum.305), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.294 (param_0.1115: f32[4096,4,8,128], param_1.1275: f32[], param_2.1096: f32[], param_3.779: f32[], param_4.486: f32[4096,4,8,128], param_5.417: f32[], param_6.290: f32[4,4096,8,128], param_7.188: pred[], param_8.106: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { + %param_0.1115 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.779 = f32[]{:T(128)S(6)} parameter(3) + %mul.1820.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.779), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.188 = pred[]{:T(512)S(6)} parameter(7) + %select_n.250.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.188), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.290 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.411.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.290), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.417 = f32[]{:T(128)} parameter(5) + %div.741.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.417), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.740.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.411.clone.1, %div.741.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.249.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.250.clone.1, %bitcast.411.clone.1, %div.740.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.911.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.523.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.911.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1824.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.249.clone.1, %broadcast.523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.106 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.915.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.522.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.915.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %mul.1823.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.106, %broadcast.522.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.773.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1824.clone.1, %mul.1823.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1096 = f32[]{:T(128)S(6)} parameter(2) + %div.737.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1096), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.62.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.249.clone.1, %select_n.249.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.914.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.521.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.914.clone.1), dimensions={}, metadata={op_name="broadcast.56"} + %mul.1822.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.62.clone.1, %broadcast.521.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.486 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.913.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.520.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.913.clone.1), dimensions={}, metadata={op_name="broadcast.55"} + %mul.1821.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.486, %broadcast.520.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.772.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1822.clone.1, %mul.1821.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1275 = f32[]{:T(128)S(6)} parameter(1) + %div.736.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1275), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.735.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.772.clone.1, %div.736.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.60.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.735.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.912.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.518.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.912.clone.1), dimensions={}, metadata={op_name="broadcast.52"} + %add.771.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.60.clone.1, %broadcast.518.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.185.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.737.clone.1, %add.771.clone.1), metadata={op_name="multiply.31"} + %div.734.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.773.clone.1, %multiply.185.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1819.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1115, %broadcast.523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.770.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.734.clone.1, %mul.1819.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1818.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1820.clone.1, %add.770.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.769.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1115, %mul.1818.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.188 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.769.clone.1, %add.769.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.987 = f32[]{:T(128)} constant(0) + %reduce.104 = f32[]{:T(128)} reduce(%square.188, %constant.987), dimensions={0,1,2,3}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.106.clone.1 = f32[]{:T(128)} reduce(%integer_pow.62.clone.1, %constant.987), dimensions={0,1,2,3}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.142 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.104, %add.769.clone.1, %add.772.clone.1, %add.773.clone.1, %reduce.106.clone.1) +} + +%region_55.60 (reduce_sum.361: f32[], reduce_sum.362: f32[]) -> f32[] { + %reduce_sum.361 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.362 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.363 = f32[]{:T(128)} add(%reduce_sum.361, %reduce_sum.362), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_41.46 (reduce_sum.286: f32[], reduce_sum.287: f32[]) -> f32[] { + %reduce_sum.286 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.287 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.291 = f32[]{:T(128)} add(%reduce_sum.286, %reduce_sum.287), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.295 (param_0.1118: f32[4096,4,8,128], param_1.1278: f32[], param_2.1099: f32[], param_3.782: f32[], param_4.489: f32[4096,4,8,128], param_5.420: f32[], param_6.293: f32[4,4096,8,128], param_7.191: pred[], param_8.109: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { + %param_0.1118 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.782 = f32[]{:T(128)S(6)} parameter(3) + %mul.1847.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.782), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.191 = pred[]{:T(512)S(6)} parameter(7) + %select_n.262.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.191), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.293 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.417.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.293), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.420 = f32[]{:T(128)} parameter(5) + %div.765.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.420), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.764.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.417.clone.1, %div.765.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.261.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.262.clone.1, %bitcast.417.clone.1, %div.764.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.929.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.533.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.929.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1851.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.261.clone.1, %broadcast.533.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.109 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.933.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.532.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.933.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %mul.1850.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.109, %broadcast.532.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.790.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1851.clone.1, %mul.1850.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1099 = f32[]{:T(128)S(6)} parameter(2) + %div.761.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1099), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.65.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.261.clone.1, %select_n.261.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.932.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.531.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.932.clone.1), dimensions={}, metadata={op_name="broadcast.56"} + %mul.1849.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %broadcast.531.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.489 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.931.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.530.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.931.clone.1), dimensions={}, metadata={op_name="broadcast.55"} + %mul.1848.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.489, %broadcast.530.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.789.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1849.clone.1, %mul.1848.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1278 = f32[]{:T(128)S(6)} parameter(1) + %div.760.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1278), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.759.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.789.clone.1, %div.760.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.63.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.759.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.930.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.528.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.930.clone.1), dimensions={}, metadata={op_name="broadcast.52"} + %add.788.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.528.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.188.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.761.clone.1, %add.788.clone.1), metadata={op_name="multiply.28"} + %div.758.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.790.clone.1, %multiply.188.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1846.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1118, %broadcast.533.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.787.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.758.clone.1, %mul.1846.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1845.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1847.clone.1, %add.787.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.786.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1118, %mul.1845.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.189 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.786.clone.1, %add.786.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.990 = f32[]{:T(128)} constant(0) + %reduce.105 = f32[]{:T(128)} reduce(%square.189, %constant.990), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.107.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.990), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.143 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.105, %add.786.clone.1, %add.789.clone.1, %add.790.clone.1, %reduce.107.clone.1) +} + +%fused_computation.311 (param_0.877: bf16[4,128,4096], param_1.943: f32[4,128], param_2.720: f32[4,128], param_3.440: bf16[4,128,4096], param_4.265: bf16[4096]) -> bf16[4,128,4096] { + %param_3.440 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.265 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %mul.1754 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.265), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1728 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_3.440, %mul.1754), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.989 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.1728), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.720 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %mul.1725 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_2.720), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1716 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.989, %mul.1725), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.877 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1000 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.877), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.943 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.1723 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.943), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1722 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1000, %mul.1723), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %add_any.138 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1716, %mul.1722), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} + ROOT %convert_element_type.987 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.138), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} +} + +%region_4.7 (reduce_sum.114: f32[], reduce_sum.115: f32[]) -> f32[] { + %reduce_sum.114 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.115 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.116 = f32[]{:T(128)} add(%reduce_sum.114, %reduce_sum.115), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.312 (param_0.1138: bf16[4,128,4096]) -> f32[4,128] { + %param_0.1138 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.991 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1138), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.192 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.991, %convert_element_type.991), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} + %constant.1011 = f32[]{:T(128)} constant(0) + ROOT %reduce.108 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.192, %constant.1011), dimensions={2}, to_apply=%region_4.7, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} +} + +%region_10.13 (reduce_sum.141: f32[], reduce_sum.142: f32[]) -> f32[] { + %reduce_sum.141 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.142 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.143 = f32[]{:T(128)} add(%reduce_sum.141, %reduce_sum.142), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.314 (param_0.1133: bf16[4,128,4096], param_1.1289: bf16[4,128,4096], param_2.1107: bf16[4096]) -> f32[4,128] { + %param_0.1133 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.998 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1133), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.1289 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %param_2.1107 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.1753 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1107), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1727 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1289, %mul.1753), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.997 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.1727), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %mul.1720 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.998, %convert_element_type.997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1005 = f32[]{:T(128)} constant(0) + ROOT %reduce.109 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1720, %constant.1005), dimensions={2}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} +} + +%region_8.11 (reduce_sum.129: bf16[], reduce_sum.130: bf16[]) -> bf16[] { + %reduce_sum.129 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.130 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.134 = bf16[]{:T(256)} add(%reduce_sum.129, %reduce_sum.130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.235.clone.clone (param_0.1102: f32[4096,128256]) -> bf16[4096,128256,1] { + %param_0.1102 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %convert_element_type.1049 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1102), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + ROOT %bitcast.449 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1049), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} +} + +%fused_computation.280.clone.1.clone.clone (param_0.1103: bf16[4,128,128256], param_1.1265: s32[4,128], param_2.1075: f32[4,128], param_3.770: f32[4,128], param_4.478: bf16[4,128], param_5.409: f32[4,128]) -> bf16[4,128,128256] { + %param_5.409 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.1925 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.409), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.770 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.1924 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.770), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1103 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1052 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1103), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.478 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.88 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.478), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.87 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1052, %sub.88), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.60 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.87), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.1923 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1924, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1075 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.819 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1075), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.818 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1923, %div.819), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1265 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.43 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1265), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.42 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.41 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.43, %eq.42), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %convert_element_type.1051 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.86 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.818, %convert_element_type.1051), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.1922 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1925, %sub.86), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1050 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1922), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.315 (param_0.1101: f32[4,128], param_1.1264: bf16[4,128,4096], param_2.1076: f32[4096,128256], param_3.771: bf16[4,128,128256], param_4.479: s32[4,128], param_5.410: f32[4,128], param_6.285: f32[4,128], param_7.183: bf16[4,128], param_8.102: f32[4,128]) -> (bf16[4096], bf16[4,128,4096]) { + %param_1.1264 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1010 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1264), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1101 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1742 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1101), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1741 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1010, %mul.1742), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1009 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1741), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_3.771 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.479 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %param_5.410 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.285 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.183 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) + %param_8.102 = f32[4,128]{1,0:T(4,128)S(1)} parameter(8) + %multiply_convert_fusion.2.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_3.771, %param_4.479, %param_5.410, %param_6.285, %param_7.183, /*index=5*/%param_8.102), kind=kLoop, calls=%fused_computation.280.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_2.1076 = f32[4096,128256]{1,0:T(8,128)} parameter(2) + %fusion.219.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1076), kind=kLoop, calls=%fused_computation.235.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %convolution.86.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.219.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %mul.1724 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1009, %convolution.86.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.878 = bf16[]{:T(256)} constant(0) + %reduce.110 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%mul.1724, %constant.878), dimensions={0,1}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} + ROOT %tuple.153 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.110, %convolution.86.clone.1) +} + +%fused_computation.323 (param_0.911: f32[64], param_1.978: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.978 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %div.621 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.978), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %param_0.911 = f32[64]{0:T(128)S(1)} parameter(0) + %div.619 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.911), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %div.618 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.621, %div.619), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %sin.38 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.618), metadata={op_name="jit(train_step)/layers/sin" stack_frame_id=0} + %convert_element_type.1018 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %cos.41.clone.1 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.618), metadata={op_name="jit(train_step)/layers/cos" stack_frame_id=0} + %convert_element_type.1017.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.150 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1018, %convert_element_type.1017.clone.1) +} + +%fused_computation.324 (param_0.908: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.908 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.858 = bf16[]{:T(256)} constant(-inf) + %pad.38 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.908, %constant.858), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.37 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.908, %constant.858), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + ROOT %maximum.34 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.38, %pad.37), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +} + +%fused_computation.325 (param_0.910: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.910 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.857 = bf16[]{:T(256)} constant(-inf) + %pad.40 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.910, %constant.857), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.39 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.910, %constant.857), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + ROOT %maximum.35 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.40, %pad.39), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +} + +%region_27.32 (reduce_sum.217: f32[], reduce_sum.221: f32[]) -> f32[] { + %reduce_sum.217 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.221 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.222 = f32[]{:T(128)} add(%reduce_sum.217, %reduce_sum.221), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_26.31 (reduce_sum.214: f32[], reduce_sum.215: f32[]) -> f32[] { + %reduce_sum.214 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.215 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.216 = f32[]{:T(128)} add(%reduce_sum.214, %reduce_sum.215), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.329 (param_0.1130: f32[4,4096], param_1.1287: f32[4,4096]) -> (f32[], f32[]) { + %param_0.1130 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(0) + %bitcast.371 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_0.1130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.195 = f32[4096,4]{0,1:T(4,128)} multiply(%bitcast.371, %bitcast.371), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1002 = f32[]{:T(128)} constant(0) + %reduce.111 = f32[]{:T(128)} reduce(%square.195, %constant.1002), dimensions={0,1}, to_apply=%region_27.32, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1287 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(1) + %bitcast.375.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_1.1287), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.198.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%bitcast.375.clone.1, %bitcast.375.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.112.clone.1 = f32[]{:T(128)} reduce(%square.198.clone.1, %constant.1002), dimensions={0,1}, to_apply=%region_26.31, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.157 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.111, %reduce.112.clone.1) +} + +%region_54.59 (reduce_sum.355: f32[], reduce_sum.356: f32[]) -> f32[] { + %reduce_sum.355 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.356 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.357 = f32[]{:T(128)} add(%reduce_sum.355, %reduce_sum.356), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_40.45 (reduce_sum.280: f32[], reduce_sum.284: f32[]) -> f32[] { + %reduce_sum.280 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.284 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.285 = f32[]{:T(128)} add(%reduce_sum.280, %reduce_sum.284), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.332 (param_0.1119: f32[4096,4], param_1.1279: f32[], param_2.1100: f32[], param_3.783: f32[], param_4.490: f32[4096,4], param_5.421: f32[], param_6.294: f32[4,4096], param_7.192: pred[], param_8.110: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { + %param_0.1119 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.783 = f32[]{:T(128)S(6)} parameter(3) + %mul.1854.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.783), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.192 = pred[]{:T(512)S(6)} parameter(7) + %select_n.266.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.192), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.294 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.419.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.294), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.421 = f32[]{:T(128)} parameter(5) + %div.773.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_5.421), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.772.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%bitcast.419.clone.1, %div.773.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.265.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%select_n.266.clone.1, %bitcast.419.clone.1, %div.772.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.935.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.539.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.935.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1858.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.265.clone.1, %broadcast.539.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.110 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.939.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.538.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.939.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1857.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.110, %broadcast.538.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.795.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1858.clone.1, %mul.1857.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1100 = f32[]{:T(128)S(6)} parameter(2) + %div.769.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1100), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.66.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.265.clone.1, %select_n.265.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.938.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.537.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.938.clone.1), dimensions={}, metadata={op_name="broadcast.58"} + %mul.1856.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.66.clone.1, %broadcast.537.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.490 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.937.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.536.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.937.clone.1), dimensions={}, metadata={op_name="broadcast.57"} + %mul.1855.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.490, %broadcast.536.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.794.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1856.clone.1, %mul.1855.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1279 = f32[]{:T(128)S(6)} parameter(1) + %div.768.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1279), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.767.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.794.clone.1, %div.768.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.64.clone.1 = f32[4096,4]{0,1:T(4,128)} sqrt(%div.767.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.936.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.534.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.936.clone.1), dimensions={}, metadata={op_name="broadcast.53"} + %add.793.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.534.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.189.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.769.clone.1, %add.793.clone.1), metadata={op_name="multiply.27"} + %div.766.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.795.clone.1, %multiply.189.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1853.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1119, %broadcast.539.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.792.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.766.clone.1, %mul.1853.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1852.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1854.clone.1, %add.792.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.791.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1119, %mul.1852.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.199 = f32[4096,4]{0,1:T(4,128)} multiply(%add.791.clone.1, %add.791.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.991 = f32[]{:T(128)} constant(0) + %reduce.113 = f32[]{:T(128)} reduce(%square.199, %constant.991), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.115.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.991), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.144 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.113, %add.791.clone.1, %add.794.clone.1, %add.795.clone.1, %reduce.115.clone.1) +} + +%region_53.58 (reduce_sum.349: f32[], reduce_sum.350: f32[]) -> f32[] { + %reduce_sum.349 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.350 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.354 = f32[]{:T(128)} add(%reduce_sum.349, %reduce_sum.350), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_39.44 (reduce_sum.277: f32[], reduce_sum.278: f32[]) -> f32[] { + %reduce_sum.277 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.278 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.279 = f32[]{:T(128)} add(%reduce_sum.277, %reduce_sum.278), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.333 (param_0.1120: f32[4096,4], param_1.1280: f32[], param_2.1101: f32[], param_3.784: f32[], param_4.491: f32[4096,4], param_5.422: f32[], param_6.295: f32[4,4096], param_7.193: pred[], param_8.111: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { + %param_0.1120 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.784 = f32[]{:T(128)S(6)} parameter(3) + %mul.1861.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.784), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.193 = pred[]{:T(512)S(6)} parameter(7) + %select_n.270.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.193), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.295 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.421.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.295), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.422 = f32[]{:T(128)} parameter(5) + %div.781.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_5.422), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.780.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%bitcast.421.clone.1, %div.781.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.269.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%select_n.270.clone.1, %bitcast.421.clone.1, %div.780.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.941.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.545.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.941.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1865.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.269.clone.1, %broadcast.545.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.111 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.945.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.544.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.945.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1864.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.111, %broadcast.544.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.800.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1865.clone.1, %mul.1864.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1101 = f32[]{:T(128)S(6)} parameter(2) + %div.777.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1101), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.67.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.269.clone.1, %select_n.269.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.944.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.543.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.944.clone.1), dimensions={}, metadata={op_name="broadcast.58"} + %mul.1863.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.543.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.491 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.943.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.542.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.943.clone.1), dimensions={}, metadata={op_name="broadcast.57"} + %mul.1862.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.491, %broadcast.542.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.799.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1863.clone.1, %mul.1862.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1280 = f32[]{:T(128)S(6)} parameter(1) + %div.776.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1280), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.775.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.799.clone.1, %div.776.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.65.clone.1 = f32[4096,4]{0,1:T(4,128)} sqrt(%div.775.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.942.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.540.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.942.clone.1), dimensions={}, metadata={op_name="broadcast.53"} + %add.798.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.65.clone.1, %broadcast.540.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.190.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.777.clone.1, %add.798.clone.1), metadata={op_name="multiply.26"} + %div.774.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.800.clone.1, %multiply.190.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1860.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1120, %broadcast.545.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.797.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.774.clone.1, %mul.1860.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1859.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1861.clone.1, %add.797.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.796.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1120, %mul.1859.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.200 = f32[4096,4]{0,1:T(4,128)} multiply(%add.796.clone.1, %add.796.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.992 = f32[]{:T(128)} constant(0) + %reduce.114 = f32[]{:T(128)} reduce(%square.200, %constant.992), dimensions={0,1}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.116.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.992), dimensions={0,1}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.145 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.114, %add.796.clone.1, %add.799.clone.1, %add.800.clone.1, %reduce.116.clone.1) +} + +%region_9.12 (reduce_sum.135: f32[], reduce_sum.136: f32[]) -> f32[] { + %reduce_sum.135 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.136 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.137 = f32[]{:T(128)} add(%reduce_sum.135, %reduce_sum.136), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.344 (param_0.1134: bf16[4096]) -> f32[] { + %param_0.1134 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %convert_element_type.1022 = f32[4096]{0:T(1024)} convert(%param_0.1134), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.203 = f32[4096]{0:T(1024)} multiply(%convert_element_type.1022, %convert_element_type.1022), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1006 = f32[]{:T(128)} constant(0) + ROOT %reduce.117 = f32[]{:T(128)} reduce(%square.203, %constant.1006), dimensions={0}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_49.54 (reduce_sum.328: f32[], reduce_sum.329: f32[]) -> f32[] { + %reduce_sum.328 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.329 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.333 = f32[]{:T(128)} add(%reduce_sum.328, %reduce_sum.329), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_35.40 (reduce_sum.256: f32[], reduce_sum.257: f32[]) -> f32[] { + %reduce_sum.256 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.257 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.258 = f32[]{:T(128)} add(%reduce_sum.256, %reduce_sum.257), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.345 (param_0.1124: f32[4096], param_1.1284: f32[], param_2.1105: f32[], param_3.788: f32[], param_4.495: f32[4096], param_5.426: f32[], param_6.299: bf16[4096], param_7.197: pred[], param_8.115: f32[4096]) -> (f32[], f32[4096], f32[4096], f32[4096], f32[]) { + %param_0.1124 = f32[4096]{0:T(1024)S(1)} parameter(0) + %param_3.788 = f32[]{:T(128)S(6)} parameter(3) + %mul.1892.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_3.788), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.197 = pred[]{:T(512)S(6)} parameter(7) + %select_n.286.clone.1 = pred[4096]{0:T(1024)(128)(4,1)} broadcast(%param_7.197), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.299 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(6) + %convert_element_type.1037.clone.1 = f32[4096]{0:T(1024)} convert(%param_6.299), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_5.426 = f32[]{:T(128)} parameter(5) + %div.813.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_5.426), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.812.clone.1 = f32[4096]{0:T(1024)} divide(%convert_element_type.1037.clone.1, %div.813.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.285.clone.1 = f32[4096]{0:T(1024)} select(%select_n.286.clone.1, %convert_element_type.1037.clone.1, %div.812.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.965.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.561.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.965.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.1898.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.285.clone.1, %broadcast.561.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.115 = f32[4096]{0:T(1024)S(1)} parameter(8) + %constant.969.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1899.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.969.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1897.clone.1 = f32[4096]{0:T(1024)} multiply(%param_8.115, %mul.1899.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.822.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1898.clone.1, %mul.1897.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1105 = f32[]{:T(128)S(6)} parameter(2) + %div.809.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_2.1105), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.71.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.285.clone.1, %select_n.285.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.968.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1896.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.968.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1894.clone.1 = f32[4096]{0:T(1024)} multiply(%integer_pow.71.clone.1, %mul.1896.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.495 = f32[4096]{0:T(1024)S(1)} parameter(4) + %constant.967.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1895.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.967.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1893.clone.1 = f32[4096]{0:T(1024)} multiply(%param_4.495, %mul.1895.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.821.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1894.clone.1, %mul.1893.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1284 = f32[]{:T(128)S(6)} parameter(1) + %div.808.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_1.1284), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.807.clone.1 = f32[4096]{0:T(1024)} divide(%add.821.clone.1, %div.808.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.69.clone.1 = f32[4096]{0:T(1024)} sqrt(%div.807.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.966.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.820.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.966.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.819.clone.1 = f32[4096]{0:T(1024)} add(%sqrt.69.clone.1, %add.820.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.194.clone.1 = f32[4096]{0:T(1024)} multiply(%div.809.clone.1, %add.819.clone.1), metadata={op_name="multiply.22"} + %div.806.clone.1 = f32[4096]{0:T(1024)} divide(%add.822.clone.1, %multiply.194.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1891.clone.1 = f32[4096]{0:T(1024)} multiply(%param_0.1124, %broadcast.561.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.818.clone.1 = f32[4096]{0:T(1024)} add(%div.806.clone.1, %mul.1891.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1890.clone.1 = f32[4096]{0:T(1024)} multiply(%mul.1892.clone.1, %add.818.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.817.clone.1 = f32[4096]{0:T(1024)S(1)} add(%param_0.1124, %mul.1890.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.204 = f32[4096]{0:T(1024)} multiply(%add.817.clone.1, %add.817.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.996 = f32[]{:T(128)} constant(0) + %reduce.118 = f32[]{:T(128)} reduce(%square.204, %constant.996), dimensions={0}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.119.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.996), dimensions={0}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.148 = (f32[]{:T(128)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.118, %add.817.clone.1, %add.821.clone.1, %add.822.clone.1, %reduce.119.clone.1) +} + +%fused_computation.351 (param_0.971: s32[512]) -> s32[1024] { + %constant.793 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.500 = s32[1024]{0:T(1024)} broadcast(%constant.793), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %param_0.971 = s32[512]{0:T(512)S(1)} parameter(0) + %constant.794 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %pad.41 = s32[1024]{0:T(1024)} pad(%param_0.971, %constant.794), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %constant.792 = s32[] constant(128255), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.499 = s32[1024]{0:T(1024)} broadcast(%constant.792), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.500, %pad.41, %broadcast.499), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} +} + +%fused_computation.352 (param_0.970: s32[4,128]) -> s32[512] { + %param_0.970 = s32[4,128]{1,0:T(4,128)} parameter(0) + %constant.882 = s32[]{:T(128)} constant(0) + %broadcast.508 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.882), dimensions={}, metadata={op_name="broadcast.81"} + %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.970, %broadcast.508), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} + %constant.879 = s32[]{:T(128)} constant(128256) + %add.746 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.879), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %add.734 = s32[4,128]{1,0:T(4,128)} add(%param_0.970, %add.746), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.734, %param_0.970), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} + ROOT %bitcast.376 = s32[512]{0:T(512)S(1)} bitcast(%select_n.178), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} +} + +%region_61.66 (reduce_sum.391: f32[], reduce_sum.392: f32[]) -> f32[] { + %reduce_sum.391 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.392 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.396 = f32[]{:T(128)} add(%reduce_sum.391, %reduce_sum.392), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.53 (reduce_sum.322: f32[], reduce_sum.326: f32[]) -> f32[] { + %reduce_sum.322 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.326 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.327 = f32[]{:T(128)} add(%reduce_sum.322, %reduce_sum.326), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.353 (param_0.1135: bf16[4,128], param_1.1291: f32[4,128], param_2.1108: f32[4,128], param_3.790: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { + %param_3.790 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %constant.971.clone.1 = s32[]{:T(128)} constant(0) + %broadcast.562.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.971.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.790, %broadcast.562.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} + %param_1.1291 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1291), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} + %param_0.1135 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) + %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1135), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + %add.748 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %square.207 = f32[4,128]{1,0:T(4,128)} multiply(%add.748, %add.748), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} + %constant.1008 = f32[]{:T(128)} constant(0) + %broadcast.502 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1008), dimensions={}, metadata={op_name="broadcast.32"} + %mul.1791 = f32[4,128]{1,0:T(4,128)} multiply(%square.207, %broadcast.502), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %mul.1783 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.1791, %broadcast.502), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.120 = f32[]{:T(128)} reduce(%mul.1783, %constant.1008), dimensions={0,1}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %param_2.1108 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1108), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} + %add.735.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.1791), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %mul.1784.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.735.clone.1, %broadcast.502), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.121.clone.1 = f32[]{:T(128)} reduce(%mul.1784.clone.1, %constant.1008), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %mul.1789.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.748, %broadcast.502), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %constant.883.clone.1 = f32[]{:T(128)} constant(1) + %add.743.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.883.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + %add.736.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.1789.clone.1, %add.743.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + ROOT %tuple.149 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.120, %reduce.121.clone.1, %ne.6.clone.1, %add.736.clone.1) +} + +%fused_computation.356 (param_0.994: f32[4,128], param_1.1105: f32[4,128]) -> f32[4,128] { + %param_0.994 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1105 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %constant.873 = f32[]{:T(128)} constant(0.000244140625) + %broadcast.510 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.873), dimensions={}, metadata={op_name="broadcast.245"} + %div.656 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1105, %broadcast.510), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.871 = f32[]{:T(128)} constant(1e-05) + %add.756 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.871), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.755 = f32[4,128]{1,0:T(4,128)} add(%div.656, %add.756), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %rsqrt.90 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.755), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} + %div.649 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.90, %add.755), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.856 = f32[]{:T(128)} constant(-0.5) + %mul.1795 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.856), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1788 = f32[4,128]{1,0:T(4,128)} multiply(%div.649, %mul.1795), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1787 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.994, %mul.1788), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.855 = f32[]{:T(128)} constant(0.00048828125) + %mul.1794 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.855), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.1786 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1787, %mul.1794), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%region_5.8 (reduce_sum.120: s32[], reduce_sum.121: s32[]) -> s32[] { + %reduce_sum.120 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.121 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.122 = s32[]{:T(128)} add(%reduce_sum.120, %reduce_sum.121), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} +} + +%fused_computation.359 (param_0.1011: pred[4,128]) -> s32[] { + %param_0.1011 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %convert_element_type.1029 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1011), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} + %constant.881 = s32[]{:T(128)} constant(0) + ROOT %reduce.122 = s32[]{:T(128)} reduce(%convert_element_type.1029, %constant.881), dimensions={0,1}, to_apply=%region_5.8, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%fused_computation.361 (param_0.997: f32[4,128]) -> f32[4,128] { + %param_0.997 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.874 = f32[]{:T(128)} constant(0.000244140625) + %broadcast.506 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.874), dimensions={}, metadata={op_name="broadcast.245"} + %div.654 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.997, %broadcast.506), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.872 = f32[]{:T(128)} constant(1e-05) + %add.745 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.872), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.742 = f32[4,128]{1,0:T(4,128)} add(%div.654, %add.745), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + ROOT %rsqrt.88 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.742), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} +} + +%fused_computation.362 (param_0.996: pred[4,128], param_1.1290: f32[]) -> f32[4,128] { + %param_0.996 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %param_1.1290 = f32[]{:T(128)S(6)} parameter(1) + %broadcast_in_dim.283 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1290), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} + %constant.1007 = f32[]{:T(128)} constant(0) + %broadcast.504 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1007), dimensions={}, metadata={op_name="broadcast.32"} + ROOT %mul.1796 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.996, %broadcast_in_dim.283, %broadcast.504), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} +} + +%fused_computation.364 () -> f32[64] { + %constant.877 = f32[]{:T(128)} constant(500000) + %broadcast.513 = f32[64]{0:T(128)} broadcast(%constant.877), dimensions={}, metadata={op_name="broadcast.236"} + %iota.46 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/layers/iota" stack_frame_id=0} + %constant.876 = s32[]{:T(128)} constant(2) + %broadcast.512 = s32[64]{0:T(128)} broadcast(%constant.876), dimensions={}, metadata={op_name="broadcast.237"} + %mul.1797 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.512), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} + %convert_element_type.1030 = f32[64]{0:T(128)} convert(%mul.1797), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %constant.875 = f32[]{:T(128)} constant(0.0078125) + %broadcast.511 = f32[64]{0:T(128)} broadcast(%constant.875), dimensions={}, metadata={op_name="broadcast.238"} + %div.657 = f32[64]{0:T(128)} multiply(%convert_element_type.1030, %broadcast.511), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.513, %div.657), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} +} + +%fused_computation.365 (param_0.1009: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { + %param_0.1009 = s32[4,128]{1,0:T(4,128)} parameter(0) + %convert_element_type.1031 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1009), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %bitcast.377 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1031), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.151 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.377, %convert_element_type.1031) +} + +%fused_computation.369 (param_0.1110: f32[4096,4]) -> bf16[4,4096] { + %param_0.1110 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %bitcast.451 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1110), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.69 = bf16[4,4096]{1,0:T(4,128)(2,1)S(1)} convert(%bitcast.451) +} + +%fused_computation.370 (param_0.1111: f32[4096,4]) -> bf16[4,4096] { + %param_0.1111 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %bitcast.452 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1111), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.71 = bf16[4,4096]{1,0:T(4,128)(2,1)} convert(%bitcast.452) +} + +%region_6.9 (reduce_max.6: bf16[], reduce_max.8: bf16[]) -> bf16[] { + %reduce_max.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_max"} + %reduce_max.8 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_max"} + ROOT %reduce_max.9 = bf16[]{:T(256)} maximum(%reduce_max.6, %reduce_max.8), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.237.clone.clone (param_0.1097: f32[4096,128256]) -> bf16[4096,128256,1] { + %param_0.1097 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %convert_element_type.1042 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1097), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + ROOT %bitcast.447 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1042), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} +} + +%fused_computation.317.clone.clone (param_0.1098: f32[4,128], param_1.1261: bf16[4,128,4096], param_2.1071: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1261 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1044 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1261), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1098 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1916 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1098), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1915 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1044, %mul.1916), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1043 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1915), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1071 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.1917 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1071), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.1914 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1043, %mul.1917), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.371 (param_0.1112: f32[4096,128256], param_1.1272: f32[4,128], param_2.1093: bf16[4,128,4096], param_3.776: bf16[4096]) -> (bf16[4,128], bf16[4,128,128256]) { + %param_1.1272 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1093 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.776 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.240.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1272, %param_2.1093, %param_3.776), kind=kLoop, calls=%fused_computation.317.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1112 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %fusion.221.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1112), kind=kLoop, calls=%fused_computation.237.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %convolution.87.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convolution(%fusion.240.clone.1, %fusion.221.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %constant.984 = bf16[]{:T(256)} constant(-inf) + %reduce.123 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.87.clone.1, %constant.984), dimensions={2}, to_apply=%region_6.9, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + ROOT %tuple.152 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,128256]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.123, %convolution.87.clone.1) +} + +%fused_computation.372 (param_0.1109: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { + %param_0.1109 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %bitcast.450 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1109), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.73 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.450) +} + +%convert_element_type.541.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { + %lhs = bf16[] parameter(0) + %rhs = bf16[] parameter(1) + ROOT %add.609 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.122.clone.clone (param_0.1249: bf16[4,4096], param_1.1380: s32[]) -> bf16[4096] { + %param_0.1249 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1380 = s32[]{:T(128)S(6)} parameter(1) + %constant.1108 = s32[]{:T(128)} constant(0) + %dynamic_slice.316 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1249, %param_1.1380, %constant.1108), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1109 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.135 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.316, %constant.1109), dimensions={0}, to_apply=%convert_element_type.541.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%region_12.14 (reduce_sum.144: f32[], reduce_sum.148: f32[]) -> f32[] { + %reduce_sum.144 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.148 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.149 = f32[]{:T(128)} add(%reduce_sum.144, %reduce_sum.148), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.58.clone.clone (param_0.1250: bf16[4,4,128,4096], param_1.1381: s32[]) -> f32[4,128] { + %param_0.1250 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1381 = s32[]{:T(128)S(6)} parameter(1) + %constant.1110 = s32[]{:T(128)} constant(0) + %dynamic_slice.317 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1250, %param_1.1381, %constant.1110, %constant.1110, %constant.1110), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.548 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.317), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %convert_element_type.1109 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.548), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.214 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1109, %convert_element_type.1109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1111 = f32[]{:T(128)} constant(0) + ROOT %reduce.136 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.214, %constant.1111), dimensions={2}, to_apply=%region_12.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} +} + +%fused_computation.143.clone.1.clone (param_0.1251: f32[4,128]) -> f32[4,128] { + %param_0.1251 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1113 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.81 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1113), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.842 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1251, %closed_call.81), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1112 = f32[]{:T(128)} constant(1e-05) + %closed_call.80 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1112), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.844 = f32[4,128]{1,0:T(4,128)} add(%div.842, %closed_call.80), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.97 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.844), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.24.clone.1.clone.clone (param_0.1265: bf16[4,4096,32,128], param_1.1391: s32[]) -> bf16[4096,32,128,1] { + %param_0.1265 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1391 = s32[]{:T(128)S(6)} parameter(1) + %constant.1126 = s32[]{:T(128)} constant(0) + %dynamic_slice.323 = bf16[1,4096,32,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1265, %param_1.1391, %constant.1126, %constant.1126, %constant.1126), dynamic_slice_sizes={1,4096,32,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.559 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.323), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.91.clone.clone (param_0.1266: bf16[4096], param_1.1392: f32[4,128], param_2.1170: bf16[4,4,128,4096], param_3.835: s32[]) -> bf16[4,128,4096,1] { + %param_2.1170 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.835 = s32[]{:T(128)S(6)} parameter(3) + %constant.1127 = s32[]{:T(128)} constant(0) + %dynamic_slice.324 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1170, %param_3.835, %constant.1127, %constant.1127, %constant.1127), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.561 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.324), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %convert_element_type.1117 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.561), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1392 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2081 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1392), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2080 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1117, %mul.2081), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1116 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2080), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1266 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %mul.2079 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_0.1266), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2078 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1116, %mul.2079), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.560 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2078), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.36.clone.clone (param_0.1267: bf16[4,4096,32,128], param_1.1393: s32[], param_2.1171: f32[4,128], param_3.836: bf16[4,4,128,4096], param_4.524: bf16[4096]) -> bf16[4,128,32,128] { + %param_4.524 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %param_2.1171 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.836 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1393 = s32[]{:T(128)S(6)} parameter(1) + %fusion.343 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_4.524, %param_2.1171, %param_3.836, %param_1.1393), kind=kLoop, calls=%fused_computation.91.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1267 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.342 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1267, %param_1.1393), kind=kLoop, calls=%fused_computation.24.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.113 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.343, %fusion.342), window={size=1x32 pad=0_0x31_31 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.70.clone.clone (param_0.1268: bf16[4,128,32,128]) -> (bf16[4,128,32,64], bf16[4,128,32,64]) { + %param_0.1268 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.160 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1268), slice={[0:4], [0:128], [0:32], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %neg.129 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.160), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} + %split.161 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1268), slice={[0:4], [0:128], [0:32], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + ROOT %tuple.187 = (bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) +} + +%fused_computation.145.clone.clone () -> f32[64] { + %constant.1116 = f32[]{:T(128)} constant(500000) + %closed_call.84 = f32[64]{0:T(128)} broadcast(%constant.1116), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %iota.51 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/iota" stack_frame_id=0} + %constant.1115 = s32[]{:T(128)} constant(2) + %closed_call.83 = s32[64]{0:T(128)} broadcast(%constant.1115), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2065 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.83), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1110 = f32[64]{0:T(128)} convert(%mul.2065), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %constant.1114 = f32[]{:T(128)} constant(0.0078125) + %closed_call.82 = f32[64]{0:T(128)} broadcast(%constant.1114), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.843 = f32[64]{0:T(128)} multiply(%convert_element_type.1110, %closed_call.82), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + ROOT %pow.38 = f32[64]{0:T(128)S(1)} power(%closed_call.84, %div.843), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/pow" stack_frame_id=0} +} + +%fused_computation.117.clone.clone (param_0.1252: f32[64], param_1.1382: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1382 = f32[4,128]{1,0:T(4,128)} parameter(1) + %div.846 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1382), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %param_0.1252 = f32[64]{0:T(128)S(1)} parameter(0) + %div.845 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1252), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %div.844 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.846, %div.845), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %cos.43 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.844), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/cos" stack_frame_id=0} + %convert_element_type.1111 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %sin.35.clone.3 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.844), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/sin" stack_frame_id=0} + %convert_element_type.845.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.185 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1111, %convert_element_type.845.clone.3) +} + +%fused_computation.120.clone.clone (param_0.1259: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1259 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1122 = bf16[]{:T(256)} constant(-inf) + %pad.61 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1259, %constant.1122), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.60 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1259, %constant.1122), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %maximum.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.61, %pad.60), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + ROOT %bitcast.554 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.45), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.119.clone.clone (param_0.1253: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1253 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1117 = bf16[]{:T(256)} constant(-inf) + %pad.59 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1253, %constant.1117), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.58 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1253, %constant.1117), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %maximum.44 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.59, %pad.58), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + ROOT %bitcast.549 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.73.clone.clone (param_0.1269: bf16[4,128,32,64], param_1.1394: bf16[4,128,32,64], param_2.1172: bf16[4,128,32,128], param_3.837: bf16[4,128,128], param_4.525: bf16[4,128,128]) -> bf16[4,32,128,128] { + %param_2.1172 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) + %param_4.525 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) + %mul.2085 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.525), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2083 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1172, %mul.2085), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1394 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1128 = bf16[]{:T(256)} constant(-inf) + %pad.65 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1394, %constant.1128), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1269 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.64 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1269, %constant.1128), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %maximum.47 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.65, %pad.64), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_3.837 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2084 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.837), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2082 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.47, %mul.2084), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.846 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2083, %mul.2082), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %bitcast.562 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.846), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.90.clone.clone (param_0.1261: bf16[4096], param_1.1388: f32[4,128], param_2.1167: bf16[4,4,128,4096], param_3.832: s32[]) -> bf16[4,128,4096,1] { + %param_2.1167 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.832 = s32[]{:T(128)S(6)} parameter(3) + %constant.1124 = s32[]{:T(128)} constant(0) + %dynamic_slice.322 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1167, %param_3.832, %constant.1124, %constant.1124, %constant.1124), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.557 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.322), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %convert_element_type.1115 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.557), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1388 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2073 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1388), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2072 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1115, %mul.2073), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1114 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2072), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1261 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %mul.2071 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_0.1261), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2070 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1114, %mul.2071), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.556 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2070), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.64.clone.1.clone.clone (param_0.1260: bf16[4,4096,8,128], param_1.1387: s32[]) -> bf16[4096,8,128,1] { + %param_0.1260 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1387 = s32[]{:T(128)S(6)} parameter(1) + %constant.1123 = s32[]{:T(128)} constant(0) + %dynamic_slice.321 = bf16[1,4096,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1260, %param_1.1387, %constant.1123, %constant.1123, %constant.1123), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.555 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.321), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.clone (param_0.1262: bf16[4,4096,8,128], param_1.1389: s32[], param_2.1168: f32[4,128], param_3.833: bf16[4,4,128,4096], param_4.522: bf16[4096]) -> bf16[4,128,8,128] { + %param_4.522 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %param_2.1168 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.833 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1389 = s32[]{:T(128)S(6)} parameter(1) + %fusion.340 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_4.522, %param_2.1168, %param_3.833, %param_1.1389), kind=kLoop, calls=%fused_computation.90.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1262 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.341 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1262, %param_1.1389), kind=kLoop, calls=%fused_computation.64.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.112 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.340, %fusion.341), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.106.clone.clone (param_0.1263: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { + %param_0.1263 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1263), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %neg.128 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.158), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} + %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1263), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + ROOT %tuple.186 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) +} + +%fused_computation.109.clone.clone (param_0.1264: bf16[4,128,8,64], param_1.1390: bf16[4,128,8,64], param_2.1169: bf16[4,128,8,128], param_3.834: bf16[4,128,128], param_4.523: bf16[4,128,128]) -> bf16[4,8,128,128] { + %param_2.1169 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) + %param_4.523 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) + %mul.2077 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.523), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2075 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1169, %mul.2077), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1390 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1125 = bf16[]{:T(256)} constant(-inf) + %pad.63 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1390, %constant.1125), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1264 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.62 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1264, %constant.1125), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %maximum.46 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.63, %pad.62), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_3.834 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2076 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.834), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2074 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.46, %mul.2076), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.845 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2075, %mul.2074), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %bitcast.558 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.845), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.135.clone.clone (param_0.1255: bf16[4,4096,8,128], param_1.1384: s32[]) -> bf16[1,4096,8,128] { + %param_0.1255 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) + %param_1.1384 = s32[]{:T(128)S(6)} parameter(1) + %constant.1120 = s32[]{:T(128)} constant(0) + ROOT %dynamic_slice.319 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1255, %param_1.1384, %constant.1120, %constant.1120, %constant.1120), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +} + +%fused_computation.65.clone.1.clone.clone.clone.clone (param_0.1256: bf16[1,4096,8,128]) -> bf16[4096,8,128,1] { + %param_0.1256 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %copy.248 = bf16[1,4096,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1256), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} + ROOT %bitcast.550 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} bitcast(%copy.248), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.88.clone.clone.clone.clone (param_0.1257: bf16[4096], param_1.1385: f32[4,128], param_2.1165: bf16[4,4,128,4096], param_3.830: s32[]) -> bf16[4,128,4096,1] { + %param_2.1165 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.830 = s32[]{:T(128)S(6)} parameter(3) + %constant.1121 = s32[]{:T(128)} constant(0) + %dynamic_slice.320 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1165, %param_3.830, %constant.1121, %constant.1121, %constant.1121), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.552 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.320), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %convert_element_type.1113 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.552), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1385 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2069 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1385), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2068 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1113, %mul.2069), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1112 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2068), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1257 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %mul.2067 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_0.1257), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2066 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1112, %mul.2067), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.551 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2066), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.114.clone.clone (param_0.1258: bf16[1,4096,8,128], param_1.1386: f32[4,128], param_2.1166: bf16[4,4,128,4096], param_3.831: s32[], param_4.521: bf16[4096]) -> bf16[4,8,128,128] { + %param_4.521 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %param_1.1386 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1166 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.831 = s32[]{:T(128)S(6)} parameter(3) + %fusion.339 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_4.521, %param_1.1386, %param_2.1166, %param_3.831), kind=kLoop, calls=%fused_computation.88.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1258 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %fusion.338 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1258), kind=kLoop, calls=%fused_computation.65.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.111 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.339, %fusion.338), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + ROOT %bitcast.553 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%convolution.111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.366.clone.clone (param_0.1293: f32[4,32,128,128]) -> (f32[4,32,128,1], f32[4,32,128]) { + %param_0.1293 = f32[4,32,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) + %slice.11 = f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1293), slice={[0:4], [0:32], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} + %bitcast.262.clone.3 = f32[4,32,128]{2,1,0:T(8,128)S(1)} bitcast(%slice.11), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/squeeze" stack_frame_id=0} + ROOT %tuple.192 = (f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)}, f32[4,32,128]{2,1,0:T(8,128)S(1)}) tuple(%slice.11, %bitcast.262.clone.3) +} + +%region_13.16 (reduce_sum.150: f32[], reduce_sum.151: f32[]) -> f32[] { + %reduce_sum.150 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.151 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.155 = f32[]{:T(128)} add(%reduce_sum.150, %reduce_sum.151), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone (param_0.1270: bf16[4,32,128,4096], param_1.1395: s32[]) -> bf16[32,128,4096,1] { + %param_0.1270 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1395 = s32[]{:T(128)S(6)} parameter(1) + %constant.1129 = s32[]{:T(128)} constant(0) + %dynamic_slice.325 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1270, %param_1.1395, %constant.1129, %constant.1129, %constant.1129), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.563 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.325), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.81.clone.clone.clone.clone.clone.clone (param_0.1271: bf16[4,32,128,128]) -> bf16[4,128,32,128] { + %param_0.1271 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.564 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1271), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.61.clone.clone (param_0.1272: bf16[4,32,128,4096], param_1.1396: s32[], param_2.1173: bf16[4,32,128,128], param_3.838: bf16[4,4,128,4096]) -> (f32[4,128], bf16[4,128,4096]) { + %param_3.838 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1396 = s32[]{:T(128)S(6)} parameter(1) + %constant.357.clone.1.clone.3 = s32[]{:T(128)} constant(0) + %dynamic_slice.208.clone.3 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.838, %param_1.1396, %constant.357.clone.1.clone.3, %constant.357.clone.1.clone.3, %constant.357.clone.1.clone.3), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.207.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.208.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %param_2.1173 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %fusion.83.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1173), kind=kLoop, calls=%fused_computation.81.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} + %param_0.1272 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %fusion.82.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1272, %param_1.1396), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.62.clone.3 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} convolution(%fusion.83.clone.3, %fusion.82.clone.3), window={size=1x32}, dim_labels=0b1f_1io0->0bf1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %bitcast.182.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%convolution.62.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %add.621.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.207.clone.3, %bitcast.182.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %convert_element_type.1118 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%add.621.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.215 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1118, %convert_element_type.1118), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1130 = f32[]{:T(128)} constant(0) + %reduce.138 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.215, %constant.1130), dimensions={2}, to_apply=%region_13.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.188 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.138, %add.621.clone.3) +} + +%convert_element_type.556.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { + %lhs.1 = bf16[] parameter(0) + %rhs.1 = bf16[] parameter(1) + ROOT %add.610 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.121.clone.clone (param_0.1254: bf16[4,4096], param_1.1383: s32[]) -> bf16[4096] { + %param_0.1254 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1383 = s32[]{:T(128)S(6)} parameter(1) + %constant.1118 = s32[]{:T(128)} constant(0) + %dynamic_slice.318 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1254, %param_1.1383, %constant.1118), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1119 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.137 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.318, %constant.1119), dimensions={0}, to_apply=%convert_element_type.556.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.12.clone.clone.clone (param_0.1273: bf16[4,14336,4096], param_1.1397: s32[]) -> bf16[14336,4096,1] { + %param_0.1273 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1397 = s32[]{:T(128)S(6)} parameter(1) + %constant.1131 = s32[]{:T(128)} constant(0) + %dynamic_slice.326 = bf16[1,14336,4096]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1273, %param_1.1397, %constant.1131, %constant.1131), dynamic_slice_sizes={1,14336,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.566 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.326), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%bitcast_fusion.3.clone.clone (bitcast_input.12: bf16[4,128,4096]) -> bf16[4,128,4096] { + %bitcast_input.12 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.565 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.12) +} + +%fused_computation.13.clone.clone (param_0.1274: bf16[4,128,4096], param_1.1398: bf16[4,14336,4096], param_2.1174: s32[]) -> bf16[14336,4,128] { + %param_1.1398 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1174 = s32[]{:T(128)S(6)} parameter(2) + %fusion.344 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} fusion(%param_1.1398, %param_2.1174), kind=kLoop, calls=%fused_computation.12.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1274 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %fusion.345 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1274), kind=kLoop, calls=%bitcast_fusion.3.clone.clone + ROOT %convolution.114 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} convolution(%fusion.344, %fusion.345), window={size=4 pad=3_3 rhs_reversal=1}, dim_labels=bf0_0oi->b0f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.144.clone.1.clone (param_0.1275: f32[4,128]) -> f32[4,128] { + %param_0.1275 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1133 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.86 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1133), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.847 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1275, %closed_call.86), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1132 = f32[]{:T(128)} constant(1e-05) + %closed_call.85 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1132), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.847 = f32[4,128]{1,0:T(4,128)} add(%div.847, %closed_call.85), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.98 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.847), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.11.clone.1.clone.clone (param_0.1279: bf16[4,4096,14336], param_1.1402: s32[]) -> bf16[4096,14336,1] { + %param_0.1279 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1402 = s32[]{:T(128)S(6)} parameter(1) + %constant.1135 = s32[]{:T(128)} constant(0) + %dynamic_slice.328 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1279, %param_1.1402, %constant.1135, %constant.1135), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.568 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.328), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.96.clone.2.clone.clone (param_0.1280: f32[4,128], param_1.1403: bf16[4,128,4096], param_2.1177: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1403 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1122 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1403), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1280 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2092 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1280), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2091 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1122, %mul.2092), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1121 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2091), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1177 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2093 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1177), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2090 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1121, %mul.2093), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.23.clone.clone (param_0.1281: bf16[4,4096,14336], param_1.1404: s32[], param_2.1178: f32[4,128], param_3.840: bf16[4,128,4096], param_4.527: bf16[4096]) -> bf16[4,128,14336] { + %param_2.1178 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.840 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.527 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.349 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1178, %param_3.840, %param_4.527), kind=kLoop, calls=%fused_computation.96.clone.2.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1281 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1404 = s32[]{:T(128)S(6)} parameter(1) + %fusion.348 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1281, %param_1.1404), kind=kLoop, calls=%fused_computation.11.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.116 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.349, %fusion.348), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.14.clone.1.clone.clone (param_0.1282: bf16[4,4096,14336], param_1.1405: s32[]) -> bf16[4096,14336,1] { + %param_0.1282 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1405 = s32[]{:T(128)S(6)} parameter(1) + %constant.1136 = s32[]{:T(128)} constant(0) + %dynamic_slice.329 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1282, %param_1.1405, %constant.1136, %constant.1136), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.569 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.329), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.39.clone.1.clone.clone (param_0.1283: bf16[14336,4,128], param_1.1406: bf16[4,128,14336]) -> bf16[4,128,14336] { + %param_1.1406 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1137 = bf16[]{:T(256)} constant(1) + %jit_silu_.44 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1137), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} + %neg.130 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_1.1406), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} + %exp.69 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} exponential(%neg.130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/exp" stack_frame_id=0} + %add.848 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.69, %jit_silu_.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} + %div.848 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.44, %add.848), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} + %mul.2095 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1406, %div.848), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} + %param_0.1283 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(0) + %bitcast.570 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_0.1283), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} + ROOT %mul.2094 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.2095, %bitcast.570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} +} + +%fused_computation.21.clone.clone (param_0.1284: bf16[4,4096,14336], param_1.1407: s32[], param_2.1179: bf16[14336,4,128], param_3.841: bf16[4,128,14336]) -> bf16[4,128,4096] { + %param_2.1179 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) + %param_3.841 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %bitcast_multiply_fusion.15 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1179, %param_3.841), kind=kLoop, calls=%fused_computation.39.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %param_0.1284 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1407 = s32[]{:T(128)S(6)} parameter(1) + %fusion.350 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1284, %param_1.1407), kind=kLoop, calls=%fused_computation.14.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.117 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} convolution(%bitcast_multiply_fusion.15, %fusion.350), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.14.clone.clone.clone (param_0.1276: bf16[4,4096,14336], param_1.1399: s32[]) -> bf16[4096,14336,1] { + %param_0.1276 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1399 = s32[]{:T(128)S(6)} parameter(1) + %constant.1134 = s32[]{:T(128)} constant(0) + %dynamic_slice.327 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1276, %param_1.1399, %constant.1134, %constant.1134), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.567 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.327), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.96.clone.1.clone.clone (param_0.1277: f32[4,128], param_1.1400: bf16[4,128,4096], param_2.1175: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1400 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1120 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1400), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1277 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2088 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1277), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2087 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1120, %mul.2088), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1119 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2087), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1175 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2089 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1175), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2086 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1119, %mul.2089), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.20.clone.clone (param_0.1278: bf16[4,4096,14336], param_1.1401: s32[], param_2.1176: f32[4,128], param_3.839: bf16[4,128,4096], param_4.526: bf16[4096]) -> bf16[4,128,14336] { + %param_2.1176 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.839 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.526 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.347 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1176, %param_3.839, %param_4.526), kind=kLoop, calls=%fused_computation.96.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1278 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1401 = s32[]{:T(128)S(6)} parameter(1) + %fusion.346 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1278, %param_1.1401), kind=kLoop, calls=%fused_computation.14.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.115 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.347, %fusion.346), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%region_14.17 (reduce_sum.166: f32[], reduce_sum.167: f32[]) -> f32[] { + %reduce_sum.166 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + %reduce_sum.167 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + ROOT %reduce_sum.168 = f32[]{:T(128)} add(%reduce_sum.166, %reduce_sum.167), metadata={op_name="checkpoint/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.11.clone.clone.clone.clone.clone.clone.clone (param_0.1285: bf16[4,4096,14336], param_1.1408: s32[]) -> bf16[4096,14336,1] { + %param_0.1285 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1408 = s32[]{:T(128)S(6)} parameter(1) + %constant.1138 = s32[]{:T(128)} constant(0) + %dynamic_slice.330 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1285, %param_1.1408, %constant.1138, %constant.1138), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.571 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.330), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.38.clone.1.clone.clone.clone.clone (param_0.1286: bf16[4,128,14336], param_1.1409: bf16[4,128,14336], param_2.1180: bf16[14336,4,128]) -> bf16[4,128,14336] { + %param_2.1180 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) + %bitcast.572 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_2.1180), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} + %param_1.1409 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %mul.2100 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%bitcast.572, %param_1.1409), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1139 = bf16[]{:T(256)} constant(1) + %jit_silu_.45 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1139), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} + %param_0.1286 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %neg.131 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_0.1286), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} + %exp.70 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} exponential(%neg.131), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/exp" stack_frame_id=0} + %add.849 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.70, %jit_silu_.45), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} + %div.849 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.45, %add.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} + %mul.2099 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.2100, %div.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} + %mul.2098 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1286, %mul.2100), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} + %sub.98 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} subtract(%jit_silu_.45, %div.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/sub" stack_frame_id=0} + %mul.2097 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%div.849, %sub.98), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} + %mul.2096 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.2098, %mul.2097), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} + ROOT %add_any.145 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%mul.2099, %mul.2096), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} +} + +%fused_computation.63.clone.clone (param_0.1287: bf16[4,128,4096], param_1.1410: bf16[4096], param_2.1181: bf16[4,128,4096], param_3.842: bf16[4,4096,14336], param_4.528: s32[], param_5.435: bf16[4,128,14336], param_6.305: bf16[4,128,14336], param_7.200: bf16[14336,4,128]) -> (f32[4,128], bf16[4,128,4096]) { + %param_0.1287 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1124 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1287), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1181 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_5.435 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(5) + %param_6.305 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(6) + %param_7.200 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(7) + %fusion.134.clone.3 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_5.435, %param_6.305, %param_7.200), kind=kLoop, calls=%fused_computation.38.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} + %param_3.842 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.528 = s32[]{:T(128)S(6)} parameter(4) + %fusion.79.clone.3 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_3.842, %param_4.528), kind=kLoop, calls=%fused_computation.11.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.60.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convolution(%fusion.134.clone.3, %fusion.79.clone.3), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} + %add_any.132.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_2.1181, %convolution.60.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + %param_1.1410 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2103 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_1.1410), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2102 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%add_any.132.clone.3, %mul.2103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %convert_element_type.1123 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.2102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} + %mul.2101 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1124, %convert_element_type.1123), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1140 = f32[]{:T(128)} constant(0) + %reduce.139 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.2101, %constant.1140), dimensions={2}, to_apply=%region_14.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.189 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.139, %add_any.132.clone.3) +} + +%fused_computation.140.clone.clone (param_0.1288: f32[4,128], param_1.1411: f32[4,128]) -> f32[4,128] { + %param_0.1288 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1411 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %constant.1144 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.89 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1144), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.851 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1411, %closed_call.89), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1143 = f32[]{:T(128)} constant(1e-05) + %closed_call.88 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1143), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.850 = f32[4,128]{1,0:T(4,128)} add(%div.851, %closed_call.88), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %rsqrt.99 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.850), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} + %div.850 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.99, %add.850), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1142 = f32[]{:T(128)} constant(-0.5) + %closed_call.87 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1142), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2106 = f32[4,128]{1,0:T(4,128)} multiply(%div.850, %closed_call.87), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2105 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1288, %mul.2106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1141 = f32[]{:T(128)} constant(0.00048828125) + %mul.2107 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1141), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + ROOT %mul.2104 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.2105, %mul.2107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} +} + +%region_20.24 (reduce_sum.175: bf16[], reduce_sum.179: bf16[]) -> bf16[] { + %reduce_sum.175 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + %reduce_sum.179 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + ROOT %reduce_sum.180 = bf16[]{:T(256)} add(%reduce_sum.175, %reduce_sum.179), metadata={op_name="checkpoint/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.93.clone.clone (param_0.1289: bf16[4,128,4096], param_1.1412: f32[4,128], param_2.1182: bf16[4,128,4096], param_3.843: bf16[4,128,4096], param_4.529: f32[4,128], param_5.436: bf16[4096]) -> (bf16[4096], bf16[4,128,4096]) { + %param_2.1182 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %convert_element_type.1126 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_2.1182), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1412 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2110 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1412), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2109 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1126, %mul.2110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1125 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1289 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %mul.2108 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1125, %param_0.1289), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1145 = bf16[]{:T(256)} constant(0) + %reduce.140 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%mul.2108, %constant.1145), dimensions={0,1}, to_apply=%region_20.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum" stack_frame_id=0} + %param_3.843 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_5.436 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(5) + %mul.1399.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_5.436), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.1353.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1289, %mul.1399.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %convert_element_type.769.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.1353.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} + %mul.1333.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.769.clone.3, %mul.2110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %param_4.529 = f32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %mul.1344.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_4.529), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %mul.1332.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1126, %mul.1344.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %add_any.126.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1333.clone.3, %mul.1332.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + %convert_element_type.767.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.126.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} + %add_any.124.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_3.843, %convert_element_type.767.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + ROOT %tuple.190 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.140, %add_any.124.clone.3) +} + +%region_15.18 (dot_general.157: f32[], dot_general.158: f32[]) -> f32[] { + %dot_general.157 = f32[]{:T(128)} parameter(0), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} + %dot_general.158 = f32[]{:T(128)} parameter(1), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} + ROOT %add.157 = f32[]{:T(128)} add(%dot_general.157, %dot_general.158), metadata={op_name="add.31"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.25.clone.clone.clone.clone.clone.clone.clone (param_0.1290: bf16[4,32,128,4096], param_1.1413: s32[]) -> bf16[32,128,4096,1] { + %param_0.1290 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1413 = s32[]{:T(128)S(6)} parameter(1) + %constant.1146 = s32[]{:T(128)} constant(0) + %dynamic_slice.331 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1290, %param_1.1413, %constant.1146, %constant.1146, %constant.1146), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.573 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.331), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.76.clone.clone.clone.clone.clone.clone (param_0.1291: bf16[4,128,4096]) -> bf16[4,128,4096,1] { + %param_0.1291 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.574 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%param_0.1291), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} +} + +%fused_computation.66.clone.clone (param_0.1292: bf16[4,32,128,128], param_1.1414: bf16[4,32,128,4096], param_2.1183: s32[], param_3.844: bf16[4,128,4096]) -> (f32[4,32,128], bf16[4,32,128,128]) { + %param_0.1292 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert.87 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%param_0.1292), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert" stack_frame_id=0} + %param_3.844 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %fusion.95.clone.3 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_3.844), kind=kLoop, calls=%fused_computation.76.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + %param_1.1414 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1183 = s32[]{:T(128)S(6)} parameter(2) + %fusion.94.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.1414, %param_2.1183), kind=kLoop, calls=%fused_computation.25.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.64.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.95.clone.3, %fusion.94.clone.3), window={size=1x32 pad=0_0x31_31 rhs_reversal=0x1}, dim_labels=0bf1_1oi0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} + %constant.611.clone.3 = bf16[]{:T(256)} constant(0.25) + %div.442.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.611.clone.3), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} + %div.441.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%convolution.64.clone.3, %div.442.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} + %bitcast.209.clone.3 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%div.441.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} + %convert.86 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%bitcast.209.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert.1" stack_frame_id=0} + %multiply.196 = f32[4,32,128,128]{3,2,1,0:T(8,128)} multiply(%convert.87, %convert.86), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/multiply" stack_frame_id=0} + %constant.1147 = f32[]{:T(128)} constant(0) + %dot_general.189 = f32[4,32,128]{2,1,0:T(8,128)S(1)} reduce(%multiply.196, %constant.1147), dimensions={3}, to_apply=%region_15.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general" stack_frame_id=0} + ROOT %tuple.191 = (f32[4,32,128]{2,1,0:T(8,128)S(1)}, bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)}) tuple(%dot_general.189, %bitcast.209.clone.3) +} diff --git a/src/maxtext/utils/reference_hlo_qwen3_1.7b.txt b/src/maxtext/utils/reference_hlo_qwen3_1.7b.txt new file mode 100644 index 0000000000..6bdc2b6141 --- /dev/null +++ b/src/maxtext/utils/reference_hlo_qwen3_1.7b.txt @@ -0,0 +1,2000 @@ +HloModule jit_train_step, is_scheduled=true, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36, {}, may-alias), {37}: (37, {}, may-alias), {38}: (38, {}, may-alias), {39}: (39, {}, may-alias), {40}: (40, {}, may-alias), {41}: (41, {}, may-alias) }, entry_computation_layout={(s32[]{:T(128)}, f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, /*index=5*/f32[2048,4]{0,1:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, /*index=10*/f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, s32[]{:T(128)}, /*index=15*/f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, /*index=20*/f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, /*index=25*/f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, /*index=30*/f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, /*index=35*/f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, /*index=40*/f32[151936,2048]{1,0:T(8,128)}, s32[]{:T(128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)}, /*index=45*/s32[4,128]{1,0:T(4,128)}, s32[4,128]{1,0:T(4,128)})->(s32[]{:T(128)}, f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, /*index=5*/f32[2048,4]{0,1:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, /*index=10*/f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, s32[]{:T(128)}, /*index=15*/f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, /*index=20*/f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, /*index=25*/f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[2048]{0:T(1024)}, f32[2048,4,6144]{2,1,0:T(4,128)}, /*index=30*/f32[2048,4,6144]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, /*index=35*/f32[128,4]{0,1:T(4,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[128,4]{0,1:T(4,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, /*index=40*/f32[151936,2048]{1,0:T(8,128)}, s32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, /*index=45*/f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, f32[]{:T(128)}, /*index=50*/f32[]{:T(128)}, s32[]{:T(128)}, f32[]{:T(128)})}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false}, allow_spmd_sharding_propagation_to_output={false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,false,true,true,true,true,true,true,true,true,true,true,true}, num_partitions=4 + +FileNames + +FunctionNames + +FileLocations + +StackFrames + + +%fused_computation (param_0.2: bf16[151936,2048], param_1.7: s32[1024]) -> bf16[512,2048] { + %param_0.2 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.7 = s32[1024]{0:T(1024)S(1)} parameter(1) + %custom-call.1 = s32[1024]{0:T(1024)} custom-call(%param_1.7), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %slice.6 = s32[512]{0:T(512)} slice(%custom-call.1), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.554 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.261 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.554), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.261), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,2048}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.260 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.553 = bf16[512,2048]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.260), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} +} + +%region_42.47.clone (scatter-add.6: bf16[], scatter-add.7: bf16[]) -> bf16[] { + %scatter-add.7 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + %scatter-add.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} + ROOT %add.560 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.1 (param_0.3: bf16[151936,2048], param_1.5: s32[512], param_2.4: bf16[512,2048]) -> bf16[151936,2048] { + %param_0.3 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %param_1.5 = s32[512]{0:T(512)S(1)} parameter(1) + %reshape.561 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.266 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.561), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %param_2.4 = bf16[512,2048]{1,0:T(8,128)(2,1)} parameter(2) + %reshape.562 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + %transpose.267 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%reshape.562), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + ROOT %scatter.2 = bf16[151936,2048]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.266, %transpose.267), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_42.47.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} +} + +%region_71.76 (reduce_sum.569: f32[], reduce_sum.570: f32[]) -> f32[] { + %reduce_sum.570 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.569 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.571 = f32[]{:T(128)} add(%reduce_sum.569, %reduce_sum.570), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_56.61 (reduce_sum.488: f32[], reduce_sum.492: f32[]) -> f32[] { + %reduce_sum.492 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.488 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.493 = f32[]{:T(128)} add(%reduce_sum.488, %reduce_sum.492), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.288 (param_0.1409: f32[151936,2048], param_1.1583: f32[], param_2.1325: f32[], param_3.908: f32[], param_4.547: f32[151936,2048], param_5.481: f32[], param_6.356: bf16[151936,2048], param_7.200: bf16[151936,2048,1], param_8.118: pred[], param_9.97: f32[151936,2048]) -> (f32[], f32[151936,2048], f32[151936,2048], f32[151936,2048], f32[]) { + %param_0.1409 = f32[151936,2048]{1,0:T(8,128)} parameter(0) + %param_3.908 = f32[]{:T(128)S(6)} parameter(3) + %mul.2449.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_3.908), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.118 = pred[]{:T(512)S(6)} parameter(8) + %select_n.268.clone.1 = pred[151936,2048]{1,0:T(8,128)(4,1)} broadcast(%param_8.118), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_7.200 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} parameter(7) + %bitcast.445.clone.1 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%param_7.200), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %convert_element_type.1433.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.445.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_6.356 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.1432.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%param_6.356), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %add_any.188.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1433.clone.1, %convert_element_type.1432.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} + %param_5.481 = f32[]{:T(128)} parameter(5) + %div.860.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_5.481), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.859.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add_any.188.clone.1, %div.860.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.267.clone.1 = f32[151936,2048]{1,0:T(8,128)} select(%select_n.268.clone.1, %add_any.188.clone.1, %div.859.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1080.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.754.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1080.clone.1), dimensions={}, metadata={op_name="broadcast.74"} + %mul.2455.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.267.clone.1, %broadcast.754.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_9.97 = f32[151936,2048]{1,0:T(8,128)} parameter(9) + %constant.1084.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2456.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1084.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2454.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_9.97, %mul.2456.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.917.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.2455.clone.1, %mul.2454.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1325 = f32[]{:T(128)S(6)} parameter(2) + %div.856.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_2.1325), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.65.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.267.clone.1, %select_n.267.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1083.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2453.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1083.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2451.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %mul.2453.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.547 = f32[151936,2048]{1,0:T(8,128)} parameter(4) + %constant.1082.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2452.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1082.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2450.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_4.547, %mul.2452.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.916.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.2451.clone.1, %mul.2450.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1583 = f32[]{:T(128)S(6)} parameter(1) + %div.855.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_1.1583), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.854.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.916.clone.1, %div.855.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.62.clone.1 = f32[151936,2048]{1,0:T(8,128)} sqrt(%div.854.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1081.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.915.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1081.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.914.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%sqrt.62.clone.1, %add.915.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.287.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%div.856.clone.1, %add.914.clone.1), metadata={op_name="multiply.46"} + %div.853.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.917.clone.1, %multiply.287.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2448.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_0.1409, %broadcast.754.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.913.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%div.853.clone.1, %mul.2448.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2447.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%mul.2449.clone.1, %add.913.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.912.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%param_0.1409, %mul.2447.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.214 = f32[151936,2048]{1,0:T(8,128)} multiply(%add.912.clone.1, %add.912.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1188 = f32[]{:T(128)} constant(0) + %reduce.106 = f32[]{:T(128)} reduce(%square.214, %constant.1188), dimensions={0,1}, to_apply=%region_71.76, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.108.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.1188), dimensions={0,1}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.145 = (f32[]{:T(128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.106, %add.912.clone.1, %add.916.clone.1, %add.917.clone.1, %reduce.108.clone.1) +} + +%region_43.48 (reduce_sum.422: f32[], reduce_sum.423: f32[]) -> f32[] { + %reduce_sum.423 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.422 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.424 = f32[]{:T(128)} add(%reduce_sum.422, %reduce_sum.423), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.378.clone.clone (param_0.1396: f32[4,128], param_1.1576: bf16[4,128,2048], param_2.1301: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1576 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1475 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1576), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1396 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2627 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1396), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2626 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1475, %mul.2627), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1474 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2626), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1301 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2628 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1301), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.2625 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1474, %mul.2628), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.300.clone.clone.clone (param_0.1397: bf16[4,128,151936], param_1.1577: s32[4,128], param_2.1302: f32[4,128], param_3.901: f32[4,128], param_4.537: bf16[4,128], param_5.459: f32[4,128]) -> bf16[4,128,151936] { + %param_5.459 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.2632 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.459), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.901 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.2631 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.901), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1397 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1478 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1397), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.537 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.92 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.537), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.91 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1478, %sub.92), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.62 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.91), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.2630 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2631, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1302 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.966 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1302), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.965 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2630, %div.966), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1577 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.49 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1577), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.48 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.47 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.49, %eq.48), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %convert_element_type.1477 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.90 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.965, %convert_element_type.1477), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.2629 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2632, %sub.90), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1476 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2629), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.292 (param_0.1422: bf16[151936,2048], param_1.1596: f32[4,128], param_2.1338: bf16[4,128,2048], param_3.921: bf16[2048], param_4.560: bf16[4,128,151936], param_5.494: s32[4,128], param_6.369: f32[4,128], param_7.213: f32[4,128], param_8.131: bf16[4,128], param_9.98: f32[4,128]) -> (f32[], bf16[151936,2048,1]) { + %param_4.560 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(4) + %param_5.494 = s32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.369 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.213 = f32[4,128]{1,0:T(4,128)S(1)} parameter(7) + %param_8.131 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(8) + %param_9.98 = f32[4,128]{1,0:T(4,128)S(1)} parameter(9) + %multiply_convert_fusion.2.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_4.560, %param_5.494, %param_6.369, %param_7.213, %param_8.131, /*index=5*/%param_9.98), kind=kLoop, calls=%fused_computation.300.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.1596 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1338 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.921 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.279.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1596, %param_2.1338, %param_3.921), kind=kLoop, calls=%fused_computation.378.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convolution.86.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.279.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %bitcast.314 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%convolution.86.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %convert_element_type.1347 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.314), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_0.1422 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1346 = f32[151936,2048]{1,0:T(8,128)} convert(%param_0.1422), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %add_any.175 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1347, %convert_element_type.1346), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} + %square.215 = f32[151936,2048]{1,0:T(8,128)} multiply(%add_any.175, %add_any.175), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1201 = f32[]{:T(128)} constant(0) + %reduce.107 = f32[]{:T(128)} reduce(%square.215, %constant.1201), dimensions={0,1}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.167 = (f32[]{:T(128)}, bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.107, %convolution.86.clone.1) +} + +%region_57.62 (reduce_sum.494: f32[], reduce_sum.495: f32[]) -> f32[] { + %reduce_sum.495 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.494 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.499 = f32[]{:T(128)} add(%reduce_sum.494, %reduce_sum.495), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.299 (param_0.1433: bf16[4,128,151936], param_1.1604: f32[4,128], param_2.1341: s32[4,128], param_3.923: bf16[4,128]) -> f32[4,128] { + %param_2.1341 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.30 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1341), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.25 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.24 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.30, %eq.25), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %param_0.1433 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1364 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1433), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_3.923 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.73 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.923), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.64 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1364, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_1.1604 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.71 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1604), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.60 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%sub.64, %sub.71), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %constant.1213 = f32[]{:T(128)} constant(0) + %broadcast.681 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%constant.1213), dimensions={}, metadata={op_name="broadcast.99"} + %mul.2269 = f32[4,128,151936]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.681), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.109 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.2269, %constant.1213), dimensions={2}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%region_9.12 (reduce_sum.237: f32[], reduce_sum.241: f32[]) -> f32[] { + %reduce_sum.241 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.237 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.262 = f32[]{:T(128)} add(%reduce_sum.237, %reduce_sum.241), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.304 (param_0.1434: bf16[4,128,151936], param_1.1605: bf16[4,128]) -> f32[4,128] { + %param_0.1434 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1370 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1434), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.1605 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.74 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1605), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.70 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1370, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.54 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.70), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %constant.1214 = f32[]{:T(128)} constant(0) + ROOT %reduce.110 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1214), dimensions={2}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%region_33.38 (reduce_sum.374: f32[], reduce_sum.375: f32[]) -> f32[] { + %reduce_sum.375 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.374 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.376 = f32[]{:T(128)} add(%reduce_sum.374, %reduce_sum.375), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.309 (param_0.1428: f32[4,6144,2048]) -> f32[] { + %param_0.1428 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(0) + %bitcast.328 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_0.1428), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.218 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%bitcast.328, %bitcast.328), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1207 = f32[]{:T(128)} constant(0) + ROOT %reduce.111 = f32[]{:T(128)} reduce(%square.218, %constant.1207), dimensions={0,1,2}, to_apply=%region_33.38, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_32.37 (reduce_sum.368: f32[], reduce_sum.369: f32[]) -> f32[] { + %reduce_sum.369 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.368 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.373 = f32[]{:T(128)} add(%reduce_sum.368, %reduce_sum.369), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_31.36 (reduce_sum.362: f32[], reduce_sum.366: f32[]) -> f32[] { + %reduce_sum.366 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.362 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.367 = f32[]{:T(128)} add(%reduce_sum.362, %reduce_sum.366), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.311 (param_0.1429: f32[4,2048,6144], param_1.1600: f32[4,2048,6144]) -> (f32[], f32[]) { + %param_0.1429 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(0) + %bitcast.332 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_0.1429), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.221 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.332, %bitcast.332), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1208 = f32[]{:T(128)} constant(0) + %reduce.112 = f32[]{:T(128)} reduce(%square.221, %constant.1208), dimensions={0,1,2}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1600 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(1) + %bitcast.336.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_1.1600), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.224.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.336.clone.1, %bitcast.336.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.113.clone.1 = f32[]{:T(128)} reduce(%square.224.clone.1, %constant.1208), dimensions={0,1,2}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.168 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.112, %reduce.113.clone.1) +} + +%fused_computation.314 (param_0.940: f32[6144,4,2048]) -> bf16[4,6144,2048] { + %param_0.940 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) + %copy.186 = bf16[6144,4,2048]{2,0,1:T(8,128)(2,1)} copy(%param_0.940), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} + ROOT %bitcast.337 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} bitcast(%copy.186), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%fused_computation.315 (param_0.942: f32[2048,4,6144]) -> bf16[4,2048,6144] { + %param_0.942 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %copy.187 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.942), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} + ROOT %bitcast.338 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.187), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%fused_computation.316 (param_0.944: f32[2048,4,6144]) -> bf16[4,2048,6144] { + %param_0.944 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %copy.188 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.944), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} + ROOT %bitcast.339 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.188), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%region_62.67 (reduce_sum.521: f32[], reduce_sum.522: f32[]) -> f32[] { + %reduce_sum.522 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.521 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.523 = f32[]{:T(128)} add(%reduce_sum.521, %reduce_sum.522), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_47.52 (reduce_sum.443: f32[], reduce_sum.444: f32[]) -> f32[] { + %reduce_sum.444 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.443 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.445 = f32[]{:T(128)} add(%reduce_sum.443, %reduce_sum.444), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.317 (param_0.1418: f32[6144,4,2048], param_1.1592: f32[], param_2.1334: f32[], param_3.917: f32[], param_4.556: f32[6144,4,2048], param_5.490: f32[], param_6.365: f32[4,6144,2048], param_7.209: pred[], param_8.127: f32[6144,4,2048]) -> (f32[], f32[6144,4,2048], f32[6144,4,2048], f32[6144,4,2048], f32[]) { + %param_0.1418 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) + %param_3.917 = f32[]{:T(128)S(6)} parameter(3) + %mul.2521.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_3.917), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.209 = pred[]{:T(512)S(6)} parameter(7) + %select_n.304.clone.1 = pred[6144,4,2048]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.209), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.365 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(6) + %bitcast.463.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_6.365), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.490 = f32[]{:T(128)} parameter(5) + %div.932.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_5.490), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.931.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%bitcast.463.clone.1, %div.932.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.303.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} select(%select_n.304.clone.1, %bitcast.463.clone.1, %div.931.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1134.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.796.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1134.clone.1), dimensions={}, metadata={op_name="broadcast.83"} + %mul.2527.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.303.clone.1, %broadcast.796.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.127 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(8) + %constant.1138.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2528.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1138.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2526.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_8.127, %mul.2528.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.965.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2527.clone.1, %mul.2526.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1334 = f32[]{:T(128)S(6)} parameter(2) + %div.928.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_2.1334), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.74.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.303.clone.1, %select_n.303.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1137.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2525.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1137.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2523.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%integer_pow.74.clone.1, %mul.2525.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.556 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(4) + %constant.1136.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2524.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1136.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2522.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_4.556, %mul.2524.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.964.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2523.clone.1, %mul.2522.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1592 = f32[]{:T(128)S(6)} parameter(1) + %div.927.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_1.1592), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.926.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.964.clone.1, %div.927.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.71.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} sqrt(%div.926.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1135.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.963.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1135.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.962.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%sqrt.71.clone.1, %add.963.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.296.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%div.928.clone.1, %add.962.clone.1), metadata={op_name="multiply.37"} + %div.925.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.965.clone.1, %multiply.296.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2520.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_0.1418, %broadcast.796.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.961.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%div.925.clone.1, %mul.2520.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2519.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%mul.2521.clone.1, %add.961.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.960.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%param_0.1418, %mul.2519.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.225 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%add.960.clone.1, %add.960.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1197 = f32[]{:T(128)} constant(0) + %reduce.114 = f32[]{:T(128)} reduce(%square.225, %constant.1197), dimensions={0,1,2}, to_apply=%region_62.67, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.117.clone.1 = f32[]{:T(128)} reduce(%integer_pow.74.clone.1, %constant.1197), dimensions={0,1,2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.146 = (f32[]{:T(128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.114, %add.960.clone.1, %add.964.clone.1, %add.965.clone.1, %reduce.117.clone.1) +} + +%region_61.66 (reduce_sum.515: f32[], reduce_sum.516: f32[]) -> f32[] { + %reduce_sum.516 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.515 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.520 = f32[]{:T(128)} add(%reduce_sum.515, %reduce_sum.516), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_46.51 (reduce_sum.437: f32[], reduce_sum.438: f32[]) -> f32[] { + %reduce_sum.438 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.437 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.439 = f32[]{:T(128)} add(%reduce_sum.437, %reduce_sum.438), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.318 (param_0.1419: f32[2048,4,6144], param_1.1593: f32[], param_2.1335: f32[], param_3.918: f32[], param_4.557: f32[2048,4,6144], param_5.491: f32[], param_6.366: f32[4,2048,6144], param_7.210: pred[], param_8.128: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { + %param_0.1419 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %param_3.918 = f32[]{:T(128)S(6)} parameter(3) + %mul.2531.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.918), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.210 = pred[]{:T(512)S(6)} parameter(7) + %select_n.308.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.210), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.366 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) + %bitcast.465.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.366), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.491 = f32[]{:T(128)} parameter(5) + %div.940.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.491), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.939.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.465.clone.1, %div.940.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.307.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.308.clone.1, %bitcast.465.clone.1, %div.939.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1140.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.802.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1140.clone.1), dimensions={}, metadata={op_name="broadcast.85"} + %mul.2535.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.307.clone.1, %broadcast.802.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.128 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(8) + %constant.1144.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.801.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1144.clone.1), dimensions={}, metadata={op_name="broadcast.84"} + %mul.2534.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.128, %broadcast.801.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.970.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2535.clone.1, %mul.2534.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1335 = f32[]{:T(128)S(6)} parameter(2) + %div.936.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1335), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.75.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.307.clone.1, %select_n.307.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1143.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.800.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1143.clone.1), dimensions={}, metadata={op_name="broadcast.73"} + %mul.2533.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.75.clone.1, %broadcast.800.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.557 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) + %constant.1142.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.799.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1142.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.2532.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.557, %broadcast.799.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.969.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2533.clone.1, %mul.2532.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1593 = f32[]{:T(128)S(6)} parameter(1) + %div.935.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1593), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.934.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.969.clone.1, %div.935.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.72.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} sqrt(%div.934.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1141.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.797.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1141.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %add.968.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.72.clone.1, %broadcast.797.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.297.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.936.clone.1, %add.968.clone.1), metadata={op_name="multiply.36"} + %div.933.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.970.clone.1, %multiply.297.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2530.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1419, %broadcast.802.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.967.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.933.clone.1, %mul.2530.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2529.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2531.clone.1, %add.967.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.966.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1419, %mul.2529.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.226 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.966.clone.1, %add.966.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1198 = f32[]{:T(128)} constant(0) + %reduce.115 = f32[]{:T(128)} reduce(%square.226, %constant.1198), dimensions={0,1,2}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.118.clone.1 = f32[]{:T(128)} reduce(%integer_pow.75.clone.1, %constant.1198), dimensions={0,1,2}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.147 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.115, %add.966.clone.1, %add.969.clone.1, %add.970.clone.1, %reduce.118.clone.1) +} + +%region_60.65 (reduce_sum.509: f32[], reduce_sum.513: f32[]) -> f32[] { + %reduce_sum.513 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.509 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.514 = f32[]{:T(128)} add(%reduce_sum.509, %reduce_sum.513), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_45.50 (reduce_sum.431: f32[], reduce_sum.432: f32[]) -> f32[] { + %reduce_sum.432 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.431 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.436 = f32[]{:T(128)} add(%reduce_sum.431, %reduce_sum.432), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.319 (param_0.1420: f32[2048,4,6144], param_1.1594: f32[], param_2.1336: f32[], param_3.919: f32[], param_4.558: f32[2048,4,6144], param_5.492: f32[], param_6.367: f32[4,2048,6144], param_7.211: pred[], param_8.129: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { + %param_0.1420 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %param_3.919 = f32[]{:T(128)S(6)} parameter(3) + %mul.2538.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.919), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.211 = pred[]{:T(512)S(6)} parameter(7) + %select_n.312.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.211), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.367 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) + %bitcast.467.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.367), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.492 = f32[]{:T(128)} parameter(5) + %div.948.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.492), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.947.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.467.clone.1, %div.948.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.311.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.312.clone.1, %bitcast.467.clone.1, %div.947.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1146.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.808.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1146.clone.1), dimensions={}, metadata={op_name="broadcast.85"} + %mul.2542.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.311.clone.1, %broadcast.808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.129 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(8) + %constant.1150.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.807.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1150.clone.1), dimensions={}, metadata={op_name="broadcast.84"} + %mul.2541.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.129, %broadcast.807.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.975.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2542.clone.1, %mul.2541.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1336 = f32[]{:T(128)S(6)} parameter(2) + %div.944.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1336), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.76.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.311.clone.1, %select_n.311.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1149.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.806.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1149.clone.1), dimensions={}, metadata={op_name="broadcast.73"} + %mul.2540.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.76.clone.1, %broadcast.806.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.558 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) + %constant.1148.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.805.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1148.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.2539.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.558, %broadcast.805.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.974.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2540.clone.1, %mul.2539.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1594 = f32[]{:T(128)S(6)} parameter(1) + %div.943.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1594), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.942.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.974.clone.1, %div.943.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.73.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} sqrt(%div.942.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1147.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.803.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1147.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %add.973.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.73.clone.1, %broadcast.803.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.298.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.944.clone.1, %add.973.clone.1), metadata={op_name="multiply.35"} + %div.941.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.975.clone.1, %multiply.298.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2537.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1420, %broadcast.808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.972.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.941.clone.1, %mul.2537.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2536.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2538.clone.1, %add.972.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.971.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1420, %mul.2536.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.227 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.971.clone.1, %add.971.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1199 = f32[]{:T(128)} constant(0) + %reduce.116 = f32[]{:T(128)} reduce(%square.227, %constant.1199), dimensions={0,1,2}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.119.clone.1 = f32[]{:T(128)} reduce(%integer_pow.76.clone.1, %constant.1199), dimensions={0,1,2}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.148 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.116, %add.971.clone.1, %add.974.clone.1, %add.975.clone.1, %reduce.119.clone.1) +} + +%region_39.44 (reduce_sum.404: f32[], reduce_sum.408: f32[]) -> f32[] { + %reduce_sum.408 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.404 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.409 = f32[]{:T(128)} add(%reduce_sum.404, %reduce_sum.408), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.335 (param_0.1423: f32[4,2048,16,128]) -> f32[] { + %param_0.1423 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.343 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1423), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.230 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%bitcast.343, %bitcast.343), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1202 = f32[]{:T(128)} constant(0) + ROOT %reduce.120 = f32[]{:T(128)} reduce(%square.230, %constant.1202), dimensions={0,1,2,3}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_38.43 (reduce_sum.401: f32[], reduce_sum.402: f32[]) -> f32[] { + %reduce_sum.402 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.401 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.403 = f32[]{:T(128)} add(%reduce_sum.401, %reduce_sum.402), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.337 (param_0.1424: f32[4,16,128,2048]) -> f32[] { + %param_0.1424 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.347 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_0.1424), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.233 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%bitcast.347, %bitcast.347), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1203 = f32[]{:T(128)} constant(0) + ROOT %reduce.121 = f32[]{:T(128)} reduce(%square.233, %constant.1203), dimensions={0,1,2,3}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%fused_computation.338 (param_0.989: f32[16,4,128,2048]) -> bf16[4,16,128,2048] { + %param_0.989 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) + %copy.189 = bf16[16,4,128,2048]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.989), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} + ROOT %bitcast.348 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.189), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%region_68.73 (reduce_sum.551: f32[], reduce_sum.555: f32[]) -> f32[] { + %reduce_sum.555 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.551 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.556 = f32[]{:T(128)} add(%reduce_sum.551, %reduce_sum.555), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_53.58 (reduce_sum.473: f32[], reduce_sum.474: f32[]) -> f32[] { + %reduce_sum.474 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.473 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.478 = f32[]{:T(128)} add(%reduce_sum.473, %reduce_sum.474), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.339 (param_0.1412: f32[2048,4,16,128], param_1.1586: f32[], param_2.1328: f32[], param_3.911: f32[], param_4.550: f32[2048,4,16,128], param_5.484: f32[], param_6.359: f32[4,2048,16,128], param_7.203: pred[], param_8.121: f32[2048,4,16,128]) -> (f32[], f32[2048,4,16,128], f32[2048,4,16,128], f32[2048,4,16,128], f32[]) { + %param_0.1412 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.911 = f32[]{:T(128)S(6)} parameter(3) + %mul.2473.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_3.911), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.203 = pred[]{:T(512)S(6)} parameter(7) + %select_n.280.clone.1 = pred[2048,4,16,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.203), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.359 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.451.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_6.359), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.484 = f32[]{:T(128)} parameter(5) + %div.884.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_5.484), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.883.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%bitcast.451.clone.1, %div.884.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.279.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} select(%select_n.280.clone.1, %bitcast.451.clone.1, %div.883.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1098.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.768.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1098.clone.1), dimensions={}, metadata={op_name="broadcast.75"} + %mul.2479.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.279.clone.1, %broadcast.768.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.121 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1102.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2480.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1102.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2478.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_8.121, %mul.2480.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.933.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.2479.clone.1, %mul.2478.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1328 = f32[]{:T(128)S(6)} parameter(2) + %div.880.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1328), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.68.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.279.clone.1, %select_n.279.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1101.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2477.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1101.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2475.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.68.clone.1, %mul.2477.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.550 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1100.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2476.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1100.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2474.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_4.550, %mul.2476.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.932.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.2475.clone.1, %mul.2474.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1586 = f32[]{:T(128)S(6)} parameter(1) + %div.879.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1586), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.878.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.932.clone.1, %div.879.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.65.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} sqrt(%div.878.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1099.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.931.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1099.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.930.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%sqrt.65.clone.1, %add.931.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.290.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%div.880.clone.1, %add.930.clone.1), metadata={op_name="multiply.43"} + %div.877.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.933.clone.1, %multiply.290.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2472.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_0.1412, %broadcast.768.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.929.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%div.877.clone.1, %mul.2472.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2471.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%mul.2473.clone.1, %add.929.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.928.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%param_0.1412, %mul.2471.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.234 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%add.928.clone.1, %add.928.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1191 = f32[]{:T(128)} constant(0) + %reduce.122 = f32[]{:T(128)} reduce(%square.234, %constant.1191), dimensions={0,1,2,3}, to_apply=%region_68.73, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.124.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.1191), dimensions={0,1,2,3}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.149 = (f32[]{:T(128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.122, %add.928.clone.1, %add.932.clone.1, %add.933.clone.1, %reduce.124.clone.1) +} + +%region_67.72 (reduce_sum.548: f32[], reduce_sum.549: f32[]) -> f32[] { + %reduce_sum.549 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.548 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.550 = f32[]{:T(128)} add(%reduce_sum.548, %reduce_sum.549), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_52.57 (reduce_sum.467: f32[], reduce_sum.471: f32[]) -> f32[] { + %reduce_sum.471 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.467 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.472 = f32[]{:T(128)} add(%reduce_sum.467, %reduce_sum.471), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.340 (param_0.1413: f32[16,4,128,2048], param_1.1587: f32[], param_2.1329: f32[], param_3.912: f32[], param_4.551: f32[16,4,128,2048], param_5.485: f32[], param_6.360: f32[4,16,128,2048], param_7.204: pred[], param_8.122: f32[16,4,128,2048]) -> (f32[], f32[16,4,128,2048], f32[16,4,128,2048], f32[16,4,128,2048], f32[]) { + %param_0.1413 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) + %param_3.912 = f32[]{:T(128)S(6)} parameter(3) + %mul.2483.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_3.912), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.204 = pred[]{:T(512)S(6)} parameter(7) + %select_n.284.clone.1 = pred[16,4,128,2048]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.204), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.360 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.453.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_6.360), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.485 = f32[]{:T(128)} parameter(5) + %div.892.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_5.485), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.891.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%bitcast.453.clone.1, %div.892.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.283.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} select(%select_n.284.clone.1, %bitcast.453.clone.1, %div.891.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1104.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.770.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1104.clone.1), dimensions={}, metadata={op_name="broadcast.76"} + %mul.2489.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %broadcast.770.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.122 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(8) + %constant.1108.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2490.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1108.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2488.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_8.122, %mul.2490.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.939.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.2489.clone.1, %mul.2488.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1329 = f32[]{:T(128)S(6)} parameter(2) + %div.888.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_2.1329), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.69.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %select_n.283.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1107.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2487.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1107.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2485.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%integer_pow.69.clone.1, %mul.2487.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.551 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(4) + %constant.1106.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2486.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1106.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2484.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_4.551, %mul.2486.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.938.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.2485.clone.1, %mul.2484.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1587 = f32[]{:T(128)S(6)} parameter(1) + %div.887.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_1.1587), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.886.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.938.clone.1, %div.887.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.66.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} sqrt(%div.886.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1105.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.937.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1105.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.936.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%sqrt.66.clone.1, %add.937.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.291.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%div.888.clone.1, %add.936.clone.1), metadata={op_name="multiply.42"} + %div.885.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.939.clone.1, %multiply.291.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2482.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_0.1413, %broadcast.770.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.935.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%div.885.clone.1, %mul.2482.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2481.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%mul.2483.clone.1, %add.935.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.934.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%param_0.1413, %mul.2481.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.235 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%add.934.clone.1, %add.934.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1192 = f32[]{:T(128)} constant(0) + %reduce.123 = f32[]{:T(128)} reduce(%square.235, %constant.1192), dimensions={0,1,2,3}, to_apply=%region_67.72, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.125.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.1192), dimensions={0,1,2,3}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.150 = (f32[]{:T(128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.123, %add.934.clone.1, %add.938.clone.1, %add.939.clone.1, %reduce.125.clone.1) +} + +%region_41.46 (reduce_sum.416: f32[], reduce_sum.417: f32[]) -> f32[] { + %reduce_sum.417 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.416 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.418 = f32[]{:T(128)} add(%reduce_sum.416, %reduce_sum.417), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_36.41 (reduce_sum.389: f32[], reduce_sum.390: f32[]) -> f32[] { + %reduce_sum.390 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.389 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.394 = f32[]{:T(128)} add(%reduce_sum.389, %reduce_sum.390), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.352 (param_0.1426: f32[4,2048,8,128], param_1.1598: f32[4,2048,8,128]) -> (f32[], f32[]) { + %param_0.1426 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(0) + %bitcast.352 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1426), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.238 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.352, %bitcast.352), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1205 = f32[]{:T(128)} constant(0) + %reduce.126 = f32[]{:T(128)} reduce(%square.238, %constant.1205), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1598 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(1) + %bitcast.356.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1598), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.241.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.356.clone.1, %bitcast.356.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.127.clone.1 = f32[]{:T(128)} reduce(%square.241.clone.1, %constant.1205), dimensions={0,1,2,3}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.169 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.126, %reduce.127.clone.1) +} + +%fused_computation.355 (param_0.1021: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { + %param_0.1021 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %copy.190 = bf16[2048,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.1021), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} + ROOT %bitcast.357 = bf16[4,2048,8,128]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%region_70.75 (reduce_sum.563: f32[], reduce_sum.564: f32[]) -> f32[] { + %reduce_sum.564 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.563 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.565 = f32[]{:T(128)} add(%reduce_sum.563, %reduce_sum.564), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_55.60 (reduce_sum.485: f32[], reduce_sum.486: f32[]) -> f32[] { + %reduce_sum.486 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.485 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.487 = f32[]{:T(128)} add(%reduce_sum.485, %reduce_sum.486), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.356 (param_0.1410: f32[2048,4,8,128], param_1.1584: f32[], param_2.1326: f32[], param_3.909: f32[], param_4.548: f32[2048,4,8,128], param_5.482: f32[], param_6.357: f32[4,2048,8,128], param_7.201: pred[], param_8.119: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { + %param_0.1410 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.909 = f32[]{:T(128)S(6)} parameter(3) + %mul.2459.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.909), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.201 = pred[]{:T(512)S(6)} parameter(7) + %select_n.272.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.201), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.357 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.447.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.357), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.482 = f32[]{:T(128)} parameter(5) + %div.868.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.482), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.867.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.447.clone.1, %div.868.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.271.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.272.clone.1, %bitcast.447.clone.1, %div.867.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1086.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.760.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1086.clone.1), dimensions={}, metadata={op_name="broadcast.80"} + %mul.2463.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %broadcast.760.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.119 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1090.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.759.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1090.clone.1), dimensions={}, metadata={op_name="broadcast.79"} + %mul.2462.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.119, %broadcast.759.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.922.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2463.clone.1, %mul.2462.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1326 = f32[]{:T(128)S(6)} parameter(2) + %div.864.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1326), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.66.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %select_n.271.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1089.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.758.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1089.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.2461.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.66.clone.1, %broadcast.758.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.548 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1088.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.757.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1088.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.2460.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.548, %broadcast.757.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.921.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2461.clone.1, %mul.2460.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1584 = f32[]{:T(128)S(6)} parameter(1) + %div.863.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1584), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.862.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.921.clone.1, %div.863.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.63.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.862.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1087.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.755.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1087.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %add.920.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.755.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.288.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.864.clone.1, %add.920.clone.1), metadata={op_name="multiply.45"} + %div.861.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.922.clone.1, %multiply.288.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2458.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1410, %broadcast.760.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.919.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.861.clone.1, %mul.2458.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2457.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.2459.clone.1, %add.919.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.918.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1410, %mul.2457.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.242 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.918.clone.1, %add.918.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1189 = f32[]{:T(128)} constant(0) + %reduce.128 = f32[]{:T(128)} reduce(%square.242, %constant.1189), dimensions={0,1,2,3}, to_apply=%region_70.75, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.130.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.1189), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.151 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.128, %add.918.clone.1, %add.921.clone.1, %add.922.clone.1, %reduce.130.clone.1) +} + +%region_65.70 (reduce_sum.536: f32[], reduce_sum.537: f32[]) -> f32[] { + %reduce_sum.537 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.536 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.541 = f32[]{:T(128)} add(%reduce_sum.536, %reduce_sum.537), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_50.55 (reduce_sum.458: f32[], reduce_sum.459: f32[]) -> f32[] { + %reduce_sum.459 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.458 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.460 = f32[]{:T(128)} add(%reduce_sum.458, %reduce_sum.459), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.357 (param_0.1415: f32[2048,4,8,128], param_1.1589: f32[], param_2.1331: f32[], param_3.914: f32[], param_4.553: f32[2048,4,8,128], param_5.487: f32[], param_6.362: f32[4,2048,8,128], param_7.206: pred[], param_8.124: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { + %param_0.1415 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.914 = f32[]{:T(128)S(6)} parameter(3) + %mul.2500.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.914), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.206 = pred[]{:T(512)S(6)} parameter(7) + %select_n.292.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.206), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.362 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(6) + %bitcast.457.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.362), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.487 = f32[]{:T(128)} parameter(5) + %div.908.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.487), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.907.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.457.clone.1, %div.908.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.291.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.292.clone.1, %bitcast.457.clone.1, %div.907.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1116.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.782.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1116.clone.1), dimensions={}, metadata={op_name="broadcast.80"} + %mul.2504.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %broadcast.782.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.124 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1120.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.781.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1120.clone.1), dimensions={}, metadata={op_name="broadcast.79"} + %mul.2503.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.124, %broadcast.781.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.949.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2504.clone.1, %mul.2503.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1331 = f32[]{:T(128)S(6)} parameter(2) + %div.904.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1331), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.71.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %select_n.291.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1119.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.780.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1119.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.2502.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.71.clone.1, %broadcast.780.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.553 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1118.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.779.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1118.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.2501.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.553, %broadcast.779.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.948.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2502.clone.1, %mul.2501.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1589 = f32[]{:T(128)S(6)} parameter(1) + %div.903.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1589), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.902.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.948.clone.1, %div.903.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.68.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.902.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1117.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.777.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1117.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %add.947.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.68.clone.1, %broadcast.777.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.293.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.904.clone.1, %add.947.clone.1), metadata={op_name="multiply.40"} + %div.901.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.949.clone.1, %multiply.293.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2499.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1415, %broadcast.782.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.946.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.901.clone.1, %mul.2499.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2498.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.2500.clone.1, %add.946.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.945.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1415, %mul.2498.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.243 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.945.clone.1, %add.945.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1194 = f32[]{:T(128)} constant(0) + %reduce.129 = f32[]{:T(128)} reduce(%square.243, %constant.1194), dimensions={0,1,2,3}, to_apply=%region_65.70, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.131.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.1194), dimensions={0,1,2,3}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.152 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.129, %add.945.clone.1, %add.948.clone.1, %add.949.clone.1, %reduce.131.clone.1) +} + +%fused_computation.373 (param_0.1095: bf16[4,128,2048], param_1.1142: f32[4,128], param_2.842: f32[4,128], param_3.484: bf16[4,128,2048], param_4.283: bf16[2048]) -> bf16[4,128,2048] { + %param_3.484 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.283 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %mul.2385 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.283), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2359 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_3.484, %mul.2385), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1387 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%mul.2359), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.842 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %mul.2356 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_2.842), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2347 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1387, %mul.2356), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1095 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1398 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1095), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.1142 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2354 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_1.1142), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2353 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1398, %mul.2354), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %add_any.184 = f32[4,128,2048]{2,1,0:T(8,128)} add(%mul.2347, %mul.2353), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} + ROOT %convert_element_type.1385 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%add_any.184), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} +} + +%region_6.9 (reduce_sum.228: f32[], reduce_sum.229: f32[]) -> f32[] { + %reduce_sum.229 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.228 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.230 = f32[]{:T(128)} add(%reduce_sum.228, %reduce_sum.229), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.374 (param_0.1435: bf16[4,128,2048]) -> f32[4,128] { + %param_0.1435 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1389 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1435), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.246 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1389, %convert_element_type.1389), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} + %constant.1215 = f32[]{:T(128)} constant(0) + ROOT %reduce.132 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.246, %constant.1215), dimensions={2}, to_apply=%region_6.9, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} +} + +%region_12.15 (reduce_sum.275: f32[], reduce_sum.276: f32[]) -> f32[] { + %reduce_sum.276 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.275 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.277 = f32[]{:T(128)} add(%reduce_sum.275, %reduce_sum.276), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.376 (param_0.1430: bf16[4,128,2048], param_1.1601: bf16[4,128,2048], param_2.1339: bf16[2048]) -> f32[4,128] { + %param_0.1430 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1396 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1430), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.1601 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %param_2.1339 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2384 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1339), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2358 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1601, %mul.2384), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1395 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%mul.2358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %mul.2351 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1396, %convert_element_type.1395), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1209 = f32[]{:T(128)} constant(0) + ROOT %reduce.133 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.2351, %constant.1209), dimensions={2}, to_apply=%region_12.15, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} +} + +%region_10.13 (reduce_sum.263: bf16[], reduce_sum.264: bf16[]) -> bf16[] { + %reduce_sum.264 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.263 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.268 = bf16[]{:T(256)} add(%reduce_sum.263, %reduce_sum.264), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.296.clone.clone (param_0.1392: bf16[151936,2048]) -> bf16[151936,2048,1] { + %param_0.1392 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.505 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1392), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} +} + +%fused_computation.300.clone.1.clone.clone (param_0.1393: bf16[4,128,151936], param_1.1573: s32[4,128], param_2.1296: f32[4,128], param_3.896: f32[4,128], param_4.533: bf16[4,128], param_5.455: f32[4,128]) -> bf16[4,128,151936] { + %param_5.455 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.2616 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.455), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.896 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.2615 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.896), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1393 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1468 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1393), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.533 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.86 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.533), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.85 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1468, %sub.86), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.60 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.85), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.2614 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2615, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1296 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.962 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1296), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.961 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2614, %div.962), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1573 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.43 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1573), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.42 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %eq.41 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.43, %eq.42), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %convert_element_type.1467 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.84 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.961, %convert_element_type.1467), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.2613 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2616, %sub.84), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1466 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2613), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.377 (param_0.1391: f32[4,128], param_1.1572: bf16[4,128,2048], param_2.1297: bf16[151936,2048], param_3.897: bf16[4,128,151936], param_4.534: s32[4,128], param_5.456: f32[4,128], param_6.342: f32[4,128], param_7.198: bf16[4,128], param_8.116: f32[4,128]) -> (bf16[2048], bf16[4,128,2048]) { + %param_1.1572 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1408 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1572), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1391 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2373 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1391), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2372 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1408, %mul.2373), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1407 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2372), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_3.897 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.534 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %param_5.456 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.342 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.198 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) + %param_8.116 = f32[4,128]{1,0:T(4,128)S(1)} parameter(8) + %multiply_convert_fusion.3.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_3.897, %param_4.534, %param_5.456, %param_6.342, %param_7.198, /*index=5*/%param_8.116), kind=kLoop, calls=%fused_computation.300.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_2.1297 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(2) + %fusion.261.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1297), kind=kLoop, calls=%fused_computation.296.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %convolution.84.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.3.clone.1, %fusion.261.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %mul.2355 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1407, %convolution.84.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1051 = bf16[]{:T(256)} constant(0) + %reduce.134 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%mul.2355, %constant.1051), dimensions={0,1}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} + ROOT %tuple.166 = (bf16[2048]{0:T(1024)(128)(2,1)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.134, %convolution.84.clone.1) +} + +%fused_computation.385 (param_0.1129: f32[64], param_1.1177: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1177 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %div.720 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1177), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %param_0.1129 = f32[64]{0:T(128)S(1)} parameter(0) + %div.718 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1129), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %div.717 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.720, %div.718), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %sin.38 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.717), metadata={op_name="jit(train_step)/layers/sin" stack_frame_id=0} + %convert_element_type.1416 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %cos.41.clone.1 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.717), metadata={op_name="jit(train_step)/layers/cos" stack_frame_id=0} + %convert_element_type.1415.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.159 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1416, %convert_element_type.1415.clone.1) +} + +%fused_computation.386 (param_0.1126: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1126 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1030 = bf16[]{:T(256)} constant(-inf) + %pad.46 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1126, %constant.1030), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1126, %constant.1030), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + ROOT %maximum.42 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.46, %pad.45), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +} + +%fused_computation.387 (param_0.1128: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1128 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1029 = bf16[]{:T(256)} constant(-inf) + %pad.48 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1128, %constant.1029), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.47 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1128, %constant.1029), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + ROOT %maximum.43 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.48, %pad.47), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +} + +%region_35.40 (reduce_sum.383: f32[], reduce_sum.387: f32[]) -> f32[] { + %reduce_sum.387 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.383 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.388 = f32[]{:T(128)} add(%reduce_sum.383, %reduce_sum.387), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_34.39 (reduce_sum.380: f32[], reduce_sum.381: f32[]) -> f32[] { + %reduce_sum.381 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.380 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.382 = f32[]{:T(128)} add(%reduce_sum.380, %reduce_sum.381), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.391 (param_0.1427: f32[4,2048], param_1.1599: f32[4,2048]) -> (f32[], f32[]) { + %param_0.1427 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(0) + %bitcast.385 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_0.1427), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.249 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.385, %bitcast.385), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1206 = f32[]{:T(128)} constant(0) + %reduce.135 = f32[]{:T(128)} reduce(%square.249, %constant.1206), dimensions={0,1}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1599 = f32[4,2048]{1,0:T(4,128)} parameter(1) + %bitcast.389.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_1.1599), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.252.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.389.clone.1, %bitcast.389.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.136.clone.1 = f32[]{:T(128)} reduce(%square.252.clone.1, %constant.1206), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.170 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.135, %reduce.136.clone.1) +} + +%region_64.69 (reduce_sum.530: f32[], reduce_sum.534: f32[]) -> f32[] { + %reduce_sum.534 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.530 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.535 = f32[]{:T(128)} add(%reduce_sum.530, %reduce_sum.534), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_49.54 (reduce_sum.452: f32[], reduce_sum.453: f32[]) -> f32[] { + %reduce_sum.453 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.452 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.457 = f32[]{:T(128)} add(%reduce_sum.452, %reduce_sum.453), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.394 (param_0.1416: f32[2048,4], param_1.1590: f32[], param_2.1332: f32[], param_3.915: f32[], param_4.554: f32[2048,4], param_5.488: f32[], param_6.363: f32[4,2048], param_7.207: pred[], param_8.125: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { + %param_0.1416 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.915 = f32[]{:T(128)S(6)} parameter(3) + %mul.2507.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.915), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.207 = pred[]{:T(512)S(6)} parameter(7) + %select_n.296.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.207), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.363 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.459.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.363), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.488 = f32[]{:T(128)} parameter(5) + %div.916.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.488), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.915.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.459.clone.1, %div.916.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.295.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.296.clone.1, %bitcast.459.clone.1, %div.915.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1122.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.788.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1122.clone.1), dimensions={}, metadata={op_name="broadcast.82"} + %mul.2511.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.295.clone.1, %broadcast.788.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.125 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.1126.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.787.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1126.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %mul.2510.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.125, %broadcast.787.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.954.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2511.clone.1, %mul.2510.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1332 = f32[]{:T(128)S(6)} parameter(2) + %div.912.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1332), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.72.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.295.clone.1, %select_n.295.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1125.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.786.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1125.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.2509.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.72.clone.1, %broadcast.786.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.554 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1124.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.785.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1124.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.2508.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.554, %broadcast.785.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.953.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2509.clone.1, %mul.2508.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1590 = f32[]{:T(128)S(6)} parameter(1) + %div.911.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1590), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.910.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.953.clone.1, %div.911.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.69.clone.1 = f32[2048,4]{0,1:T(4,128)} sqrt(%div.910.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1123.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.783.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1123.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %add.952.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.69.clone.1, %broadcast.783.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.294.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.912.clone.1, %add.952.clone.1), metadata={op_name="multiply.39"} + %div.909.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.954.clone.1, %multiply.294.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2506.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1416, %broadcast.788.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.951.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.909.clone.1, %mul.2506.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2505.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.2507.clone.1, %add.951.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.950.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1416, %mul.2505.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.253 = f32[2048,4]{0,1:T(4,128)} multiply(%add.950.clone.1, %add.950.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1195 = f32[]{:T(128)} constant(0) + %reduce.137 = f32[]{:T(128)} reduce(%square.253, %constant.1195), dimensions={0,1}, to_apply=%region_64.69, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.139.clone.1 = f32[]{:T(128)} reduce(%integer_pow.72.clone.1, %constant.1195), dimensions={0,1}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.153 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.137, %add.950.clone.1, %add.953.clone.1, %add.954.clone.1, %reduce.139.clone.1) +} + +%region_63.68 (reduce_sum.527: f32[], reduce_sum.528: f32[]) -> f32[] { + %reduce_sum.528 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.527 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.529 = f32[]{:T(128)} add(%reduce_sum.527, %reduce_sum.528), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.53 (reduce_sum.446: f32[], reduce_sum.450: f32[]) -> f32[] { + %reduce_sum.450 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.446 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.451 = f32[]{:T(128)} add(%reduce_sum.446, %reduce_sum.450), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.395 (param_0.1417: f32[2048,4], param_1.1591: f32[], param_2.1333: f32[], param_3.916: f32[], param_4.555: f32[2048,4], param_5.489: f32[], param_6.364: f32[4,2048], param_7.208: pred[], param_8.126: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { + %param_0.1417 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.916 = f32[]{:T(128)S(6)} parameter(3) + %mul.2514.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.916), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.208 = pred[]{:T(512)S(6)} parameter(7) + %select_n.300.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.208), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.364 = f32[4,2048]{1,0:T(4,128)} parameter(6) + %bitcast.461.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.364), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.489 = f32[]{:T(128)} parameter(5) + %div.924.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.489), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.923.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.461.clone.1, %div.924.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.299.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.300.clone.1, %bitcast.461.clone.1, %div.923.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1128.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.794.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1128.clone.1), dimensions={}, metadata={op_name="broadcast.82"} + %mul.2518.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.299.clone.1, %broadcast.794.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.126 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.1132.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.793.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1132.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %mul.2517.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.126, %broadcast.793.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.959.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2518.clone.1, %mul.2517.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1333 = f32[]{:T(128)S(6)} parameter(2) + %div.920.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1333), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.73.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.299.clone.1, %select_n.299.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1131.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.792.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1131.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.2516.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.73.clone.1, %broadcast.792.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.555 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1130.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.791.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1130.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.2515.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.555, %broadcast.791.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.958.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2516.clone.1, %mul.2515.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1591 = f32[]{:T(128)S(6)} parameter(1) + %div.919.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1591), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.918.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.958.clone.1, %div.919.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.70.clone.1 = f32[2048,4]{0,1:T(4,128)} sqrt(%div.918.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1129.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.789.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1129.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %add.957.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.70.clone.1, %broadcast.789.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.295.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.920.clone.1, %add.957.clone.1), metadata={op_name="multiply.38"} + %div.917.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.959.clone.1, %multiply.295.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2513.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1417, %broadcast.794.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.956.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.917.clone.1, %mul.2513.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2512.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.2514.clone.1, %add.956.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.955.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1417, %mul.2512.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.254 = f32[2048,4]{0,1:T(4,128)} multiply(%add.955.clone.1, %add.955.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1196 = f32[]{:T(128)} constant(0) + %reduce.138 = f32[]{:T(128)} reduce(%square.254, %constant.1196), dimensions={0,1}, to_apply=%region_63.68, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.140.clone.1 = f32[]{:T(128)} reduce(%integer_pow.73.clone.1, %constant.1196), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.154 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.138, %add.955.clone.1, %add.958.clone.1, %add.959.clone.1, %reduce.140.clone.1) +} + +%region_11.14 (reduce_sum.269: f32[], reduce_sum.270: f32[]) -> f32[] { + %reduce_sum.270 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.269 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.271 = f32[]{:T(128)} add(%reduce_sum.269, %reduce_sum.270), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.406 (param_0.1431: bf16[2048]) -> f32[] { + %param_0.1431 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %convert_element_type.1420 = f32[2048]{0:T(1024)} convert(%param_0.1431), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.257 = f32[2048]{0:T(1024)} multiply(%convert_element_type.1420, %convert_element_type.1420), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1210 = f32[]{:T(128)} constant(0) + ROOT %reduce.141 = f32[]{:T(128)} reduce(%square.257, %constant.1210), dimensions={0}, to_apply=%region_11.14, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_59.64 (reduce_sum.506: f32[], reduce_sum.507: f32[]) -> f32[] { + %reduce_sum.507 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.506 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.508 = f32[]{:T(128)} add(%reduce_sum.506, %reduce_sum.507), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_44.49 (reduce_sum.425: f32[], reduce_sum.429: f32[]) -> f32[] { + %reduce_sum.429 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.425 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.430 = f32[]{:T(128)} add(%reduce_sum.425, %reduce_sum.429), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.407 (param_0.1421: f32[2048], param_1.1595: f32[], param_2.1337: f32[], param_3.920: f32[], param_4.559: f32[2048], param_5.493: f32[], param_6.368: bf16[2048], param_7.212: pred[], param_8.130: f32[2048]) -> (f32[], f32[2048], f32[2048], f32[2048], f32[]) { + %param_0.1421 = f32[2048]{0:T(1024)S(1)} parameter(0) + %param_3.920 = f32[]{:T(128)S(6)} parameter(3) + %mul.2545.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_3.920), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.212 = pred[]{:T(512)S(6)} parameter(7) + %select_n.316.clone.1 = pred[2048]{0:T(1024)(128)(4,1)} broadcast(%param_7.212), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.368 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(6) + %convert_element_type.1435.clone.1 = f32[2048]{0:T(1024)} convert(%param_6.368), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_5.493 = f32[]{:T(128)} parameter(5) + %div.956.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_5.493), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.955.clone.1 = f32[2048]{0:T(1024)} divide(%convert_element_type.1435.clone.1, %div.956.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.315.clone.1 = f32[2048]{0:T(1024)} select(%select_n.316.clone.1, %convert_element_type.1435.clone.1, %div.955.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1152.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.810.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1152.clone.1), dimensions={}, metadata={op_name="broadcast.86"} + %mul.2551.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.315.clone.1, %broadcast.810.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.130 = f32[2048]{0:T(1024)S(1)} parameter(8) + %constant.1156.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2552.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1156.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2550.clone.1 = f32[2048]{0:T(1024)} multiply(%param_8.130, %mul.2552.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.981.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2551.clone.1, %mul.2550.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1337 = f32[]{:T(128)S(6)} parameter(2) + %div.952.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_2.1337), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.77.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.315.clone.1, %select_n.315.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1155.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2549.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1155.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2547.clone.1 = f32[2048]{0:T(1024)} multiply(%integer_pow.77.clone.1, %mul.2549.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.559 = f32[2048]{0:T(1024)S(1)} parameter(4) + %constant.1154.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2548.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1154.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2546.clone.1 = f32[2048]{0:T(1024)} multiply(%param_4.559, %mul.2548.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.980.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2547.clone.1, %mul.2546.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1595 = f32[]{:T(128)S(6)} parameter(1) + %div.951.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_1.1595), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.950.clone.1 = f32[2048]{0:T(1024)} divide(%add.980.clone.1, %div.951.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.74.clone.1 = f32[2048]{0:T(1024)} sqrt(%div.950.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1153.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.979.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1153.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.978.clone.1 = f32[2048]{0:T(1024)} add(%sqrt.74.clone.1, %add.979.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.299.clone.1 = f32[2048]{0:T(1024)} multiply(%div.952.clone.1, %add.978.clone.1), metadata={op_name="multiply.34"} + %div.949.clone.1 = f32[2048]{0:T(1024)} divide(%add.981.clone.1, %multiply.299.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2544.clone.1 = f32[2048]{0:T(1024)} multiply(%param_0.1421, %broadcast.810.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.977.clone.1 = f32[2048]{0:T(1024)} add(%div.949.clone.1, %mul.2544.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2543.clone.1 = f32[2048]{0:T(1024)} multiply(%mul.2545.clone.1, %add.977.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.976.clone.1 = f32[2048]{0:T(1024)S(1)} add(%param_0.1421, %mul.2543.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.258 = f32[2048]{0:T(1024)} multiply(%add.976.clone.1, %add.976.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1200 = f32[]{:T(128)} constant(0) + %reduce.142 = f32[]{:T(128)} reduce(%square.258, %constant.1200), dimensions={0}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.143.clone.1 = f32[]{:T(128)} reduce(%integer_pow.77.clone.1, %constant.1200), dimensions={0}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.157 = (f32[]{:T(128)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.142, %add.976.clone.1, %add.980.clone.1, %add.981.clone.1, %reduce.143.clone.1) +} + +%fused_computation.413 (param_0.1191: s32[512]) -> s32[1024] { + %constant.960 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.727 = s32[1024]{0:T(1024)} broadcast(%constant.960), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %param_0.1191 = s32[512]{0:T(512)S(1)} parameter(0) + %constant.961 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %pad.49 = s32[1024]{0:T(1024)} pad(%param_0.1191, %constant.961), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %constant.959 = s32[] constant(151935), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.726 = s32[1024]{0:T(1024)} broadcast(%constant.959), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.727, %pad.49, %broadcast.726), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} +} + +%fused_computation.416 (param_0.1190: s32[4,128]) -> s32[512] { + %param_0.1190 = s32[4,128]{1,0:T(4,128)} parameter(0) + %constant.1055 = s32[]{:T(128)} constant(0) + %broadcast.747 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1055), dimensions={}, metadata={op_name="broadcast.95"} + %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.1190, %broadcast.747), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} + %constant.1052 = s32[]{:T(128)} constant(151936) + %add.901 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1052), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %add.879 = s32[4,128]{1,0:T(4,128)} add(%param_0.1190, %add.901), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.879, %param_0.1190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} + ROOT %bitcast.390 = s32[512]{0:T(512)S(1)} bitcast(%select_n.178), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} +} + +%region_40.45 (reduce_sum.410: f32[], reduce_sum.411: f32[]) -> f32[] { + %reduce_sum.411 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.410 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.415 = f32[]{:T(128)} add(%reduce_sum.410, %reduce_sum.411), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_37.42 (reduce_sum.395: f32[], reduce_sum.396: f32[]) -> f32[] { + %reduce_sum.396 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.395 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.397 = f32[]{:T(128)} add(%reduce_sum.395, %reduce_sum.396), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.418 (param_0.1425: f32[4,128], param_1.1597: f32[4,128]) -> (f32[], f32[]) { + %param_0.1425 = f32[4,128]{1,0:T(4,128)} parameter(0) + %bitcast.394 = f32[128,4]{0,1:T(4,128)} bitcast(%param_0.1425), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.261 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.394, %bitcast.394), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1204 = f32[]{:T(128)} constant(0) + %reduce.144 = f32[]{:T(128)} reduce(%square.261, %constant.1204), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1597 = f32[4,128]{1,0:T(4,128)} parameter(1) + %bitcast.398.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_1.1597), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.264.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.398.clone.1, %bitcast.398.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.145.clone.1 = f32[]{:T(128)} reduce(%square.264.clone.1, %constant.1204), dimensions={0,1}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.171 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.144, %reduce.145.clone.1) +} + +%region_72.77 (reduce_sum.572: f32[], reduce_sum.576: f32[]) -> f32[] { + %reduce_sum.576 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.572 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.577 = f32[]{:T(128)} add(%reduce_sum.572, %reduce_sum.576), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_58.63 (reduce_sum.500: f32[], reduce_sum.501: f32[]) -> f32[] { + %reduce_sum.501 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.500 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.502 = f32[]{:T(128)} add(%reduce_sum.500, %reduce_sum.501), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.421 (param_0.1432: bf16[4,128], param_1.1603: f32[4,128], param_2.1340: f32[4,128], param_3.922: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { + %param_3.922 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %constant.1158.clone.1 = s32[]{:T(128)} constant(0) + %broadcast.811.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1158.clone.1), dimensions={}, metadata={op_name="broadcast.95"} + %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.922, %broadcast.811.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} + %param_1.1603 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1603), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} + %param_0.1432 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) + %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1432), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + %add.903 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %square.269 = f32[4,128]{1,0:T(4,128)} multiply(%add.903, %add.903), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} + %constant.1212 = f32[]{:T(128)} constant(0) + %broadcast.741 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1212), dimensions={}, metadata={op_name="broadcast.50"} + %mul.2434 = f32[4,128]{1,0:T(4,128)} multiply(%square.269, %broadcast.741), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %mul.2414 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.2434, %broadcast.741), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.146 = f32[]{:T(128)} reduce(%mul.2414, %constant.1212), dimensions={0,1}, to_apply=%region_72.77, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %param_2.1340 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1340), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} + %add.880.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.2434), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %mul.2415.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.880.clone.1, %broadcast.741), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.149.clone.1 = f32[]{:T(128)} reduce(%mul.2415.clone.1, %constant.1212), dimensions={0,1}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %mul.2432.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.903, %broadcast.741), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %constant.1056.clone.1 = f32[]{:T(128)} constant(1) + %add.898.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1056.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + %add.891.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.2432.clone.1, %add.898.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + ROOT %tuple.158 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.146, %reduce.149.clone.1, %ne.6.clone.1, %add.891.clone.1) +} + +%region_69.74 (reduce_sum.557: f32[], reduce_sum.558: f32[]) -> f32[] { + %reduce_sum.558 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.557 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.562 = f32[]{:T(128)} add(%reduce_sum.557, %reduce_sum.558), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_54.59 (reduce_sum.479: f32[], reduce_sum.480: f32[]) -> f32[] { + %reduce_sum.480 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.479 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.481 = f32[]{:T(128)} add(%reduce_sum.479, %reduce_sum.480), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.422 (param_0.1411: f32[128,4], param_1.1585: f32[], param_2.1327: f32[], param_3.910: f32[], param_4.549: f32[128,4], param_5.483: f32[], param_6.358: f32[4,128], param_7.202: pred[], param_8.120: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { + %param_0.1411 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.910 = f32[]{:T(128)S(6)} parameter(3) + %mul.2466.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.910), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.202 = pred[]{:T(512)S(6)} parameter(7) + %select_n.276.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.202), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.358 = f32[4,128]{1,0:T(4,128)} parameter(6) + %bitcast.449.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.483 = f32[]{:T(128)} parameter(5) + %div.876.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.483), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.875.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.449.clone.1, %div.876.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.275.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.276.clone.1, %bitcast.449.clone.1, %div.875.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1092.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.766.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1092.clone.1), dimensions={}, metadata={op_name="broadcast.78"} + %mul.2470.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.275.clone.1, %broadcast.766.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.120 = f32[128,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.1096.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.765.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1096.clone.1), dimensions={}, metadata={op_name="broadcast.77"} + %mul.2469.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.120, %broadcast.765.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.927.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2470.clone.1, %mul.2469.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1327 = f32[]{:T(128)S(6)} parameter(2) + %div.872.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1327), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.67.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.275.clone.1, %select_n.275.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1095.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.764.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1095.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.2468.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.764.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.549 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1094.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.763.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1094.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.2467.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.549, %broadcast.763.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.926.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2468.clone.1, %mul.2467.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1585 = f32[]{:T(128)S(6)} parameter(1) + %div.871.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1585), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.870.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.926.clone.1, %div.871.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.64.clone.1 = f32[128,4]{0,1:T(4,128)} sqrt(%div.870.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1093.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.761.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1093.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %add.925.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.761.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.289.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.872.clone.1, %add.925.clone.1), metadata={op_name="multiply.44"} + %div.869.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.927.clone.1, %multiply.289.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2465.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1411, %broadcast.766.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.924.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.869.clone.1, %mul.2465.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2464.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.2466.clone.1, %add.924.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.923.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1411, %mul.2464.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.265 = f32[128,4]{0,1:T(4,128)} multiply(%add.923.clone.1, %add.923.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1190 = f32[]{:T(128)} constant(0) + %reduce.147 = f32[]{:T(128)} reduce(%square.265, %constant.1190), dimensions={0,1}, to_apply=%region_69.74, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.151.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.1190), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.160 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.147, %add.923.clone.1, %add.926.clone.1, %add.927.clone.1, %reduce.151.clone.1) +} + +%region_66.71 (reduce_sum.542: f32[], reduce_sum.543: f32[]) -> f32[] { + %reduce_sum.543 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.542 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.544 = f32[]{:T(128)} add(%reduce_sum.542, %reduce_sum.543), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_51.56 (reduce_sum.464: f32[], reduce_sum.465: f32[]) -> f32[] { + %reduce_sum.465 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.464 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.466 = f32[]{:T(128)} add(%reduce_sum.464, %reduce_sum.465), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.423 (param_0.1414: f32[128,4], param_1.1588: f32[], param_2.1330: f32[], param_3.913: f32[], param_4.552: f32[128,4], param_5.486: f32[], param_6.361: f32[4,128], param_7.205: pred[], param_8.123: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { + %param_0.1414 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.913 = f32[]{:T(128)S(6)} parameter(3) + %mul.2493.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.913), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.205 = pred[]{:T(512)S(6)} parameter(7) + %select_n.288.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.205), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.361 = f32[4,128]{1,0:T(4,128)} parameter(6) + %bitcast.455.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.361), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.486 = f32[]{:T(128)} parameter(5) + %div.900.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.486), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.899.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.455.clone.1, %div.900.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.287.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.288.clone.1, %bitcast.455.clone.1, %div.899.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1110.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.776.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1110.clone.1), dimensions={}, metadata={op_name="broadcast.78"} + %mul.2497.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.287.clone.1, %broadcast.776.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.123 = f32[128,4]{0,1:T(4,128)S(1)} parameter(8) + %constant.1114.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.775.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1114.clone.1), dimensions={}, metadata={op_name="broadcast.77"} + %mul.2496.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.123, %broadcast.775.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.944.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2497.clone.1, %mul.2496.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1330 = f32[]{:T(128)S(6)} parameter(2) + %div.896.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1330), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.70.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.287.clone.1, %select_n.287.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1113.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.774.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1113.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.2495.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.774.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.552 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1112.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.773.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1112.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.2494.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.552, %broadcast.773.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.943.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2495.clone.1, %mul.2494.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1588 = f32[]{:T(128)S(6)} parameter(1) + %div.895.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1588), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.894.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.943.clone.1, %div.895.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.67.clone.1 = f32[128,4]{0,1:T(4,128)} sqrt(%div.894.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1111.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.771.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1111.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %add.942.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.67.clone.1, %broadcast.771.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.292.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.896.clone.1, %add.942.clone.1), metadata={op_name="multiply.41"} + %div.893.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.944.clone.1, %multiply.292.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2492.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1414, %broadcast.776.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.941.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.893.clone.1, %mul.2492.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2491.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.2493.clone.1, %add.941.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.940.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1414, %mul.2491.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.266 = f32[128,4]{0,1:T(4,128)} multiply(%add.940.clone.1, %add.940.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1193 = f32[]{:T(128)} constant(0) + %reduce.148 = f32[]{:T(128)} reduce(%square.266, %constant.1193), dimensions={0,1}, to_apply=%region_66.71, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.152.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.1193), dimensions={0,1}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.161 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.148, %add.940.clone.1, %add.943.clone.1, %add.944.clone.1, %reduce.152.clone.1) +} + +%fused_computation.432 (param_0.1242: f32[4,128], param_1.1350: f32[4,128]) -> f32[4,128] { + %param_0.1242 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1350 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %constant.1046 = f32[]{:T(128)} constant(0.00048828125) + %broadcast.749 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1046), dimensions={}, metadata={op_name="broadcast.362"} + %div.767 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1350, %broadcast.749), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.1044 = f32[]{:T(128)} constant(1e-06) + %add.911 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1044), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.910 = f32[4,128]{1,0:T(4,128)} add(%div.767, %add.911), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %rsqrt.168 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.910), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} + %div.754 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.168, %add.910), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.1028 = f32[]{:T(128)} constant(-0.5) + %mul.2440 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1028), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2431 = f32[4,128]{1,0:T(4,128)} multiply(%div.754, %mul.2440), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2430 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1242, %mul.2431), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1027 = f32[]{:T(128)} constant(0.0009765625) + %mul.2439 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1027), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.2429 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.2430, %mul.2439), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%region_7.10 (reduce_sum.234: s32[], reduce_sum.235: s32[]) -> s32[] { + %reduce_sum.235 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.234 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.236 = s32[]{:T(128)} add(%reduce_sum.234, %reduce_sum.235), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} +} + +%fused_computation.435 (param_0.1261: pred[4,128]) -> s32[] { + %param_0.1261 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %convert_element_type.1427 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1261), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} + %constant.1054 = s32[]{:T(128)} constant(0) + ROOT %reduce.150 = s32[]{:T(128)} reduce(%convert_element_type.1427, %constant.1054), dimensions={0,1}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%fused_computation.439 (param_0.1245: f32[4,128]) -> f32[4,128] { + %param_0.1245 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1047 = f32[]{:T(128)} constant(0.00048828125) + %broadcast.745 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1047), dimensions={}, metadata={op_name="broadcast.362"} + %div.759 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1245, %broadcast.745), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.1045 = f32[]{:T(128)} constant(1e-06) + %add.900 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1045), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.897 = f32[4,128]{1,0:T(4,128)} add(%div.759, %add.900), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + ROOT %rsqrt.166 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.897), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} +} + +%fused_computation.440 (param_0.1244: pred[4,128], param_1.1602: f32[]) -> f32[4,128] { + %param_0.1244 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %param_1.1602 = f32[]{:T(128)S(6)} parameter(1) + %broadcast_in_dim.309 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1602), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} + %constant.1211 = f32[]{:T(128)} constant(0) + %broadcast.743 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1211), dimensions={}, metadata={op_name="broadcast.50"} + ROOT %mul.2441 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.1244, %broadcast_in_dim.309, %broadcast.743), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} +} + +%fused_computation.442 () -> f32[64] { + %constant.1050 = f32[]{:T(128)} constant(1e+06) + %broadcast.752 = f32[64]{0:T(128)} broadcast(%constant.1050), dimensions={}, metadata={op_name="broadcast.353"} + %iota.46 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/layers/iota" stack_frame_id=0} + %constant.1049 = s32[]{:T(128)} constant(2) + %broadcast.751 = s32[64]{0:T(128)} broadcast(%constant.1049), dimensions={}, metadata={op_name="broadcast.354"} + %mul.2442 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.751), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} + %convert_element_type.1428 = f32[64]{0:T(128)} convert(%mul.2442), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %constant.1048 = f32[]{:T(128)} constant(0.0078125) + %broadcast.750 = f32[64]{0:T(128)} broadcast(%constant.1048), dimensions={}, metadata={op_name="broadcast.355"} + %div.768 = f32[64]{0:T(128)} multiply(%convert_element_type.1428, %broadcast.750), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.752, %div.768), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} +} + +%fused_computation.443 (param_0.1259: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { + %param_0.1259 = s32[4,128]{1,0:T(4,128)} parameter(0) + %convert_element_type.1429 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1259), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %bitcast.399 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1429), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.163 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.399, %convert_element_type.1429) +} + +%fused_computation.446 (param_0.1400: f32[2048,4]) -> bf16[4,2048] { + %param_0.1400 = f32[2048,4]{0,1:T(4,128)} parameter(0) + %bitcast.507 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1400), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.79 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.507) +} + +%fused_computation.447 (param_0.1401: f32[2048,4]) -> bf16[4,2048] { + %param_0.1401 = f32[2048,4]{0,1:T(4,128)} parameter(0) + %bitcast.508 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1401), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.81 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.508) +} + +%fused_computation.448 (param_0.1402: f32[128,4]) -> bf16[4,128] { + %param_0.1402 = f32[128,4]{0,1:T(4,128)} parameter(0) + %bitcast.509 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1402), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.83 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.509) +} + +%fused_computation.449 (param_0.1403: f32[128,4]) -> bf16[4,128] { + %param_0.1403 = f32[128,4]{0,1:T(4,128)} parameter(0) + %bitcast.510 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1403), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.85 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.510) +} + +%region_8.11 (reduce_max.6: bf16[], reduce_max.8: bf16[]) -> bf16[] { + %reduce_max.8 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_max"} + %reduce_max.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_max"} + ROOT %reduce_max.9 = bf16[]{:T(256)} maximum(%reduce_max.6, %reduce_max.8), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.298.clone.clone (param_0.1387: bf16[151936,2048]) -> bf16[151936,2048,1] { + %param_0.1387 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.503 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1387), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} +} + +%fused_computation.379.clone.clone (param_0.1388: f32[4,128], param_1.1569: bf16[4,128,2048], param_2.1292: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1569 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1462 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1569), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1388 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2607 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1388), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2606 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1462, %mul.2607), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1461 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2606), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1292 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2608 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1292), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.2605 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1461, %mul.2608), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.450 (param_0.1404: bf16[151936,2048], param_1.1578: f32[4,128], param_2.1316: bf16[4,128,2048], param_3.903: bf16[2048]) -> (bf16[4,128], bf16[4,128,151936]) { + %param_1.1578 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1316 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.903 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.280.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1578, %param_2.1316, %param_3.903), kind=kLoop, calls=%fused_computation.379.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1404 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %fusion.263.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1404), kind=kLoop, calls=%fused_computation.298.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %convolution.85.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convolution(%fusion.280.clone.1, %fusion.263.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %constant.1183 = bf16[]{:T(256)} constant(-inf) + %reduce.153 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.85.clone.1, %constant.1183), dimensions={2}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + ROOT %tuple.165 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,151936]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.153, %convolution.85.clone.1) +} + +%fused_computation.451 (param_0.1399: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { + %param_0.1399 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %bitcast.506 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1399), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.87 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.506) +} + +%convert_element_type.785.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { + %rhs = bf16[] parameter(1) + %lhs = bf16[] parameter(0) + ROOT %add.730 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.167.clone.clone (param_0.1572: bf16[4,2048], param_1.1711: s32[]) -> bf16[2048] { + %param_0.1572 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1711 = s32[]{:T(128)S(6)} parameter(1) + %constant.1348 = s32[]{:T(128)} constant(0) + %dynamic_slice.394 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1572, %param_1.1711, %constant.1348), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1349 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.174 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.394, %constant.1349), dimensions={0}, to_apply=%convert_element_type.785.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%region_14.16 (reduce_sum.278: f32[], reduce_sum.282: f32[]) -> f32[] { + %reduce_sum.282 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.278 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.283 = f32[]{:T(128)} add(%reduce_sum.278, %reduce_sum.282), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.61.clone.clone (param_0.1573: bf16[4,4,128,2048], param_1.1712: s32[]) -> f32[4,128] { + %param_0.1573 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1712 = s32[]{:T(128)S(6)} parameter(1) + %constant.1350 = s32[]{:T(128)} constant(0) + %dynamic_slice.395 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1573, %param_1.1712, %constant.1350, %constant.1350, %constant.1350), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.602 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.395), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %convert_element_type.1585 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.602), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.280 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1585, %convert_element_type.1585), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1351 = f32[]{:T(128)} constant(0) + ROOT %reduce.175 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.280, %constant.1351), dimensions={2}, to_apply=%region_14.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} +} + +%fused_computation.190.clone.1.clone (param_0.1574: f32[4,128]) -> f32[4,128] { + %param_0.1574 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1353 = f32[]{:T(128)} constant(0.00048828125) + %closed_call.106 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1353), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.999 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1574, %closed_call.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1352 = f32[]{:T(128)} constant(1e-06) + %closed_call.105 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1352), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.1015 = f32[4,128]{1,0:T(4,128)} add(%div.999, %closed_call.105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.181 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1015), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%region_15.17 (reduce_sum.284: f32[], reduce_sum.285: f32[]) -> f32[] { + %reduce_sum.285 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.284 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.289 = f32[]{:T(128)} add(%reduce_sum.284, %reduce_sum.285), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.25.clone.1.clone.clone.clone.clone (param_0.1587: bf16[4,2048,16,128], param_1.1721: s32[]) -> bf16[2048,16,128,1] { + %param_0.1587 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1721 = s32[]{:T(128)S(6)} parameter(1) + %constant.1363 = s32[]{:T(128)} constant(0) + %dynamic_slice.400 = bf16[1,2048,16,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1587, %param_1.1721, %constant.1363, %constant.1363, %constant.1363), dynamic_slice_sizes={1,2048,16,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.610 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.400), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.2.clone.clone.clone.clone (param_0.1588: f32[4,128], param_1.1722: bf16[2048], param_2.1414: bf16[4,4,128,2048], param_3.972: s32[]) -> bf16[4,128,2048,1] { + %param_2.1414 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.972 = s32[]{:T(128)S(6)} parameter(3) + %constant.1364 = s32[]{:T(128)} constant(0) + %dynamic_slice.401 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1414, %param_3.972, %constant.1364, %constant.1364, %constant.1364), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1596 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} convert(%dynamic_slice.401), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1588 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2890 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_0.1588), dimensions={1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2889 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} multiply(%convert_element_type.1596, %mul.2890), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1595 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} convert(%mul.2889), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1722 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2891 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} broadcast(%param_1.1722), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2888 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1595, %mul.2891), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.611 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2888), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.64.clone.clone (param_0.1589: bf16[4,2048,16,128], param_1.1723: s32[], param_2.1415: f32[4,128], param_3.973: bf16[2048], param_4.596: bf16[4,4,128,2048]) -> (f32[4,128,16], bf16[4,128,16,128]) { + %param_2.1415 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.973 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %param_4.596 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(4) + %param_1.1723 = s32[]{:T(128)S(6)} parameter(1) + %fusion.91.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1415, %param_3.973, %param_4.596, %param_1.1723), kind=kLoop, calls=%fused_computation.89.clone.2.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1589 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.49.clone.3 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1589, %param_1.1723), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.44.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.91.clone.3, %fusion.49.clone.3), window={size=1x16 pad=0_0x15_15 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %convert_element_type.1597 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%convolution.44.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.282 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1597, %convert_element_type.1597), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1365 = f32[]{:T(128)} constant(0) + %reduce.177 = f32[4,128,16]{1,2,0:T(8,128)S(1)} reduce(%square.282, %constant.1365), dimensions={3}, to_apply=%region_15.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.210 = (f32[4,128,16]{1,2,0:T(8,128)S(1)}, bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.177, %convolution.44.clone.3) +} + +%fused_computation.162.clone.1.clone (param_0.1590: f32[4,128,16]) -> f32[4,128,16] { + %param_0.1590 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) + %constant.1366 = f32[]{:T(128)} constant(0.0078125) + %closed_call.108 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1366), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.1001 = f32[4,128,16]{1,2,0:T(8,128)} multiply(%param_0.1590, %closed_call.108), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1367 = f32[]{:T(128)} constant(1e-06) + %add.1020 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1367), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %add.1019 = f32[4,128,16]{1,2,0:T(8,128)} add(%div.1001, %add.1020), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.183 = f32[4,128,16]{1,2,0:T(8,128)S(1)} rsqrt(%add.1019), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.193.clone.clone (param_0.1591: bf16[4,128], param_1.1724: s32[]) -> bf16[128] { + %param_0.1591 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1724 = s32[]{:T(128)S(6)} parameter(1) + %constant.1368 = s32[]{:T(128)} constant(0) + %dynamic_slice.402 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1591, %param_1.1724, %constant.1368), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.612 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.402), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.118.clone.1.clone (param_0.1592: f32[4,128,16], param_1.1725: bf16[4,128,16,128], param_2.1416: bf16[128]) -> bf16[4,128,16,128] { + %param_1.1725 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1599 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_1.1725), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1592 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) + %mul.2894 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1592), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2893 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1599, %mul.2894), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1598 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2893), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1416 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) + %mul.2895 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1416), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2892 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%convert_element_type.1598, %mul.2895), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.93.clone.clone (param_0.1593: bf16[4,128,16,128]) -> (bf16[4,128,16,64], bf16[4,128,16,64]) { + %param_0.1593 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.160 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1593), slice={[0:4], [0:128], [0:16], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %neg.129 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.160), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} + %split.161 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1593), slice={[0:4], [0:128], [0:16], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + ROOT %tuple.211 = (bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) +} + +%fused_computation.198.clone.clone () -> f32[64] { + %constant.1343 = f32[]{:T(128)} constant(1e+06) + %closed_call.104 = f32[64]{0:T(128)} broadcast(%constant.1343), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %iota.51 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/iota" stack_frame_id=0} + %constant.1342 = s32[]{:T(128)} constant(2) + %closed_call.103 = s32[64]{0:T(128)} broadcast(%constant.1342), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2867 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1583 = f32[64]{0:T(128)} convert(%mul.2867), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %constant.1344 = f32[]{:T(128)} constant(0.0078125) + %closed_call.102 = f32[64]{0:T(128)} broadcast(%constant.1344), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.995 = f32[64]{0:T(128)} multiply(%convert_element_type.1583, %closed_call.102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + ROOT %pow.38 = f32[64]{0:T(128)S(1)} power(%closed_call.104, %div.995), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/pow" stack_frame_id=0} +} + +%fused_computation.154.clone.clone (param_0.1569: f32[64], param_1.1709: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1709 = f32[4,128]{1,0:T(4,128)} parameter(1) + %div.998 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1709), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %param_0.1569 = f32[64]{0:T(128)S(1)} parameter(0) + %div.997 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1569), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %div.996 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.998, %div.997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %cos.43 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.996), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/cos" stack_frame_id=0} + %convert_element_type.1584 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %sin.35.clone.3 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.996), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/sin" stack_frame_id=0} + %convert_element_type.1213.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.207 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1584, %convert_element_type.1213.clone.3) +} + +%fused_computation.157.clone.1.clone (param_0.1570: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1570 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1345 = bf16[]{:T(256)} constant(-inf) + %pad.69 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1570, %constant.1345), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.68 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1570, %constant.1345), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %maximum.53 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.69, %pad.68), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + ROOT %bitcast.601 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.53), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.156.clone.1.clone (param_0.1583: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1583 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1361 = bf16[]{:T(256)} constant(-inf) + %pad.71 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1583, %constant.1361), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.70 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1583, %constant.1361), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %maximum.54 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.71, %pad.70), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + ROOT %bitcast.608 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.54), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.102.clone.clone (param_0.1594: bf16[4,128,16,64], param_1.1726: bf16[4,128,16,64], param_2.1417: bf16[4,128,128], param_3.974: bf16[4,128,128], param_4.597: bf16[128], param_5.512: f32[4,128,16], param_6.382: bf16[4,128,16,128]) -> bf16[4,16,128,128] { + %param_6.382 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(6) + %convert_element_type.1601 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_6.382), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_5.512 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(5) + %mul.2904 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_5.512), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2903 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1601, %mul.2904), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1600 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2903), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_4.597 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(4) + %mul.2902 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.597), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2901 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%convert_element_type.1600, %mul.2902), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_3.974 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2900 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.974), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2898 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%mul.2901, %mul.2900), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1726 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1369 = bf16[]{:T(256)} constant(-inf) + %pad.75 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1726, %constant.1369), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1594 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.74 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1594, %constant.1369), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %maximum.56 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.75, %pad.74), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_2.1417 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %mul.2899 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1417), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2897 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.56, %mul.2899), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.1021 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2898, %mul.2897), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %constant.1370 = bf16[]{:T(256)} constant(0.08838) + %closed_call.109 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.1370), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2896 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%add.1021, %closed_call.109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.613 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%mul.2896), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%region_16.18 (reduce_sum.290: f32[], reduce_sum.291: f32[]) -> f32[] { + %reduce_sum.291 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.290 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.292 = f32[]{:T(128)} add(%reduce_sum.290, %reduce_sum.291), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.70.clone.1.clone.clone.clone.clone (param_0.1579: bf16[4,2048,8,128], param_1.1716: s32[]) -> bf16[2048,8,128,1] { + %param_0.1579 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1716 = s32[]{:T(128)S(6)} parameter(1) + %constant.1356 = s32[]{:T(128)} constant(0) + %dynamic_slice.398 = bf16[1,2048,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1579, %param_1.1716, %constant.1356, %constant.1356, %constant.1356), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.606 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.398), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.1.clone.clone.clone.clone (param_0.1580: f32[4,128], param_1.1717: bf16[2048], param_2.1410: bf16[4,4,128,2048], param_3.969: s32[]) -> bf16[4,128,2048,1] { + %param_2.1410 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.969 = s32[]{:T(128)S(6)} parameter(3) + %constant.1357 = s32[]{:T(128)} constant(0) + %dynamic_slice.399 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1410, %param_3.969, %constant.1357, %constant.1357, %constant.1357), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1589 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} convert(%dynamic_slice.399), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1580 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2874 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_0.1580), dimensions={1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2873 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} multiply(%convert_element_type.1589, %mul.2874), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1588 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} convert(%mul.2873), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1717 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2875 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} broadcast(%param_1.1717), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2872 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1588, %mul.2875), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.607 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2872), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.85.clone.clone (param_0.1581: bf16[4,2048,8,128], param_1.1718: s32[], param_2.1411: f32[4,128], param_3.970: bf16[2048], param_4.594: bf16[4,4,128,2048]) -> (f32[4,128,8], bf16[4,128,8,128]) { + %param_2.1411 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.970 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %param_4.594 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(4) + %param_1.1718 = s32[]{:T(128)S(6)} parameter(1) + %fusion.90.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1411, %param_3.970, %param_4.594, %param_1.1718), kind=kLoop, calls=%fused_computation.89.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1581 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.85.clone.3 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1581, %param_1.1718), kind=kLoop, calls=%fused_computation.70.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.56.clone.3 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.90.clone.3, %fusion.85.clone.3), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %convert_element_type.1590 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%convolution.56.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.281 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1590, %convert_element_type.1590), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1358 = f32[]{:T(128)} constant(0) + %reduce.176 = f32[4,128,8]{1,2,0:T(8,128)S(1)} reduce(%square.281, %constant.1358), dimensions={3}, to_apply=%region_16.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.208 = (f32[4,128,8]{1,2,0:T(8,128)S(1)}, bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.176, %convolution.56.clone.3) +} + +%fused_computation.165.clone.1.clone (param_0.1582: f32[4,128,8]) -> f32[4,128,8] { + %param_0.1582 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) + %constant.1359 = f32[]{:T(128)} constant(0.0078125) + %closed_call.107 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1359), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.1000 = f32[4,128,8]{1,2,0:T(8,128)} multiply(%param_0.1582, %closed_call.107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1360 = f32[]{:T(128)} constant(1e-06) + %add.1017 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1360), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %add.1016 = f32[4,128,8]{1,2,0:T(8,128)} add(%div.1000, %add.1017), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.182 = f32[4,128,8]{1,2,0:T(8,128)S(1)} rsqrt(%add.1016), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.195.clone.clone (param_0.1568: bf16[4,128], param_1.1708: s32[]) -> bf16[128] { + %param_0.1568 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1708 = s32[]{:T(128)S(6)} parameter(1) + %constant.1341 = s32[]{:T(128)} constant(0) + %dynamic_slice.392 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1568, %param_1.1708, %constant.1341), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.600 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.392), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.144.clone.1.clone (param_0.1584: f32[4,128,8], param_1.1719: bf16[4,128,8,128], param_2.1412: bf16[128]) -> bf16[4,128,8,128] { + %param_1.1719 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1592 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_1.1719), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1584 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) + %mul.2878 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1584), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2877 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1592, %mul.2878), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1591 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2877), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1412 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) + %mul.2879 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1412), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2876 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%convert_element_type.1591, %mul.2879), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.125.clone.clone (param_0.1585: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { + %param_0.1585 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1585), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %neg.128 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.158), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} + %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1585), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + ROOT %tuple.209 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) +} + +%fused_computation.128.clone.clone (param_0.1586: bf16[4,128,8,64], param_1.1720: bf16[4,128,8,64], param_2.1413: bf16[4,128,128], param_3.971: bf16[4,128,128], param_4.595: bf16[128], param_5.511: f32[4,128,8], param_6.381: bf16[4,128,8,128]) -> bf16[4,8,128,128] { + %param_6.381 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(6) + %convert_element_type.1594 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_6.381), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_5.511 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(5) + %mul.2887 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_5.511), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2886 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1594, %mul.2887), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1593 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2886), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_4.595 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(4) + %mul.2885 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.595), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2884 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%convert_element_type.1593, %mul.2885), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_3.971 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2883 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.971), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2881 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%mul.2884, %mul.2883), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1720 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1362 = bf16[]{:T(256)} constant(-inf) + %pad.73 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1720, %constant.1362), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1586 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.72 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1586, %constant.1362), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %maximum.55 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.73, %pad.72), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_2.1413 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %mul.2882 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1413), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2880 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.55, %mul.2882), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.1018 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2881, %mul.2880), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %bitcast.609 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.1018), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.181.clone.clone (param_0.1575: bf16[4,2048,8,128], param_1.1713: s32[]) -> bf16[1,2048,8,128] { + %param_0.1575 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) + %param_1.1713 = s32[]{:T(128)S(6)} parameter(1) + %constant.1354 = s32[]{:T(128)} constant(0) + ROOT %dynamic_slice.396 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1575, %param_1.1713, %constant.1354, %constant.1354, %constant.1354), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +} + +%fused_computation.71.clone.1.clone.clone.clone.clone (param_0.1576: bf16[1,2048,8,128]) -> bf16[2048,8,128,1] { + %param_0.1576 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %copy.200 = bf16[1,2048,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1576), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} + ROOT %bitcast.603 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} bitcast(%copy.200), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.clone.clone.clone.clone (param_0.1577: f32[4,128], param_1.1714: bf16[2048], param_2.1408: bf16[4,4,128,2048], param_3.967: s32[]) -> bf16[4,128,2048,1] { + %param_2.1408 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.967 = s32[]{:T(128)S(6)} parameter(3) + %constant.1355 = s32[]{:T(128)} constant(0) + %dynamic_slice.397 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1408, %param_3.967, %constant.1355, %constant.1355, %constant.1355), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1587 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} convert(%dynamic_slice.397), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1577 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2870 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_0.1577), dimensions={1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2869 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} multiply(%convert_element_type.1587, %mul.2870), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1586 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} convert(%mul.2869), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1714 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2871 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} broadcast(%param_1.1714), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2868 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1586, %mul.2871), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.604 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2868), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.151.clone.clone (param_0.1578: bf16[1,2048,8,128], param_1.1715: f32[4,128], param_2.1409: bf16[2048], param_3.968: bf16[4,4,128,2048], param_4.593: s32[]) -> bf16[4,8,128,128] { + %param_1.1715 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1409 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %param_3.968 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.593 = s32[]{:T(128)S(6)} parameter(4) + %fusion.380 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_1.1715, %param_2.1409, %param_3.968, %param_4.593), kind=kLoop, calls=%fused_computation.89.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1578 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %fusion.379 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1578), kind=kLoop, calls=%fused_computation.71.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.105 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.380, %fusion.379), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + ROOT %bitcast.605 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%convolution.105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.199.clone.clone (param_0.1618: f32[4,16,128,128]) -> (f32[4,16,128], f32[4,16,128,1]) { + %param_0.1618 = f32[4,16,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) + %slice.11 = f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1618), slice={[0:4], [0:16], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} + %bitcast.626 = f32[4,16,128]{2,1,0:T(8,128)S(1)} bitcast(%slice.11), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/squeeze" stack_frame_id=0} + ROOT %tuple.216 = (f32[4,16,128]{2,1,0:T(8,128)S(1)}, f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)}) tuple(%bitcast.626, %slice.11) +} + +%region_17.20 (reduce_sum.296: f32[], reduce_sum.297: f32[]) -> f32[] { + %reduce_sum.297 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.296 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.298 = f32[]{:T(128)} add(%reduce_sum.296, %reduce_sum.297), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone (param_0.1595: bf16[4,16,128,2048], param_1.1727: s32[]) -> bf16[16,128,2048,1] { + %param_0.1595 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1727 = s32[]{:T(128)S(6)} parameter(1) + %constant.1371 = s32[]{:T(128)} constant(0) + %dynamic_slice.403 = bf16[1,16,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1595, %param_1.1727, %constant.1371, %constant.1371, %constant.1371), dynamic_slice_sizes={1,16,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.614 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.403), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.110.clone.clone.clone.clone.clone.clone (param_0.1596: bf16[4,16,128,128]) -> bf16[4,128,16,128] { + %param_0.1596 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.615 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1596), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.67.clone.clone (param_0.1597: bf16[4,16,128,2048], param_1.1728: s32[], param_2.1418: bf16[4,16,128,128], param_3.975: bf16[4,4,128,2048]) -> (f32[4,128], bf16[4,128,2048]) { + %param_3.975 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1728 = s32[]{:T(128)S(6)} parameter(1) + %constant.414.clone.1.clone.3 = s32[]{:T(128)} constant(0) + %dynamic_slice.265.clone.3 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.975, %param_1.1728, %constant.414.clone.1.clone.3, %constant.414.clone.1.clone.3, %constant.414.clone.1.clone.3), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.212.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.265.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %param_2.1418 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %fusion.100.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1418), kind=kLoop, calls=%fused_computation.110.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} + %param_0.1597 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %fusion.99.clone.3 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1597, %param_1.1728), kind=kLoop, calls=%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.62.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} convolution(%fusion.100.clone.3, %fusion.99.clone.3), window={size=1x16}, dim_labels=0b1f_1io0->0bf1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %bitcast.204.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%convolution.62.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %add.744.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.212.clone.3, %bitcast.204.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %convert_element_type.1602 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%add.744.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.283 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1602, %convert_element_type.1602), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1372 = f32[]{:T(128)} constant(0) + %reduce.178 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.283, %constant.1372), dimensions={2}, to_apply=%region_17.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.212 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.178, %add.744.clone.3) +} + +%convert_element_type.808.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { + %rhs.1 = bf16[] parameter(1) + %lhs.1 = bf16[] parameter(0) + ROOT %add.731 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.166.clone.clone (param_0.1571: bf16[4,2048], param_1.1710: s32[]) -> bf16[2048] { + %param_0.1571 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1710 = s32[]{:T(128)S(6)} parameter(1) + %constant.1346 = s32[]{:T(128)} constant(0) + %dynamic_slice.393 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1571, %param_1.1710, %constant.1346), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1347 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.173 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.393, %constant.1347), dimensions={0}, to_apply=%convert_element_type.808.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.191.clone.1.clone (param_0.1598: f32[4,128]) -> f32[4,128] { + %param_0.1598 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1374 = f32[]{:T(128)} constant(0.00048828125) + %closed_call.111 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1374), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.1002 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1598, %closed_call.111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1373 = f32[]{:T(128)} constant(1e-06) + %closed_call.110 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1373), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.1022 = f32[4,128]{1,0:T(4,128)} add(%div.1002, %closed_call.110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.184 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1022), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.11.clone.1.clone.clone (param_0.1599: bf16[4,2048,6144], param_1.1729: s32[]) -> bf16[2048,6144,1] { + %param_0.1599 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1729 = s32[]{:T(128)S(6)} parameter(1) + %constant.1375 = s32[]{:T(128)} constant(0) + %dynamic_slice.404 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1599, %param_1.1729, %constant.1375, %constant.1375), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.616 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.404), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.116.clone.3.clone.clone (param_0.1600: f32[4,128], param_1.1730: bf16[4,128,2048], param_2.1419: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1730 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1604 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1730), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1600 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2907 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1600), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2906 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1604, %mul.2907), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1603 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2906), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1419 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2908 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1419), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2905 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1603, %mul.2908), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.23.clone.clone (param_0.1601: bf16[4,2048,6144], param_1.1731: s32[], param_2.1420: f32[4,128], param_3.976: bf16[4,128,2048], param_4.598: bf16[2048]) -> bf16[4,128,6144] { + %param_2.1420 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.976 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.598 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.382 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1420, %param_3.976, %param_4.598), kind=kLoop, calls=%fused_computation.116.clone.3.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1601 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1731 = s32[]{:T(128)S(6)} parameter(1) + %fusion.381 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1601, %param_1.1731), kind=kLoop, calls=%fused_computation.11.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.106 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.382, %fusion.381), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.14.clone.clone.clone (param_0.1604: bf16[4,2048,6144], param_1.1734: s32[]) -> bf16[2048,6144,1] { + %param_0.1604 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1734 = s32[]{:T(128)S(6)} parameter(1) + %constant.1377 = s32[]{:T(128)} constant(0) + %dynamic_slice.406 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1604, %param_1.1734, %constant.1377, %constant.1377), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.619 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.406), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.116.clone.2.clone.clone (param_0.1605: f32[4,128], param_1.1735: bf16[4,128,2048], param_2.1422: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1735 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1606 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1735), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1605 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2911 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1605), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2910 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1606, %mul.2911), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1605 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1422 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2912 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1422), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2909 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1605, %mul.2912), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.20.clone.clone (param_0.1606: bf16[4,2048,6144], param_1.1736: s32[], param_2.1423: f32[4,128], param_3.977: bf16[4,128,2048], param_4.599: bf16[2048]) -> bf16[4,128,6144] { + %param_2.1423 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.977 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.599 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.386 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1423, %param_3.977, %param_4.599), kind=kLoop, calls=%fused_computation.116.clone.2.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1606 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1736 = s32[]{:T(128)S(6)} parameter(1) + %fusion.385 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1606, %param_1.1736), kind=kLoop, calls=%fused_computation.14.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.108 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.386, %fusion.385), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.12.clone.clone.clone (param_0.1602: bf16[4,6144,2048], param_1.1732: s32[]) -> bf16[6144,2048,1] { + %param_0.1602 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1732 = s32[]{:T(128)S(6)} parameter(1) + %constant.1376 = s32[]{:T(128)} constant(0) + %dynamic_slice.405 = bf16[1,6144,2048]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1602, %param_1.1732, %constant.1376, %constant.1376), dynamic_slice_sizes={1,6144,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.618 = bf16[6144,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.405), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%bitcast_fusion.1.clone.clone (bitcast_input.4: bf16[4,128,2048]) -> bf16[4,128,2048] { + %bitcast_input.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.617 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.4) diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index 4a500e2fe1..0902717928 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -132,7 +132,15 @@ def maybe_shard_with_pspec( def maybe_shard_with_logical( - inputs, logical_axes, mesh, shard_mode, rules=None, debug_sharding=False, extra_stack_level=0, sharding_desc="" + inputs, + logical_axes, + mesh, + shard_mode, + rules=None, + debug_sharding=False, + extra_stack_level=0, + sharding_desc="", + skip_trivial_specs=False, ): """ A wrapper of maybe_shard_with_name when logical axes are inputs @@ -147,6 +155,9 @@ def maybe_shard_with_logical( named_sharding = create_sharding(mesh, logical_axes, rules=rules) + if skip_trivial_specs and all(ax is None or ax == () for ax in named_sharding.spec): + return inputs + return maybe_shard_with_name( inputs, named_sharding, diff --git a/tests/utils/reference_hlo_deepseek3.txt b/tests/utils/reference_hlo_deepseek3.txt index eb316f1963..ac457cb0fd 100644 --- a/tests/utils/reference_hlo_deepseek3.txt +++ b/tests/utils/reference_hlo_deepseek3.txt @@ -10,21 +10,21 @@ StackFrames %region_46.56 (top_k.25: bf16[], top_k.26: bf16[], top_k.27: s32[], top_k.28: s32[]) -> pred[] { - %constant.1427 = s32[]{:T(128)} constant(0) - %constant.1428 = s32[]{:T(128)} constant(2147483647) + %constant.1376 = s32[]{:T(128)} constant(0) + %constant.1377 = s32[]{:T(128)} constant(2147483647) %top_k.25 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} %top_k.26 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} %top_k.27 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} %top_k.28 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} - %convert.393 = f32[]{:T(128)S(6)} convert(%top_k.25), metadata={op_name="convert.18"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %bitcast-convert.39 = s32[]{:T(128)S(6)} bitcast-convert(%convert.393), metadata={op_name="bitcast-convert.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.144 = pred[]{:T(512)S(6)} compare(%bitcast-convert.39, %constant.1427), direction=LT, metadata={op_name="compare.38"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.40 = s32[]{:T(128)S(6)} xor(%constant.1428, %bitcast-convert.39), metadata={op_name="xor.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %convert.278 = f32[]{:T(128)S(6)} convert(%top_k.25), metadata={op_name="convert.18"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.39 = s32[]{:T(128)S(6)} bitcast-convert(%convert.278), metadata={op_name="bitcast-convert.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.144 = pred[]{:T(512)S(6)} compare(%bitcast-convert.39, %constant.1376), direction=LT, metadata={op_name="compare.38"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.40 = s32[]{:T(128)S(6)} xor(%constant.1377, %bitcast-convert.39), metadata={op_name="xor.8"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.127 = s32[]{:T(128)S(6)} select(%compare.144, %xor.40, %bitcast-convert.39), metadata={op_name="select.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} - %convert.394 = f32[]{:T(128)S(6)} convert(%top_k.26), metadata={op_name="convert.19"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %bitcast-convert.40 = s32[]{:T(128)S(6)} bitcast-convert(%convert.394), metadata={op_name="bitcast-convert.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.145 = pred[]{:T(512)S(6)} compare(%bitcast-convert.40, %constant.1427), direction=LT, metadata={op_name="compare.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.41 = s32[]{:T(128)S(6)} xor(%constant.1428, %bitcast-convert.40), metadata={op_name="xor.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %convert.279 = f32[]{:T(128)S(6)} convert(%top_k.26), metadata={op_name="convert.19"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.40 = s32[]{:T(128)S(6)} bitcast-convert(%convert.279), metadata={op_name="bitcast-convert.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.145 = pred[]{:T(512)S(6)} compare(%bitcast-convert.40, %constant.1376), direction=LT, metadata={op_name="compare.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.41 = s32[]{:T(128)S(6)} xor(%constant.1377, %bitcast-convert.40), metadata={op_name="xor.9"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.128 = s32[]{:T(128)S(6)} select(%compare.145, %xor.41, %bitcast-convert.40), metadata={op_name="select.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} %compare.146 = pred[]{:T(512)S(6)} compare(%select.127, %select.128), direction=GT, metadata={op_name="compare.0"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %compare.147 = pred[]{:T(512)S(6)} compare(%select.128, %select.127), direction=GT, metadata={op_name="compare.117"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} @@ -69,28 +69,28 @@ StackFrames ROOT %select.134 = pred[]{:T(512)} select(%compare.156, %compare.157, %lt_to.37), metadata={op_name="select.116"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_119.141 (reduce_sum.157: bf16[], reduce_sum.158: bf16[]) -> bf16[] { - %reduce_sum.157 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} - %reduce_sum.158 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} - ROOT %reduce_sum.159 = bf16[]{:T(256)} add(%reduce_sum.157, %reduce_sum.158), metadata={op_name="checkpoint/moe_layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_119.141 (reduce_sum.151: bf16[], reduce_sum.152: bf16[]) -> bf16[] { + %reduce_sum.151 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} + %reduce_sum.152 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/reduce_sum"} + ROOT %reduce_sum.153 = bf16[]{:T(256)} add(%reduce_sum.151, %reduce_sum.152), metadata={op_name="checkpoint/moe_layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_107.126 (psum.6: bf16[], psum.9: bf16[]) -> bf16[] { %psum.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} %psum.9 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} - ROOT %add.1445 = bf16[]{:T(256)} add(%psum.6, %psum.9), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1409 = bf16[]{:T(256)} add(%psum.6, %psum.9), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_108.127 (psum.10: bf16[], psum.11: bf16[]) -> bf16[] { %psum.10 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} %psum.11 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} - ROOT %add.1446 = bf16[]{:T(256)} add(%psum.10, %psum.11), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1410 = bf16[]{:T(256)} add(%psum.10, %psum.11), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_109.128 (psum.14: bf16[], psum.15: bf16[]) -> bf16[] { %psum.14 = bf16[]{:T(256)} parameter(0), metadata={op_name="psum"} %psum.15 = bf16[]{:T(256)} parameter(1), metadata={op_name="psum"} - ROOT %add.1447 = bf16[]{:T(256)} add(%psum.14, %psum.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1411 = bf16[]{:T(256)} add(%psum.14, %psum.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %region_62.73 (reduce-window.111: s32[], reduce-window.112: s32[]) -> s32[] { @@ -212,11 +212,11 @@ StackFrames %param_1.108 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.13 = s32[1024]{0:T(1024)} custom-call(%param_1.108), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} %slice.892 = s32[512]{0:T(512)} slice(%custom-call.13), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %reshape.3306 = s32[4,128]{1,0:T(4,128)} reshape(%slice.892), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.847 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3306), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %gather.183 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} gather(%param_0.17, %transpose.847), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %transpose.846 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%gather.183), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %reshape.3305 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.846), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.3432 = s32[4,128]{1,0:T(4,128)} reshape(%slice.892), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.607 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3432), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.183 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} gather(%param_0.17, %transpose.607), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.606 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%gather.183), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.3431 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.606), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} } %fused_computation.6 (param_0.20: f32[163840,32], param_1.110: s32[1024]) -> f32[512,32] { @@ -224,11 +224,11 @@ StackFrames %param_1.110 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.15 = s32[1024]{0:T(1024)} custom-call(%param_1.110), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} %slice.894 = s32[512]{0:T(512)} slice(%custom-call.15), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %reshape.3314 = s32[4,128]{1,0:T(4,128)} reshape(%slice.894), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %transpose.853 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3314), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %gather.185 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.20, %transpose.853), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %transpose.852 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.185), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - ROOT %reshape.3313 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.852), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %reshape.3440 = s32[4,128]{1,0:T(4,128)} reshape(%slice.894), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %transpose.613 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3440), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %gather.185 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.20, %transpose.613), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %transpose.612 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.185), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + ROOT %reshape.3439 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.612), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} } %fused_computation.7 (param_0.23: f32[163840,32], param_1.112: s32[1024]) -> f32[512,32] { @@ -236,11 +236,11 @@ StackFrames %param_1.112 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.17 = s32[1024]{0:T(1024)} custom-call(%param_1.112), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} %slice.896 = s32[512]{0:T(512)} slice(%custom-call.17), slice={[0:512]}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %reshape.3322 = s32[4,128]{1,0:T(4,128)} reshape(%slice.896), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %transpose.859 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3322), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} - %gather.187 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.23, %transpose.859), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - %transpose.858 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.187), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} - ROOT %reshape.3321 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.858), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %reshape.3448 = s32[4,128]{1,0:T(4,128)} reshape(%slice.896), metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %transpose.619 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3448), dimensions={0,1}, metadata={op_name="jit(train_step)/dense_layers/broadcast_in_dim" stack_frame_id=0} + %gather.187 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.23, %transpose.619), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + %transpose.618 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.187), dimensions={0,1,2}, metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} + ROOT %reshape.3447 = f32[512,32]{1,0:T(8,128)} reshape(%transpose.618), metadata={op_name="jit(train_step)/dense_layers/gather" stack_frame_id=0} } %fused_computation.8 (param_0.26: f32[163840,32], param_1.120: s32[1024]) -> f32[512,32] { @@ -248,11 +248,11 @@ StackFrames %param_1.120 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.25 = s32[1024]{0:T(1024)} custom-call(%param_1.120), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} %slice.904 = s32[512]{0:T(512)} slice(%custom-call.25), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %reshape.3330 = s32[4,128]{1,0:T(4,128)} reshape(%slice.904), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %transpose.865 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3330), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %gather.189 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.26, %transpose.865), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %transpose.864 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.189), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - ROOT %reshape.3329 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.864), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %reshape.3456 = s32[4,128]{1,0:T(4,128)} reshape(%slice.904), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %transpose.625 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3456), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %gather.189 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.26, %transpose.625), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %transpose.624 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.189), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + ROOT %reshape.3455 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.624), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} } %fused_computation.9 (param_0.29: f32[163840,32], param_1.122: s32[1024]) -> f32[512,32] { @@ -260,11 +260,11 @@ StackFrames %param_1.122 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.27 = s32[1024]{0:T(1024)} custom-call(%param_1.122), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} %slice.906 = s32[512]{0:T(512)} slice(%custom-call.27), slice={[0:512]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %reshape.3338 = s32[4,128]{1,0:T(4,128)} reshape(%slice.906), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %transpose.871 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3338), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} - %gather.191 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.29, %transpose.871), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - %transpose.870 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.191), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} - ROOT %reshape.3337 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.870), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %reshape.3464 = s32[4,128]{1,0:T(4,128)} reshape(%slice.906), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %transpose.631 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3464), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/select_n" stack_frame_id=0} + %gather.191 = f32[4,128,32]{2,1,0:T(8,128)} gather(%param_0.29, %transpose.631), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,32}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + %transpose.630 = f32[4,128,32]{2,1,0:T(8,128)} transpose(%gather.191), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} + ROOT %reshape.3463 = f32[512,32]{1,0:T(8,128)S(1)} reshape(%transpose.630), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/gather" stack_frame_id=0} } %fused_computation.10 (param_0.32: bf16[4096,512], param_1.126: s32[4096]) -> bf16[4096,512] { @@ -272,11 +272,11 @@ StackFrames %param_1.126 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.31 = s32[4096]{0:T(1024)} custom-call(%param_1.126), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.910 = s32[4096]{0:T(1024)} slice(%custom-call.31), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3346 = s32[4096]{0:T(1024)} reshape(%slice.910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.877 = s32[4096]{0:T(1024)} transpose(%reshape.3346), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.193 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.32, %transpose.877), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.876 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.193), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3345 = bf16[4096,512]{1,0:T(8,128)(2,1)} reshape(%transpose.876), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3472 = s32[4096]{0:T(1024)} reshape(%slice.910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.637 = s32[4096]{0:T(1024)} transpose(%reshape.3472), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.193 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.32, %transpose.637), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.636 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.193), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3471 = bf16[4096,512]{1,0:T(8,128)(2,1)} reshape(%transpose.636), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.11 (param_0.35: bf16[4096,512], param_1.128: s32[4096]) -> bf16[4096,512] { @@ -284,11 +284,11 @@ StackFrames %param_1.128 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.33 = s32[4096]{0:T(1024)} custom-call(%param_1.128), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.912 = s32[4096]{0:T(1024)} slice(%custom-call.33), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3354 = s32[4096]{0:T(1024)} reshape(%slice.912), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.883 = s32[4096]{0:T(1024)} transpose(%reshape.3354), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.195 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.35, %transpose.883), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.882 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.195), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3353 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.882), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3480 = s32[4096]{0:T(1024)} reshape(%slice.912), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.643 = s32[4096]{0:T(1024)} transpose(%reshape.3480), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.195 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.35, %transpose.643), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.642 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.195), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3479 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.642), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.12 (param_0.38: bf16[4096,512], param_1.130: s32[4096]) -> bf16[4096,512] { @@ -296,11 +296,11 @@ StackFrames %param_1.130 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.35 = s32[4096]{0:T(1024)} custom-call(%param_1.130), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.914 = s32[4096]{0:T(1024)} slice(%custom-call.35), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3362 = s32[4096]{0:T(1024)} reshape(%slice.914), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.889 = s32[4096]{0:T(1024)} transpose(%reshape.3362), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.197 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.38, %transpose.889), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.888 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.197), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3361 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.888), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3488 = s32[4096]{0:T(1024)} reshape(%slice.914), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.649 = s32[4096]{0:T(1024)} transpose(%reshape.3488), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.197 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.38, %transpose.649), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.648 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.197), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3487 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.648), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.13 (param_0.41: bf16[4096,512], param_1.132: s32[4096]) -> bf16[4096,512] { @@ -308,11 +308,11 @@ StackFrames %param_1.132 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.37 = s32[4096]{0:T(1024)} custom-call(%param_1.132), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.916 = s32[4096]{0:T(1024)} slice(%custom-call.37), slice={[0:4096]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3370 = s32[4096]{0:T(1024)} reshape(%slice.916), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.895 = s32[4096]{0:T(1024)} transpose(%reshape.3370), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.199 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.41, %transpose.895), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.894 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.199), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3369 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.894), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3496 = s32[4096]{0:T(1024)} reshape(%slice.916), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.655 = s32[4096]{0:T(1024)} transpose(%reshape.3496), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.199 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.41, %transpose.655), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.654 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.199), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3495 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.654), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %fused_computation.15 (param_0.47: s32[256], param_1.124: s32[1024]) -> s32[263] { @@ -320,11 +320,11 @@ StackFrames %param_1.124 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.29 = s32[1024]{0:T(1024)} custom-call(%param_1.124), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} %slice.908 = s32[263]{0:T(512)} slice(%custom-call.29), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %reshape.3401 = s32[263]{0:T(512)} reshape(%slice.908), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %transpose.911 = s32[263]{0:T(512)} transpose(%reshape.3401), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %gather.204 = s32[263]{0:T(512)} gather(%param_0.47, %transpose.911), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %transpose.910 = s32[263]{0:T(512)} transpose(%gather.204), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - ROOT %reshape.3400 = s32[263]{0:T(512)S(1)} reshape(%transpose.910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3527 = s32[263]{0:T(512)} reshape(%slice.908), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.671 = s32[263]{0:T(512)} transpose(%reshape.3527), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.204 = s32[263]{0:T(512)} gather(%param_0.47, %transpose.671), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.670 = s32[263]{0:T(512)} transpose(%gather.204), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3526 = s32[263]{0:T(512)S(1)} reshape(%transpose.670), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} } %fused_computation.16 (param_0.50: s32[256], param_1.134: s32[1024]) -> s32[263] { @@ -332,46 +332,46 @@ StackFrames %param_1.134 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.39 = s32[1024]{0:T(1024)} custom-call(%param_1.134), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} %slice.918 = s32[263]{0:T(512)} slice(%custom-call.39), slice={[0:263]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - %reshape.3424 = s32[263]{0:T(512)} reshape(%slice.918), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %transpose.921 = s32[263]{0:T(512)} transpose(%reshape.3424), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %gather.207 = s32[263]{0:T(512)} gather(%param_0.50, %transpose.921), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - %transpose.920 = s32[263]{0:T(512)} transpose(%gather.207), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} - ROOT %reshape.3423 = s32[263]{0:T(512)S(1)} reshape(%transpose.920), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3550 = s32[263]{0:T(512)} reshape(%slice.918), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.681 = s32[263]{0:T(512)} transpose(%reshape.3550), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.207 = s32[263]{0:T(512)} gather(%param_0.50, %transpose.681), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.680 = s32[263]{0:T(512)} transpose(%gather.207), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3549 = s32[263]{0:T(512)S(1)} reshape(%transpose.680), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/jit(_take)/gather" stack_frame_id=0} } %region_173.198.clone (scatter-add.94: bf16[], scatter-add.96: bf16[]) -> bf16[] { %scatter-add.94 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} %scatter-add.96 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} - ROOT %add.1918 = bf16[]{:T(256)} add(%scatter-add.94, %scatter-add.96), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.1874 = bf16[]{:T(256)} add(%scatter-add.94, %scatter-add.96), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %fused_computation.21 (param_0.55: bf16[129280,512], param_1.65: s32[512], param_2.24: bf16[512,512]) -> bf16[129280,512] { %param_0.55 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) %param_1.65 = s32[512]{0:T(512)S(1)} parameter(1) - %reshape.3478 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.65), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.954 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3478), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %reshape.3604 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.65), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.714 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.3604), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} %param_2.24 = bf16[512,512]{1,0:T(8,128)(2,1)S(1)} parameter(2) - %reshape.3479 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} reshape(%param_2.24), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} - %transpose.955 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%reshape.3479), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} - ROOT %scatter.73 = bf16[129280,512]{1,0:T(8,128)(2,1)} scatter(%param_0.55, %transpose.954, %transpose.955), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_173.198.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} + %reshape.3605 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} reshape(%param_2.24), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} + %transpose.715 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} transpose(%reshape.3605), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/while" stack_frame_id=0} + ROOT %scatter.73 = bf16[129280,512]{1,0:T(8,128)(2,1)} scatter(%param_0.55, %transpose.714, %transpose.715), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_173.198.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} } -%region_12.18 (top_k.0: bf16[], top_k.6: bf16[], top_k.7: s32[], top_k.8: s32[]) -> pred[] { - %constant.1387 = s32[]{:T(128)} constant(0) - %constant.1388 = s32[]{:T(128)} constant(2147483647) +%region_11.17 (top_k.0: bf16[], top_k.6: bf16[], top_k.7: s32[], top_k.8: s32[]) -> pred[] { + %constant.1336 = s32[]{:T(128)} constant(0) + %constant.1337 = s32[]{:T(128)} constant(2147483647) %top_k.0 = bf16[]{:T(256)} parameter(0), metadata={op_name="top_k"} %top_k.6 = bf16[]{:T(256)} parameter(1), metadata={op_name="top_k"} %top_k.7 = s32[]{:T(128)} parameter(2), metadata={op_name="top_k"} %top_k.8 = s32[]{:T(128)} parameter(3), metadata={op_name="top_k"} - %convert.385 = f32[]{:T(128)S(6)} convert(%top_k.0), metadata={op_name="convert.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %bitcast-convert.35 = s32[]{:T(128)S(6)} bitcast-convert(%convert.385), metadata={op_name="bitcast-convert.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.128 = pred[]{:T(512)S(6)} compare(%bitcast-convert.35, %constant.1387), direction=LT, metadata={op_name="compare.35"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.36 = s32[]{:T(128)S(6)} xor(%constant.1388, %bitcast-convert.35), metadata={op_name="xor.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %convert.270 = f32[]{:T(128)S(6)} convert(%top_k.0), metadata={op_name="convert.16"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.35 = s32[]{:T(128)S(6)} bitcast-convert(%convert.270), metadata={op_name="bitcast-convert.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.128 = pred[]{:T(512)S(6)} compare(%bitcast-convert.35, %constant.1336), direction=LT, metadata={op_name="compare.35"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.36 = s32[]{:T(128)S(6)} xor(%constant.1337, %bitcast-convert.35), metadata={op_name="xor.6"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.118 = s32[]{:T(128)S(6)} select(%compare.128, %xor.36, %bitcast-convert.35), metadata={op_name="select.14"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} - %convert.386 = f32[]{:T(128)S(6)} convert(%top_k.6), metadata={op_name="convert.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %bitcast-convert.36 = s32[]{:T(128)S(6)} bitcast-convert(%convert.386), metadata={op_name="bitcast-convert.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %compare.129 = pred[]{:T(512)S(6)} compare(%bitcast-convert.36, %constant.1387), direction=LT, metadata={op_name="compare.36"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} - %xor.37 = s32[]{:T(128)S(6)} xor(%constant.1388, %bitcast-convert.36), metadata={op_name="xor.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %convert.271 = f32[]{:T(128)S(6)} convert(%top_k.6), metadata={op_name="convert.17"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %bitcast-convert.36 = s32[]{:T(128)S(6)} bitcast-convert(%convert.271), metadata={op_name="bitcast-convert.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %compare.129 = pred[]{:T(512)S(6)} compare(%bitcast-convert.36, %constant.1336), direction=LT, metadata={op_name="compare.36"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %xor.37 = s32[]{:T(128)S(6)} xor(%constant.1337, %bitcast-convert.36), metadata={op_name="xor.7"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %select.119 = s32[]{:T(128)S(6)} select(%compare.129, %xor.37, %bitcast-convert.36), metadata={op_name="select.15"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["1","3"]}]}} %compare.130 = pred[]{:T(512)S(6)} compare(%select.118, %select.119), direction=GT, metadata={op_name="compare.1"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} %compare.131 = pred[]{:T(512)S(6)} compare(%select.119, %select.118), direction=GT, metadata={op_name="compare.108"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} @@ -380,37 +380,37 @@ StackFrames ROOT %select.120 = pred[]{:T(512)} select(%compare.132, %compare.133, %compare.130), metadata={op_name="select.108"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_15.21.clone.1 (reduce-window.326: s32[], reduce-window.327: s32[]) -> s32[] { +%region_14.20.clone.1 (reduce-window.326: s32[], reduce-window.327: s32[]) -> s32[] { %reduce-window.326 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.20"} %reduce-window.327 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.20"} ROOT %reduce_window_sum.282 = s32[]{:T(128)} add(%reduce-window.326, %reduce-window.327), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_16.22.clone.1 (reduce-window.330: s32[], reduce-window.331: s32[]) -> s32[] { +%region_15.21.clone.1 (reduce-window.330: s32[], reduce-window.331: s32[]) -> s32[] { %reduce-window.330 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.56"} %reduce-window.331 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.56"} ROOT %reduce_window_sum.284 = s32[]{:T(128)} add(%reduce-window.330, %reduce-window.331), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_18.24.clone.1 (reduce-window.334: s32[], reduce-window.335: s32[]) -> s32[] { +%region_17.23.clone.1 (reduce-window.334: s32[], reduce-window.335: s32[]) -> s32[] { %reduce-window.334 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.22"} %reduce-window.335 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.22"} ROOT %reduce_window_sum.286 = s32[]{:T(128)} add(%reduce-window.334, %reduce-window.335), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_19.25.clone.1 (reduce-window.338: s32[], reduce-window.339: s32[]) -> s32[] { +%region_18.24.clone.1 (reduce-window.338: s32[], reduce-window.339: s32[]) -> s32[] { %reduce-window.338 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.57"} %reduce-window.339 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.57"} ROOT %reduce_window_sum.288 = s32[]{:T(128)} add(%reduce-window.338, %reduce-window.339), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_21.27.clone.1 (reduce-window.342: s32[], reduce-window.343: s32[]) -> s32[] { +%region_20.26.clone.1 (reduce-window.342: s32[], reduce-window.343: s32[]) -> s32[] { %reduce-window.342 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.24"} %reduce-window.343 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.24"} ROOT %reduce_window_sum.290 = s32[]{:T(128)} add(%reduce-window.342, %reduce-window.343), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_22.28.clone.1 (reduce-window.346: s32[], reduce-window.347: s32[]) -> s32[] { +%region_21.27.clone.1 (reduce-window.346: s32[], reduce-window.347: s32[]) -> s32[] { %reduce-window.346 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.58"} %reduce-window.347 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.58"} ROOT %reduce_window_sum.292 = s32[]{:T(128)} add(%reduce-window.346, %reduce-window.347), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} @@ -421,32 +421,32 @@ StackFrames %param_1.114 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.19 = s32[1024]{0:T(1024)} custom-call(%param_1.114), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} %slice.898 = s32[263]{0:T(512)} slice(%custom-call.19), slice={[0:263]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %reshape.3622 = s32[263]{0:T(512)} reshape(%slice.898), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %transpose.1037 = s32[263]{0:T(512)} transpose(%reshape.3622), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} - %gather.209 = s32[263]{0:T(512)} gather(%param_0.68, %transpose.1037), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - %transpose.1036 = s32[263]{0:T(512)} transpose(%gather.209), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} - ROOT %reshape.3621 = s32[263]{0:T(512)S(1)} reshape(%transpose.1036), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %reshape.3748 = s32[263]{0:T(512)} reshape(%slice.898), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %transpose.797 = s32[263]{0:T(512)} transpose(%reshape.3748), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/broadcast_in_dim" stack_frame_id=0} + %gather.209 = s32[263]{0:T(512)} gather(%param_0.68, %transpose.797), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + %transpose.796 = s32[263]{0:T(512)} transpose(%gather.209), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} + ROOT %reshape.3747 = s32[263]{0:T(512)S(1)} reshape(%transpose.796), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/jit(_take)/gather" stack_frame_id=0} } -%region_27.34.clone.1 (reduce-window.350: s32[], reduce-window.351: s32[]) -> s32[] { +%region_26.33.clone.1 (reduce-window.350: s32[], reduce-window.351: s32[]) -> s32[] { %reduce-window.350 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.26"} %reduce-window.351 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.26"} ROOT %reduce_window_sum.294 = s32[]{:T(128)} add(%reduce-window.350, %reduce-window.351), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_29.36.clone.1 (reduce-window.354: s32[], reduce-window.355: s32[]) -> s32[] { +%region_28.35.clone.1 (reduce-window.354: s32[], reduce-window.355: s32[]) -> s32[] { %reduce-window.354 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.27"} %reduce-window.355 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.27"} ROOT %reduce_window_sum.296 = s32[]{:T(128)} add(%reduce-window.354, %reduce-window.355), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_30.37.clone.1 (reduce-window.358: s32[], reduce-window.359: s32[]) -> s32[] { +%region_29.36.clone.1 (reduce-window.358: s32[], reduce-window.359: s32[]) -> s32[] { %reduce-window.358 = s32[]{:T(128)} parameter(0), metadata={op_name="reduce-window.59"} %reduce-window.359 = s32[]{:T(128)} parameter(1), metadata={op_name="reduce-window.59"} ROOT %reduce_window_sum.298 = s32[]{:T(128)} add(%reduce-window.358, %reduce-window.359), metadata={op_name="reduce_window_sum"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_13.19 (sort.44: s32[], sort.45: s32[], sort.46: s32[], sort.47: s32[], sort.48: s32[], sort.49: s32[]) -> pred[] { +%region_12.18 (sort.44: s32[], sort.45: s32[], sort.46: s32[], sort.47: s32[], sort.48: s32[], sort.49: s32[]) -> pred[] { %sort.46 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} %sort.47 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} %sort.44 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} @@ -465,14 +465,14 @@ StackFrames %param_1.116 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.21 = s32[4096]{0:T(1024)} custom-call(%param_1.116), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.900 = s32[4096]{0:T(1024)} slice(%custom-call.21), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3645 = s32[4096]{0:T(1024)} reshape(%slice.900), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.1043 = s32[4096]{0:T(1024)} transpose(%reshape.3645), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.210 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.71, %transpose.1043), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.1042 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.210), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3644 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1042), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3771 = s32[4096]{0:T(1024)} reshape(%slice.900), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.803 = s32[4096]{0:T(1024)} transpose(%reshape.3771), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.210 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.71, %transpose.803), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.802 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.210), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3770 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.802), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } -%region_31.39 (sort.50: s32[], sort.51: s32[], sort.52: s32[], sort.53: s32[], sort.54: s32[], sort.55: s32[]) -> pred[] { +%region_30.38 (sort.50: s32[], sort.51: s32[], sort.52: s32[], sort.53: s32[], sort.54: s32[], sort.55: s32[]) -> pred[] { %sort.52 = s32[]{:T(128)} parameter(2), metadata={op_name="jit(argsort)/sort"} %sort.53 = s32[]{:T(128)} parameter(3), metadata={op_name="jit(argsort)/sort"} %sort.50 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(argsort)/sort"} @@ -491,11 +491,11 @@ StackFrames %param_1.118 = s32[4096]{0:T(1024)S(1)} parameter(1) %custom-call.23 = s32[4096]{0:T(1024)} custom-call(%param_1.118), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[4096]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} %slice.902 = s32[4096]{0:T(1024)} slice(%custom-call.23), slice={[0:4096]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %reshape.3647 = s32[4096]{0:T(1024)} reshape(%slice.902), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %transpose.1045 = s32[4096]{0:T(1024)} transpose(%reshape.3647), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} - %gather.211 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.72, %transpose.1045), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - %transpose.1044 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.211), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} - ROOT %reshape.3646 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.1044), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %reshape.3773 = s32[4096]{0:T(1024)} reshape(%slice.902), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %transpose.805 = s32[4096]{0:T(1024)} transpose(%reshape.3773), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/broadcast_in_dim" stack_frame_id=0} + %gather.211 = bf16[4096,512]{1,0:T(8,128)(2,1)} gather(%param_0.72, %transpose.805), offset_dims={1}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1,512}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + %transpose.804 = bf16[4096,512]{1,0:T(8,128)(2,1)} transpose(%gather.211), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} + ROOT %reshape.3772 = bf16[4096,512]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.804), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/sort_activations/gather" stack_frame_id=0} } %compare (name: s32[], name.1: s32[], name.2: bf16[], name.3: bf16[]) -> pred[] { @@ -538,49 +538,49 @@ StackFrames ROOT %compare.381 = pred[] compare(%name.16, %name.17), direction=LT, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%called_computation.13 (param_0.4519: s32[256]) -> s32[256] { - %param_0.4519 = s32[256]{0:T(256)} parameter(0) - ROOT %copy.2073 = s32[256]{0:T(256)} copy(%param_0.4519), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"1134","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.13 (param_0.4545: s32[256]) -> s32[256] { + %param_0.4545 = s32[256]{0:T(256)} parameter(0) + ROOT %copy.2073 = s32[256]{0:T(256)} copy(%param_0.4545), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"1134","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.13 (param_0.4520: s32[256]) -> s32[256] { - %param_0.4520 = s32[256]{0:T(256)} parameter(0) - ROOT %copy.2074.cloned.1 = s32[256]{0:T(256)} call(%param_0.4520), to_apply=%called_computation.13 +%async_computation.13 (param_0.4546: s32[256]) -> s32[256] { + %param_0.4546 = s32[256]{0:T(256)} parameter(0) + ROOT %copy.2074.cloned.1 = s32[256]{0:T(256)} call(%param_0.4546), to_apply=%called_computation.13 }, execution_thread="sparsecore" %region_49.59 (scatter-add.14: s32[], scatter-add.15: s32[]) -> s32[] { %scatter-add.14 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.15 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1387 = s32[]{:T(128)S(7)} add(%scatter-add.14, %scatter-add.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1351 = s32[]{:T(128)S(7)} add(%scatter-add.14, %scatter-add.15), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.22.clone.clone (param_0.4521: s32[256], param_1.5326: s32[4096], param_2.4491: s32[4096]) -> s32[256] { - %param_0.4521 = s32[256]{0:T(256)} parameter(0) - %param_1.5326 = s32[4096]{0:T(1024)} parameter(1) - %reshape.3911 = s32[4096]{0:T(1024)} reshape(%param_1.5326), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} - %transpose.1100 = s32[4096]{0:T(1024)} transpose(%reshape.3911), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} - %param_2.4491 = s32[4096]{0:T(1024)} parameter(2) - %reshape.3912 = s32[4096]{0:T(1024)} reshape(%param_2.4491), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - %transpose.1101 = s32[4096]{0:T(1024)} transpose(%reshape.3912), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.231 = s32[256]{0:T(256)} scatter(%param_0.4521, %transpose.1100, %transpose.1101), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_49.59, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} +%fused_computation.22.clone.clone (param_0.4547: s32[256], param_1.5330: s32[4096], param_2.4486: s32[4096]) -> s32[256] { + %param_0.4547 = s32[256]{0:T(256)} parameter(0) + %param_1.5330 = s32[4096]{0:T(1024)} parameter(1) + %reshape.4037 = s32[4096]{0:T(1024)} reshape(%param_1.5330), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} + %transpose.860 = s32[4096]{0:T(1024)} transpose(%reshape.4037), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(clip)/max" stack_frame_id=0} + %param_2.4486 = s32[4096]{0:T(1024)} parameter(2) + %reshape.4038 = s32[4096]{0:T(1024)} reshape(%param_2.4486), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + %transpose.861 = s32[4096]{0:T(1024)} transpose(%reshape.4038), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.231 = s32[256]{0:T(256)} scatter(%param_0.4547, %transpose.860, %transpose.861), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_49.59, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.14 (param_0.4522: s32[256], param_1.5327: s32[4096], param_2.4492: s32[4096]) -> s32[256] { - %param_0.4522 = s32[256]{0:T(256)} parameter(0) - %param_1.5327 = s32[4096]{0:T(1024)} parameter(1) - %param_2.4492 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.39 = s32[256]{0:T(256)} fusion(%param_0.4522, %param_1.5327, %param_2.4492), kind=kCustom, calls=%fused_computation.22.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.14 (param_0.4548: s32[256], param_1.5331: s32[4096], param_2.4487: s32[4096]) -> s32[256] { + %param_0.4548 = s32[256]{0:T(256)} parameter(0) + %param_1.5331 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4487 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.39 = s32[256]{0:T(256)} fusion(%param_0.4548, %param_1.5331, %param_2.4487), kind=kCustom, calls=%fused_computation.22.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.14 (param_0.4523: s32[256], param_1.5328: s32[4096], param_2.4493: s32[4096]) -> s32[256] { - %param_0.4523 = s32[256]{0:T(256)} parameter(0) - %param_1.5328 = s32[4096]{0:T(1024)} parameter(1) - %param_2.4493 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.40.cloned.1 = s32[256]{0:T(256)} call(%param_0.4523, %param_1.5328, %param_2.4493), to_apply=%called_computation.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} +%async_computation.14 (param_0.4549: s32[256], param_1.5332: s32[4096], param_2.4488: s32[4096]) -> s32[256] { + %param_0.4549 = s32[256]{0:T(256)} parameter(0) + %param_1.5332 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4488 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.40.cloned.1 = s32[256]{0:T(256)} call(%param_0.4549, %param_1.5332, %param_2.4488), to_apply=%called_computation.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation (param_0.84: s32[256], param_1.136: s32[4096], param_2.80: s32[4096], param_3.3085: token[]) -> s32[256] { - %param_3.3085 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation (param_0.84: s32[256], param_1.136: s32[4096], param_2.80: s32[4096], param_3.3090: token[]) -> s32[256] { + %param_3.3090 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.84 = s32[256]{0:T(256)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.136 = s32[4096]{0:T(1024)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.80 = s32[4096]{0:T(1024)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -590,57 +590,57 @@ StackFrames ROOT %scatter_offload_custom_fusion.40.cloned.1.call-done = s32[256]{0:T(256)} async-done(%scatter_offload_custom_fusion.40.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation (param_0.85: s32[256], param_1.137: s32[4096], param_2.81: s32[4096], param_3.3084: token[]) -> s32[256] { - %param_3.3084 = token[] parameter(3) +%async_computation (param_0.85: s32[256], param_1.137: s32[4096], param_2.81: s32[4096], param_3.3089: token[]) -> s32[256] { + %param_3.3089 = token[] parameter(3) %param_0.85 = s32[256]{0:T(256)} parameter(0) %param_1.137 = s32[4096]{0:T(1024)} parameter(1) %param_2.81 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.2.cloned.1 = s32[256]{0:T(256)} call(%param_0.85, %param_1.137, %param_2.81, %param_3.3084), to_apply=%called_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.2.cloned.1 = s32[256]{0:T(256)} call(%param_0.85, %param_1.137, %param_2.81, %param_3.3089), to_apply=%called_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.15 (param_0.4524: f32[9]) -> f32[9] { - %param_0.4524 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2075 = f32[9]{0:T(128)} copy(%param_0.4524), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.15 (param_0.4550: f32[9]) -> f32[9] { + %param_0.4550 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2075 = f32[9]{0:T(128)} copy(%param_0.4550), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.15 (param_0.4525: f32[9]) -> f32[9] { - %param_0.4525 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2076.cloned.1 = f32[9]{0:T(128)} call(%param_0.4525), to_apply=%called_computation.15 +%async_computation.15 (param_0.4551: f32[9]) -> f32[9] { + %param_0.4551 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2076.cloned.1 = f32[9]{0:T(128)} call(%param_0.4551), to_apply=%called_computation.15 }, execution_thread="sparsecore" %region_61.72 (scatter-add.24: f32[], scatter-add.25: f32[]) -> f32[] { %scatter-add.24 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.25 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1393 = f32[]{:T(128)S(7)} add(%scatter-add.24, %scatter-add.25), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1357 = f32[]{:T(128)S(7)} add(%scatter-add.24, %scatter-add.25), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.24.clone.clone (param_0.4526: f32[9], param_1.5329: s32[256], param_2.4494: f32[256]) -> f32[9] { - %param_0.4526 = f32[9]{0:T(128)} parameter(0) - %param_1.5329 = s32[256]{0:T(256)} parameter(1) - %reshape.3913 = s32[256]{0:T(256)} reshape(%param_1.5329), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1102 = s32[256]{0:T(256)} transpose(%reshape.3913), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %param_2.4494 = f32[256]{0:T(256)} parameter(2) - %reshape.3914 = f32[256]{0:T(256)} reshape(%param_2.4494), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1103 = f32[256]{0:T(256)} transpose(%reshape.3914), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.232 = f32[9]{0:T(128)} scatter(%param_0.4526, %transpose.1102, %transpose.1103), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_61.72, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.24.clone.clone (param_0.4552: f32[9], param_1.5333: s32[256], param_2.4489: f32[256]) -> f32[9] { + %param_0.4552 = f32[9]{0:T(128)} parameter(0) + %param_1.5333 = s32[256]{0:T(256)} parameter(1) + %reshape.4039 = s32[256]{0:T(256)} reshape(%param_1.5333), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.862 = s32[256]{0:T(256)} transpose(%reshape.4039), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4489 = f32[256]{0:T(256)} parameter(2) + %reshape.4040 = f32[256]{0:T(256)} reshape(%param_2.4489), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.863 = f32[256]{0:T(256)} transpose(%reshape.4040), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.232 = f32[9]{0:T(128)} scatter(%param_0.4552, %transpose.862, %transpose.863), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_61.72, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.16 (param_0.4527: f32[9], param_1.5330: s32[256], param_2.4495: f32[256]) -> f32[9] { - %param_0.4527 = f32[9]{0:T(128)} parameter(0) - %param_1.5330 = s32[256]{0:T(256)} parameter(1) - %param_2.4495 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.41 = f32[9]{0:T(128)} fusion(%param_0.4527, %param_1.5330, %param_2.4495), kind=kCustom, calls=%fused_computation.24.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.16 (param_0.4553: f32[9], param_1.5334: s32[256], param_2.4490: f32[256]) -> f32[9] { + %param_0.4553 = f32[9]{0:T(128)} parameter(0) + %param_1.5334 = s32[256]{0:T(256)} parameter(1) + %param_2.4490 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.41 = f32[9]{0:T(128)} fusion(%param_0.4553, %param_1.5334, %param_2.4490), kind=kCustom, calls=%fused_computation.24.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.16 (param_0.4528: f32[9], param_1.5331: s32[256], param_2.4496: f32[256]) -> f32[9] { - %param_0.4528 = f32[9]{0:T(128)} parameter(0) - %param_1.5331 = s32[256]{0:T(256)} parameter(1) - %param_2.4496 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.42.cloned.1 = f32[9]{0:T(128)} call(%param_0.4528, %param_1.5331, %param_2.4496), to_apply=%called_computation.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.16 (param_0.4554: f32[9], param_1.5335: s32[256], param_2.4491: f32[256]) -> f32[9] { + %param_0.4554 = f32[9]{0:T(128)} parameter(0) + %param_1.5335 = s32[256]{0:T(256)} parameter(1) + %param_2.4491 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.42.cloned.1 = f32[9]{0:T(128)} call(%param_0.4554, %param_1.5335, %param_2.4491), to_apply=%called_computation.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.1 (param_0.87: f32[9], param_1.139: s32[256], param_2.83: f32[256], param_3.3099: token[]) -> f32[9] { - %param_3.3099 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.1 (param_0.87: f32[9], param_1.139: s32[256], param_2.83: f32[256], param_3.3104: token[]) -> f32[9] { + %param_3.3104 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.87 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.139 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.83 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -650,57 +650,57 @@ StackFrames ROOT %scatter_offload_custom_fusion.42.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.42.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.1 (param_0.88: f32[9], param_1.140: s32[256], param_2.84: f32[256], param_3.3098: token[]) -> f32[9] { - %param_3.3098 = token[] parameter(3) +%async_computation.1 (param_0.88: f32[9], param_1.140: s32[256], param_2.84: f32[256], param_3.3103: token[]) -> f32[9] { + %param_3.3103 = token[] parameter(3) %param_0.88 = f32[9]{0:T(128)} parameter(0) %param_1.140 = s32[256]{0:T(256)} parameter(1) %param_2.84 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.5.cloned.1 = f32[9]{0:T(128)} call(%param_0.88, %param_1.140, %param_2.84, %param_3.3098), to_apply=%called_computation.1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.5.cloned.1 = f32[9]{0:T(128)} call(%param_0.88, %param_1.140, %param_2.84, %param_3.3103), to_apply=%called_computation.1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.17 (param_0.4529: s32[263]) -> s32[263] { - %param_0.4529 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2077 = s32[263]{0:T(512)} copy(%param_0.4529), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.17 (param_0.4555: s32[263]) -> s32[263] { + %param_0.4555 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2077 = s32[263]{0:T(512)} copy(%param_0.4555), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.17 (param_0.4530: s32[263]) -> s32[263] { - %param_0.4530 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2078.cloned.1 = s32[263]{0:T(512)} call(%param_0.4530), to_apply=%called_computation.17 +%async_computation.17 (param_0.4556: s32[263]) -> s32[263] { + %param_0.4556 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2078.cloned.1 = s32[263]{0:T(512)} call(%param_0.4556), to_apply=%called_computation.17 }, execution_thread="sparsecore" %region_63.74 (scatter-add.28: s32[], scatter-add.29: s32[]) -> s32[] { %scatter-add.28 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.29 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1394 = s32[]{:T(128)S(7)} add(%scatter-add.28, %scatter-add.29), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1358 = s32[]{:T(128)S(7)} add(%scatter-add.28, %scatter-add.29), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.25.clone.clone (param_0.4531: s32[263], param_1.5332: s32[8], param_2.4497: s32[8]) -> s32[263] { - %param_0.4531 = s32[263]{0:T(512)} parameter(0) - %param_1.5332 = s32[8]{0:T(128)} parameter(1) - %reshape.3915 = s32[8]{0:T(128)} reshape(%param_1.5332), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1104 = s32[8]{0:T(128)} transpose(%reshape.3915), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4497 = s32[8]{0:T(128)} parameter(2) - %reshape.3916 = s32[8]{0:T(128)} reshape(%param_2.4497), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - %transpose.1105 = s32[8]{0:T(128)} transpose(%reshape.3916), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - ROOT %scatter-add.233 = s32[263]{0:T(512)} scatter(%param_0.4531, %transpose.1104, %transpose.1105), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_63.74, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.25.clone.clone (param_0.4557: s32[263], param_1.5336: s32[8], param_2.4492: s32[8]) -> s32[263] { + %param_0.4557 = s32[263]{0:T(512)} parameter(0) + %param_1.5336 = s32[8]{0:T(128)} parameter(1) + %reshape.4041 = s32[8]{0:T(128)} reshape(%param_1.5336), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.864 = s32[8]{0:T(128)} transpose(%reshape.4041), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4492 = s32[8]{0:T(128)} parameter(2) + %reshape.4042 = s32[8]{0:T(128)} reshape(%param_2.4492), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.865 = s32[8]{0:T(128)} transpose(%reshape.4042), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.233 = s32[263]{0:T(512)} scatter(%param_0.4557, %transpose.864, %transpose.865), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_63.74, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.18 (param_0.4532: s32[263], param_1.5333: s32[8], param_2.4498: s32[8]) -> s32[263] { - %param_0.4532 = s32[263]{0:T(512)} parameter(0) - %param_1.5333 = s32[8]{0:T(128)} parameter(1) - %param_2.4498 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.43 = s32[263]{0:T(512)} fusion(%param_0.4532, %param_1.5333, %param_2.4498), kind=kCustom, calls=%fused_computation.25.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.18 (param_0.4558: s32[263], param_1.5337: s32[8], param_2.4493: s32[8]) -> s32[263] { + %param_0.4558 = s32[263]{0:T(512)} parameter(0) + %param_1.5337 = s32[8]{0:T(128)} parameter(1) + %param_2.4493 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.43 = s32[263]{0:T(512)} fusion(%param_0.4558, %param_1.5337, %param_2.4493), kind=kCustom, calls=%fused_computation.25.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.18 (param_0.4533: s32[263], param_1.5334: s32[8], param_2.4499: s32[8]) -> s32[263] { - %param_0.4533 = s32[263]{0:T(512)} parameter(0) - %param_1.5334 = s32[8]{0:T(128)} parameter(1) - %param_2.4499 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.44.cloned.1 = s32[263]{0:T(512)} call(%param_0.4533, %param_1.5334, %param_2.4499), to_apply=%called_computation.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.18 (param_0.4559: s32[263], param_1.5338: s32[8], param_2.4494: s32[8]) -> s32[263] { + %param_0.4559 = s32[263]{0:T(512)} parameter(0) + %param_1.5338 = s32[8]{0:T(128)} parameter(1) + %param_2.4494 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.44.cloned.1 = s32[263]{0:T(512)} call(%param_0.4559, %param_1.5338, %param_2.4494), to_apply=%called_computation.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.2 (param_0.90: s32[263], param_1.142: s32[8], param_2.86: s32[8], param_3.3105: token[]) -> s32[263] { - %param_3.3105 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.2 (param_0.90: s32[263], param_1.142: s32[8], param_2.86: s32[8], param_3.3110: token[]) -> s32[263] { + %param_3.3110 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.90 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.142 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.86 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -710,57 +710,57 @@ StackFrames ROOT %scatter_offload_custom_fusion.44.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.44.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.2 (param_0.91: s32[263], param_1.143: s32[8], param_2.87: s32[8], param_3.3104: token[]) -> s32[263] { - %param_3.3104 = token[] parameter(3) +%async_computation.2 (param_0.91: s32[263], param_1.143: s32[8], param_2.87: s32[8], param_3.3109: token[]) -> s32[263] { + %param_3.3109 = token[] parameter(3) %param_0.91 = s32[263]{0:T(512)} parameter(0) %param_1.143 = s32[8]{0:T(128)} parameter(1) %param_2.87 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.8.cloned.1 = s32[263]{0:T(512)} call(%param_0.91, %param_1.143, %param_2.87, %param_3.3104), to_apply=%called_computation.2, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.8.cloned.1 = s32[263]{0:T(512)} call(%param_0.91, %param_1.143, %param_2.87, %param_3.3109), to_apply=%called_computation.2, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/rematted_computation/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.19 (param_0.4534: s32[263]) -> s32[263] { - %param_0.4534 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2079 = s32[263]{0:T(512)} copy(%param_0.4534), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.19 (param_0.4560: s32[263]) -> s32[263] { + %param_0.4560 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2079 = s32[263]{0:T(512)} copy(%param_0.4560), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.19 (param_0.4535: s32[263]) -> s32[263] { - %param_0.4535 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2080.cloned.1 = s32[263]{0:T(512)} call(%param_0.4535), to_apply=%called_computation.19 +%async_computation.19 (param_0.4561: s32[263]) -> s32[263] { + %param_0.4561 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2080.cloned.1 = s32[263]{0:T(512)} call(%param_0.4561), to_apply=%called_computation.19 }, execution_thread="sparsecore" %region_73.86.clone (scatter-add.163: s32[], scatter-add.164: s32[]) -> s32[] { %scatter-add.163 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.164 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2527 = s32[]{:T(128)S(7)} add(%scatter-add.163, %scatter-add.164), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2483 = s32[]{:T(128)S(7)} add(%scatter-add.163, %scatter-add.164), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.26.clone.clone (param_0.4536: s32[263], param_1.5335: s32[256], param_2.4500: s32[256]) -> s32[263] { - %param_0.4536 = s32[263]{0:T(512)} parameter(0) - %param_1.5335 = s32[256]{0:T(256)} parameter(1) - %reshape.3917 = s32[256]{0:T(256)} reshape(%param_1.5335), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1106 = s32[256]{0:T(256)} transpose(%reshape.3917), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4500 = s32[256]{0:T(256)} parameter(2) - %reshape.3918 = s32[256]{0:T(256)} reshape(%param_2.4500), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1107 = s32[256]{0:T(256)} transpose(%reshape.3918), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.234 = s32[263]{0:T(512)} scatter(%param_0.4536, %transpose.1106, %transpose.1107), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_73.86.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.26.clone.clone (param_0.4562: s32[263], param_1.5339: s32[256], param_2.4495: s32[256]) -> s32[263] { + %param_0.4562 = s32[263]{0:T(512)} parameter(0) + %param_1.5339 = s32[256]{0:T(256)} parameter(1) + %reshape.4043 = s32[256]{0:T(256)} reshape(%param_1.5339), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.866 = s32[256]{0:T(256)} transpose(%reshape.4043), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4495 = s32[256]{0:T(256)} parameter(2) + %reshape.4044 = s32[256]{0:T(256)} reshape(%param_2.4495), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.867 = s32[256]{0:T(256)} transpose(%reshape.4044), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.234 = s32[263]{0:T(512)} scatter(%param_0.4562, %transpose.866, %transpose.867), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_73.86.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.20 (param_0.4537: s32[263], param_1.5336: s32[256], param_2.4501: s32[256]) -> s32[263] { - %param_0.4537 = s32[263]{0:T(512)} parameter(0) - %param_1.5336 = s32[256]{0:T(256)} parameter(1) - %param_2.4501 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.45 = s32[263]{0:T(512)} fusion(%param_0.4537, %param_1.5336, %param_2.4501), kind=kCustom, calls=%fused_computation.26.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.20 (param_0.4563: s32[263], param_1.5340: s32[256], param_2.4496: s32[256]) -> s32[263] { + %param_0.4563 = s32[263]{0:T(512)} parameter(0) + %param_1.5340 = s32[256]{0:T(256)} parameter(1) + %param_2.4496 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.45 = s32[263]{0:T(512)} fusion(%param_0.4563, %param_1.5340, %param_2.4496), kind=kCustom, calls=%fused_computation.26.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.20 (param_0.4538: s32[263], param_1.5337: s32[256], param_2.4502: s32[256]) -> s32[263] { - %param_0.4538 = s32[263]{0:T(512)} parameter(0) - %param_1.5337 = s32[256]{0:T(256)} parameter(1) - %param_2.4502 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.46.cloned.1 = s32[263]{0:T(512)} call(%param_0.4538, %param_1.5337, %param_2.4502), to_apply=%called_computation.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.20 (param_0.4564: s32[263], param_1.5341: s32[256], param_2.4497: s32[256]) -> s32[263] { + %param_0.4564 = s32[263]{0:T(512)} parameter(0) + %param_1.5341 = s32[256]{0:T(256)} parameter(1) + %param_2.4497 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.46.cloned.1 = s32[263]{0:T(512)} call(%param_0.4564, %param_1.5341, %param_2.4497), to_apply=%called_computation.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.3 (param_0.93: s32[263], param_1.145: s32[256], param_2.89: s32[256], param_3.3091: token[]) -> s32[263] { - %param_3.3091 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.3 (param_0.93: s32[263], param_1.145: s32[256], param_2.89: s32[256], param_3.3096: token[]) -> s32[263] { + %param_3.3096 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.93 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.145 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.89 = s32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -770,57 +770,57 @@ StackFrames ROOT %scatter_offload_custom_fusion.46.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.46.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.3 (param_0.94: s32[263], param_1.146: s32[256], param_2.90: s32[256], param_3.3090: token[]) -> s32[263] { - %param_3.3090 = token[] parameter(3) +%async_computation.3 (param_0.94: s32[263], param_1.146: s32[256], param_2.90: s32[256], param_3.3095: token[]) -> s32[263] { + %param_3.3095 = token[] parameter(3) %param_0.94 = s32[263]{0:T(512)} parameter(0) %param_1.146 = s32[256]{0:T(256)} parameter(1) %param_2.90 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.11.cloned.1 = s32[263]{0:T(512)} call(%param_0.94, %param_1.146, %param_2.90, %param_3.3090), to_apply=%called_computation.3, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.11.cloned.1 = s32[263]{0:T(512)} call(%param_0.94, %param_1.146, %param_2.90, %param_3.3095), to_apply=%called_computation.3, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.21 (param_0.4539: f32[9]) -> f32[9] { - %param_0.4539 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2081 = f32[9]{0:T(128)} copy(%param_0.4539), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.21 (param_0.4565: f32[9]) -> f32[9] { + %param_0.4565 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2081 = f32[9]{0:T(128)} copy(%param_0.4565), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.21 (param_0.4540: f32[9]) -> f32[9] { - %param_0.4540 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2082.cloned.1 = f32[9]{0:T(128)} call(%param_0.4540), to_apply=%called_computation.21 +%async_computation.21 (param_0.4566: f32[9]) -> f32[9] { + %param_0.4566 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2082.cloned.1 = f32[9]{0:T(128)} call(%param_0.4566), to_apply=%called_computation.21 }, execution_thread="sparsecore" %region_79.95.clone (scatter-add.167: f32[], scatter-add.168: f32[]) -> f32[] { %scatter-add.167 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.168 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2529 = f32[]{:T(128)S(7)} add(%scatter-add.167, %scatter-add.168), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2485 = f32[]{:T(128)S(7)} add(%scatter-add.167, %scatter-add.168), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.27.clone.clone (param_0.4541: f32[9], param_1.5338: s32[256], param_2.4503: f32[256]) -> f32[9] { - %param_0.4541 = f32[9]{0:T(128)} parameter(0) - %param_1.5338 = s32[256]{0:T(256)} parameter(1) - %reshape.3919 = s32[256]{0:T(256)} reshape(%param_1.5338), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1108 = s32[256]{0:T(256)} transpose(%reshape.3919), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %param_2.4503 = f32[256]{0:T(256)} parameter(2) - %reshape.3920 = f32[256]{0:T(256)} reshape(%param_2.4503), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1109 = f32[256]{0:T(256)} transpose(%reshape.3920), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.235 = f32[9]{0:T(128)} scatter(%param_0.4541, %transpose.1108, %transpose.1109), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_79.95.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.27.clone.clone (param_0.4567: f32[9], param_1.5342: s32[256], param_2.4498: f32[256]) -> f32[9] { + %param_0.4567 = f32[9]{0:T(128)} parameter(0) + %param_1.5342 = s32[256]{0:T(256)} parameter(1) + %reshape.4045 = s32[256]{0:T(256)} reshape(%param_1.5342), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.868 = s32[256]{0:T(256)} transpose(%reshape.4045), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4498 = f32[256]{0:T(256)} parameter(2) + %reshape.4046 = f32[256]{0:T(256)} reshape(%param_2.4498), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.869 = f32[256]{0:T(256)} transpose(%reshape.4046), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.235 = f32[9]{0:T(128)} scatter(%param_0.4567, %transpose.868, %transpose.869), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_79.95.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.22 (param_0.4542: f32[9], param_1.5339: s32[256], param_2.4504: f32[256]) -> f32[9] { - %param_0.4542 = f32[9]{0:T(128)} parameter(0) - %param_1.5339 = s32[256]{0:T(256)} parameter(1) - %param_2.4504 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.47 = f32[9]{0:T(128)} fusion(%param_0.4542, %param_1.5339, %param_2.4504), kind=kCustom, calls=%fused_computation.27.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.22 (param_0.4568: f32[9], param_1.5343: s32[256], param_2.4499: f32[256]) -> f32[9] { + %param_0.4568 = f32[9]{0:T(128)} parameter(0) + %param_1.5343 = s32[256]{0:T(256)} parameter(1) + %param_2.4499 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.47 = f32[9]{0:T(128)} fusion(%param_0.4568, %param_1.5343, %param_2.4499), kind=kCustom, calls=%fused_computation.27.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.22 (param_0.4543: f32[9], param_1.5340: s32[256], param_2.4505: f32[256]) -> f32[9] { - %param_0.4543 = f32[9]{0:T(128)} parameter(0) - %param_1.5340 = s32[256]{0:T(256)} parameter(1) - %param_2.4505 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.48.cloned.1 = f32[9]{0:T(128)} call(%param_0.4543, %param_1.5340, %param_2.4505), to_apply=%called_computation.22, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.22 (param_0.4569: f32[9], param_1.5344: s32[256], param_2.4500: f32[256]) -> f32[9] { + %param_0.4569 = f32[9]{0:T(128)} parameter(0) + %param_1.5344 = s32[256]{0:T(256)} parameter(1) + %param_2.4500 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.48.cloned.1 = f32[9]{0:T(128)} call(%param_0.4569, %param_1.5344, %param_2.4500), to_apply=%called_computation.22, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.4 (param_0.96: f32[9], param_1.148: s32[256], param_2.92: f32[256], param_3.3097: token[]) -> f32[9] { - %param_3.3097 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.4 (param_0.96: f32[9], param_1.148: s32[256], param_2.92: f32[256], param_3.3102: token[]) -> f32[9] { + %param_3.3102 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.96 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.148 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.92 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -830,57 +830,57 @@ StackFrames ROOT %scatter_offload_custom_fusion.48.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.48.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.4 (param_0.97: f32[9], param_1.149: s32[256], param_2.93: f32[256], param_3.3096: token[]) -> f32[9] { - %param_3.3096 = token[] parameter(3) +%async_computation.4 (param_0.97: f32[9], param_1.149: s32[256], param_2.93: f32[256], param_3.3101: token[]) -> f32[9] { + %param_3.3101 = token[] parameter(3) %param_0.97 = f32[9]{0:T(128)} parameter(0) %param_1.149 = s32[256]{0:T(256)} parameter(1) %param_2.93 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.14.cloned.1 = f32[9]{0:T(128)} call(%param_0.97, %param_1.149, %param_2.93, %param_3.3096), to_apply=%called_computation.4, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.14.cloned.1 = f32[9]{0:T(128)} call(%param_0.97, %param_1.149, %param_2.93, %param_3.3101), to_apply=%called_computation.4, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.23 (param_0.4544: s32[263]) -> s32[263] { - %param_0.4544 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2083 = s32[263]{0:T(512)} copy(%param_0.4544), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.23 (param_0.4570: s32[263]) -> s32[263] { + %param_0.4570 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2083 = s32[263]{0:T(512)} copy(%param_0.4570), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.23 (param_0.4545: s32[263]) -> s32[263] { - %param_0.4545 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2084.cloned.1 = s32[263]{0:T(512)} call(%param_0.4545), to_apply=%called_computation.23 +%async_computation.23 (param_0.4571: s32[263]) -> s32[263] { + %param_0.4571 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2084.cloned.1 = s32[263]{0:T(512)} call(%param_0.4571), to_apply=%called_computation.23 }, execution_thread="sparsecore" %region_81.97.clone (scatter-add.171: s32[], scatter-add.172: s32[]) -> s32[] { %scatter-add.171 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.172 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2531 = s32[]{:T(128)S(7)} add(%scatter-add.171, %scatter-add.172), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2487 = s32[]{:T(128)S(7)} add(%scatter-add.171, %scatter-add.172), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.28.clone.clone (param_0.4546: s32[263], param_1.5341: s32[8], param_2.4506: s32[8]) -> s32[263] { - %param_0.4546 = s32[263]{0:T(512)} parameter(0) - %param_1.5341 = s32[8]{0:T(128)} parameter(1) - %reshape.3921 = s32[8]{0:T(128)} reshape(%param_1.5341), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1110 = s32[8]{0:T(128)} transpose(%reshape.3921), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4506 = s32[8]{0:T(128)} parameter(2) - %reshape.3922 = s32[8]{0:T(128)} reshape(%param_2.4506), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - %transpose.1111 = s32[8]{0:T(128)} transpose(%reshape.3922), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - ROOT %scatter-add.236 = s32[263]{0:T(512)} scatter(%param_0.4546, %transpose.1110, %transpose.1111), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_81.97.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.28.clone.clone (param_0.4572: s32[263], param_1.5345: s32[8], param_2.4501: s32[8]) -> s32[263] { + %param_0.4572 = s32[263]{0:T(512)} parameter(0) + %param_1.5345 = s32[8]{0:T(128)} parameter(1) + %reshape.4047 = s32[8]{0:T(128)} reshape(%param_1.5345), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.870 = s32[8]{0:T(128)} transpose(%reshape.4047), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4501 = s32[8]{0:T(128)} parameter(2) + %reshape.4048 = s32[8]{0:T(128)} reshape(%param_2.4501), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.871 = s32[8]{0:T(128)} transpose(%reshape.4048), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.236 = s32[263]{0:T(512)} scatter(%param_0.4572, %transpose.870, %transpose.871), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_81.97.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.24 (param_0.4547: s32[263], param_1.5342: s32[8], param_2.4507: s32[8]) -> s32[263] { - %param_0.4547 = s32[263]{0:T(512)} parameter(0) - %param_1.5342 = s32[8]{0:T(128)} parameter(1) - %param_2.4507 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.49 = s32[263]{0:T(512)} fusion(%param_0.4547, %param_1.5342, %param_2.4507), kind=kCustom, calls=%fused_computation.28.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.24 (param_0.4573: s32[263], param_1.5346: s32[8], param_2.4502: s32[8]) -> s32[263] { + %param_0.4573 = s32[263]{0:T(512)} parameter(0) + %param_1.5346 = s32[8]{0:T(128)} parameter(1) + %param_2.4502 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.49 = s32[263]{0:T(512)} fusion(%param_0.4573, %param_1.5346, %param_2.4502), kind=kCustom, calls=%fused_computation.28.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.24 (param_0.4548: s32[263], param_1.5343: s32[8], param_2.4508: s32[8]) -> s32[263] { - %param_0.4548 = s32[263]{0:T(512)} parameter(0) - %param_1.5343 = s32[8]{0:T(128)} parameter(1) - %param_2.4508 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.50.cloned.1 = s32[263]{0:T(512)} call(%param_0.4548, %param_1.5343, %param_2.4508), to_apply=%called_computation.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.24 (param_0.4574: s32[263], param_1.5347: s32[8], param_2.4503: s32[8]) -> s32[263] { + %param_0.4574 = s32[263]{0:T(512)} parameter(0) + %param_1.5347 = s32[8]{0:T(128)} parameter(1) + %param_2.4503 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.50.cloned.1 = s32[263]{0:T(512)} call(%param_0.4574, %param_1.5347, %param_2.4503), to_apply=%called_computation.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.5 (param_0.99: s32[263], param_1.151: s32[8], param_2.95: s32[8], param_3.3107: token[]) -> s32[263] { - %param_3.3107 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.5 (param_0.99: s32[263], param_1.151: s32[8], param_2.95: s32[8], param_3.3112: token[]) -> s32[263] { + %param_3.3112 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.99 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.151 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.95 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -890,57 +890,57 @@ StackFrames ROOT %scatter_offload_custom_fusion.50.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.50.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.5 (param_0.100: s32[263], param_1.152: s32[8], param_2.96: s32[8], param_3.3106: token[]) -> s32[263] { - %param_3.3106 = token[] parameter(3) +%async_computation.5 (param_0.100: s32[263], param_1.152: s32[8], param_2.96: s32[8], param_3.3111: token[]) -> s32[263] { + %param_3.3111 = token[] parameter(3) %param_0.100 = s32[263]{0:T(512)} parameter(0) %param_1.152 = s32[8]{0:T(128)} parameter(1) %param_2.96 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.17.cloned.1 = s32[263]{0:T(512)} call(%param_0.100, %param_1.152, %param_2.96, %param_3.3106), to_apply=%called_computation.5, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.17.cloned.1 = s32[263]{0:T(512)} call(%param_0.100, %param_1.152, %param_2.96, %param_3.3111), to_apply=%called_computation.5, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.25 (param_0.4549: s32[263]) -> s32[263] { - %param_0.4549 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2085 = s32[263]{0:T(512)} copy(%param_0.4549), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.25 (param_0.4575: s32[263]) -> s32[263] { + %param_0.4575 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2085 = s32[263]{0:T(512)} copy(%param_0.4575), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.25 (param_0.4550: s32[263]) -> s32[263] { - %param_0.4550 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2086.cloned.1 = s32[263]{0:T(512)} call(%param_0.4550), to_apply=%called_computation.25 +%async_computation.25 (param_0.4576: s32[263]) -> s32[263] { + %param_0.4576 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2086.cloned.1 = s32[263]{0:T(512)} call(%param_0.4576), to_apply=%called_computation.25 }, execution_thread="sparsecore" %region_96.114 (scatter-add.48: s32[], scatter-add.49: s32[]) -> s32[] { %scatter-add.48 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.49 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1434 = s32[]{:T(128)S(7)} add(%scatter-add.48, %scatter-add.49), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1398 = s32[]{:T(128)S(7)} add(%scatter-add.48, %scatter-add.49), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.29.clone.clone (param_0.4551: s32[263], param_1.5344: s32[256], param_2.4509: s32[256]) -> s32[263] { - %param_0.4551 = s32[263]{0:T(512)} parameter(0) - %param_1.5344 = s32[256]{0:T(256)} parameter(1) - %reshape.3923 = s32[256]{0:T(256)} reshape(%param_1.5344), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %transpose.1112 = s32[256]{0:T(256)} transpose(%reshape.3923), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %param_2.4509 = s32[256]{0:T(256)} parameter(2) - %reshape.3924 = s32[256]{0:T(256)} reshape(%param_2.4509), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1113 = s32[256]{0:T(256)} transpose(%reshape.3924), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.237 = s32[263]{0:T(512)} scatter(%param_0.4551, %transpose.1112, %transpose.1113), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_96.114, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%fused_computation.29.clone.clone (param_0.4577: s32[263], param_1.5348: s32[256], param_2.4504: s32[256]) -> s32[263] { + %param_0.4577 = s32[263]{0:T(512)} parameter(0) + %param_1.5348 = s32[256]{0:T(256)} parameter(1) + %reshape.4049 = s32[256]{0:T(256)} reshape(%param_1.5348), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %transpose.872 = s32[256]{0:T(256)} transpose(%reshape.4049), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %param_2.4504 = s32[256]{0:T(256)} parameter(2) + %reshape.4050 = s32[256]{0:T(256)} reshape(%param_2.4504), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.873 = s32[256]{0:T(256)} transpose(%reshape.4050), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.237 = s32[263]{0:T(512)} scatter(%param_0.4577, %transpose.872, %transpose.873), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_96.114, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.26 (param_0.4552: s32[263], param_1.5345: s32[256], param_2.4510: s32[256]) -> s32[263] { - %param_0.4552 = s32[263]{0:T(512)} parameter(0) - %param_1.5345 = s32[256]{0:T(256)} parameter(1) - %param_2.4510 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.51 = s32[263]{0:T(512)} fusion(%param_0.4552, %param_1.5345, %param_2.4510), kind=kCustom, calls=%fused_computation.29.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.26 (param_0.4578: s32[263], param_1.5349: s32[256], param_2.4505: s32[256]) -> s32[263] { + %param_0.4578 = s32[263]{0:T(512)} parameter(0) + %param_1.5349 = s32[256]{0:T(256)} parameter(1) + %param_2.4505 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.51 = s32[263]{0:T(512)} fusion(%param_0.4578, %param_1.5349, %param_2.4505), kind=kCustom, calls=%fused_computation.29.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.26 (param_0.4553: s32[263], param_1.5346: s32[256], param_2.4511: s32[256]) -> s32[263] { - %param_0.4553 = s32[263]{0:T(512)} parameter(0) - %param_1.5346 = s32[256]{0:T(256)} parameter(1) - %param_2.4511 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.52.cloned.1 = s32[263]{0:T(512)} call(%param_0.4553, %param_1.5346, %param_2.4511), to_apply=%called_computation.26, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%async_computation.26 (param_0.4579: s32[263], param_1.5350: s32[256], param_2.4506: s32[256]) -> s32[263] { + %param_0.4579 = s32[263]{0:T(512)} parameter(0) + %param_1.5350 = s32[256]{0:T(256)} parameter(1) + %param_2.4506 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.52.cloned.1 = s32[263]{0:T(512)} call(%param_0.4579, %param_1.5350, %param_2.4506), to_apply=%called_computation.26, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.6 (param_0.102: s32[263], param_1.154: s32[256], param_2.98: s32[256], param_3.3093: token[]) -> s32[263] { - %param_3.3093 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.6 (param_0.102: s32[263], param_1.154: s32[256], param_2.98: s32[256], param_3.3098: token[]) -> s32[263] { + %param_3.3098 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.102 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.154 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.98 = s32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -950,47 +950,47 @@ StackFrames ROOT %scatter_offload_custom_fusion.52.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.52.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.6 (param_0.103: s32[263], param_1.155: s32[256], param_2.99: s32[256], param_3.3092: token[]) -> s32[263] { - %param_3.3092 = token[] parameter(3) +%async_computation.6 (param_0.103: s32[263], param_1.155: s32[256], param_2.99: s32[256], param_3.3097: token[]) -> s32[263] { + %param_3.3097 = token[] parameter(3) %param_0.103 = s32[263]{0:T(512)} parameter(0) %param_1.155 = s32[256]{0:T(256)} parameter(1) %param_2.99 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.20.cloned.1 = s32[263]{0:T(512)} call(%param_0.103, %param_1.155, %param_2.99, %param_3.3092), to_apply=%called_computation.6, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.20.cloned.1 = s32[263]{0:T(512)} call(%param_0.103, %param_1.155, %param_2.99, %param_3.3097), to_apply=%called_computation.6, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %region_102.120 (scatter-add.52: f32[], scatter-add.53: f32[]) -> f32[] { %scatter-add.52 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.53 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1437 = f32[]{:T(128)S(7)} add(%scatter-add.52, %scatter-add.53), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1401 = f32[]{:T(128)S(7)} add(%scatter-add.52, %scatter-add.53), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.30.clone.clone (param_0.4556: f32[9], param_1.5347: s32[256], param_2.4512: f32[256]) -> f32[9] { - %param_0.4556 = f32[9]{0:T(128)} parameter(0) - %param_1.5347 = s32[256]{0:T(256)} parameter(1) - %reshape.3925 = s32[256]{0:T(256)} reshape(%param_1.5347), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1114 = s32[256]{0:T(256)} transpose(%reshape.3925), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} - %param_2.4512 = f32[256]{0:T(256)} parameter(2) - %reshape.3926 = f32[256]{0:T(256)} reshape(%param_2.4512), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1115 = f32[256]{0:T(256)} transpose(%reshape.3926), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.238 = f32[9]{0:T(128)} scatter(%param_0.4556, %transpose.1114, %transpose.1115), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_102.120, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%fused_computation.30.clone.clone (param_0.4582: f32[9], param_1.5351: s32[256], param_2.4507: f32[256]) -> f32[9] { + %param_0.4582 = f32[9]{0:T(128)} parameter(0) + %param_1.5351 = s32[256]{0:T(256)} parameter(1) + %reshape.4051 = s32[256]{0:T(256)} reshape(%param_1.5351), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.874 = s32[256]{0:T(256)} transpose(%reshape.4051), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4507 = f32[256]{0:T(256)} parameter(2) + %reshape.4052 = f32[256]{0:T(256)} reshape(%param_2.4507), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.875 = f32[256]{0:T(256)} transpose(%reshape.4052), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.238 = f32[9]{0:T(128)} scatter(%param_0.4582, %transpose.874, %transpose.875), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_102.120, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.28 (param_0.4557: f32[9], param_1.5348: s32[256], param_2.4513: f32[256]) -> f32[9] { - %param_0.4557 = f32[9]{0:T(128)} parameter(0) - %param_1.5348 = s32[256]{0:T(256)} parameter(1) - %param_2.4513 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.53 = f32[9]{0:T(128)} fusion(%param_0.4557, %param_1.5348, %param_2.4513), kind=kCustom, calls=%fused_computation.30.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.28 (param_0.4583: f32[9], param_1.5352: s32[256], param_2.4508: f32[256]) -> f32[9] { + %param_0.4583 = f32[9]{0:T(128)} parameter(0) + %param_1.5352 = s32[256]{0:T(256)} parameter(1) + %param_2.4508 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.53 = f32[9]{0:T(128)} fusion(%param_0.4583, %param_1.5352, %param_2.4508), kind=kCustom, calls=%fused_computation.30.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.28 (param_0.4558: f32[9], param_1.5349: s32[256], param_2.4514: f32[256]) -> f32[9] { - %param_0.4558 = f32[9]{0:T(128)} parameter(0) - %param_1.5349 = s32[256]{0:T(256)} parameter(1) - %param_2.4514 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.54.cloned.1 = f32[9]{0:T(128)} call(%param_0.4558, %param_1.5349, %param_2.4514), to_apply=%called_computation.28, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%async_computation.28 (param_0.4584: f32[9], param_1.5353: s32[256], param_2.4509: f32[256]) -> f32[9] { + %param_0.4584 = f32[9]{0:T(128)} parameter(0) + %param_1.5353 = s32[256]{0:T(256)} parameter(1) + %param_2.4509 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.54.cloned.1 = f32[9]{0:T(128)} call(%param_0.4584, %param_1.5353, %param_2.4509), to_apply=%called_computation.28, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.7 (param_0.105: f32[9], param_1.157: s32[256], param_2.101: f32[256], param_3.3101: token[]) -> f32[9] { - %param_3.3101 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.7 (param_0.105: f32[9], param_1.157: s32[256], param_2.101: f32[256], param_3.3106: token[]) -> f32[9] { + %param_3.3106 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.105 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.157 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.101 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -998,47 +998,47 @@ StackFrames ROOT %scatter_offload_custom_fusion.54.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.54.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.7 (param_0.106: f32[9], param_1.158: s32[256], param_2.102: f32[256], param_3.3100: token[]) -> f32[9] { - %param_3.3100 = token[] parameter(3) +%async_computation.7 (param_0.106: f32[9], param_1.158: s32[256], param_2.102: f32[256], param_3.3105: token[]) -> f32[9] { + %param_3.3105 = token[] parameter(3) %param_0.106 = f32[9]{0:T(128)} parameter(0) %param_1.158 = s32[256]{0:T(256)} parameter(1) %param_2.102 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.23.cloned.1 = f32[9]{0:T(128)} call(%param_0.106, %param_1.158, %param_2.102, %param_3.3100), to_apply=%called_computation.7, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.23.cloned.1 = f32[9]{0:T(128)} call(%param_0.106, %param_1.158, %param_2.102, %param_3.3105), to_apply=%called_computation.7, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" %region_104.122 (scatter-add.83: s32[], scatter-add.84: s32[]) -> s32[] { %scatter-add.83 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.84 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1438 = s32[]{:T(128)S(7)} add(%scatter-add.83, %scatter-add.84), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1402 = s32[]{:T(128)S(7)} add(%scatter-add.83, %scatter-add.84), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.31.clone.clone (param_0.4561: s32[263], param_1.5350: s32[8], param_2.4515: s32[8]) -> s32[263] { - %param_0.4561 = s32[263]{0:T(512)} parameter(0) - %param_1.5350 = s32[8]{0:T(128)} parameter(1) - %reshape.3927 = s32[8]{0:T(128)} reshape(%param_1.5350), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %transpose.1116 = s32[8]{0:T(128)} transpose(%reshape.3927), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} - %param_2.4515 = s32[8]{0:T(128)} parameter(2) - %reshape.3928 = s32[8]{0:T(128)} reshape(%param_2.4515), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - %transpose.1117 = s32[8]{0:T(128)} transpose(%reshape.3928), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} - ROOT %scatter-add.239 = s32[263]{0:T(512)} scatter(%param_0.4561, %transpose.1116, %transpose.1117), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_104.122, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%fused_computation.31.clone.clone (param_0.4587: s32[263], param_1.5354: s32[8], param_2.4510: s32[8]) -> s32[263] { + %param_0.4587 = s32[263]{0:T(512)} parameter(0) + %param_1.5354 = s32[8]{0:T(128)} parameter(1) + %reshape.4053 = s32[8]{0:T(128)} reshape(%param_1.5354), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %transpose.876 = s32[8]{0:T(128)} transpose(%reshape.4053), dimensions={0}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/select_n" stack_frame_id=0} + %param_2.4510 = s32[8]{0:T(128)} parameter(2) + %reshape.4054 = s32[8]{0:T(128)} reshape(%param_2.4510), metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + %transpose.877 = s32[8]{0:T(128)} transpose(%reshape.4054), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/jit(gmm)/broadcast.80" stack_frame_id=0} + ROOT %scatter-add.239 = s32[263]{0:T(512)} scatter(%param_0.4587, %transpose.876, %transpose.877), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_104.122, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.30 (param_0.4562: s32[263], param_1.5351: s32[8], param_2.4516: s32[8]) -> s32[263] { - %param_0.4562 = s32[263]{0:T(512)} parameter(0) - %param_1.5351 = s32[8]{0:T(128)} parameter(1) - %param_2.4516 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.55 = s32[263]{0:T(512)} fusion(%param_0.4562, %param_1.5351, %param_2.4516), kind=kCustom, calls=%fused_computation.31.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.30 (param_0.4588: s32[263], param_1.5355: s32[8], param_2.4511: s32[8]) -> s32[263] { + %param_0.4588 = s32[263]{0:T(512)} parameter(0) + %param_1.5355 = s32[8]{0:T(128)} parameter(1) + %param_2.4511 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.55 = s32[263]{0:T(512)} fusion(%param_0.4588, %param_1.5355, %param_2.4511), kind=kCustom, calls=%fused_computation.31.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.30 (param_0.4563: s32[263], param_1.5352: s32[8], param_2.4517: s32[8]) -> s32[263] { - %param_0.4563 = s32[263]{0:T(512)} parameter(0) - %param_1.5352 = s32[8]{0:T(128)} parameter(1) - %param_2.4517 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.56.cloned.1 = s32[263]{0:T(512)} call(%param_0.4563, %param_1.5352, %param_2.4517), to_apply=%called_computation.30, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} +%async_computation.30 (param_0.4589: s32[263], param_1.5356: s32[8], param_2.4512: s32[8]) -> s32[263] { + %param_0.4589 = s32[263]{0:T(512)} parameter(0) + %param_1.5356 = s32[8]{0:T(128)} parameter(1) + %param_2.4512 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.56.cloned.1 = s32[263]{0:T(512)} call(%param_0.4589, %param_1.5356, %param_2.4512), to_apply=%called_computation.30, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.8 (param_0.108: s32[263], param_1.160: s32[8], param_2.104: s32[8], param_3.3109: token[]) -> s32[263] { - %param_3.3109 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.8 (param_0.108: s32[263], param_1.160: s32[8], param_2.104: s32[8], param_3.3114: token[]) -> s32[263] { + %param_3.3114 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.108 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.160 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.104 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -1046,47 +1046,47 @@ StackFrames ROOT %scatter_offload_custom_fusion.56.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.56.cloned.1.call-start), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.8 (param_0.109: s32[263], param_1.161: s32[8], param_2.105: s32[8], param_3.3108: token[]) -> s32[263] { - %param_3.3108 = token[] parameter(3) +%async_computation.8 (param_0.109: s32[263], param_1.161: s32[8], param_2.105: s32[8], param_3.3113: token[]) -> s32[263] { + %param_3.3113 = token[] parameter(3) %param_0.109 = s32[263]{0:T(512)} parameter(0) %param_1.161 = s32[8]{0:T(128)} parameter(1) %param_2.105 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.26.cloned.1 = s32[263]{0:T(512)} call(%param_0.109, %param_1.161, %param_2.105, %param_3.3108), to_apply=%called_computation.8, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.26.cloned.1 = s32[263]{0:T(512)} call(%param_0.109, %param_1.161, %param_2.105, %param_3.3113), to_apply=%called_computation.8, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/shard_map/jit(tgmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%region_14.20 (scatter-add.0: s32[], scatter-add.1: s32[]) -> s32[] { +%region_13.19 (scatter-add.0: s32[], scatter-add.1: s32[]) -> s32[] { %scatter-add.0 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.1 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.1345 = s32[]{:T(128)S(7)} add(%scatter-add.0, %scatter-add.1), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.1310 = s32[]{:T(128)S(7)} add(%scatter-add.0, %scatter-add.1), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.17.clone.clone.clone (param_0.4566: s32[256], param_1.5353: s32[4096], param_2.4518: s32[4096]) -> s32[256] { - %param_0.4566 = s32[256]{0:T(256)} parameter(0) - %param_1.5353 = s32[4096]{0:T(1024)} parameter(1) - %reshape.3929 = s32[4096]{0:T(1024)} reshape(%param_1.5353), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} - %transpose.1118 = s32[4096]{0:T(1024)} transpose(%reshape.3929), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} - %param_2.4518 = s32[4096]{0:T(1024)} parameter(2) - %reshape.3930 = s32[4096]{0:T(1024)} reshape(%param_2.4518), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - %transpose.1119 = s32[4096]{0:T(1024)} transpose(%reshape.3930), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} - ROOT %scatter-add.240 = s32[256]{0:T(256)} scatter(%param_0.4566, %transpose.1118, %transpose.1119), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_14.20, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} +%fused_computation.17.clone.clone.clone (param_0.4592: s32[256], param_1.5357: s32[4096], param_2.4513: s32[4096]) -> s32[256] { + %param_0.4592 = s32[256]{0:T(256)} parameter(0) + %param_1.5357 = s32[4096]{0:T(1024)} parameter(1) + %reshape.4055 = s32[4096]{0:T(1024)} reshape(%param_1.5357), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} + %transpose.878 = s32[4096]{0:T(1024)} transpose(%reshape.4055), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/select_n" stack_frame_id=0} + %param_2.4513 = s32[4096]{0:T(1024)} parameter(2) + %reshape.4056 = s32[4096]{0:T(1024)} reshape(%param_2.4513), metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + %transpose.879 = s32[4096]{0:T(1024)} transpose(%reshape.4056), dimensions={0}, metadata={op_name="jit(train_step)/moe_layers/shard_map/broadcast_in_dim" stack_frame_id=0} + ROOT %scatter-add.240 = s32[256]{0:T(256)} scatter(%param_0.4592, %transpose.878, %transpose.879), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_13.19, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.32 (param_0.4567: s32[256], param_1.5354: s32[4096], param_2.4519: s32[4096]) -> s32[256] { - %param_0.4567 = s32[256]{0:T(256)} parameter(0) - %param_1.5354 = s32[4096]{0:T(1024)} parameter(1) - %param_2.4519 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.57 = s32[256]{0:T(256)} fusion(%param_0.4567, %param_1.5354, %param_2.4519), kind=kCustom, calls=%fused_computation.17.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.32 (param_0.4593: s32[256], param_1.5358: s32[4096], param_2.4514: s32[4096]) -> s32[256] { + %param_0.4593 = s32[256]{0:T(256)} parameter(0) + %param_1.5358 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4514 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.57 = s32[256]{0:T(256)} fusion(%param_0.4593, %param_1.5358, %param_2.4514), kind=kCustom, calls=%fused_computation.17.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["256"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"4160","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.32 (param_0.4568: s32[256], param_1.5355: s32[4096], param_2.4520: s32[4096]) -> s32[256] { - %param_0.4568 = s32[256]{0:T(256)} parameter(0) - %param_1.5355 = s32[4096]{0:T(1024)} parameter(1) - %param_2.4520 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.58.cloned.1 = s32[256]{0:T(256)} call(%param_0.4568, %param_1.5355, %param_2.4520), to_apply=%called_computation.32, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} +%async_computation.32 (param_0.4594: s32[256], param_1.5359: s32[4096], param_2.4515: s32[4096]) -> s32[256] { + %param_0.4594 = s32[256]{0:T(256)} parameter(0) + %param_1.5359 = s32[4096]{0:T(1024)} parameter(1) + %param_2.4515 = s32[4096]{0:T(1024)} parameter(2) + ROOT %scatter_offload_custom_fusion.58.cloned.1 = s32[256]{0:T(256)} call(%param_0.4594, %param_1.5359, %param_2.4515), to_apply=%called_computation.32, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.9 (param_0.111: s32[256], param_1.163: s32[4096], param_2.107: s32[4096], param_3.3087: token[]) -> s32[256] { - %param_3.3087 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.9 (param_0.111: s32[256], param_1.163: s32[4096], param_2.107: s32[4096], param_3.3092: token[]) -> s32[256] { + %param_3.3092 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.111 = s32[256]{0:T(256)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.163 = s32[4096]{0:T(1024)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.107 = s32[4096]{0:T(1024)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -1094,57 +1094,57 @@ StackFrames ROOT %scatter_offload_custom_fusion.58.cloned.1.call-done = s32[256]{0:T(256)} async-done(%scatter_offload_custom_fusion.58.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.9 (param_0.112: s32[256], param_1.164: s32[4096], param_2.108: s32[4096], param_3.3086: token[]) -> s32[256] { - %param_3.3086 = token[] parameter(3) +%async_computation.9 (param_0.112: s32[256], param_1.164: s32[4096], param_2.108: s32[4096], param_3.3091: token[]) -> s32[256] { + %param_3.3091 = token[] parameter(3) %param_0.112 = s32[256]{0:T(256)} parameter(0) %param_1.164 = s32[4096]{0:T(1024)} parameter(1) %param_2.108 = s32[4096]{0:T(1024)} parameter(2) - ROOT %scatter_offload_custom_fusion.29.cloned.1 = s32[256]{0:T(256)} call(%param_0.112, %param_1.164, %param_2.108, %param_3.3086), to_apply=%called_computation.9, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.29.cloned.1 = s32[256]{0:T(256)} call(%param_0.112, %param_1.164, %param_2.108, %param_3.3091), to_apply=%called_computation.9, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.33 (param_0.4569: s32[263]) -> s32[263] { - %param_0.4569 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2093 = s32[263]{0:T(512)} copy(%param_0.4569), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.33 (param_0.4595: s32[263]) -> s32[263] { + %param_0.4595 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2093 = s32[263]{0:T(512)} copy(%param_0.4595), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.33 (param_0.4570: s32[263]) -> s32[263] { - %param_0.4570 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2094.cloned.1 = s32[263]{0:T(512)} call(%param_0.4570), to_apply=%called_computation.33 +%async_computation.33 (param_0.4596: s32[263]) -> s32[263] { + %param_0.4596 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2094.cloned.1 = s32[263]{0:T(512)} call(%param_0.4596), to_apply=%called_computation.33 }, execution_thread="sparsecore" -%region_20.26.clone.1 (scatter-add.141: s32[], scatter-add.142: s32[]) -> s32[] { +%region_19.25.clone.1 (scatter-add.141: s32[], scatter-add.142: s32[]) -> s32[] { %scatter-add.141 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.142 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2516 = s32[]{:T(128)S(7)} add(%scatter-add.141, %scatter-add.142), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2472 = s32[]{:T(128)S(7)} add(%scatter-add.141, %scatter-add.142), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.18.clone.clone.clone (param_0.4571: s32[263], param_1.5356: s32[256], param_2.4521: s32[256]) -> s32[263] { - %param_0.4571 = s32[263]{0:T(512)} parameter(0) - %param_1.5356 = s32[256]{0:T(256)} parameter(1) - %reshape.3931 = s32[256]{0:T(256)} reshape(%param_1.5356), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1120 = s32[256]{0:T(256)} transpose(%reshape.3931), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4521 = s32[256]{0:T(256)} parameter(2) - %reshape.3932 = s32[256]{0:T(256)} reshape(%param_2.4521) - %transpose.1121 = s32[256]{0:T(256)} transpose(%reshape.3932), dimensions={0} - ROOT %scatter-add.241 = s32[263]{0:T(512)} scatter(%param_0.4571, %transpose.1120, %transpose.1121), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_20.26.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.18.clone.clone.clone (param_0.4597: s32[263], param_1.5360: s32[256], param_2.4516: s32[256]) -> s32[263] { + %param_0.4597 = s32[263]{0:T(512)} parameter(0) + %param_1.5360 = s32[256]{0:T(256)} parameter(1) + %reshape.4057 = s32[256]{0:T(256)} reshape(%param_1.5360), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.880 = s32[256]{0:T(256)} transpose(%reshape.4057), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4516 = s32[256]{0:T(256)} parameter(2) + %reshape.4058 = s32[256]{0:T(256)} reshape(%param_2.4516) + %transpose.881 = s32[256]{0:T(256)} transpose(%reshape.4058), dimensions={0} + ROOT %scatter-add.241 = s32[263]{0:T(512)} scatter(%param_0.4597, %transpose.880, %transpose.881), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_19.25.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.34 (param_0.4572: s32[263], param_1.5357: s32[256], param_2.4522: s32[256]) -> s32[263] { - %param_0.4572 = s32[263]{0:T(512)} parameter(0) - %param_1.5357 = s32[256]{0:T(256)} parameter(1) - %param_2.4522 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.59 = s32[263]{0:T(512)} fusion(%param_0.4572, %param_1.5357, %param_2.4522), kind=kCustom, calls=%fused_computation.18.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.34 (param_0.4598: s32[263], param_1.5361: s32[256], param_2.4517: s32[256]) -> s32[263] { + %param_0.4598 = s32[263]{0:T(512)} parameter(0) + %param_1.5361 = s32[256]{0:T(256)} parameter(1) + %param_2.4517 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.59 = s32[263]{0:T(512)} fusion(%param_0.4598, %param_1.5361, %param_2.4517), kind=kCustom, calls=%fused_computation.18.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"384","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.34 (param_0.4573: s32[263], param_1.5358: s32[256], param_2.4523: s32[256]) -> s32[263] { - %param_0.4573 = s32[263]{0:T(512)} parameter(0) - %param_1.5358 = s32[256]{0:T(256)} parameter(1) - %param_2.4523 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.60.cloned.1 = s32[263]{0:T(512)} call(%param_0.4573, %param_1.5358, %param_2.4523), to_apply=%called_computation.34, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.34 (param_0.4599: s32[263], param_1.5362: s32[256], param_2.4518: s32[256]) -> s32[263] { + %param_0.4599 = s32[263]{0:T(512)} parameter(0) + %param_1.5362 = s32[256]{0:T(256)} parameter(1) + %param_2.4518 = s32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.60.cloned.1 = s32[263]{0:T(512)} call(%param_0.4599, %param_1.5362, %param_2.4518), to_apply=%called_computation.34, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.10 (param_0.114: s32[263], param_1.166: s32[256], param_2.110: s32[256], param_3.3089: token[]) -> s32[263] { - %param_3.3089 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.10 (param_0.114: s32[263], param_1.166: s32[256], param_2.110: s32[256], param_3.3094: token[]) -> s32[263] { + %param_3.3094 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.114 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.166 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.110 = s32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -1154,57 +1154,57 @@ StackFrames ROOT %scatter_offload_custom_fusion.60.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.60.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.10 (param_0.115: s32[263], param_1.167: s32[256], param_2.111: s32[256], param_3.3088: token[]) -> s32[263] { - %param_3.3088 = token[] parameter(3) +%async_computation.10 (param_0.115: s32[263], param_1.167: s32[256], param_2.111: s32[256], param_3.3093: token[]) -> s32[263] { + %param_3.3093 = token[] parameter(3) %param_0.115 = s32[263]{0:T(512)} parameter(0) %param_1.167 = s32[256]{0:T(256)} parameter(1) %param_2.111 = s32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.32.cloned.1 = s32[263]{0:T(512)} call(%param_0.115, %param_1.167, %param_2.111, %param_3.3088), to_apply=%called_computation.10, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.32.cloned.1 = s32[263]{0:T(512)} call(%param_0.115, %param_1.167, %param_2.111, %param_3.3093), to_apply=%called_computation.10, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.35 (param_0.4574: f32[9]) -> f32[9] { - %param_0.4574 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2095 = f32[9]{0:T(128)} copy(%param_0.4574), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.35 (param_0.4600: f32[9]) -> f32[9] { + %param_0.4600 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2095 = f32[9]{0:T(128)} copy(%param_0.4600), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"1131","iteration_bounds":[],"scratchpad_allocation_size":"128","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.35 (param_0.4575: f32[9]) -> f32[9] { - %param_0.4575 = f32[9]{0:T(128)} parameter(0) - ROOT %copy.2096.cloned.1 = f32[9]{0:T(128)} call(%param_0.4575), to_apply=%called_computation.35 +%async_computation.35 (param_0.4601: f32[9]) -> f32[9] { + %param_0.4601 = f32[9]{0:T(128)} parameter(0) + ROOT %copy.2096.cloned.1 = f32[9]{0:T(128)} call(%param_0.4601), to_apply=%called_computation.35 }, execution_thread="sparsecore" -%region_26.33.clone.1 (scatter-add.145: f32[], scatter-add.146: f32[]) -> f32[] { +%region_25.32.clone.1 (scatter-add.145: f32[], scatter-add.146: f32[]) -> f32[] { %scatter-add.145 = f32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.146 = f32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2518 = f32[]{:T(128)S(7)} add(%scatter-add.145, %scatter-add.146), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2474 = f32[]{:T(128)S(7)} add(%scatter-add.145, %scatter-add.146), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.19.clone.clone.clone (param_0.4576: f32[9], param_1.5359: s32[256], param_2.4524: f32[256]) -> f32[9] { - %param_0.4576 = f32[9]{0:T(128)} parameter(0) - %param_1.5359 = s32[256]{0:T(256)} parameter(1) - %reshape.3933 = s32[256]{0:T(256)} reshape(%param_1.5359), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %transpose.1122 = s32[256]{0:T(256)} transpose(%reshape.3933), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} - %param_2.4524 = f32[256]{0:T(256)} parameter(2) - %reshape.3934 = f32[256]{0:T(256)} reshape(%param_2.4524) - %transpose.1123 = f32[256]{0:T(256)} transpose(%reshape.3934), dimensions={0} - ROOT %scatter-add.242 = f32[9]{0:T(128)} scatter(%param_0.4576, %transpose.1122, %transpose.1123), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_26.33.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.19.clone.clone.clone (param_0.4602: f32[9], param_1.5363: s32[256], param_2.4519: f32[256]) -> f32[9] { + %param_0.4602 = f32[9]{0:T(128)} parameter(0) + %param_1.5363 = s32[256]{0:T(256)} parameter(1) + %reshape.4059 = s32[256]{0:T(256)} reshape(%param_1.5363), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %transpose.882 = s32[256]{0:T(256)} transpose(%reshape.4059), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/broadcast_in_dim" stack_frame_id=0} + %param_2.4519 = f32[256]{0:T(256)} parameter(2) + %reshape.4060 = f32[256]{0:T(256)} reshape(%param_2.4519) + %transpose.883 = f32[256]{0:T(256)} transpose(%reshape.4060), dimensions={0} + ROOT %scatter-add.242 = f32[9]{0:T(128)} scatter(%param_0.4602, %transpose.882, %transpose.883), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, indices_are_sorted=true, to_apply=%region_25.32.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.36 (param_0.4577: f32[9], param_1.5360: s32[256], param_2.4525: f32[256]) -> f32[9] { - %param_0.4577 = f32[9]{0:T(128)} parameter(0) - %param_1.5360 = s32[256]{0:T(256)} parameter(1) - %param_2.4525 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.61 = f32[9]{0:T(128)} fusion(%param_0.4577, %param_1.5360, %param_2.4525), kind=kCustom, calls=%fused_computation.19.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.36 (param_0.4603: f32[9], param_1.5364: s32[256], param_2.4520: f32[256]) -> f32[9] { + %param_0.4603 = f32[9]{0:T(128)} parameter(0) + %param_1.5364 = s32[256]{0:T(256)} parameter(1) + %param_2.4520 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.61 = f32[9]{0:T(128)} fusion(%param_0.4603, %param_1.5364, %param_2.4520), kind=kCustom, calls=%fused_computation.19.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["16"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"1312","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.36 (param_0.4578: f32[9], param_1.5361: s32[256], param_2.4526: f32[256]) -> f32[9] { - %param_0.4578 = f32[9]{0:T(128)} parameter(0) - %param_1.5361 = s32[256]{0:T(256)} parameter(1) - %param_2.4526 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.62.cloned.1 = f32[9]{0:T(128)} call(%param_0.4578, %param_1.5361, %param_2.4526), to_apply=%called_computation.36, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.36 (param_0.4604: f32[9], param_1.5365: s32[256], param_2.4521: f32[256]) -> f32[9] { + %param_0.4604 = f32[9]{0:T(128)} parameter(0) + %param_1.5365 = s32[256]{0:T(256)} parameter(1) + %param_2.4521 = f32[256]{0:T(256)} parameter(2) + ROOT %scatter_offload_custom_fusion.62.cloned.1 = f32[9]{0:T(128)} call(%param_0.4604, %param_1.5365, %param_2.4521), to_apply=%called_computation.36, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.11 (param_0.117: f32[9], param_1.169: s32[256], param_2.113: f32[256], param_3.3095: token[]) -> f32[9] { - %param_3.3095 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.11 (param_0.117: f32[9], param_1.169: s32[256], param_2.113: f32[256], param_3.3100: token[]) -> f32[9] { + %param_3.3100 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.117 = f32[9]{0:T(128)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.169 = s32[256]{0:T(256)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.113 = f32[256]{0:T(256)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -1214,57 +1214,57 @@ StackFrames ROOT %scatter_offload_custom_fusion.62.cloned.1.call-done = f32[9]{0:T(128)} async-done(%scatter_offload_custom_fusion.62.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.11 (param_0.118: f32[9], param_1.170: s32[256], param_2.114: f32[256], param_3.3094: token[]) -> f32[9] { - %param_3.3094 = token[] parameter(3) +%async_computation.11 (param_0.118: f32[9], param_1.170: s32[256], param_2.114: f32[256], param_3.3099: token[]) -> f32[9] { + %param_3.3099 = token[] parameter(3) %param_0.118 = f32[9]{0:T(128)} parameter(0) %param_1.170 = s32[256]{0:T(256)} parameter(1) %param_2.114 = f32[256]{0:T(256)} parameter(2) - ROOT %scatter_offload_custom_fusion.35.cloned.1 = f32[9]{0:T(128)} call(%param_0.118, %param_1.170, %param_2.114, %param_3.3094), to_apply=%called_computation.11, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.35.cloned.1 = f32[9]{0:T(128)} call(%param_0.118, %param_1.170, %param_2.114, %param_3.3099), to_apply=%called_computation.11, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.37 (param_0.4579: s32[263]) -> s32[263] { - %param_0.4579 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2097 = s32[263]{0:T(512)} copy(%param_0.4579), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.37 (param_0.4605: s32[263]) -> s32[263] { + %param_0.4605 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2097 = s32[263]{0:T(512)} copy(%param_0.4605), backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["32"],"input_window_bounds":[],"estimated_cycles":"1141","iteration_bounds":[],"scratchpad_allocation_size":"512","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"16","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.37 (param_0.4580: s32[263]) -> s32[263] { - %param_0.4580 = s32[263]{0:T(512)} parameter(0) - ROOT %copy.2098.cloned.1 = s32[263]{0:T(512)} call(%param_0.4580), to_apply=%called_computation.37 +%async_computation.37 (param_0.4606: s32[263]) -> s32[263] { + %param_0.4606 = s32[263]{0:T(512)} parameter(0) + ROOT %copy.2098.cloned.1 = s32[263]{0:T(512)} call(%param_0.4606), to_apply=%called_computation.37 }, execution_thread="sparsecore" -%region_28.35.clone.1 (scatter-add.149: s32[], scatter-add.150: s32[]) -> s32[] { +%region_27.34.clone.1 (scatter-add.149: s32[], scatter-add.150: s32[]) -> s32[] { %scatter-add.149 = s32[]{:T(128)S(7)} parameter(0), metadata={op_name="scatter-add"} %scatter-add.150 = s32[]{:T(128)S(7)} parameter(1), metadata={op_name="scatter-add"} - ROOT %add.2520 = s32[]{:T(128)S(7)} add(%scatter-add.149, %scatter-add.150), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} + ROOT %add.2476 = s32[]{:T(128)S(7)} add(%scatter-add.149, %scatter-add.150), metadata={op_name="add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["128"],"input_window_bounds":[],"estimated_cycles":"1165","iteration_bounds":[],"scratchpad_allocation_size":"520","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[{"unroll_dimension":"0","unroll_factor":"4","pipeline_remainder":false,"fully_unroll_if_trip_count_is_at_most":"0"}],"vectorizing_shape":[]},"scoped_memory_configs":[],"used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%fused_computation.20.clone.clone.clone (param_0.4581: s32[263], param_1.5362: s32[8], param_2.4527: s32[8]) -> s32[263] { - %param_0.4581 = s32[263]{0:T(512)} parameter(0) - %param_1.5362 = s32[8]{0:T(128)} parameter(1) - %reshape.3935 = s32[8]{0:T(128)} reshape(%param_1.5362), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %transpose.1124 = s32[8]{0:T(128)} transpose(%reshape.3935), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} - %param_2.4527 = s32[8]{0:T(128)} parameter(2) - %reshape.3936 = s32[8]{0:T(128)} reshape(%param_2.4527) - %transpose.1125 = s32[8]{0:T(128)} transpose(%reshape.3936), dimensions={0} - ROOT %scatter-add.243 = s32[263]{0:T(512)} scatter(%param_0.4581, %transpose.1124, %transpose.1125), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_28.35.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%fused_computation.20.clone.clone.clone (param_0.4607: s32[263], param_1.5366: s32[8], param_2.4522: s32[8]) -> s32[263] { + %param_0.4607 = s32[263]{0:T(512)} parameter(0) + %param_1.5366 = s32[8]{0:T(128)} parameter(1) + %reshape.4061 = s32[8]{0:T(128)} reshape(%param_1.5366), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %transpose.884 = s32[8]{0:T(128)} transpose(%reshape.4061), dimensions={0}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/select_n" stack_frame_id=0} + %param_2.4522 = s32[8]{0:T(128)} parameter(2) + %reshape.4062 = s32[8]{0:T(128)} reshape(%param_2.4522) + %transpose.885 = s32[8]{0:T(128)} transpose(%reshape.4062), dimensions={0} + ROOT %scatter-add.243 = s32[263]{0:T(512)} scatter(%param_0.4607, %transpose.884, %transpose.885), update_window_dims={}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=1, to_apply=%region_27.34.clone.1, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.38 (param_0.4582: s32[263], param_1.5363: s32[8], param_2.4528: s32[8]) -> s32[263] { - %param_0.4582 = s32[263]{0:T(512)} parameter(0) - %param_1.5363 = s32[8]{0:T(128)} parameter(1) - %param_2.4528 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.63 = s32[263]{0:T(512)} fusion(%param_0.4582, %param_1.5363, %param_2.4528), kind=kCustom, calls=%fused_computation.20.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} +%called_computation.38 (param_0.4608: s32[263], param_1.5367: s32[8], param_2.4523: s32[8]) -> s32[263] { + %param_0.4608 = s32[263]{0:T(512)} parameter(0) + %param_1.5367 = s32[8]{0:T(128)} parameter(1) + %param_2.4523 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.63 = s32[263]{0:T(512)} fusion(%param_0.4608, %param_1.5367, %param_2.4523), kind=kCustom, calls=%fused_computation.20.clone.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0}, backend_config={"flag_configs":[],"window_config":{"kernel_window_bounds":[],"output_window_bounds":["8"],"input_window_bounds":[],"estimated_cycles":"9223372036854775807","iteration_bounds":[],"scratchpad_allocation_size":"256","cost_model_type":"COST_MODEL_TYPE_INVALID","ml_estimated_microseconds":0,"is_mask":false,"pad_output_on_minor_dim":"0","pad_input_on_minor_dim":"0","estimated_vmem_bytes":"0","estimated_bundle_count":"0","estimated_scoped_vmem_bytes":"0"},"loop_config":{"loop_order":[],"unrolled_loops":[],"vectorizing_shape":[]},"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_TILE","used_scoped_memory_configs":[]} }, execution_thread="sparsecore" -%async_computation.38 (param_0.4583: s32[263], param_1.5364: s32[8], param_2.4529: s32[8]) -> s32[263] { - %param_0.4583 = s32[263]{0:T(512)} parameter(0) - %param_1.5364 = s32[8]{0:T(128)} parameter(1) - %param_2.4529 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.64.cloned.1 = s32[263]{0:T(512)} call(%param_0.4583, %param_1.5364, %param_2.4529), to_apply=%called_computation.38, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} +%async_computation.38 (param_0.4609: s32[263], param_1.5368: s32[8], param_2.4524: s32[8]) -> s32[263] { + %param_0.4609 = s32[263]{0:T(512)} parameter(0) + %param_1.5368 = s32[8]{0:T(128)} parameter(1) + %param_2.4524 = s32[8]{0:T(128)} parameter(2) + ROOT %scatter_offload_custom_fusion.64.cloned.1 = s32[263]{0:T(512)} call(%param_0.4609, %param_1.5368, %param_2.4524), to_apply=%called_computation.38, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%called_computation.12 (param_0.120: s32[263], param_1.172: s32[8], param_2.116: s32[8], param_3.3103: token[]) -> s32[263] { - %param_3.3103 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} +%called_computation.12 (param_0.120: s32[263], param_1.172: s32[8], param_2.116: s32[8], param_3.3108: token[]) -> s32[263] { + %param_3.3108 = token[] parameter(3), backend_config={"flag_configs":[],"scoped_memory_configs":[],"implicit_sharding":{"type":"REPLICATED","tile_assignment_dimensions":[],"tile_assignment_devices":[],"tuple_shardings":[],"replicate_on_last_tile_dim":false,"metadata":[],"last_tile_dims":[],"iota_reshape_dims":[],"iota_transpose_perm":[],"is_shard_group":false,"shard_group_id":"0","shard_group_type":"AS"},"used_scoped_memory_configs":[]} %param_0.120 = s32[263]{0:T(512)} parameter(0), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_1.172 = s32[8]{0:T(128)} parameter(1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} %param_2.116 = s32[8]{0:T(128)} parameter(2), backend_config={"flag_configs":[],"scoped_memory_configs":[],"compute_type":"COMPUTE_TYPE_SCALAR","used_scoped_memory_configs":[]} @@ -1274,727 +1274,727 @@ StackFrames ROOT %scatter_offload_custom_fusion.64.cloned.1.call-done = s32[263]{0:T(512)} async-done(%scatter_offload_custom_fusion.64.cloned.1.call-start), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%async_computation.12 (param_0.121: s32[263], param_1.173: s32[8], param_2.117: s32[8], param_3.3102: token[]) -> s32[263] { - %param_3.3102 = token[] parameter(3) +%async_computation.12 (param_0.121: s32[263], param_1.173: s32[8], param_2.117: s32[8], param_3.3107: token[]) -> s32[263] { + %param_3.3107 = token[] parameter(3) %param_0.121 = s32[263]{0:T(512)} parameter(0) %param_1.173 = s32[8]{0:T(128)} parameter(1) %param_2.117 = s32[8]{0:T(128)} parameter(2) - ROOT %scatter_offload_custom_fusion.38.cloned.1 = s32[263]{0:T(512)} call(%param_0.121, %param_1.173, %param_2.117, %param_3.3102), to_apply=%called_computation.12, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} + ROOT %scatter_offload_custom_fusion.38.cloned.1 = s32[263]{0:T(512)} call(%param_0.121, %param_1.173, %param_2.117, %param_3.3107), to_apply=%called_computation.12, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/shard_map/jit(gmm)/scatter-add" stack_frame_id=0} }, execution_thread="sparsecore" -%region_154.179 (reduce_sum.431: f32[], reduce_sum.254: f32[]) -> f32[] { - %reduce_sum.431 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.254 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.258 = f32[]{:T(128)} add(%reduce_sum.431, %reduce_sum.254), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_154.179 (reduce_sum.502: f32[], reduce_sum.336: f32[]) -> f32[] { + %reduce_sum.502 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.336 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.337 = f32[]{:T(128)} add(%reduce_sum.502, %reduce_sum.336), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.466 (param_0.4166: f32[3,1536,128,192]) -> f32[] { - %param_0.4166 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(0) - %bitcast.672 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_0.4166), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.564 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%bitcast.672, %bitcast.672), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5101 = f32[]{:T(128)} constant(0) - ROOT %reduce.669 = f32[]{:T(128)} reduce(%square.564, %constant.5101), dimensions={0,1,2,3}, to_apply=%region_154.179, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.465 (param_0.4192: f32[3,1536,128,192]) -> f32[] { + %param_0.4192 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(0) + %bitcast.654 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_0.4192), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.564 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%bitcast.654, %bitcast.654), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5049 = f32[]{:T(128)} constant(0) + ROOT %reduce.610 = f32[]{:T(128)} reduce(%square.564, %constant.5049), dimensions={0,1,2,3}, to_apply=%region_154.179, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.467 (param_0.1420: f32[1536,3,128,192]) -> bf16[3,1536,128,192] { - %param_0.1420 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) - %copy.1550 = bf16[1536,3,128,192]{2,0,3,1:T(8,128)(2,1)} copy(%param_0.1420), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wq_b\'][\'kernel\']"} - ROOT %bitcast.673 = bf16[3,1536,128,192]{2,1,3,0:T(8,128)(2,1)} bitcast(%copy.1550), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} +%fused_computation.466 (param_0.1438: f32[1536,3,128,192]) -> bf16[3,1536,128,192] { + %param_0.1438 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) + %copy.1550 = bf16[1536,3,128,192]{2,0,3,1:T(8,128)(2,1)} copy(%param_0.1438), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wq_b\'][\'kernel\']"} + ROOT %bitcast.655 = bf16[3,1536,128,192]{2,1,3,0:T(8,128)(2,1)} bitcast(%copy.1550), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} } -%region_221.246 (reduce_sum.893: f32[], reduce_sum.603: f32[]) -> f32[] { - %reduce_sum.893 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.603 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.604 = f32[]{:T(128)} add(%reduce_sum.893, %reduce_sum.603), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_221.246 (reduce_sum.964: f32[], reduce_sum.965: f32[]) -> f32[] { + %reduce_sum.964 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.965 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.534 = f32[]{:T(128)} add(%reduce_sum.964, %reduce_sum.965), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_187.212 (reduce_sum.655: f32[], reduce_sum.449: f32[]) -> f32[] { - %reduce_sum.655 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.449 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.450 = f32[]{:T(128)} add(%reduce_sum.655, %reduce_sum.449), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_187.212 (reduce_sum.726: f32[], reduce_sum.727: f32[]) -> f32[] { + %reduce_sum.726 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.727 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.452 = f32[]{:T(128)} add(%reduce_sum.726, %reduce_sum.727), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.468 (param_0.4136: f32[1536,3,128,192], param_1.5026: f32[], param_2.4295: f32[], param_3.2951: f32[], param_4.2203: f32[1536,3,128,192], param_5.2006: f32[], param_6.1443: f32[3,1536,128,192], param_7.1124: pred[], param_8.889: f32[1536,3,128,192]) -> (f32[], f32[1536,3,128,192], f32[1536,3,128,192], f32[1536,3,128,192], f32[]) { - %param_0.4136 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) - %param_3.2951 = f32[]{:T(128)S(6)} parameter(3) - %mul.4713.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_3.2951), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.467 (param_0.4162: f32[1536,3,128,192], param_1.5030: f32[], param_2.4290: f32[], param_3.2956: f32[], param_4.2203: f32[1536,3,128,192], param_5.2003: f32[], param_6.1444: f32[3,1536,128,192], param_7.1124: pred[], param_8.889: f32[1536,3,128,192]) -> (f32[], f32[1536,3,128,192], f32[1536,3,128,192], f32[1536,3,128,192], f32[]) { + %param_0.4162 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(0) + %param_3.2956 = f32[]{:T(128)S(6)} parameter(3) + %mul.5439.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_3.2956), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1124 = pred[]{:T(512)S(6)} parameter(7) %select_n.2121.clone.1 = pred[1536,3,128,192]{2,3,1,0:T(8,128)(4,1)} broadcast(%param_7.1124), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.1443 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(6) - %bitcast.1374.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_6.1443), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} - %param_5.2006 = f32[]{:T(128)} parameter(5) - %div.2562.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_5.2006), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2561.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%bitcast.1374.clone.1, %div.2562.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2120.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%select_n.2121.clone.1, %bitcast.1374.clone.1, %div.2561.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4860.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4272.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4860.clone.1), dimensions={}, metadata={op_name="broadcast.334"} - %mul.4719.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2120.clone.1, %broadcast.4272.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1444 = f32[3,1536,128,192]{2,3,0,1:T(8,128)} parameter(6) + %bitcast.1356.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} bitcast(%param_6.1444), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %param_5.2003 = f32[]{:T(128)} parameter(5) + %div.2562.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_5.2003), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2561.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%bitcast.1356.clone.1, %div.2562.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2120.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} select(%select_n.2121.clone.1, %bitcast.1356.clone.1, %div.2561.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4808.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4133.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4808.clone.1), dimensions={}, metadata={op_name="broadcast.334"} + %mul.5445.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2120.clone.1, %broadcast.4133.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.889 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(8) - %constant.4864.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4720.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4864.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4718.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_8.889, %mul.4720.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3488.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4719.clone.1, %mul.4718.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4295 = f32[]{:T(128)S(6)} parameter(2) - %div.2558.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_2.4295), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4812.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5446.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4812.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5444.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_8.889, %mul.5446.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3444.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.5445.clone.1, %mul.5444.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4290 = f32[]{:T(128)S(6)} parameter(2) + %div.2558.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_2.4290), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.399.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%select_n.2120.clone.1, %select_n.2120.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4863.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4717.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4863.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4715.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%integer_pow.399.clone.1, %mul.4717.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4811.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5443.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4811.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5441.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%integer_pow.399.clone.1, %mul.5443.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2203 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} parameter(4) - %constant.4862.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4716.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4862.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4714.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_4.2203, %mul.4716.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3487.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.4715.clone.1, %mul.4714.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.5026 = f32[]{:T(128)S(6)} parameter(1) - %div.2557.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_1.5026), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2556.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3487.clone.1, %div.2557.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4810.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5442.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4810.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5440.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_4.2203, %mul.5442.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3443.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%mul.5441.clone.1, %mul.5440.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5030 = f32[]{:T(128)S(6)} parameter(1) + %div.2557.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%param_1.5030), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2556.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3443.clone.1, %div.2557.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.157.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} sqrt(%div.2556.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4861.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3486.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4861.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3485.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%sqrt.157.clone.1, %add.3486.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1293.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%div.2558.clone.1, %add.3485.clone.1), metadata={op_name="multiply.290"} - %div.2555.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3488.clone.1, %multiply.1293.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4712.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_0.4136, %broadcast.4272.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3484.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%div.2555.clone.1, %mul.4712.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4711.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%mul.4713.clone.1, %add.3484.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3483.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%param_0.4136, %mul.4711.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.565 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%add.3483.clone.1, %add.3483.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5071 = f32[]{:T(128)} constant(0) - %reduce.670 = f32[]{:T(128)} reduce(%square.565, %constant.5071), dimensions={0,1,2,3}, to_apply=%region_221.246, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.671.clone.1 = f32[]{:T(128)} reduce(%integer_pow.399.clone.1, %constant.5071), dimensions={0,1,2,3}, to_apply=%region_187.212, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.656 = (f32[]{:T(128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.670, %add.3483.clone.1, %add.3487.clone.1, %add.3488.clone.1, %reduce.671.clone.1) -} - -%region_160.185 (reduce_sum.473: f32[], reduce_sum.293: f32[]) -> f32[] { - %reduce_sum.473 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.293 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.300 = f32[]{:T(128)} add(%reduce_sum.473, %reduce_sum.293), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_158.183 (reduce_sum.459: f32[], reduce_sum.460: f32[]) -> f32[] { - %reduce_sum.459 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.460 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.461 = f32[]{:T(128)} add(%reduce_sum.459, %reduce_sum.460), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.494 (param_0.4162: bf16[256,512,512], param_1.5048: bf16[256,512,512]) -> (f32[], f32[]) { - %param_0.4162 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) - %broadcast_in_dim.1245 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4162), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.695 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1245), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.570 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.695, %bitcast.695), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5097 = f32[]{:T(128)} constant(0) - %reduce.672 = f32[]{:T(128)} reduce(%square.570, %constant.5097), dimensions={0,1,2,3}, to_apply=%region_160.185, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.5048 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(1) - %broadcast_in_dim.1253.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_1.5048), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.703.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1253.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.576.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.703.clone.1, %bitcast.703.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.674.clone.1 = f32[]{:T(128)} reduce(%square.576.clone.1, %constant.5097), dimensions={0,1,2,3}, to_apply=%region_158.183, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.764 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.672, %reduce.674.clone.1) -} - -%region_159.184 (reduce_sum.466: f32[], reduce_sum.279: f32[]) -> f32[] { - %reduce_sum.466 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.279 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.286 = f32[]{:T(128)} add(%reduce_sum.466, %reduce_sum.279), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.496 (param_0.4161: bf16[256,512,512]) -> f32[] { - %param_0.4161 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) - %broadcast_in_dim.1249 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4161), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.699 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1249), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.573 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.699, %bitcast.699), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5096 = f32[]{:T(128)} constant(0) - ROOT %reduce.673 = f32[]{:T(128)} reduce(%square.573, %constant.5096), dimensions={0,1,2,3}, to_apply=%region_159.184, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} -} - -%region_227.252 (reduce_sum.935: f32[], reduce_sum.631: f32[]) -> f32[] { - %reduce_sum.935 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.631 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.632 = f32[]{:T(128)} add(%reduce_sum.935, %reduce_sum.631), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_193.218 (reduce_sum.697: f32[], reduce_sum.471: f32[]) -> f32[] { - %reduce_sum.697 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.471 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.472 = f32[]{:T(128)} add(%reduce_sum.697, %reduce_sum.471), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.514 (param_0.4130: f32[], param_1.5020: f32[256,1,512,512], param_2.4289: f32[], param_3.2945: f32[256,1,512,512], param_4.2197: f32[], param_5.2000: bf16[256,512,512], param_6.1437: pred[], param_7.1118: f32[], param_8.883: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %constant.4809.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3442.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} broadcast(%constant.4809.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3441.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%sqrt.157.clone.1, %add.3442.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1086.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%div.2558.clone.1, %add.3441.clone.1), metadata={op_name="multiply.263"} + %div.2555.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} divide(%add.3444.clone.1, %multiply.1086.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5438.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%param_0.4162, %broadcast.4133.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3440.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%div.2555.clone.1, %mul.5438.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5437.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%mul.5439.clone.1, %add.3440.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3439.clone.1 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} add(%param_0.4162, %mul.5437.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.565 = f32[1536,3,128,192]{2,3,1,0:T(8,128)} multiply(%add.3439.clone.1, %add.3439.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5019 = f32[]{:T(128)} constant(0) + %reduce.611 = f32[]{:T(128)} reduce(%square.565, %constant.5019), dimensions={0,1,2,3}, to_apply=%region_221.246, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.612.clone.1 = f32[]{:T(128)} reduce(%integer_pow.399.clone.1, %constant.5019), dimensions={0,1,2,3}, to_apply=%region_187.212, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.656 = (f32[]{:T(128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[1536,3,128,192]{2,3,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.611, %add.3439.clone.1, %add.3443.clone.1, %add.3444.clone.1, %reduce.612.clone.1) +} + +%region_160.185 (reduce_sum.544: f32[], reduce_sum.364: f32[]) -> f32[] { + %reduce_sum.544 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.364 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.365 = f32[]{:T(128)} add(%reduce_sum.544, %reduce_sum.364), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_158.183 (reduce_sum.530: f32[], reduce_sum.352: f32[]) -> f32[] { + %reduce_sum.530 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.352 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.357 = f32[]{:T(128)} add(%reduce_sum.530, %reduce_sum.352), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.493 (param_0.4188: bf16[256,512,512], param_1.5052: bf16[256,512,512]) -> (f32[], f32[]) { + %param_0.4188 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) + %broadcast_in_dim.1270 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4188), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.677 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1270), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.570 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.677, %bitcast.677), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5045 = f32[]{:T(128)} constant(0) + %reduce.613 = f32[]{:T(128)} reduce(%square.570, %constant.5045), dimensions={0,1,2,3}, to_apply=%region_160.185, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.5052 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(1) + %broadcast_in_dim.1278.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_1.5052), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.685.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1278.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.576.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.685.clone.1, %bitcast.685.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.615.clone.1 = f32[]{:T(128)} reduce(%square.576.clone.1, %constant.5045), dimensions={0,1,2,3}, to_apply=%region_158.183, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.764 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.613, %reduce.615.clone.1) +} + +%region_159.184 (reduce_sum.537: f32[], reduce_sum.358: f32[]) -> f32[] { + %reduce_sum.537 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.358 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.359 = f32[]{:T(128)} add(%reduce_sum.537, %reduce_sum.358), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.495 (param_0.4187: bf16[256,512,512]) -> f32[] { + %param_0.4187 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(0) + %broadcast_in_dim.1274 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_0.4187), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.681 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1274), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.573 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.681, %bitcast.681), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5044 = f32[]{:T(128)} constant(0) + ROOT %reduce.614 = f32[]{:T(128)} reduce(%square.573, %constant.5044), dimensions={0,1,2,3}, to_apply=%region_159.184, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +} + +%region_227.252 (reduce_sum.1006: f32[], reduce_sum.1007: f32[]) -> f32[] { + %reduce_sum.1006 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1007 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.548 = f32[]{:T(128)} add(%reduce_sum.1006, %reduce_sum.1007), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_193.218 (reduce_sum.768: f32[], reduce_sum.769: f32[]) -> f32[] { + %reduce_sum.768 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.769 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.466 = f32[]{:T(128)} add(%reduce_sum.768, %reduce_sum.769), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.513 (param_0.4156: f32[], param_1.5024: f32[256,1,512,512], param_2.4284: f32[], param_3.2950: f32[256,1,512,512], param_4.2197: f32[], param_5.1997: bf16[256,512,512], param_6.1438: pred[], param_7.1118: f32[], param_8.883: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { %param_8.883 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) - %bitcast.1359.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.883), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %bitcast.1341.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.883), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} %param_7.1118 = f32[]{:T(128)S(6)} parameter(7) - %mul.4662.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1118), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_6.1437 = pred[]{:T(512)S(6)} parameter(6) - %select_n.2103.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1437), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_5.2000 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) - %broadcast_in_dim.1459.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2000), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.1361.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1459.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %mul.5388.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1118), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1438 = pred[]{:T(512)S(6)} parameter(6) + %select_n.2103.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1438), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_5.1997 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.1484.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.1997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1343.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1484.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %param_4.2197 = f32[]{:T(128)} parameter(4) %div.2520.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2197), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2519.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1361.clone.1, %div.2520.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2102.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2103.clone.1, %bitcast.1361.clone.1, %div.2519.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4830.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4252.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4830.clone.1), dimensions={}, metadata={op_name="broadcast.2362"} - %mul.4664.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2102.clone.1, %broadcast.4252.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_3.2945 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) - %bitcast.1360.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2945), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} - %constant.4829.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.4251.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4829.clone.1), dimensions={}, metadata={op_name="broadcast.329"} - %mul.4663.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1360.clone.1, %broadcast.4251.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3453.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4664.clone.1, %mul.4663.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4289 = f32[]{:T(128)S(6)} parameter(2) - %div.2518.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4289), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2519.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1343.clone.1, %div.2520.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2102.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2103.clone.1, %bitcast.1343.clone.1, %div.2519.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4778.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4113.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4778.clone.1), dimensions={}, metadata={op_name="broadcast.2239"} + %mul.5390.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2102.clone.1, %broadcast.4113.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_3.2950 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1342.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2950), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %constant.4777.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4112.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4777.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.5389.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1342.clone.1, %broadcast.4112.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3409.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5390.clone.1, %mul.5389.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4284 = f32[]{:T(128)S(6)} parameter(2) + %div.2518.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4284), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.393.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2102.clone.1, %select_n.2102.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4828.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.4254.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4828.clone.1), dimensions={}, metadata={op_name="broadcast.2365"} - %mul.4666.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.393.clone.1, %broadcast.4254.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_1.5020 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) - %bitcast.1362.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5020), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} - %constant.4827.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.4253.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4827.clone.1), dimensions={}, metadata={op_name="broadcast.312"} - %mul.4665.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1362.clone.1, %broadcast.4253.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3454.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4666.clone.1, %mul.4665.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_0.4130 = f32[]{:T(128)S(6)} parameter(0) - %div.2517.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4130), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2516.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3454.clone.1, %div.2517.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4776.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4115.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4776.clone.1), dimensions={}, metadata={op_name="broadcast.2242"} + %mul.5392.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.393.clone.1, %broadcast.4115.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_1.5024 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1344.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5024), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wo\']"} + %constant.4775.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4114.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4775.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.5391.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1344.clone.1, %broadcast.4114.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3410.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5392.clone.1, %mul.5391.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_0.4156 = f32[]{:T(128)S(6)} parameter(0) + %div.2517.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4156), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2516.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3410.clone.1, %div.2517.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.151.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2516.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4831.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.4250.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4831.clone.1), dimensions={}, metadata={op_name="broadcast.305"} - %add.3452.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.151.clone.1, %broadcast.4250.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1287.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2518.clone.1, %add.3452.clone.1), metadata={op_name="multiply.296"} - %div.2515.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3453.clone.1, %multiply.1287.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4661.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1359.clone.1, %broadcast.4252.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3451.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2515.clone.1, %mul.4661.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4660.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4662.clone.1, %add.3451.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3450.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1359.clone.1, %mul.4660.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.577 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3450.clone.1, %add.3450.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5065 = f32[]{:T(128)} constant(0) - %reduce.675 = f32[]{:T(128)} reduce(%square.577, %constant.5065), dimensions={0,1,2,3}, to_apply=%region_227.252, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %bitcast.849.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3454.clone.1) - %bitcast.822.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3453.clone.1) - %reduce.684.clone.1 = f32[]{:T(128)} reduce(%integer_pow.393.clone.1, %constant.5065), dimensions={0,1,2,3}, to_apply=%region_193.218, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.666 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.675, %add.3450.clone.1, %bitcast.849.clone.1, %bitcast.822.clone.1, %reduce.684.clone.1) -} - -%region_226.251 (reduce_sum.928: f32[], reduce_sum.625: f32[]) -> f32[] { - %reduce_sum.928 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.625 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.626 = f32[]{:T(128)} add(%reduce_sum.928, %reduce_sum.625), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_192.217 (reduce_sum.690: f32[], reduce_sum.465: f32[]) -> f32[] { - %reduce_sum.690 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.465 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.470 = f32[]{:T(128)} add(%reduce_sum.690, %reduce_sum.465), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.515 (param_0.4131: f32[], param_1.5021: f32[256,1,512,512], param_2.4290: f32[], param_3.2946: f32[256,1,512,512], param_4.2198: f32[], param_5.2001: bf16[256,512,512], param_6.1438: pred[], param_7.1119: f32[], param_8.884: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %constant.4779.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4111.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4779.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3408.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.151.clone.1, %broadcast.4111.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1080.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2518.clone.1, %add.3408.clone.1), metadata={op_name="multiply.269"} + %div.2515.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3409.clone.1, %multiply.1080.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5387.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1341.clone.1, %broadcast.4113.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3407.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2515.clone.1, %mul.5387.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5386.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.5388.clone.1, %add.3407.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3406.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1341.clone.1, %mul.5386.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.577 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3406.clone.1, %add.3406.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5013 = f32[]{:T(128)} constant(0) + %reduce.616 = f32[]{:T(128)} reduce(%square.577, %constant.5013), dimensions={0,1,2,3}, to_apply=%region_227.252, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.831.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3410.clone.1) + %bitcast.804.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3409.clone.1) + %reduce.625.clone.1 = f32[]{:T(128)} reduce(%integer_pow.393.clone.1, %constant.5013), dimensions={0,1,2,3}, to_apply=%region_193.218, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.666 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.616, %add.3406.clone.1, %bitcast.831.clone.1, %bitcast.804.clone.1, %reduce.625.clone.1) +} + +%region_226.251 (reduce_sum.999: f32[], reduce_sum.1000: f32[]) -> f32[] { + %reduce_sum.999 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1000 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.543 = f32[]{:T(128)} add(%reduce_sum.999, %reduce_sum.1000), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_192.217 (reduce_sum.761: f32[], reduce_sum.762: f32[]) -> f32[] { + %reduce_sum.761 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.762 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.465 = f32[]{:T(128)} add(%reduce_sum.761, %reduce_sum.762), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.514 (param_0.4157: f32[], param_1.5025: f32[256,1,512,512], param_2.4285: f32[], param_3.2951: f32[256,1,512,512], param_4.2198: f32[], param_5.1998: bf16[256,512,512], param_6.1439: pred[], param_7.1119: f32[], param_8.884: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { %param_8.884 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) - %bitcast.1363.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.884), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %bitcast.1345.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.884), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} %param_7.1119 = f32[]{:T(128)S(6)} parameter(7) - %mul.4669.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1119), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_6.1438 = pred[]{:T(512)S(6)} parameter(6) - %select_n.2105.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1438), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_5.2001 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) - %broadcast_in_dim.1460.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2001), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.1365.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1460.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %mul.5395.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1119), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1439 = pred[]{:T(512)S(6)} parameter(6) + %select_n.2105.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1439), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_5.1998 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.1485.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.1998), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1347.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1485.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %param_4.2198 = f32[]{:T(128)} parameter(4) %div.2526.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2198), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2525.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1365.clone.1, %div.2526.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2104.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2105.clone.1, %bitcast.1365.clone.1, %div.2525.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4835.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4257.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4835.clone.1), dimensions={}, metadata={op_name="broadcast.2362"} - %mul.4671.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2104.clone.1, %broadcast.4257.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_3.2946 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) - %bitcast.1364.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2946), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} - %constant.4834.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.4256.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4834.clone.1), dimensions={}, metadata={op_name="broadcast.329"} - %mul.4670.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1364.clone.1, %broadcast.4256.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3458.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4671.clone.1, %mul.4670.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4290 = f32[]{:T(128)S(6)} parameter(2) - %div.2524.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4290), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2525.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1347.clone.1, %div.2526.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2104.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2105.clone.1, %bitcast.1347.clone.1, %div.2525.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4783.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4118.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4783.clone.1), dimensions={}, metadata={op_name="broadcast.2239"} + %mul.5397.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2104.clone.1, %broadcast.4118.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_3.2951 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1346.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2951), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %constant.4782.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4117.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4782.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.5396.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1346.clone.1, %broadcast.4117.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3414.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5397.clone.1, %mul.5396.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4285 = f32[]{:T(128)S(6)} parameter(2) + %div.2524.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4285), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.394.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2104.clone.1, %select_n.2104.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4833.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.4259.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4833.clone.1), dimensions={}, metadata={op_name="broadcast.2365"} - %mul.4673.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.394.clone.1, %broadcast.4259.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_1.5021 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) - %bitcast.1366.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5021), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} - %constant.4832.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.4258.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4832.clone.1), dimensions={}, metadata={op_name="broadcast.312"} - %mul.4672.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1366.clone.1, %broadcast.4258.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3459.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4673.clone.1, %mul.4672.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_0.4131 = f32[]{:T(128)S(6)} parameter(0) - %div.2523.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4131), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2522.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3459.clone.1, %div.2523.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4781.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4120.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4781.clone.1), dimensions={}, metadata={op_name="broadcast.2242"} + %mul.5399.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.394.clone.1, %broadcast.4120.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_1.5025 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1348.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5025), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_1\']"} + %constant.4780.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4119.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4780.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.5398.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1348.clone.1, %broadcast.4119.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3415.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5399.clone.1, %mul.5398.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_0.4157 = f32[]{:T(128)S(6)} parameter(0) + %div.2523.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4157), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2522.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3415.clone.1, %div.2523.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.152.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2522.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4836.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.4255.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4836.clone.1), dimensions={}, metadata={op_name="broadcast.305"} - %add.3457.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.152.clone.1, %broadcast.4255.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1288.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2524.clone.1, %add.3457.clone.1), metadata={op_name="multiply.295"} - %div.2521.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3458.clone.1, %multiply.1288.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4668.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1363.clone.1, %broadcast.4257.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3456.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2521.clone.1, %mul.4668.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4667.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4669.clone.1, %add.3456.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3455.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1363.clone.1, %mul.4667.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.578 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3455.clone.1, %add.3455.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5066 = f32[]{:T(128)} constant(0) - %reduce.676 = f32[]{:T(128)} reduce(%square.578, %constant.5066), dimensions={0,1,2,3}, to_apply=%region_226.251, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %bitcast.840.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3459.clone.1) - %bitcast.813.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3458.clone.1) - %reduce.685.clone.1 = f32[]{:T(128)} reduce(%integer_pow.394.clone.1, %constant.5066), dimensions={0,1,2,3}, to_apply=%region_192.217, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.665 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.676, %add.3455.clone.1, %bitcast.840.clone.1, %bitcast.813.clone.1, %reduce.685.clone.1) -} - -%region_225.250 (reduce_sum.921: f32[], reduce_sum.619: f32[]) -> f32[] { - %reduce_sum.921 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.619 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.624 = f32[]{:T(128)} add(%reduce_sum.921, %reduce_sum.619), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_191.216 (reduce_sum.683: f32[], reduce_sum.463: f32[]) -> f32[] { - %reduce_sum.683 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.463 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.464 = f32[]{:T(128)} add(%reduce_sum.683, %reduce_sum.463), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.516 (param_0.4132: f32[], param_1.5022: f32[256,1,512,512], param_2.4291: f32[], param_3.2947: f32[256,1,512,512], param_4.2199: f32[], param_5.2002: bf16[256,512,512], param_6.1439: pred[], param_7.1120: f32[], param_8.885: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { + %constant.4784.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4116.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4784.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3413.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.152.clone.1, %broadcast.4116.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1081.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2524.clone.1, %add.3413.clone.1), metadata={op_name="multiply.268"} + %div.2521.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3414.clone.1, %multiply.1081.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5394.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1345.clone.1, %broadcast.4118.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3412.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2521.clone.1, %mul.5394.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5393.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.5395.clone.1, %add.3412.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3411.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1345.clone.1, %mul.5393.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.578 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3411.clone.1, %add.3411.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5014 = f32[]{:T(128)} constant(0) + %reduce.617 = f32[]{:T(128)} reduce(%square.578, %constant.5014), dimensions={0,1,2,3}, to_apply=%region_226.251, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.822.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3415.clone.1) + %bitcast.795.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3414.clone.1) + %reduce.626.clone.1 = f32[]{:T(128)} reduce(%integer_pow.394.clone.1, %constant.5014), dimensions={0,1,2,3}, to_apply=%region_192.217, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.665 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.617, %add.3411.clone.1, %bitcast.822.clone.1, %bitcast.795.clone.1, %reduce.626.clone.1) +} + +%region_225.250 (reduce_sum.992: f32[], reduce_sum.993: f32[]) -> f32[] { + %reduce_sum.992 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.993 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.542 = f32[]{:T(128)} add(%reduce_sum.992, %reduce_sum.993), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_191.216 (reduce_sum.754: f32[], reduce_sum.755: f32[]) -> f32[] { + %reduce_sum.754 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.755 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.464 = f32[]{:T(128)} add(%reduce_sum.754, %reduce_sum.755), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.515 (param_0.4158: f32[], param_1.5026: f32[256,1,512,512], param_2.4286: f32[], param_3.2952: f32[256,1,512,512], param_4.2199: f32[], param_5.1999: bf16[256,512,512], param_6.1440: pred[], param_7.1120: f32[], param_8.885: f32[256,1,512,512]) -> (f32[], f32[256,1,512,512], f32[256,1,512,512], f32[256,1,512,512], f32[]) { %param_8.885 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(8) - %bitcast.1367.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.885), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %bitcast.1349.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_8.885), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} %param_7.1120 = f32[]{:T(128)S(6)} parameter(7) - %mul.4676.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1120), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_6.1439 = pred[]{:T(512)S(6)} parameter(6) - %select_n.2107.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1439), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_5.2002 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) - %broadcast_in_dim.1461.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.2002), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.1369.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1461.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %mul.5402.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_7.1120), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1440 = pred[]{:T(512)S(6)} parameter(6) + %select_n.2107.clone.1 = pred[256,1,512,512]{3,2,0,1:T(8,128)(4,1)} broadcast(%param_6.1440), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_5.1999 = bf16[256,512,512]{2,1,0:T(8,128)(2,1)} parameter(5) + %broadcast_in_dim.1486.clone.1 = f32[256,512,512]{2,1,0:T(8,128)} convert(%param_5.1999), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.1351.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%broadcast_in_dim.1486.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} %param_4.2199 = f32[]{:T(128)} parameter(4) %div.2532.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_4.2199), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2531.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1369.clone.1, %div.2532.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2106.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2107.clone.1, %bitcast.1369.clone.1, %div.2531.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4840.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4262.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4840.clone.1), dimensions={}, metadata={op_name="broadcast.2362"} - %mul.4678.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2106.clone.1, %broadcast.4262.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_3.2947 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) - %bitcast.1368.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2947), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} - %constant.4839.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.4261.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4839.clone.1), dimensions={}, metadata={op_name="broadcast.329"} - %mul.4677.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1368.clone.1, %broadcast.4261.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3463.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4678.clone.1, %mul.4677.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4291 = f32[]{:T(128)S(6)} parameter(2) - %div.2530.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4291), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2531.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%bitcast.1351.clone.1, %div.2532.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2106.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} select(%select_n.2107.clone.1, %bitcast.1351.clone.1, %div.2531.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4788.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4123.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4788.clone.1), dimensions={}, metadata={op_name="broadcast.2239"} + %mul.5404.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2106.clone.1, %broadcast.4123.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_3.2952 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(3) + %bitcast.1350.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_3.2952), sharding={replicated}, metadata={op_name="state.opt_state[0].mu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %constant.4787.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.4122.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4787.clone.1), dimensions={}, metadata={op_name="broadcast.329"} + %mul.5403.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1350.clone.1, %broadcast.4122.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3419.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5404.clone.1, %mul.5403.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4286 = f32[]{:T(128)S(6)} parameter(2) + %div.2530.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_2.4286), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.395.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%select_n.2106.clone.1, %select_n.2106.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4838.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.4264.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4838.clone.1), dimensions={}, metadata={op_name="broadcast.2365"} - %mul.4680.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.395.clone.1, %broadcast.4264.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_1.5022 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) - %bitcast.1370.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5022), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} - %constant.4837.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.4263.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4837.clone.1), dimensions={}, metadata={op_name="broadcast.312"} - %mul.4679.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1370.clone.1, %broadcast.4263.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3464.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.4680.clone.1, %mul.4679.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_0.4132 = f32[]{:T(128)S(6)} parameter(0) - %div.2529.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4132), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2528.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3464.clone.1, %div.2529.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4786.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.4125.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4786.clone.1), dimensions={}, metadata={op_name="broadcast.2242"} + %mul.5406.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%integer_pow.395.clone.1, %broadcast.4125.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_1.5026 = f32[256,1,512,512]{3,2,1,0:T(8,128)} parameter(1) + %bitcast.1352.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} bitcast(%param_1.5026), sharding={replicated}, metadata={op_name="state.opt_state[0].nu[\'params\'][\'decoder\'][\'moe_layers\'][\'DeepSeekMoeBlock_0\'][\'MoeBlock_0\'][\'wi_0\']"} + %constant.4785.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.4124.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4785.clone.1), dimensions={}, metadata={op_name="broadcast.312"} + %mul.5405.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1352.clone.1, %broadcast.4124.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3420.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%mul.5406.clone.1, %mul.5405.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_0.4158 = f32[]{:T(128)S(6)} parameter(0) + %div.2529.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%param_0.4158), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2528.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3420.clone.1, %div.2529.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.153.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} sqrt(%div.2528.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4841.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.4260.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4841.clone.1), dimensions={}, metadata={op_name="broadcast.305"} - %add.3462.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.153.clone.1, %broadcast.4260.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1289.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2530.clone.1, %add.3462.clone.1), metadata={op_name="multiply.294"} - %div.2527.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3463.clone.1, %multiply.1289.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4675.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1367.clone.1, %broadcast.4262.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3461.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2527.clone.1, %mul.4675.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4674.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.4676.clone.1, %add.3461.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3460.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1367.clone.1, %mul.4674.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.579 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3460.clone.1, %add.3460.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5067 = f32[]{:T(128)} constant(0) - %reduce.677 = f32[]{:T(128)} reduce(%square.579, %constant.5067), dimensions={0,1,2,3}, to_apply=%region_225.250, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %bitcast.831.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3464.clone.1) - %bitcast.804.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3463.clone.1) - %reduce.686.clone.1 = f32[]{:T(128)} reduce(%integer_pow.395.clone.1, %constant.5067), dimensions={0,1,2,3}, to_apply=%region_191.216, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.664 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.677, %add.3460.clone.1, %bitcast.831.clone.1, %bitcast.804.clone.1, %reduce.686.clone.1) -} - -%region_155.180 (reduce_sum.438: f32[], reduce_sum.259: f32[]) -> f32[] { - %reduce_sum.438 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.259 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.260 = f32[]{:T(128)} add(%reduce_sum.438, %reduce_sum.259), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.528.clone.clone.clone (param_0.4075: bf16[4,128,129280], param_1.4954: s32[4,128], param_2.4222: f32[4,128], param_3.2913: f32[4,128], param_4.2170: bf16[4,128], param_5.1978: f32[4,128]) -> bf16[4,128,129280] { - %param_5.1978 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.4889 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_5.1978), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.2913 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.4888 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2913), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.4075 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.3151 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4075), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %constant.4789.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.4121.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} broadcast(%constant.4789.clone.1), dimensions={}, metadata={op_name="broadcast.305"} + %add.3418.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%sqrt.153.clone.1, %broadcast.4121.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1082.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%div.2530.clone.1, %add.3418.clone.1), metadata={op_name="multiply.267"} + %div.2527.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} divide(%add.3419.clone.1, %multiply.1082.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5401.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%bitcast.1349.clone.1, %broadcast.4123.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3417.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%div.2527.clone.1, %mul.5401.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5400.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%mul.5402.clone.1, %add.3417.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3416.clone.1 = f32[256,1,512,512]{3,2,0,1:T(8,128)} add(%bitcast.1349.clone.1, %mul.5400.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.579 = f32[256,1,512,512]{3,2,0,1:T(8,128)} multiply(%add.3416.clone.1, %add.3416.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5015 = f32[]{:T(128)} constant(0) + %reduce.618 = f32[]{:T(128)} reduce(%square.579, %constant.5015), dimensions={0,1,2,3}, to_apply=%region_225.250, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %bitcast.813.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3420.clone.1) + %bitcast.786.clone.1 = f32[256,1,512,512]{3,2,1,0:T(8,128)} bitcast(%add.3419.clone.1) + %reduce.627.clone.1 = f32[]{:T(128)} reduce(%integer_pow.395.clone.1, %constant.5015), dimensions={0,1,2,3}, to_apply=%region_191.216, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.664 = (f32[]{:T(128)}, f32[256,1,512,512]{3,2,0,1:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[256,1,512,512]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.618, %add.3416.clone.1, %bitcast.813.clone.1, %bitcast.786.clone.1, %reduce.627.clone.1) +} + +%region_155.180 (reduce_sum.509: f32[], reduce_sum.338: f32[]) -> f32[] { + %reduce_sum.509 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.338 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.343 = f32[]{:T(128)} add(%reduce_sum.509, %reduce_sum.338), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.527.clone.clone.clone (param_0.4101: bf16[4,128,129280], param_1.4958: s32[4,128], param_2.4217: f32[4,128], param_3.2918: f32[4,128], param_4.2170: bf16[4,128], param_5.1975: f32[4,128]) -> bf16[4,128,129280] { + %param_5.1975 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.5639 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_5.1975), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.2918 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.5638 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2918), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.4101 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.3163 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4101), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} %param_4.2170 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) %sub.791 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_4.2170), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.790 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.3151, %sub.791), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.790 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.3163, %sub.791), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.534 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.790), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.4887 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4888, %exp.534), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.4222 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.2685 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4222), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.2684 = f32[4,128,129280]{2,1,0:T(8,128)} divide(%mul.4887, %div.2685), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.4954 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.363 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.4954), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %mul.5637 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.5638, %exp.534), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.4217 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.2685 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4217), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.2684 = f32[4,128,129280]{2,1,0:T(8,128)} divide(%mul.5637, %div.2685), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.4958 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.363 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.4958), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.362 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.361 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.363, %eq.362), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.3150 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%eq.361), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.789 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%div.2684, %convert_element_type.3150), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.4886 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.4889, %sub.789), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.3149 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} convert(%mul.4886), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} -} - -%fused_computation.935.clone.clone (param_0.4076: f32[4,128], param_1.4955: bf16[4,128,512], param_2.4224: bf16[512]) -> bf16[4,128,512] { - %param_2.4224 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(2) - %dot_general.831 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4224), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.4955 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.3153 = f32[4,128,512]{2,1,0:T(8,128)} convert(%param_1.4955), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.4076 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.4891 = f32[4,128,512]{2,1,0:T(8,128)} broadcast(%param_0.4076), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.4890 = f32[4,128,512]{2,1,0:T(8,128)} multiply(%convert_element_type.3153, %mul.4891), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.3152 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} convert(%mul.4890), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.830 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.831, %convert_element_type.3152), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.517 (param_0.4165: bf16[4,128,129280], param_1.5050: s32[4,128], param_2.4316: f32[4,128], param_3.2969: f32[4,128], param_4.2219: bf16[4,128], param_5.2020: f32[4,128], param_6.1457: f32[4,128], param_7.1138: bf16[4,128,512], param_8.902: bf16[512]) -> (f32[], bf16[512,129280,1]) { - %param_6.1457 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %convert_element_type.3162 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%eq.361), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.789 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%div.2684, %convert_element_type.3162), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.5636 = f32[4,128,129280]{2,1,0:T(8,128)} multiply(%mul.5639, %sub.789), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.3161 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} convert(%mul.5636), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.934.clone.clone (param_0.4102: f32[4,128], param_1.4959: bf16[4,128,512], param_2.4219: bf16[512]) -> bf16[4,128,512] { + %param_1.4959 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.3165 = f32[4,128,512]{2,1,0:T(8,128)} convert(%param_1.4959), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.4102 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.5642 = f32[4,128,512]{2,1,0:T(8,128)} broadcast(%param_0.4102), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.5641 = f32[4,128,512]{2,1,0:T(8,128)} multiply(%convert_element_type.3165, %mul.5642), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.3164 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} convert(%mul.5641), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.4219 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(2) + %mul.5643 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4219), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.5640 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.3164, %mul.5643), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.516 (param_0.4191: bf16[4,128,129280], param_1.5054: s32[4,128], param_2.4311: f32[4,128], param_3.2974: f32[4,128], param_4.2219: bf16[4,128], param_5.2017: f32[4,128], param_6.1458: f32[4,128], param_7.1138: bf16[4,128,512], param_8.902: bf16[512]) -> (f32[], bf16[512,129280,1]) { + %param_6.1458 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) %param_7.1138 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)S(1)} parameter(7) %param_8.902 = bf16[512]{0:T(512)(128)(2,1)S(1)} parameter(8) - %fusion.573.clone.1 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} fusion(%param_6.1457, %param_7.1138, %param_8.902), kind=kLoop, calls=%fused_computation.935.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.4165 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.5050 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.4316 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.2969 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %fusion.572.clone.1 = bf16[4,128,512]{2,1,0:T(8,128)(2,1)} fusion(%param_6.1458, %param_7.1138, %param_8.902), kind=kLoop, calls=%fused_computation.934.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.4191 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.5054 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.4311 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.2974 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) %param_4.2219 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %param_5.2020 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %multiply_convert_fusion.1.clone.1 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} fusion(%param_0.4165, %param_1.5050, %param_2.4316, %param_3.2969, %param_4.2219, /*index=5*/%param_5.2020), kind=kLoop, calls=%fused_computation.528.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %convolution.141.clone.1 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.573.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %bitcast.776 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%convolution.141.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.2653 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.776), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - %square.581 = f32[512,129280]{1,0:T(8,128)} multiply(%convert_element_type.2653, %convert_element_type.2653), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5100 = f32[]{:T(128)} constant(0) - %reduce.678 = f32[]{:T(128)} reduce(%square.581, %constant.5100), dimensions={0,1}, to_apply=%region_155.180, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.754 = (f32[]{:T(128)}, bf16[512,129280,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.678, %convolution.141.clone.1) + %param_5.2017 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %multiply_convert_fusion.1.clone.1 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} fusion(%param_0.4191, %param_1.5054, %param_2.4311, %param_3.2974, %param_4.2219, /*index=5*/%param_5.2017), kind=kLoop, calls=%fused_computation.527.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %convolution.141.clone.1 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.572.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %bitcast.758 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%convolution.141.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.2665 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.758), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %square.581 = f32[512,129280]{1,0:T(8,128)} multiply(%convert_element_type.2665, %convert_element_type.2665), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5048 = f32[]{:T(128)} constant(0) + %reduce.619 = f32[]{:T(128)} reduce(%square.581, %constant.5048), dimensions={0,1}, to_apply=%region_155.180, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.754 = (f32[]{:T(128)}, bf16[512,129280,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.619, %convolution.141.clone.1) } -%region_174.199 (reduce_sum.564: f32[], reduce_sum.387: f32[]) -> f32[] { - %reduce_sum.564 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.387 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.388 = f32[]{:T(128)} add(%reduce_sum.564, %reduce_sum.387), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_174.199 (reduce_sum.635: f32[], reduce_sum.636: f32[]) -> f32[] { + %reduce_sum.635 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.636 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.424 = f32[]{:T(128)} add(%reduce_sum.635, %reduce_sum.636), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.518 (param_0.4149: bf16[129280,512]) -> f32[] { - %param_0.4149 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.2655 = f32[129280,512]{1,0:T(8,128)} convert(%param_0.4149), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %square.583 = f32[129280,512]{1,0:T(8,128)} multiply(%convert_element_type.2655, %convert_element_type.2655), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5084 = f32[]{:T(128)} constant(0) - ROOT %reduce.679 = f32[]{:T(128)} reduce(%square.583, %constant.5084), dimensions={0,1}, to_apply=%region_174.199, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.517 (param_0.4175: bf16[129280,512]) -> f32[] { + %param_0.4175 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2667 = f32[129280,512]{1,0:T(8,128)} convert(%param_0.4175), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %square.583 = f32[129280,512]{1,0:T(8,128)} multiply(%convert_element_type.2667, %convert_element_type.2667), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5032 = f32[]{:T(128)} constant(0) + ROOT %reduce.620 = f32[]{:T(128)} reduce(%square.583, %constant.5032), dimensions={0,1}, to_apply=%region_174.199, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_240.265 (reduce_sum.1026: f32[], reduce_sum.689: f32[]) -> f32[] { - %reduce_sum.1026 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.689 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.694 = f32[]{:T(128)} add(%reduce_sum.1026, %reduce_sum.689), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_240.265 (reduce_sum.1097: f32[], reduce_sum.1098: f32[]) -> f32[] { + %reduce_sum.1097 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1098 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.577 = f32[]{:T(128)} add(%reduce_sum.1097, %reduce_sum.1098), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_206.231 (reduce_sum.788: f32[], reduce_sum.533: f32[]) -> f32[] { - %reduce_sum.788 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.533 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.534 = f32[]{:T(128)} add(%reduce_sum.788, %reduce_sum.533), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_206.231 (reduce_sum.859: f32[], reduce_sum.860: f32[]) -> f32[] { + %reduce_sum.859 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.860 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.499 = f32[]{:T(128)} add(%reduce_sum.859, %reduce_sum.860), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.519 (param_0.4117: f32[129280,512], param_1.5007: f32[], param_2.4276: f32[], param_3.2932: f32[], param_4.2184: f32[129280,512], param_5.1987: f32[], param_6.1424: bf16[129280,512], param_7.1105: pred[], param_8.870: f32[129280,512]) -> (f32[], f32[129280,512], f32[129280,512], f32[129280,512], f32[]) { - %param_0.4117 = f32[129280,512]{1,0:T(8,128)} parameter(0) - %param_3.2932 = f32[]{:T(128)S(6)} parameter(3) - %mul.4550.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_3.2932), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.518 (param_0.4143: f32[129280,512], param_1.5011: f32[], param_2.4271: f32[], param_3.2937: f32[], param_4.2184: f32[129280,512], param_5.1984: f32[], param_6.1425: bf16[129280,512], param_7.1105: pred[], param_8.870: f32[129280,512]) -> (f32[], f32[129280,512], f32[129280,512], f32[129280,512], f32[]) { + %param_0.4143 = f32[129280,512]{1,0:T(8,128)} parameter(0) + %param_3.2937 = f32[]{:T(128)S(6)} parameter(3) + %mul.5276.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_3.2937), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1105 = pred[]{:T(512)S(6)} parameter(7) %select_n.2061.clone.1 = pred[129280,512]{1,0:T(8,128)(4,1)} broadcast(%param_7.1105), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.1424 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(6) - %convert_element_type.3094.clone.1 = f32[129280,512]{1,0:T(8,128)} convert(%param_6.1424), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %param_5.1987 = f32[]{:T(128)} parameter(5) - %div.2426.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_5.1987), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2425.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%convert_element_type.3094.clone.1, %div.2426.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2060.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%select_n.2061.clone.1, %convert_element_type.3094.clone.1, %div.2425.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4750.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4202.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4750.clone.1), dimensions={}, metadata={op_name="broadcast.318"} - %mul.4556.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2060.clone.1, %broadcast.4202.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1425 = bf16[129280,512]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.3106.clone.1 = f32[129280,512]{1,0:T(8,128)} convert(%param_6.1425), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %param_5.1984 = f32[]{:T(128)} parameter(5) + %div.2426.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_5.1984), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2425.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%convert_element_type.3106.clone.1, %div.2426.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2060.clone.1 = f32[129280,512]{1,0:T(8,128)} select(%select_n.2061.clone.1, %convert_element_type.3106.clone.1, %div.2425.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4698.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4063.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4698.clone.1), dimensions={}, metadata={op_name="broadcast.318"} + %mul.5282.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2060.clone.1, %broadcast.4063.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.870 = f32[129280,512]{1,0:T(8,128)} parameter(8) - %constant.4754.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4557.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4754.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4555.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_8.870, %mul.4557.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3383.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4556.clone.1, %mul.4555.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4276 = f32[]{:T(128)S(6)} parameter(2) - %div.2422.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_2.4276), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4702.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5283.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4702.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5281.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_8.870, %mul.5283.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3339.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.5282.clone.1, %mul.5281.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4271 = f32[]{:T(128)S(6)} parameter(2) + %div.2422.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_2.4271), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.380.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%select_n.2060.clone.1, %select_n.2060.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4753.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4554.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4753.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4552.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%integer_pow.380.clone.1, %mul.4554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4701.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5280.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4701.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5278.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%integer_pow.380.clone.1, %mul.5280.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2184 = f32[129280,512]{1,0:T(8,128)} parameter(4) - %constant.4752.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4553.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4752.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4551.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_4.2184, %mul.4553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3382.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.4552.clone.1, %mul.4551.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.5007 = f32[]{:T(128)S(6)} parameter(1) - %div.2421.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_1.5007), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2420.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3382.clone.1, %div.2421.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4700.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5279.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4700.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5277.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_4.2184, %mul.5279.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3338.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%mul.5278.clone.1, %mul.5277.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5011 = f32[]{:T(128)S(6)} parameter(1) + %div.2421.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%param_1.5011), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2420.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3338.clone.1, %div.2421.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.138.clone.1 = f32[129280,512]{1,0:T(8,128)} sqrt(%div.2420.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4751.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3381.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4751.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3380.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%sqrt.138.clone.1, %add.3381.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1274.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%div.2422.clone.1, %add.3380.clone.1), metadata={op_name="multiply.309"} - %div.2419.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3383.clone.1, %multiply.1274.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4549.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_0.4117, %broadcast.4202.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3379.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%div.2419.clone.1, %mul.4549.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4548.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%mul.4550.clone.1, %add.3379.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3378.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%param_0.4117, %mul.4548.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.584 = f32[129280,512]{1,0:T(8,128)} multiply(%add.3378.clone.1, %add.3378.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5052 = f32[]{:T(128)} constant(0) - %reduce.680 = f32[]{:T(128)} reduce(%square.584, %constant.5052), dimensions={0,1}, to_apply=%region_240.265, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.687.clone.1 = f32[]{:T(128)} reduce(%integer_pow.380.clone.1, %constant.5052), dimensions={0,1}, to_apply=%region_206.231, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.667 = (f32[]{:T(128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.680, %add.3378.clone.1, %add.3382.clone.1, %add.3383.clone.1, %reduce.687.clone.1) -} - -%region_222.247 (reduce_sum.900: f32[], reduce_sum.605: f32[]) -> f32[] { - %reduce_sum.900 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.605 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.610 = f32[]{:T(128)} add(%reduce_sum.900, %reduce_sum.605), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_188.213 (reduce_sum.662: f32[], reduce_sum.451: f32[]) -> f32[] { - %reduce_sum.662 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.451 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.455 = f32[]{:T(128)} add(%reduce_sum.662, %reduce_sum.451), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.520 (param_0.4135: f32[512,129280], param_1.5025: f32[], param_2.4294: f32[], param_3.2950: f32[], param_4.2202: f32[512,129280], param_5.2005: f32[], param_6.1442: bf16[512,129280,1], param_7.1123: pred[], param_8.888: f32[512,129280]) -> (f32[], f32[512,129280], f32[512,129280], f32[512,129280], f32[]) { - %param_0.4135 = f32[512,129280]{1,0:T(8,128)} parameter(0) - %param_3.2950 = f32[]{:T(128)S(6)} parameter(3) - %mul.4703.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_3.2950), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4699.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3337.clone.1 = f32[129280,512]{1,0:T(8,128)} broadcast(%constant.4699.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3336.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%sqrt.138.clone.1, %add.3337.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1067.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%div.2422.clone.1, %add.3336.clone.1), metadata={op_name="multiply.282"} + %div.2419.clone.1 = f32[129280,512]{1,0:T(8,128)} divide(%add.3339.clone.1, %multiply.1067.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5275.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%param_0.4143, %broadcast.4063.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3335.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%div.2419.clone.1, %mul.5275.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5274.clone.1 = f32[129280,512]{1,0:T(8,128)} multiply(%mul.5276.clone.1, %add.3335.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3334.clone.1 = f32[129280,512]{1,0:T(8,128)} add(%param_0.4143, %mul.5274.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.584 = f32[129280,512]{1,0:T(8,128)} multiply(%add.3334.clone.1, %add.3334.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5000 = f32[]{:T(128)} constant(0) + %reduce.621 = f32[]{:T(128)} reduce(%square.584, %constant.5000), dimensions={0,1}, to_apply=%region_240.265, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.628.clone.1 = f32[]{:T(128)} reduce(%integer_pow.380.clone.1, %constant.5000), dimensions={0,1}, to_apply=%region_206.231, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.667 = (f32[]{:T(128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[129280,512]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.621, %add.3334.clone.1, %add.3338.clone.1, %add.3339.clone.1, %reduce.628.clone.1) +} + +%region_222.247 (reduce_sum.971: f32[], reduce_sum.972: f32[]) -> f32[] { + %reduce_sum.971 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.972 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.535 = f32[]{:T(128)} add(%reduce_sum.971, %reduce_sum.972), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_188.213 (reduce_sum.733: f32[], reduce_sum.734: f32[]) -> f32[] { + %reduce_sum.733 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.734 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.457 = f32[]{:T(128)} add(%reduce_sum.733, %reduce_sum.734), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.519 (param_0.4161: f32[512,129280], param_1.5029: f32[], param_2.4289: f32[], param_3.2955: f32[], param_4.2202: f32[512,129280], param_5.2002: f32[], param_6.1443: bf16[512,129280,1], param_7.1123: pred[], param_8.888: f32[512,129280]) -> (f32[], f32[512,129280], f32[512,129280], f32[512,129280], f32[]) { + %param_0.4161 = f32[512,129280]{1,0:T(8,128)} parameter(0) + %param_3.2955 = f32[]{:T(128)S(6)} parameter(3) + %mul.5429.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_3.2955), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1123 = pred[]{:T(512)S(6)} parameter(7) %select_n.2117.clone.1 = pred[512,129280]{1,0:T(8,128)(4,1)} broadcast(%param_7.1123), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.1442 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} parameter(6) - %bitcast.1372.clone.1 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%param_6.1442), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.3096.clone.1 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.1372.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - %param_5.2005 = f32[]{:T(128)} parameter(5) - %div.2554.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_5.2005), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2553.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%convert_element_type.3096.clone.1, %div.2554.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2116.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%select_n.2117.clone.1, %convert_element_type.3096.clone.1, %div.2553.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4854.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4270.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4854.clone.1), dimensions={}, metadata={op_name="broadcast.333"} - %mul.4709.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2116.clone.1, %broadcast.4270.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1443 = bf16[512,129280,1]{1,0,2:T(8,128)(2,1)} parameter(6) + %bitcast.1354.clone.1 = bf16[512,129280]{1,0:T(8,128)(2,1)} bitcast(%param_6.1443), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.3108.clone.1 = f32[512,129280]{1,0:T(8,128)} convert(%bitcast.1354.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %param_5.2002 = f32[]{:T(128)} parameter(5) + %div.2554.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_5.2002), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2553.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%convert_element_type.3108.clone.1, %div.2554.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2116.clone.1 = f32[512,129280]{1,0:T(8,128)} select(%select_n.2117.clone.1, %convert_element_type.3108.clone.1, %div.2553.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4802.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4131.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4802.clone.1), dimensions={}, metadata={op_name="broadcast.333"} + %mul.5435.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2116.clone.1, %broadcast.4131.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.888 = f32[512,129280]{1,0:T(8,128)} parameter(8) - %constant.4858.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4710.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4858.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4708.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_8.888, %mul.4710.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3482.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4709.clone.1, %mul.4708.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4294 = f32[]{:T(128)S(6)} parameter(2) - %div.2550.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_2.4294), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4806.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5436.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4806.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5434.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_8.888, %mul.5436.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3438.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.5435.clone.1, %mul.5434.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4289 = f32[]{:T(128)S(6)} parameter(2) + %div.2550.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_2.4289), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.398.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%select_n.2116.clone.1, %select_n.2116.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4857.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4707.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4857.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4705.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%integer_pow.398.clone.1, %mul.4707.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4805.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5433.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4805.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5431.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%integer_pow.398.clone.1, %mul.5433.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2202 = f32[512,129280]{1,0:T(8,128)} parameter(4) - %constant.4856.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4706.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4856.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4704.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_4.2202, %mul.4706.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3481.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.4705.clone.1, %mul.4704.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.5025 = f32[]{:T(128)S(6)} parameter(1) - %div.2549.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_1.5025), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2548.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3481.clone.1, %div.2549.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4804.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5432.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4804.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5430.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_4.2202, %mul.5432.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3437.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%mul.5431.clone.1, %mul.5430.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5029 = f32[]{:T(128)S(6)} parameter(1) + %div.2549.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%param_1.5029), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2548.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3437.clone.1, %div.2549.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.156.clone.1 = f32[512,129280]{1,0:T(8,128)} sqrt(%div.2548.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4855.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3480.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4855.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3479.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%sqrt.156.clone.1, %add.3480.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1292.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%div.2550.clone.1, %add.3479.clone.1), metadata={op_name="multiply.291"} - %div.2547.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3482.clone.1, %multiply.1292.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4702.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_0.4135, %broadcast.4270.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3478.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%div.2547.clone.1, %mul.4702.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4701.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%mul.4703.clone.1, %add.3478.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3477.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%param_0.4135, %mul.4701.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.585 = f32[512,129280]{1,0:T(8,128)} multiply(%add.3477.clone.1, %add.3477.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5070 = f32[]{:T(128)} constant(0) - %reduce.681 = f32[]{:T(128)} reduce(%square.585, %constant.5070), dimensions={0,1}, to_apply=%region_222.247, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.688.clone.1 = f32[]{:T(128)} reduce(%integer_pow.398.clone.1, %constant.5070), dimensions={0,1}, to_apply=%region_188.213, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.668 = (f32[]{:T(128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.681, %add.3477.clone.1, %add.3481.clone.1, %add.3482.clone.1, %reduce.688.clone.1) -} - -%region_207.232 (reduce_sum.795: f32[], reduce_sum.535: f32[]) -> f32[] { - %reduce_sum.795 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.535 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.540 = f32[]{:T(128)} add(%reduce_sum.795, %reduce_sum.535), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.521 (param_0.4186: bf16[4,128,129280], param_1.5064: f32[4,128], param_2.4326: s32[4,128], param_3.2977: bf16[4,128]) -> f32[4,128] { - %param_2.4326 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %eq.299 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4326), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %constant.4803.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3436.clone.1 = f32[512,129280]{1,0:T(8,128)} broadcast(%constant.4803.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3435.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%sqrt.156.clone.1, %add.3436.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1085.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%div.2550.clone.1, %add.3435.clone.1), metadata={op_name="multiply.264"} + %div.2547.clone.1 = f32[512,129280]{1,0:T(8,128)} divide(%add.3438.clone.1, %multiply.1085.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5428.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%param_0.4161, %broadcast.4131.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3434.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%div.2547.clone.1, %mul.5428.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5427.clone.1 = f32[512,129280]{1,0:T(8,128)} multiply(%mul.5429.clone.1, %add.3434.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3433.clone.1 = f32[512,129280]{1,0:T(8,128)} add(%param_0.4161, %mul.5427.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.585 = f32[512,129280]{1,0:T(8,128)} multiply(%add.3433.clone.1, %add.3433.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5018 = f32[]{:T(128)} constant(0) + %reduce.622 = f32[]{:T(128)} reduce(%square.585, %constant.5018), dimensions={0,1}, to_apply=%region_222.247, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.629.clone.1 = f32[]{:T(128)} reduce(%integer_pow.398.clone.1, %constant.5018), dimensions={0,1}, to_apply=%region_188.213, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.668 = (f32[]{:T(128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[512,129280]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.622, %add.3433.clone.1, %add.3437.clone.1, %add.3438.clone.1, %reduce.629.clone.1) +} + +%region_207.232 (reduce_sum.866: f32[], reduce_sum.867: f32[]) -> f32[] { + %reduce_sum.866 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.867 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.500 = f32[]{:T(128)} add(%reduce_sum.866, %reduce_sum.867), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.520 (param_0.4212: bf16[4,128,129280], param_1.5068: f32[4,128], param_2.4321: s32[4,128], param_3.2982: bf16[4,128]) -> f32[4,128] { + %param_2.4321 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.299 = s32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_2.4321), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.294 = s32[4,128,129280]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.293 = pred[4,128,129280]{2,1,0:T(8,128)(4,1)} compare(%eq.299, %eq.294), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %param_0.4186 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.2660 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4186), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_3.2977 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) - %sub.652 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2977), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.643 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2660, %sub.652), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %param_1.5064 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %sub.650 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5064), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_0.4212 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2672 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4212), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_3.2982 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.652 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_3.2982), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.643 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2672, %sub.652), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_1.5068 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.650 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5068), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %sub.639 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%sub.643, %sub.650), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %constant.5124 = f32[]{:T(128)} constant(0) - %broadcast.3777 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%constant.5124), dimensions={}, metadata={op_name="broadcast.514"} - %mul.3612 = f32[4,128,129280]{2,1,0:T(8,128)} select(%eq.293, %sub.639, %broadcast.3777), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - ROOT %reduce.682 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.3612, %constant.5124), dimensions={2}, to_apply=%region_207.232, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.5072 = f32[]{:T(128)} constant(0) + %broadcast.3638 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%constant.5072), dimensions={}, metadata={op_name="broadcast.496"} + %mul.4227 = f32[4,128,129280]{2,1,0:T(8,128)} select(%eq.293, %sub.639, %broadcast.3638), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.623 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.4227, %constant.5072), dimensions={2}, to_apply=%region_207.232, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_37.47 (reduce_sum.76: f32[], reduce_sum.80: f32[]) -> f32[] { +%region_37.47 (reduce_sum.76: f32[], reduce_sum.82: f32[]) -> f32[] { %reduce_sum.76 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.80 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.83 = f32[]{:T(128)} add(%reduce_sum.76, %reduce_sum.80), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %reduce_sum.82 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.88 = f32[]{:T(128)} add(%reduce_sum.76, %reduce_sum.82), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.532 (param_0.4187: bf16[4,128,129280], param_1.5065: bf16[4,128]) -> f32[4,128] { - %param_0.4187 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.2666 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4187), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_1.5065 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) - %sub.653 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5065), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.649 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2666, %sub.653), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} +%fused_computation.531 (param_0.4213: bf16[4,128,129280], param_1.5069: bf16[4,128]) -> f32[4,128] { + %param_0.4213 = bf16[4,128,129280]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.2678 = f32[4,128,129280]{2,1,0:T(8,128)} convert(%param_0.4213), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.5069 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.653 = f32[4,128,129280]{2,1,0:T(8,128)} broadcast(%param_1.5069), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.649 = f32[4,128,129280]{2,1,0:T(8,128)} subtract(%convert_element_type.2678, %sub.653), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.448 = f32[4,128,129280]{2,1,0:T(8,128)} exponential(%sub.649), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %constant.5125 = f32[]{:T(128)} constant(0) - ROOT %reduce.683 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.448, %constant.5125), dimensions={2}, to_apply=%region_37.47, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.5073 = f32[]{:T(128)} constant(0) + ROOT %reduce.624 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.448, %constant.5073), dimensions={2}, to_apply=%region_37.47, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_152.177 (reduce_sum.417: f32[], reduce_sum.244: f32[]) -> f32[] { - %reduce_sum.417 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.244 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.251 = f32[]{:T(128)} add(%reduce_sum.417, %reduce_sum.244), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_152.177 (reduce_sum.488: f32[], reduce_sum.324: f32[]) -> f32[] { + %reduce_sum.488 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.324 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.329 = f32[]{:T(128)} add(%reduce_sum.488, %reduce_sum.324), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.540 (param_0.4168: f32[3,512,128,256]) -> f32[] { - %param_0.4168 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.752 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_0.4168), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.588 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%bitcast.752, %bitcast.752), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5103 = f32[]{:T(128)} constant(0) - ROOT %reduce.689 = f32[]{:T(128)} reduce(%square.588, %constant.5103), dimensions={0,1,2,3}, to_apply=%region_152.177, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.539 (param_0.4194: f32[3,512,128,256]) -> f32[] { + %param_0.4194 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.734 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_0.4194), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.588 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%bitcast.734, %bitcast.734), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5051 = f32[]{:T(128)} constant(0) + ROOT %reduce.630 = f32[]{:T(128)} reduce(%square.588, %constant.5051), dimensions={0,1,2,3}, to_apply=%region_152.177, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.541 (param_0.1601: f32[512,3,128,256]) -> bf16[3,512,128,256] { - %param_0.1601 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) - %copy.1551 = bf16[512,3,128,256]{3,0,2,1:T(8,128)(2,1)} copy(%param_0.1601), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wkv_b\'][\'kernel\']"} - ROOT %bitcast.753 = bf16[3,512,128,256]{3,1,2,0:T(8,128)(2,1)} bitcast(%copy.1551), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} +%fused_computation.540 (param_0.1619: f32[512,3,128,256]) -> bf16[3,512,128,256] { + %param_0.1619 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) + %copy.1551 = bf16[512,3,128,256]{3,0,2,1:T(8,128)(2,1)} copy(%param_0.1619), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'dense_layers\'][\'self_attention\'][\'wkv_b\'][\'kernel\']"} + ROOT %bitcast.735 = bf16[3,512,128,256]{3,1,2,0:T(8,128)(2,1)} bitcast(%copy.1551), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} } -%region_219.244 (reduce_sum.879: f32[], reduce_sum.591: f32[]) -> f32[] { - %reduce_sum.879 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.591 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.596 = f32[]{:T(128)} add(%reduce_sum.879, %reduce_sum.591), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_219.244 (reduce_sum.950: f32[], reduce_sum.951: f32[]) -> f32[] { + %reduce_sum.950 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.951 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.528 = f32[]{:T(128)} add(%reduce_sum.950, %reduce_sum.951), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_185.210 (reduce_sum.641: f32[], reduce_sum.437: f32[]) -> f32[] { - %reduce_sum.641 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.437 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.442 = f32[]{:T(128)} add(%reduce_sum.641, %reduce_sum.437), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_185.210 (reduce_sum.712: f32[], reduce_sum.713: f32[]) -> f32[] { + %reduce_sum.712 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.713 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.450 = f32[]{:T(128)} add(%reduce_sum.712, %reduce_sum.713), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.542 (param_0.4138: f32[512,3,128,256], param_1.5028: f32[], param_2.4297: f32[], param_3.2953: f32[], param_4.2205: f32[512,3,128,256], param_5.2008: f32[], param_6.1445: f32[3,512,128,256], param_7.1126: pred[], param_8.891: f32[512,3,128,256]) -> (f32[], f32[512,3,128,256], f32[512,3,128,256], f32[512,3,128,256], f32[]) { - %param_0.4138 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) - %param_3.2953 = f32[]{:T(128)S(6)} parameter(3) - %mul.4733.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_3.2953), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.541 (param_0.4164: f32[512,3,128,256], param_1.5032: f32[], param_2.4292: f32[], param_3.2958: f32[], param_4.2205: f32[512,3,128,256], param_5.2005: f32[], param_6.1446: f32[3,512,128,256], param_7.1126: pred[], param_8.891: f32[512,3,128,256]) -> (f32[], f32[512,3,128,256], f32[512,3,128,256], f32[512,3,128,256], f32[]) { + %param_0.4164 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(0) + %param_3.2958 = f32[]{:T(128)S(6)} parameter(3) + %mul.5459.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_3.2958), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.1126 = pred[]{:T(512)S(6)} parameter(7) %select_n.2129.clone.1 = pred[512,3,128,256]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.1126), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.1445 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.1378.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_6.1445), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} - %param_5.2008 = f32[]{:T(128)} parameter(5) - %div.2578.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_5.2008), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2577.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%bitcast.1378.clone.1, %div.2578.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.2128.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%select_n.2129.clone.1, %bitcast.1378.clone.1, %div.2577.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.4872.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.4276.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4872.clone.1), dimensions={}, metadata={op_name="broadcast.336"} - %mul.4739.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2128.clone.1, %broadcast.4276.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_6.1446 = f32[3,512,128,256]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.1360.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} bitcast(%param_6.1446), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/dense_layers.wrapped_fn/transpose" stack_frame_id=0} + %param_5.2005 = f32[]{:T(128)} parameter(5) + %div.2578.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_5.2005), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2577.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%bitcast.1360.clone.1, %div.2578.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.2128.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} select(%select_n.2129.clone.1, %bitcast.1360.clone.1, %div.2577.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.4820.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.4137.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4820.clone.1), dimensions={}, metadata={op_name="broadcast.336"} + %mul.5465.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2128.clone.1, %broadcast.4137.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.891 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(8) - %constant.4876.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.4740.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4876.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4738.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_8.891, %mul.4740.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3500.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4739.clone.1, %mul.4738.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.4297 = f32[]{:T(128)S(6)} parameter(2) - %div.2574.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_2.4297), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4824.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.5466.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4824.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5464.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_8.891, %mul.5466.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3456.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.5465.clone.1, %mul.5464.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.4292 = f32[]{:T(128)S(6)} parameter(2) + %div.2574.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_2.4292), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.401.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%select_n.2128.clone.1, %select_n.2128.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.4875.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.4737.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4875.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4735.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%integer_pow.401.clone.1, %mul.4737.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.4823.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.5463.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4823.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5461.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%integer_pow.401.clone.1, %mul.5463.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_4.2205 = f32[512,3,128,256]{3,2,1,0:T(8,128)} parameter(4) - %constant.4874.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.4736.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4874.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.4734.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_4.2205, %mul.4736.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3499.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.4735.clone.1, %mul.4734.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.5028 = f32[]{:T(128)S(6)} parameter(1) - %div.2573.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_1.5028), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.2572.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3499.clone.1, %div.2573.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.4822.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.5462.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4822.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.5460.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_4.2205, %mul.5462.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3455.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%mul.5461.clone.1, %mul.5460.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.5032 = f32[]{:T(128)S(6)} parameter(1) + %div.2573.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%param_1.5032), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.2572.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3455.clone.1, %div.2573.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.159.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} sqrt(%div.2572.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.4873.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.3498.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4873.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.3497.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%sqrt.159.clone.1, %add.3498.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.1295.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%div.2574.clone.1, %add.3497.clone.1), metadata={op_name="multiply.288"} - %div.2571.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3500.clone.1, %multiply.1295.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.4732.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_0.4138, %broadcast.4276.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3496.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%div.2571.clone.1, %mul.4732.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.4731.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%mul.4733.clone.1, %add.3496.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.3495.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%param_0.4138, %mul.4731.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.589 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%add.3495.clone.1, %add.3495.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5073 = f32[]{:T(128)} constant(0) - %reduce.690 = f32[]{:T(128)} reduce(%square.589, %constant.5073), dimensions={0,1,2,3}, to_apply=%region_219.244, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.691.clone.1 = f32[]{:T(128)} reduce(%integer_pow.401.clone.1, %constant.5073), dimensions={0,1,2,3}, to_apply=%region_185.210, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.663 = (f32[]{:T(128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.690, %add.3495.clone.1, %add.3499.clone.1, %add.3500.clone.1, %reduce.691.clone.1) -} - -%region_172.197 (reduce_sum.557: f32[], reduce_sum.381: f32[]) -> f32[] { - %reduce_sum.557 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.381 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.386 = f32[]{:T(128)} add(%reduce_sum.557, %reduce_sum.381), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.782.clone.clone (param_0.4102: f32[4,128], param_1.4999: bf16[4,128,1536], param_2.4258: bf16[1536]) -> bf16[4,128,1536,1] { - %param_2.4258 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.851 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4258), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.4999 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.3175 = f32[4,128,1536]{2,1,0:T(8,128)} convert(%param_1.4999), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} - %param_0.4102 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.4937 = f32[4,128,1536]{2,1,0:T(8,128)} broadcast(%param_0.4102), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} - %mul.4936 = f32[4,128,1536]{2,1,0:T(8,128)} multiply(%convert_element_type.3175, %mul.4937), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} - %convert_element_type.3174 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} convert(%mul.4936), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} - %dot_general.850 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.851, %convert_element_type.3174), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.1466 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dot_general.850), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} + %constant.4821.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.3454.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} broadcast(%constant.4821.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.3453.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%sqrt.159.clone.1, %add.3454.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.1088.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%div.2574.clone.1, %add.3453.clone.1), metadata={op_name="multiply.261"} + %div.2571.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} divide(%add.3456.clone.1, %multiply.1088.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.5458.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%param_0.4164, %broadcast.4137.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3452.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%div.2571.clone.1, %mul.5458.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.5457.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%mul.5459.clone.1, %add.3452.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.3451.clone.1 = f32[512,3,128,256]{3,2,1,0:T(8,128)} add(%param_0.4164, %mul.5457.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.589 = f32[512,3,128,256]{3,2,1,0:T(8,128)} multiply(%add.3451.clone.1, %add.3451.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5021 = f32[]{:T(128)} constant(0) + %reduce.631 = f32[]{:T(128)} reduce(%square.589, %constant.5021), dimensions={0,1,2,3}, to_apply=%region_219.244, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.632.clone.1 = f32[]{:T(128)} reduce(%integer_pow.401.clone.1, %constant.5021), dimensions={0,1,2,3}, to_apply=%region_185.210, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.663 = (f32[]{:T(128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[512,3,128,256]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.631, %add.3451.clone.1, %add.3455.clone.1, %add.3456.clone.1, %reduce.632.clone.1) +} + +%region_172.197 (reduce_sum.628: f32[], reduce_sum.629: f32[]) -> f32[] { + %reduce_sum.628 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.629 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.423 = f32[]{:T(128)} add(%reduce_sum.628, %reduce_sum.629), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.781.clone.clone (param_0.4128: f32[4,128], param_1.5003: bf16[4,128,1536], param_2.4253: bf16[1536]) -> bf16[4,128,1536,1] { + %param_1.5003 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.3187 = f32[4,128,1536]{2,1,0:T(8,128)} convert(%param_1.5003), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} + %param_0.4128 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.5708 = f32[4,128,1536]{2,1,0:T(8,128)} broadcast(%param_0.4128), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %mul.5707 = f32[4,128,1536]{2,1,0:T(8,128)} multiply(%convert_element_type.3187, %mul.5708), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %convert_element_type.3186 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} convert(%mul.5707), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/convert_element_type" stack_frame_id=0} + %param_2.4253 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.5709 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.4253), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %mul.5706 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.3186, %mul.5709), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + ROOT %bitcast.1448 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%mul.5706), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} } %bitcast_fusion.12 (bitcast_input.12: bf16[4,128,128,192]) -> bf16[4,128,128,192] { %bitcast_input.12 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.1488 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} bitcast(%bitcast_input.12) + ROOT %bitcast.1470 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} bitcast(%bitcast_input.12) } -%fused_computation.551 (param_0.4150: bf16[4,128,128,192], param_1.5039: f32[4,128], param_2.4308: bf16[4,128,1536], param_3.2964: bf16[1536]) -> (f32[], bf16[1536,128,192,1]) { - %param_1.5039 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.4308 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %param_3.2964 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %fusion.459.clone.1 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.5039, %param_2.4308, %param_3.2964), kind=kLoop, calls=%fused_computation.782.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.4150 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) - %fusion.746 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.4150), kind=kLoop, calls=%bitcast_fusion.12 - %convolution.146.clone.1 = bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)} convolution(%fusion.459.clone.1, %fusion.746), window={size=192x4 pad=191_191x0_0 rhs_reversal=1x0}, dim_labels=1fb0_1io0->bf01, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} - %bitcast.861 = bf16[1536,128,192]{1,0,2:T(8,128)(2,1)} bitcast(%convolution.146.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} - %broadcast_in_dim.1275 = f32[1536,128,192]{1,0,2:T(8,128)} convert(%bitcast.861), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} - %bitcast.763 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} bitcast(%broadcast_in_dim.1275), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} - %square.592 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%bitcast.763, %bitcast.763), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.5085 = f32[]{:T(128)} constant(0) - %reduce.692 = f32[]{:T(128)} reduce(%square.592, %constant.5085), dimensions={0,1,2,3}, to_apply=%region_172.197, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.763 = (f32[]{:T(128)}, bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)}) tuple(%reduce.692, %convolution.146.clone.1) +%fused_computation.550 (param_0.4176: bf16[4,128,128,192], param_1.5043: f32[4,128], param_2.4303: bf16[4,128,1536], param_3.2969: bf16[1536]) -> (f32[], bf16[1536,128,192,1]) { + %param_1.5043 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.4303 = bf16[4,128,1536]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.2969 = bf16[1536]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.458.clone.1 = bf16[4,128,1536,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.5043, %param_2.4303, %param_3.2969), kind=kLoop, calls=%fused_computation.781.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/moe_layers/mul" stack_frame_id=0} + %param_0.4176 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)S(1)} parameter(0) + %fusion.745 = bf16[4,128,128,192]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.4176), kind=kLoop, calls=%bitcast_fusion.12 + %convolution.146.clone.1 = bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)} convolution(%fusion.458.clone.1, %fusion.745), window={size=192x4 pad=191_191x0_0 rhs_reversal=1x0}, dim_labels=1fb0_1io0->bf01, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} + %bitcast.843 = bf16[1536,128,192]{1,0,2:T(8,128)(2,1)} bitcast(%convolution.146.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/closed_call/checkpoint/moe_layers/dot_general" stack_frame_id=0} + %broadcast_in_dim.1300 = f32[1536,128,192]{1,0,2:T(8,128)} convert(%bitcast.843), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/while/body/broadcast_in_dim" stack_frame_id=0} + %bitcast.745 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} bitcast(%broadcast_in_dim.1300), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/moe_layers.wrapped_fn/transpose" stack_frame_id=0} + %square.592 = f32[1536,1,128,192]{2,0,3,1:T(8,128)} multiply(%bitcast.745, %bitcast.745), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.5033 = f32[]{:T(128)} constant(0) + %reduce.633 = f32[]{:T(128)} reduce(%square.592, %constant.5033), dimensions={0,1,2,3}, to_apply=%region_172.197, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.763 = (f32[]{:T(128)}, bf16[1536,128,192,1]{1,0,3,2:T(8,128)(2,1)}) tuple(%reduce.633, %convolution.146.clone.1) } -%region_239.264 (reduce_sum.1019: f32[], reduce_sum.687: f32[]) -> f32[] { - %reduce_sum.1019 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.687 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.688 = f32[]{:T(128)} add(%reduce_sum.1019, %reduce_sum.687), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_239.264 (reduce_sum.1090: f32[], reduce_sum.1091: f32[]) -> f32[] { + %reduce_sum.1090 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.1091 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.576 = f32[]{:T(128)} add(%reduce_sum.1090, %reduce_sum.1091), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_205.230 (reduce_sum.781: f32[], reduce_sum.527: f32[]) -> f32[] { - %reduce_sum.781 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.527 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.528 = f32[]{:T(128)} add(%reduce_sum.781, %reduce_sum.527), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_205.230 (reduce_sum.852: f32[], reduce_sum.853: f32[]) -> f32[] { + %reduce_sum.852 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.853 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.494 = f32[]{:T(128)} add(%reduce_sum.852, %reduce_sum.853), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.556 (param_0.4118: f32[], param_1.5008: f32[], param_2.4277: f32[], param_3.2933: f32[1536,1,128,192], param_4.2185: f32[1536,1,128,192], param_5.1988: f32[], param_6.1425: bf16[1536,128,192,1], param_7.1106: pred[], param_8.871: f32[1536,1,128,192]) -> (f32[], f32[1536,1,128,192], f32[1536,1,128,192], f32[1536,1,128,192], f32[]) { +%fused_computation.555 (param_0.4144: f32[], param_1.5012: f32[], param_2.4272: f32[], param_3.2938: f32[1536,1,128,192], param_4.2185: f32[1536,1,128,192], param_5.1985: f32[], param_6.1426: bf16[1536,128,192,1], param_7.1106: pred[], param_8.871: f32[1536,1,128,192]) -> (f32[], f32[1536,1,128,192], f32[1536,1,128,192], f32[1536,1,128,192], f32[]) { diff --git a/tests/utils/reference_hlo_llama3_8b.txt b/tests/utils/reference_hlo_llama3_8b.txt index 27c6529df2..19a83b1a84 100644 --- a/tests/utils/reference_hlo_llama3_8b.txt +++ b/tests/utils/reference_hlo_llama3_8b.txt @@ -14,1355 +14,1355 @@ StackFrames %param_1.7 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.1 = s32[1024]{0:T(1024)} custom-call(%param_1.7), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} %slice.6 = s32[512]{0:T(512)} slice(%custom-call.1), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %reshape.342 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.326 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.342), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %gather.4 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.326), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,4096}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %transpose.325 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %reshape.341 = bf16[512,4096]{1,0:T(8,128)(2,1)} reshape(%transpose.325), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.380 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.241 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.380), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.4 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.241), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,4096}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.240 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.379 = bf16[512,4096]{1,0:T(8,128)(2,1)} reshape(%transpose.240), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} } %region_33.38.clone (scatter-add.6: bf16[], scatter-add.7: bf16[]) -> bf16[] { %scatter-add.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} %scatter-add.7 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} - ROOT %add.476 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.462 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %fused_computation.1 (param_0.3: bf16[128256,4096], param_1.5: s32[512], param_2.4: bf16[512,4096]) -> bf16[128256,4096] { %param_0.3 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) %param_1.5 = s32[512]{0:T(512)S(1)} parameter(1) - %reshape.349 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.331 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.349), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %reshape.387 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.246 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.387), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} %param_2.4 = bf16[512,4096]{1,0:T(8,128)(2,1)S(1)} parameter(2) - %reshape.350 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} - %transpose.332 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%reshape.350), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} - ROOT %scatter.2 = bf16[128256,4096]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.331, %transpose.332), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_33.38.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} + %reshape.388 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + %transpose.247 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} transpose(%reshape.388), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + ROOT %scatter.2 = bf16[128256,4096]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.246, %transpose.247), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_33.38.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} } -%region_32.37 (reduce_sum.190: f32[], reduce_sum.191: f32[]) -> f32[] { - %reduce_sum.190 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.191 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.192 = f32[]{:T(128)} add(%reduce_sum.190, %reduce_sum.191), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_32.37 (reduce_sum.244: f32[], reduce_sum.245: f32[]) -> f32[] { + %reduce_sum.244 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.245 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.249 = f32[]{:T(128)} add(%reduce_sum.244, %reduce_sum.245), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.280.clone.clone.clone (param_0.1099: bf16[4,128,128256], param_1.1265: s32[4,128], param_2.1086: f32[4,128], param_3.785: f32[4,128], param_4.487: bf16[4,128], param_5.412: f32[4,128]) -> bf16[4,128,128256] { +%fused_computation.280.clone.clone.clone (param_0.1106: bf16[4,128,128256], param_1.1269: s32[4,128], param_2.1080: f32[4,128], param_3.773: f32[4,128], param_4.481: bf16[4,128], param_5.412: f32[4,128]) -> bf16[4,128,128256] { %param_5.412 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.1613 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.412), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.785 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.1612 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.785), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.1099 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1044 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1099), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_4.487 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %sub.94 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.487), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.93 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1044, %sub.94), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %mul.1937 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.412), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.773 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.1936 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.773), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1106 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1060 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1106), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.481 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.94 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.481), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.93 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1060, %sub.94), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.62 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.93), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.1611 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1612, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.1086 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.823 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1086), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.822 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1611, %div.823), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.1265 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.49 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1265), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %mul.1935 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1936, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1080 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.823 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1080), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.822 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1935, %div.823), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1269 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.49 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1269), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.48 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.47 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.49, %eq.48), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.1043 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.92 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.822, %convert_element_type.1043), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.1610 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1613, %sub.92), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.1042 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1610), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} -} - -%fused_computation.316.clone.clone (param_0.1100: f32[4,128], param_1.1266: bf16[4,128,4096], param_2.1088: bf16[4096]) -> bf16[4,128,4096] { - %param_2.1088 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.387 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1088), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1266 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1046 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1266), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1100 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1615 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1100), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1614 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1046, %mul.1615), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1045 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1614), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.386 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.387, %convert_element_type.1045), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.219 (param_0.1119: bf16[4,128,128256], param_1.1281: s32[4,128], param_2.1112: f32[4,128], param_3.801: f32[4,128], param_4.502: bf16[4,128], param_5.427: f32[4,128], param_6.299: f32[4,128], param_7.198: bf16[4,128,4096], param_8.116: bf16[4096]) -> (f32[], bf16[4096,128256,1]) { - %param_6.299 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %convert_element_type.1059 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.92 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.822, %convert_element_type.1059), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.1934 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1937, %sub.92), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1058 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1934), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.316.clone.clone (param_0.1107: f32[4,128], param_1.1270: bf16[4,128,4096], param_2.1082: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1270 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1062 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1270), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1107 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1940 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1107), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1939 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1062, %mul.1940), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1061 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1939), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1082 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.1941 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1082), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.1938 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1061, %mul.1941), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.219 (param_0.1126: bf16[4,128,128256], param_1.1285: s32[4,128], param_2.1106: f32[4,128], param_3.789: f32[4,128], param_4.496: bf16[4,128], param_5.427: f32[4,128], param_6.300: f32[4,128], param_7.198: bf16[4,128,4096], param_8.116: bf16[4096]) -> (f32[], bf16[4096,128256,1]) { + %param_6.300 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) %param_7.198 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(7) %param_8.116 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(8) - %fusion.239.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_6.299, %param_7.198, %param_8.116), kind=kLoop, calls=%fused_computation.316.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1119 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1281 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1112 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.801 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %param_4.502 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %fusion.239.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_6.300, %param_7.198, %param_8.116), kind=kLoop, calls=%fused_computation.316.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1126 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1285 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1106 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.789 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %param_4.496 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) %param_5.427 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %multiply_convert_fusion.1.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1119, %param_1.1281, %param_2.1112, %param_3.801, %param_4.502, /*index=5*/%param_5.427), kind=kLoop, calls=%fused_computation.280.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %multiply_convert_fusion.1.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1126, %param_1.1285, %param_2.1106, %param_3.789, %param_4.496, /*index=5*/%param_5.427), kind=kLoop, calls=%fused_computation.280.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} %convolution.88.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} convolution(%fusion.239.clone.1, %multiply_convert_fusion.1.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} %bitcast.306 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%convolution.88.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.923 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.306), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - %square.157 = f32[4096,128256]{1,0:T(8,128)} multiply(%convert_element_type.923, %convert_element_type.923), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1006 = f32[]{:T(128)} constant(0) - %reduce.118 = f32[]{:T(128)} reduce(%square.157, %constant.1006), dimensions={0,1}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.154 = (f32[]{:T(128)}, bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.118, %convolution.88.clone.1) + %convert_element_type.939 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.306), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %square.157 = f32[4096,128256]{1,0:T(8,128)} multiply(%convert_element_type.939, %convert_element_type.939), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.998 = f32[]{:T(128)} constant(0) + %reduce.79 = f32[]{:T(128)} reduce(%square.157, %constant.998), dimensions={0,1}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.154 = (f32[]{:T(128)}, bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.79, %convolution.88.clone.1) } -%region_34.39 (reduce_sum.196: f32[], reduce_sum.197: f32[]) -> f32[] { - %reduce_sum.196 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.197 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.198 = f32[]{:T(128)} add(%reduce_sum.196, %reduce_sum.197), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_34.39 (reduce_sum.250: f32[], reduce_sum.251: f32[]) -> f32[] { + %reduce_sum.250 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.251 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.252 = f32[]{:T(128)} add(%reduce_sum.250, %reduce_sum.251), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.220 (param_0.1118: bf16[128256,4096]) -> f32[] { - %param_0.1118 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.925 = f32[128256,4096]{1,0:T(8,128)} convert(%param_0.1118), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %square.159 = f32[128256,4096]{1,0:T(8,128)} multiply(%convert_element_type.925, %convert_element_type.925), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1005 = f32[]{:T(128)} constant(0) - ROOT %reduce.119 = f32[]{:T(128)} reduce(%square.159, %constant.1005), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.220 (param_0.1125: bf16[128256,4096]) -> f32[] { + %param_0.1125 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.941 = f32[128256,4096]{1,0:T(8,128)} convert(%param_0.1125), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %square.159 = f32[128256,4096]{1,0:T(8,128)} multiply(%convert_element_type.941, %convert_element_type.941), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.997 = f32[]{:T(128)} constant(0) + ROOT %reduce.80 = f32[]{:T(128)} reduce(%square.159, %constant.997), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_60.65 (reduce_sum.338: f32[], reduce_sum.339: f32[]) -> f32[] { - %reduce_sum.338 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.339 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.329 = f32[]{:T(128)} add(%reduce_sum.338, %reduce_sum.339), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_60.65 (reduce_sum.385: f32[], reduce_sum.389: f32[]) -> f32[] { + %reduce_sum.385 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.389 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.390 = f32[]{:T(128)} add(%reduce_sum.385, %reduce_sum.389), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_46.51 (reduce_sum.259: f32[], reduce_sum.260: f32[]) -> f32[] { - %reduce_sum.259 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.260 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.261 = f32[]{:T(128)} add(%reduce_sum.259, %reduce_sum.260), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_46.51 (reduce_sum.313: f32[], reduce_sum.314: f32[]) -> f32[] { + %reduce_sum.313 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.314 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.315 = f32[]{:T(128)} add(%reduce_sum.313, %reduce_sum.314), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.221 (param_0.1106: f32[128256,4096], param_1.1269: f32[], param_2.1100: f32[], param_3.789: f32[], param_4.490: f32[128256,4096], param_5.415: f32[], param_6.287: bf16[128256,4096], param_7.186: pred[], param_8.104: f32[128256,4096]) -> (f32[], f32[128256,4096], f32[128256,4096], f32[128256,4096], f32[]) { - %param_0.1106 = f32[128256,4096]{1,0:T(8,128)} parameter(0) - %param_3.789 = f32[]{:T(128)S(6)} parameter(3) - %mul.1482.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_3.789), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.221 (param_0.1113: f32[128256,4096], param_1.1273: f32[], param_2.1094: f32[], param_3.777: f32[], param_4.484: f32[128256,4096], param_5.415: f32[], param_6.288: bf16[128256,4096], param_7.186: pred[], param_8.104: f32[128256,4096]) -> (f32[], f32[128256,4096], f32[128256,4096], f32[128256,4096], f32[]) { + %param_0.1113 = f32[128256,4096]{1,0:T(8,128)} parameter(0) + %param_3.777 = f32[]{:T(128)S(6)} parameter(3) + %mul.1800.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_3.777), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.186 = pred[]{:T(512)S(6)} parameter(7) %select_n.242.clone.1 = pred[128256,4096]{1,0:T(8,128)(4,1)} broadcast(%param_7.186), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.287 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(6) - %convert_element_type.1017.clone.1 = f32[128256,4096]{1,0:T(8,128)} convert(%param_6.287), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %param_6.288 = bf16[128256,4096]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.1033.clone.1 = f32[128256,4096]{1,0:T(8,128)} convert(%param_6.288), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} %param_5.415 = f32[]{:T(128)} parameter(5) %div.725.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_5.415), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.724.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%convert_element_type.1017.clone.1, %div.725.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.241.clone.1 = f32[128256,4096]{1,0:T(8,128)} select(%select_n.242.clone.1, %convert_element_type.1017.clone.1, %div.724.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.907.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.554.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.907.clone.1), dimensions={}, metadata={op_name="broadcast.61"} - %mul.1488.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.241.clone.1, %broadcast.554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.724.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%convert_element_type.1033.clone.1, %div.725.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.241.clone.1 = f32[128256,4096]{1,0:T(8,128)} select(%select_n.242.clone.1, %convert_element_type.1033.clone.1, %div.724.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.899.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.515.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.899.clone.1), dimensions={}, metadata={op_name="broadcast.61"} + %mul.1806.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.241.clone.1, %broadcast.515.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.104 = f32[128256,4096]{1,0:T(8,128)} parameter(8) - %constant.911.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1489.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.911.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1487.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_8.104, %mul.1489.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.776.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1488.clone.1, %mul.1487.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1100 = f32[]{:T(128)S(6)} parameter(2) - %div.721.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_2.1100), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.903.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1807.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.903.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1805.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_8.104, %mul.1807.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.762.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1806.clone.1, %mul.1805.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1094 = f32[]{:T(128)S(6)} parameter(2) + %div.721.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_2.1094), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.60.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%select_n.241.clone.1, %select_n.241.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.910.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1486.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.910.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1484.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%integer_pow.60.clone.1, %mul.1486.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.490 = f32[128256,4096]{1,0:T(8,128)} parameter(4) - %constant.909.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1485.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.909.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1483.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_4.490, %mul.1485.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.775.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1484.clone.1, %mul.1483.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1269 = f32[]{:T(128)S(6)} parameter(1) - %div.720.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_1.1269), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.719.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.775.clone.1, %div.720.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.902.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1804.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.902.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1802.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%integer_pow.60.clone.1, %mul.1804.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.484 = f32[128256,4096]{1,0:T(8,128)} parameter(4) + %constant.901.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1803.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.901.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1801.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_4.484, %mul.1803.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.761.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%mul.1802.clone.1, %mul.1801.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1273 = f32[]{:T(128)S(6)} parameter(1) + %div.720.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%param_1.1273), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.719.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.761.clone.1, %div.720.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.58.clone.1 = f32[128256,4096]{1,0:T(8,128)} sqrt(%div.719.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.908.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.774.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.908.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.773.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%sqrt.58.clone.1, %add.774.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.256.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%div.721.clone.1, %add.773.clone.1), metadata={op_name="multiply.42"} - %div.718.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.776.clone.1, %multiply.256.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1481.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_0.1106, %broadcast.554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.772.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%div.718.clone.1, %mul.1481.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1480.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%mul.1482.clone.1, %add.772.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.771.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%param_0.1106, %mul.1480.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.160 = f32[128256,4096]{1,0:T(8,128)} multiply(%add.771.clone.1, %add.771.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.993 = f32[]{:T(128)} constant(0) - %reduce.120 = f32[]{:T(128)} reduce(%square.160, %constant.993), dimensions={0,1}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.122.clone.1 = f32[]{:T(128)} reduce(%integer_pow.60.clone.1, %constant.993), dimensions={0,1}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.135 = (f32[]{:T(128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.120, %add.771.clone.1, %add.775.clone.1, %add.776.clone.1, %reduce.122.clone.1) -} - -%region_59.64 (reduce_sum.331: f32[], reduce_sum.332: f32[]) -> f32[] { - %reduce_sum.331 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.332 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.323 = f32[]{:T(128)} add(%reduce_sum.331, %reduce_sum.332), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_45.50 (reduce_sum.253: f32[], reduce_sum.254: f32[]) -> f32[] { - %reduce_sum.253 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.254 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.255 = f32[]{:T(128)} add(%reduce_sum.253, %reduce_sum.254), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.900.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.760.clone.1 = f32[128256,4096]{1,0:T(8,128)} broadcast(%constant.900.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.759.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%sqrt.58.clone.1, %add.760.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.183.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%div.721.clone.1, %add.759.clone.1), metadata={op_name="multiply.33"} + %div.718.clone.1 = f32[128256,4096]{1,0:T(8,128)} divide(%add.762.clone.1, %multiply.183.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1799.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%param_0.1113, %broadcast.515.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.758.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%div.718.clone.1, %mul.1799.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1798.clone.1 = f32[128256,4096]{1,0:T(8,128)} multiply(%mul.1800.clone.1, %add.758.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.757.clone.1 = f32[128256,4096]{1,0:T(8,128)} add(%param_0.1113, %mul.1798.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.160 = f32[128256,4096]{1,0:T(8,128)} multiply(%add.757.clone.1, %add.757.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.985 = f32[]{:T(128)} constant(0) + %reduce.81 = f32[]{:T(128)} reduce(%square.160, %constant.985), dimensions={0,1}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.83.clone.1 = f32[]{:T(128)} reduce(%integer_pow.60.clone.1, %constant.985), dimensions={0,1}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.135 = (f32[]{:T(128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[128256,4096]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.81, %add.757.clone.1, %add.761.clone.1, %add.762.clone.1, %reduce.83.clone.1) +} + +%region_59.64 (reduce_sum.382: f32[], reduce_sum.383: f32[]) -> f32[] { + %reduce_sum.382 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.383 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.384 = f32[]{:T(128)} add(%reduce_sum.382, %reduce_sum.383), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_45.50 (reduce_sum.307: f32[], reduce_sum.308: f32[]) -> f32[] { + %reduce_sum.307 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.308 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.312 = f32[]{:T(128)} add(%reduce_sum.307, %reduce_sum.308), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.222 (param_0.1107: f32[4096,128256], param_1.1270: f32[], param_2.1101: f32[], param_3.790: f32[], param_4.491: f32[4096,128256], param_5.416: f32[], param_6.288: bf16[4096,128256,1], param_7.187: pred[], param_8.105: f32[4096,128256]) -> (f32[], f32[4096,128256], f32[4096,128256], f32[4096,128256], f32[]) { - %param_0.1107 = f32[4096,128256]{1,0:T(8,128)} parameter(0) - %param_3.790 = f32[]{:T(128)S(6)} parameter(3) - %mul.1492.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_3.790), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.222 (param_0.1114: f32[4096,128256], param_1.1274: f32[], param_2.1095: f32[], param_3.778: f32[], param_4.485: f32[4096,128256], param_5.416: f32[], param_6.289: bf16[4096,128256,1], param_7.187: pred[], param_8.105: f32[4096,128256]) -> (f32[], f32[4096,128256], f32[4096,128256], f32[4096,128256], f32[]) { + %param_0.1114 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %param_3.778 = f32[]{:T(128)S(6)} parameter(3) + %mul.1810.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_3.778), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.187 = pred[]{:T(512)S(6)} parameter(7) %select_n.246.clone.1 = pred[4096,128256]{1,0:T(8,128)(4,1)} broadcast(%param_7.187), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.288 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} parameter(6) - %bitcast.409.clone.1 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%param_6.288), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %convert_element_type.1019.clone.1 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.409.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %param_6.289 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} parameter(6) + %bitcast.409.clone.1 = bf16[4096,128256]{1,0:T(8,128)(2,1)} bitcast(%param_6.289), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} + %convert_element_type.1035.clone.1 = f32[4096,128256]{1,0:T(8,128)} convert(%bitcast.409.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} %param_5.416 = f32[]{:T(128)} parameter(5) %div.733.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_5.416), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.732.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%convert_element_type.1019.clone.1, %div.733.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.245.clone.1 = f32[4096,128256]{1,0:T(8,128)} select(%select_n.246.clone.1, %convert_element_type.1019.clone.1, %div.732.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.913.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.556.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.913.clone.1), dimensions={}, metadata={op_name="broadcast.62"} - %mul.1498.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.245.clone.1, %broadcast.556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.732.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%convert_element_type.1035.clone.1, %div.733.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.245.clone.1 = f32[4096,128256]{1,0:T(8,128)} select(%select_n.246.clone.1, %convert_element_type.1035.clone.1, %div.732.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.905.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.517.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.905.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %mul.1816.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.245.clone.1, %broadcast.517.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.105 = f32[4096,128256]{1,0:T(8,128)} parameter(8) - %constant.917.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1499.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.917.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1497.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_8.105, %mul.1499.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.782.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1498.clone.1, %mul.1497.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1101 = f32[]{:T(128)S(6)} parameter(2) - %div.729.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_2.1101), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.909.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1817.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.909.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1815.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_8.105, %mul.1817.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.768.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1816.clone.1, %mul.1815.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1095 = f32[]{:T(128)S(6)} parameter(2) + %div.729.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_2.1095), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.61.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%select_n.245.clone.1, %select_n.245.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.916.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1496.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.916.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1494.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%integer_pow.61.clone.1, %mul.1496.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.491 = f32[4096,128256]{1,0:T(8,128)} parameter(4) - %constant.915.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1495.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.915.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1493.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_4.491, %mul.1495.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.781.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1494.clone.1, %mul.1493.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1270 = f32[]{:T(128)S(6)} parameter(1) - %div.728.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_1.1270), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.727.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.781.clone.1, %div.728.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.908.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1814.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.908.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1812.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%integer_pow.61.clone.1, %mul.1814.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.485 = f32[4096,128256]{1,0:T(8,128)} parameter(4) + %constant.907.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1813.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.907.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1811.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_4.485, %mul.1813.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.767.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%mul.1812.clone.1, %mul.1811.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1274 = f32[]{:T(128)S(6)} parameter(1) + %div.728.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%param_1.1274), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.727.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.767.clone.1, %div.728.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.59.clone.1 = f32[4096,128256]{1,0:T(8,128)} sqrt(%div.727.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.914.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.780.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.914.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.779.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%sqrt.59.clone.1, %add.780.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.257.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%div.729.clone.1, %add.779.clone.1), metadata={op_name="multiply.41"} - %div.726.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.782.clone.1, %multiply.257.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1491.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_0.1107, %broadcast.556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.778.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%div.726.clone.1, %mul.1491.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1490.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%mul.1492.clone.1, %add.778.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.777.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%param_0.1107, %mul.1490.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.161 = f32[4096,128256]{1,0:T(8,128)} multiply(%add.777.clone.1, %add.777.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.994 = f32[]{:T(128)} constant(0) - %reduce.121 = f32[]{:T(128)} reduce(%square.161, %constant.994), dimensions={0,1}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.123.clone.1 = f32[]{:T(128)} reduce(%integer_pow.61.clone.1, %constant.994), dimensions={0,1}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.136 = (f32[]{:T(128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.121, %add.777.clone.1, %add.781.clone.1, %add.782.clone.1, %reduce.123.clone.1) -} - -%region_25.30 (reduce_sum.154: f32[], reduce_sum.155: f32[]) -> f32[] { - %reduce_sum.154 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.155 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.156 = f32[]{:T(128)} add(%reduce_sum.154, %reduce_sum.155), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.239 (param_0.1124: f32[4,14336,4096]) -> f32[] { - %param_0.1124 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(0) - %bitcast.314 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_0.1124), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %constant.906.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.766.clone.1 = f32[4096,128256]{1,0:T(8,128)} broadcast(%constant.906.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.765.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%sqrt.59.clone.1, %add.766.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.184.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%div.729.clone.1, %add.765.clone.1), metadata={op_name="multiply.32"} + %div.726.clone.1 = f32[4096,128256]{1,0:T(8,128)} divide(%add.768.clone.1, %multiply.184.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1809.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%param_0.1114, %broadcast.517.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.764.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%div.726.clone.1, %mul.1809.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1808.clone.1 = f32[4096,128256]{1,0:T(8,128)} multiply(%mul.1810.clone.1, %add.764.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.763.clone.1 = f32[4096,128256]{1,0:T(8,128)} add(%param_0.1114, %mul.1808.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.161 = f32[4096,128256]{1,0:T(8,128)} multiply(%add.763.clone.1, %add.763.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.986 = f32[]{:T(128)} constant(0) + %reduce.82 = f32[]{:T(128)} reduce(%square.161, %constant.986), dimensions={0,1}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.84.clone.1 = f32[]{:T(128)} reduce(%integer_pow.61.clone.1, %constant.986), dimensions={0,1}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.136 = (f32[]{:T(128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[4096,128256]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.82, %add.763.clone.1, %add.767.clone.1, %add.768.clone.1, %reduce.84.clone.1) +} + +%region_25.30 (reduce_sum.208: f32[], reduce_sum.209: f32[]) -> f32[] { + %reduce_sum.208 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.209 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.210 = f32[]{:T(128)} add(%reduce_sum.208, %reduce_sum.209), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.239 (param_0.1131: f32[4,14336,4096]) -> f32[] { + %param_0.1131 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(0) + %bitcast.314 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_0.1131), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.164 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%bitcast.314, %bitcast.314), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1011 = f32[]{:T(128)} constant(0) - ROOT %reduce.124 = f32[]{:T(128)} reduce(%square.164, %constant.1011), dimensions={0,1,2}, to_apply=%region_25.30, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.1003 = f32[]{:T(128)} constant(0) + ROOT %reduce.85 = f32[]{:T(128)} reduce(%square.164, %constant.1003), dimensions={0,1,2}, to_apply=%region_25.30, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_24.29 (reduce_sum.148: f32[], reduce_sum.149: f32[]) -> f32[] { - %reduce_sum.148 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.149 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.150 = f32[]{:T(128)} add(%reduce_sum.148, %reduce_sum.149), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_24.29 (reduce_sum.202: f32[], reduce_sum.203: f32[]) -> f32[] { + %reduce_sum.202 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.203 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.207 = f32[]{:T(128)} add(%reduce_sum.202, %reduce_sum.203), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_23.28 (reduce_sum.142: f32[], reduce_sum.143: f32[]) -> f32[] { - %reduce_sum.142 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.143 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.147 = f32[]{:T(128)} add(%reduce_sum.142, %reduce_sum.143), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_23.28 (reduce_sum.196: f32[], reduce_sum.200: f32[]) -> f32[] { + %reduce_sum.196 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.200 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.201 = f32[]{:T(128)} add(%reduce_sum.196, %reduce_sum.200), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.241 (param_0.1125: f32[4,4096,14336], param_1.1284: f32[4,4096,14336]) -> (f32[], f32[]) { - %param_0.1125 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(0) - %bitcast.318 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_0.1125), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.241 (param_0.1132: f32[4,4096,14336], param_1.1288: f32[4,4096,14336]) -> (f32[], f32[]) { + %param_0.1132 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(0) + %bitcast.318 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_0.1132), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.167 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%bitcast.318, %bitcast.318), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1012 = f32[]{:T(128)} constant(0) - %reduce.125 = f32[]{:T(128)} reduce(%square.167, %constant.1012), dimensions={0,1,2}, to_apply=%region_24.29, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1284 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(1) - %bitcast.322.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_1.1284), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %constant.1004 = f32[]{:T(128)} constant(0) + %reduce.86 = f32[]{:T(128)} reduce(%square.167, %constant.1004), dimensions={0,1,2}, to_apply=%region_24.29, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1288 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(1) + %bitcast.322.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_1.1288), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.170.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%bitcast.322.clone.1, %bitcast.322.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.126.clone.1 = f32[]{:T(128)} reduce(%square.170.clone.1, %constant.1012), dimensions={0,1,2}, to_apply=%region_23.28, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.155 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.125, %reduce.126.clone.1) + %reduce.87.clone.1 = f32[]{:T(128)} reduce(%square.170.clone.1, %constant.1004), dimensions={0,1,2}, to_apply=%region_23.28, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.155 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.86, %reduce.87.clone.1) } -%fused_computation.244 (param_0.694: f32[14336,4,4096]) -> bf16[4,14336,4096] { - %param_0.694 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) - %copy.234 = bf16[14336,4,4096]{2,0,1:T(8,128)(2,1)} copy(%param_0.694), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} +%fused_computation.244 (param_0.699: f32[14336,4,4096]) -> bf16[4,14336,4096] { + %param_0.699 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) + %copy.234 = bf16[14336,4,4096]{2,0,1:T(8,128)(2,1)} copy(%param_0.699), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} ROOT %bitcast.323 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} bitcast(%copy.234), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%fused_computation.245 (param_0.696: f32[4096,4,14336]) -> bf16[4,4096,14336] { - %param_0.696 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) - %copy.235 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.696), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} +%fused_computation.245 (param_0.701: f32[4096,4,14336]) -> bf16[4,4096,14336] { + %param_0.701 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %copy.235 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.701), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} ROOT %bitcast.324 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} bitcast(%copy.235), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%fused_computation.246 (param_0.698: f32[4096,4,14336]) -> bf16[4,4096,14336] { - %param_0.698 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) - %copy.236 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.698), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} +%fused_computation.246 (param_0.703: f32[4096,4,14336]) -> bf16[4,4096,14336] { + %param_0.703 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %copy.236 = bf16[4096,4,14336]{2,0,1:T(8,128)(2,1)} copy(%param_0.703), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} ROOT %bitcast.325 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} bitcast(%copy.236), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_52.57 (reduce_sum.289: f32[], reduce_sum.290: f32[]) -> f32[] { - %reduce_sum.289 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.290 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.294 = f32[]{:T(128)} add(%reduce_sum.289, %reduce_sum.290), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_52.57 (reduce_sum.343: f32[], reduce_sum.347: f32[]) -> f32[] { + %reduce_sum.343 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.347 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.348 = f32[]{:T(128)} add(%reduce_sum.343, %reduce_sum.347), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_38.43 (reduce_sum.217: f32[], reduce_sum.218: f32[]) -> f32[] { - %reduce_sum.217 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.218 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.219 = f32[]{:T(128)} add(%reduce_sum.217, %reduce_sum.218), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_38.43 (reduce_sum.271: f32[], reduce_sum.272: f32[]) -> f32[] { + %reduce_sum.271 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.272 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.273 = f32[]{:T(128)} add(%reduce_sum.271, %reduce_sum.272), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.247 (param_0.1114: f32[14336,4,4096], param_1.1277: f32[], param_2.1108: f32[], param_3.797: f32[], param_4.498: f32[14336,4,4096], param_5.423: f32[], param_6.295: f32[4,14336,4096], param_7.194: pred[], param_8.112: f32[14336,4,4096]) -> (f32[], f32[14336,4,4096], f32[14336,4,4096], f32[14336,4,4096], f32[]) { - %param_0.1114 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) - %param_3.797 = f32[]{:T(128)S(6)} parameter(3) - %mul.1550.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_3.797), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.247 (param_0.1121: f32[14336,4,4096], param_1.1281: f32[], param_2.1102: f32[], param_3.785: f32[], param_4.492: f32[14336,4,4096], param_5.423: f32[], param_6.296: f32[4,14336,4096], param_7.194: pred[], param_8.112: f32[14336,4,4096]) -> (f32[], f32[14336,4,4096], f32[14336,4,4096], f32[14336,4,4096], f32[]) { + %param_0.1121 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(0) + %param_3.785 = f32[]{:T(128)S(6)} parameter(3) + %mul.1868.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_3.785), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.194 = pred[]{:T(512)S(6)} parameter(7) %select_n.274.clone.1 = pred[14336,4,4096]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.194), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.295 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(6) - %bitcast.423.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_6.295), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.296 = f32[4,14336,4096]{2,0,1:T(4,128)} parameter(6) + %bitcast.423.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} bitcast(%param_6.296), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.423 = f32[]{:T(128)} parameter(5) %div.789.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_5.423), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.788.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%bitcast.423.clone.1, %div.789.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.273.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} select(%select_n.274.clone.1, %bitcast.423.clone.1, %div.788.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.955.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.586.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.955.clone.1), dimensions={}, metadata={op_name="broadcast.69"} - %mul.1556.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.273.clone.1, %broadcast.586.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.947.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.547.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.947.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.1874.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.273.clone.1, %broadcast.547.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.112 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(8) - %constant.959.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1557.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.959.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1555.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_8.112, %mul.1557.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.820.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1556.clone.1, %mul.1555.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1108 = f32[]{:T(128)S(6)} parameter(2) - %div.785.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_2.1108), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.951.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1875.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.951.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1873.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_8.112, %mul.1875.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.806.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1874.clone.1, %mul.1873.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1102 = f32[]{:T(128)S(6)} parameter(2) + %div.785.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_2.1102), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.68.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%select_n.273.clone.1, %select_n.273.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.958.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1554.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.958.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1552.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%integer_pow.68.clone.1, %mul.1554.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.498 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(4) - %constant.957.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1553.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.957.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1551.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_4.498, %mul.1553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.819.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1552.clone.1, %mul.1551.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1277 = f32[]{:T(128)S(6)} parameter(1) - %div.784.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_1.1277), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.783.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.819.clone.1, %div.784.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.950.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1872.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.950.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1870.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%integer_pow.68.clone.1, %mul.1872.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.492 = f32[14336,4,4096]{2,1,0:T(4,128)} parameter(4) + %constant.949.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1871.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.949.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1869.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_4.492, %mul.1871.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.805.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%mul.1870.clone.1, %mul.1869.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1281 = f32[]{:T(128)S(6)} parameter(1) + %div.784.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%param_1.1281), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.783.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.805.clone.1, %div.784.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.66.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} sqrt(%div.783.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.956.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.818.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.956.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.817.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%sqrt.66.clone.1, %add.818.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.264.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%div.785.clone.1, %add.817.clone.1), metadata={op_name="multiply.34"} - %div.782.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.820.clone.1, %multiply.264.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1549.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_0.1114, %broadcast.586.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.816.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%div.782.clone.1, %mul.1549.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1548.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%mul.1550.clone.1, %add.816.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.815.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%param_0.1114, %mul.1548.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.171 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%add.815.clone.1, %add.815.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1001 = f32[]{:T(128)} constant(0) - %reduce.127 = f32[]{:T(128)} reduce(%square.171, %constant.1001), dimensions={0,1,2}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.130.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.1001), dimensions={0,1,2}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.137 = (f32[]{:T(128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.127, %add.815.clone.1, %add.819.clone.1, %add.820.clone.1, %reduce.130.clone.1) + %constant.948.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.804.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} broadcast(%constant.948.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.803.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%sqrt.66.clone.1, %add.804.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.191.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%div.785.clone.1, %add.803.clone.1), metadata={op_name="multiply.25"} + %div.782.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} divide(%add.806.clone.1, %multiply.191.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1867.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%param_0.1121, %broadcast.547.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.802.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%div.782.clone.1, %mul.1867.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1866.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%mul.1868.clone.1, %add.802.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.801.clone.1 = f32[14336,4,4096]{2,1,0:T(4,128)} add(%param_0.1121, %mul.1866.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.171 = f32[14336,4,4096]{2,1,0:T(4,128)} multiply(%add.801.clone.1, %add.801.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.993 = f32[]{:T(128)} constant(0) + %reduce.88 = f32[]{:T(128)} reduce(%square.171, %constant.993), dimensions={0,1,2}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.91.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.993), dimensions={0,1,2}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.137 = (f32[]{:T(128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[14336,4,4096]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.88, %add.801.clone.1, %add.805.clone.1, %add.806.clone.1, %reduce.91.clone.1) } -%region_51.56 (reduce_sum.283: f32[], reduce_sum.287: f32[]) -> f32[] { - %reduce_sum.283 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.287 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.288 = f32[]{:T(128)} add(%reduce_sum.283, %reduce_sum.287), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_51.56 (reduce_sum.340: f32[], reduce_sum.341: f32[]) -> f32[] { + %reduce_sum.340 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.341 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.342 = f32[]{:T(128)} add(%reduce_sum.340, %reduce_sum.341), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_37.42 (reduce_sum.211: f32[], reduce_sum.212: f32[]) -> f32[] { - %reduce_sum.211 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.212 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.213 = f32[]{:T(128)} add(%reduce_sum.211, %reduce_sum.212), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_37.42 (reduce_sum.265: f32[], reduce_sum.266: f32[]) -> f32[] { + %reduce_sum.265 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.266 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.270 = f32[]{:T(128)} add(%reduce_sum.265, %reduce_sum.266), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.248 (param_0.1115: f32[4096,4,14336], param_1.1278: f32[], param_2.1109: f32[], param_3.798: f32[], param_4.499: f32[4096,4,14336], param_5.424: f32[], param_6.296: f32[4,4096,14336], param_7.195: pred[], param_8.113: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { - %param_0.1115 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) - %param_3.798 = f32[]{:T(128)S(6)} parameter(3) - %mul.1560.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.798), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.248 (param_0.1122: f32[4096,4,14336], param_1.1282: f32[], param_2.1103: f32[], param_3.786: f32[], param_4.493: f32[4096,4,14336], param_5.424: f32[], param_6.297: f32[4,4096,14336], param_7.195: pred[], param_8.113: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { + %param_0.1122 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %param_3.786 = f32[]{:T(128)S(6)} parameter(3) + %mul.1878.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.786), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.195 = pred[]{:T(512)S(6)} parameter(7) %select_n.278.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.195), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.296 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) - %bitcast.425.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.296), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.297 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) + %bitcast.425.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.297), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.424 = f32[]{:T(128)} parameter(5) %div.797.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_5.424), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.796.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%bitcast.425.clone.1, %div.797.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.277.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%select_n.278.clone.1, %bitcast.425.clone.1, %div.796.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.961.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.592.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.961.clone.1), dimensions={}, metadata={op_name="broadcast.71"} - %mul.1564.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.277.clone.1, %broadcast.592.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.953.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.553.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.953.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1882.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.277.clone.1, %broadcast.553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.113 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(8) - %constant.965.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.591.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.965.clone.1), dimensions={}, metadata={op_name="broadcast.70"} - %mul.1563.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.113, %broadcast.591.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.825.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1564.clone.1, %mul.1563.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1109 = f32[]{:T(128)S(6)} parameter(2) - %div.793.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1109), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.957.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.552.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.957.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1881.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.113, %broadcast.552.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.811.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1882.clone.1, %mul.1881.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1103 = f32[]{:T(128)S(6)} parameter(2) + %div.793.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1103), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.69.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.277.clone.1, %select_n.277.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.964.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.590.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.964.clone.1), dimensions={}, metadata={op_name="broadcast.60"} - %mul.1562.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.69.clone.1, %broadcast.590.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.499 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) - %constant.963.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.589.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.963.clone.1), dimensions={}, metadata={op_name="broadcast.59"} - %mul.1561.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.499, %broadcast.589.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.824.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1562.clone.1, %mul.1561.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1278 = f32[]{:T(128)S(6)} parameter(1) - %div.792.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1278), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.791.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.824.clone.1, %div.792.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.956.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.551.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.956.clone.1), dimensions={}, metadata={op_name="broadcast.60"} + %mul.1880.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.69.clone.1, %broadcast.551.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.493 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) + %constant.955.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.550.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.955.clone.1), dimensions={}, metadata={op_name="broadcast.59"} + %mul.1879.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.493, %broadcast.550.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.810.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1880.clone.1, %mul.1879.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1282 = f32[]{:T(128)S(6)} parameter(1) + %div.792.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1282), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.791.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.810.clone.1, %div.792.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.67.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} sqrt(%div.791.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.962.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.587.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.962.clone.1), dimensions={}, metadata={op_name="broadcast.54"} - %add.823.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.67.clone.1, %broadcast.587.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.265.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.793.clone.1, %add.823.clone.1), metadata={op_name="multiply.33"} - %div.790.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.825.clone.1, %multiply.265.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1559.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1115, %broadcast.592.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.822.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.790.clone.1, %mul.1559.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1558.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1560.clone.1, %add.822.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.821.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1115, %mul.1558.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.172 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.821.clone.1, %add.821.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1002 = f32[]{:T(128)} constant(0) - %reduce.128 = f32[]{:T(128)} reduce(%square.172, %constant.1002), dimensions={0,1,2}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.131.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.1002), dimensions={0,1,2}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.138 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.128, %add.821.clone.1, %add.824.clone.1, %add.825.clone.1, %reduce.131.clone.1) + %constant.954.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.548.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.954.clone.1), dimensions={}, metadata={op_name="broadcast.54"} + %add.809.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.67.clone.1, %broadcast.548.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.192.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.793.clone.1, %add.809.clone.1), metadata={op_name="multiply.24"} + %div.790.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.811.clone.1, %multiply.192.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1877.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1122, %broadcast.553.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.808.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.790.clone.1, %mul.1877.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1876.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1878.clone.1, %add.808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.807.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1122, %mul.1876.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.172 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.807.clone.1, %add.807.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.994 = f32[]{:T(128)} constant(0) + %reduce.89 = f32[]{:T(128)} reduce(%square.172, %constant.994), dimensions={0,1,2}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.92.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.994), dimensions={0,1,2}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.138 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.89, %add.807.clone.1, %add.810.clone.1, %add.811.clone.1, %reduce.92.clone.1) } -%region_50.55 (reduce_sum.280: f32[], reduce_sum.281: f32[]) -> f32[] { - %reduce_sum.280 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.281 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.282 = f32[]{:T(128)} add(%reduce_sum.280, %reduce_sum.281), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_50.55 (reduce_sum.334: f32[], reduce_sum.335: f32[]) -> f32[] { + %reduce_sum.334 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.335 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.336 = f32[]{:T(128)} add(%reduce_sum.334, %reduce_sum.335), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_36.41 (reduce_sum.205: f32[], reduce_sum.206: f32[]) -> f32[] { - %reduce_sum.205 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.206 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.210 = f32[]{:T(128)} add(%reduce_sum.205, %reduce_sum.206), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_36.41 (reduce_sum.259: f32[], reduce_sum.263: f32[]) -> f32[] { + %reduce_sum.259 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.263 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.264 = f32[]{:T(128)} add(%reduce_sum.259, %reduce_sum.263), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.249 (param_0.1116: f32[4096,4,14336], param_1.1279: f32[], param_2.1110: f32[], param_3.799: f32[], param_4.500: f32[4096,4,14336], param_5.425: f32[], param_6.297: f32[4,4096,14336], param_7.196: pred[], param_8.114: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { - %param_0.1116 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) - %param_3.799 = f32[]{:T(128)S(6)} parameter(3) - %mul.1567.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.799), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.249 (param_0.1123: f32[4096,4,14336], param_1.1283: f32[], param_2.1104: f32[], param_3.787: f32[], param_4.494: f32[4096,4,14336], param_5.425: f32[], param_6.298: f32[4,4096,14336], param_7.196: pred[], param_8.114: f32[4096,4,14336]) -> (f32[], f32[4096,4,14336], f32[4096,4,14336], f32[4096,4,14336], f32[]) { + %param_0.1123 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(0) + %param_3.787 = f32[]{:T(128)S(6)} parameter(3) + %mul.1885.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_3.787), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.196 = pred[]{:T(512)S(6)} parameter(7) %select_n.282.clone.1 = pred[4096,4,14336]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.196), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.297 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) - %bitcast.427.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.297), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.298 = f32[4,4096,14336]{2,0,1:T(4,128)} parameter(6) + %bitcast.427.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} bitcast(%param_6.298), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.425 = f32[]{:T(128)} parameter(5) %div.805.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_5.425), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.804.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%bitcast.427.clone.1, %div.805.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.281.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} select(%select_n.282.clone.1, %bitcast.427.clone.1, %div.804.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.967.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.598.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.967.clone.1), dimensions={}, metadata={op_name="broadcast.71"} - %mul.1571.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.281.clone.1, %broadcast.598.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.959.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.559.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.959.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.1889.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.281.clone.1, %broadcast.559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.114 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(8) - %constant.971.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.597.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.971.clone.1), dimensions={}, metadata={op_name="broadcast.70"} - %mul.1570.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.114, %broadcast.597.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.830.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1571.clone.1, %mul.1570.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1110 = f32[]{:T(128)S(6)} parameter(2) - %div.801.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1110), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.963.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.558.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.963.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.1888.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_8.114, %broadcast.558.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.816.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1889.clone.1, %mul.1888.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1104 = f32[]{:T(128)S(6)} parameter(2) + %div.801.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_2.1104), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.70.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%select_n.281.clone.1, %select_n.281.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.970.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.596.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.970.clone.1), dimensions={}, metadata={op_name="broadcast.60"} - %mul.1569.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.596.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.500 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) - %constant.969.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.595.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.969.clone.1), dimensions={}, metadata={op_name="broadcast.59"} - %mul.1568.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.500, %broadcast.595.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.829.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1569.clone.1, %mul.1568.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1279 = f32[]{:T(128)S(6)} parameter(1) - %div.800.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1279), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.799.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.829.clone.1, %div.800.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.962.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.557.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.962.clone.1), dimensions={}, metadata={op_name="broadcast.60"} + %mul.1887.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.557.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.494 = f32[4096,4,14336]{2,1,0:T(4,128)} parameter(4) + %constant.961.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.556.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.961.clone.1), dimensions={}, metadata={op_name="broadcast.59"} + %mul.1886.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_4.494, %broadcast.556.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.815.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%mul.1887.clone.1, %mul.1886.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1283 = f32[]{:T(128)S(6)} parameter(1) + %div.800.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%param_1.1283), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.799.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.815.clone.1, %div.800.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.68.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} sqrt(%div.799.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.968.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.593.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.968.clone.1), dimensions={}, metadata={op_name="broadcast.54"} - %add.828.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.68.clone.1, %broadcast.593.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.266.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.801.clone.1, %add.828.clone.1), metadata={op_name="multiply.32"} - %div.798.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.830.clone.1, %multiply.266.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1566.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1116, %broadcast.598.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.827.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.798.clone.1, %mul.1566.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1565.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1567.clone.1, %add.827.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.826.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1116, %mul.1565.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.173 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.826.clone.1, %add.826.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1003 = f32[]{:T(128)} constant(0) - %reduce.129 = f32[]{:T(128)} reduce(%square.173, %constant.1003), dimensions={0,1,2}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.132.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.1003), dimensions={0,1,2}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.139 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.129, %add.826.clone.1, %add.829.clone.1, %add.830.clone.1, %reduce.132.clone.1) + %constant.960.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.554.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} broadcast(%constant.960.clone.1), dimensions={}, metadata={op_name="broadcast.54"} + %add.814.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%sqrt.68.clone.1, %broadcast.554.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.193.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%div.801.clone.1, %add.814.clone.1), metadata={op_name="multiply.23"} + %div.798.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} divide(%add.816.clone.1, %multiply.193.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1884.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%param_0.1123, %broadcast.559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.813.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%div.798.clone.1, %mul.1884.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1883.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%mul.1885.clone.1, %add.813.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.812.clone.1 = f32[4096,4,14336]{2,1,0:T(4,128)} add(%param_0.1123, %mul.1883.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.173 = f32[4096,4,14336]{2,1,0:T(4,128)} multiply(%add.812.clone.1, %add.812.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.995 = f32[]{:T(128)} constant(0) + %reduce.90 = f32[]{:T(128)} reduce(%square.173, %constant.995), dimensions={0,1,2}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.93.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.995), dimensions={0,1,2}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.139 = (f32[]{:T(128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[4096,4,14336]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.90, %add.812.clone.1, %add.815.clone.1, %add.816.clone.1, %reduce.93.clone.1) } -%region_30.35 (reduce_sum.178: f32[], reduce_sum.182: f32[]) -> f32[] { - %reduce_sum.178 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.182 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.183 = f32[]{:T(128)} add(%reduce_sum.178, %reduce_sum.182), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_30.35 (reduce_sum.235: f32[], reduce_sum.236: f32[]) -> f32[] { + %reduce_sum.235 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.236 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.237 = f32[]{:T(128)} add(%reduce_sum.235, %reduce_sum.236), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.267 (param_0.1120: f32[4,4096,32,128]) -> f32[] { - %param_0.1120 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.329 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1120), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.267 (param_0.1127: f32[4,4096,32,128]) -> f32[] { + %param_0.1127 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.329 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1127), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.176 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%bitcast.329, %bitcast.329), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1007 = f32[]{:T(128)} constant(0) - ROOT %reduce.133 = f32[]{:T(128)} reduce(%square.176, %constant.1007), dimensions={0,1,2,3}, to_apply=%region_30.35, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.999 = f32[]{:T(128)} constant(0) + ROOT %reduce.94 = f32[]{:T(128)} reduce(%square.176, %constant.999), dimensions={0,1,2,3}, to_apply=%region_30.35, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_29.34 (reduce_sum.175: f32[], reduce_sum.176: f32[]) -> f32[] { - %reduce_sum.175 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.176 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.177 = f32[]{:T(128)} add(%reduce_sum.175, %reduce_sum.176), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_29.34 (reduce_sum.229: f32[], reduce_sum.230: f32[]) -> f32[] { + %reduce_sum.229 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.230 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.231 = f32[]{:T(128)} add(%reduce_sum.229, %reduce_sum.230), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.269 (param_0.1121: f32[4,32,128,4096]) -> f32[] { - %param_0.1121 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.333 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_0.1121), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.269 (param_0.1128: f32[4,32,128,4096]) -> f32[] { + %param_0.1128 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.333 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_0.1128), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.179 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%bitcast.333, %bitcast.333), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1008 = f32[]{:T(128)} constant(0) - ROOT %reduce.134 = f32[]{:T(128)} reduce(%square.179, %constant.1008), dimensions={0,1,2,3}, to_apply=%region_29.34, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.1000 = f32[]{:T(128)} constant(0) + ROOT %reduce.95 = f32[]{:T(128)} reduce(%square.179, %constant.1000), dimensions={0,1,2,3}, to_apply=%region_29.34, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.270 (param_0.748: f32[32,4,128,4096]) -> bf16[4,32,128,4096] { - %param_0.748 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) - %copy.237 = bf16[32,4,128,4096]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.748), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} +%fused_computation.270 (param_0.753: f32[32,4,128,4096]) -> bf16[4,32,128,4096] { + %param_0.753 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) + %copy.237 = bf16[32,4,128,4096]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.753), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} ROOT %bitcast.334 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.237), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_57.62 (reduce_sum.317: f32[], reduce_sum.318: f32[]) -> f32[] { - %reduce_sum.317 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.318 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.316 = f32[]{:T(128)} add(%reduce_sum.317, %reduce_sum.318), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_57.62 (reduce_sum.370: f32[], reduce_sum.371: f32[]) -> f32[] { + %reduce_sum.370 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.371 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.375 = f32[]{:T(128)} add(%reduce_sum.370, %reduce_sum.371), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_43.48 (reduce_sum.241: f32[], reduce_sum.245: f32[]) -> f32[] { - %reduce_sum.241 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.245 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.246 = f32[]{:T(128)} add(%reduce_sum.241, %reduce_sum.245), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_43.48 (reduce_sum.298: f32[], reduce_sum.299: f32[]) -> f32[] { + %reduce_sum.298 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.299 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.300 = f32[]{:T(128)} add(%reduce_sum.298, %reduce_sum.299), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.271 (param_0.1109: f32[4096,4,32,128], param_1.1272: f32[], param_2.1103: f32[], param_3.792: f32[], param_4.493: f32[4096,4,32,128], param_5.418: f32[], param_6.290: f32[4,4096,32,128], param_7.189: pred[], param_8.107: f32[4096,4,32,128]) -> (f32[], f32[4096,4,32,128], f32[4096,4,32,128], f32[4096,4,32,128], f32[]) { - %param_0.1109 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.792 = f32[]{:T(128)S(6)} parameter(3) - %mul.1509.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_3.792), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.271 (param_0.1116: f32[4096,4,32,128], param_1.1276: f32[], param_2.1097: f32[], param_3.780: f32[], param_4.487: f32[4096,4,32,128], param_5.418: f32[], param_6.291: f32[4,4096,32,128], param_7.189: pred[], param_8.107: f32[4096,4,32,128]) -> (f32[], f32[4096,4,32,128], f32[4096,4,32,128], f32[4096,4,32,128], f32[]) { + %param_0.1116 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.780 = f32[]{:T(128)S(6)} parameter(3) + %mul.1827.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_3.780), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.189 = pred[]{:T(512)S(6)} parameter(7) %select_n.254.clone.1 = pred[4096,4,32,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.189), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.290 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.413.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_6.290), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.291 = f32[4,4096,32,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.413.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} bitcast(%param_6.291), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.418 = f32[]{:T(128)} parameter(5) %div.749.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_5.418), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.748.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%bitcast.413.clone.1, %div.749.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.253.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} select(%select_n.254.clone.1, %bitcast.413.clone.1, %div.748.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.925.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.564.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.925.clone.1), dimensions={}, metadata={op_name="broadcast.63"} - %mul.1515.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.253.clone.1, %broadcast.564.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.917.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.525.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.917.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %mul.1833.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.253.clone.1, %broadcast.525.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.107 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.929.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1516.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.929.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1514.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_8.107, %mul.1516.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.793.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1515.clone.1, %mul.1514.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1103 = f32[]{:T(128)S(6)} parameter(2) - %div.745.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1103), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.921.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1834.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.921.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1832.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_8.107, %mul.1834.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.779.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1833.clone.1, %mul.1832.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1097 = f32[]{:T(128)S(6)} parameter(2) + %div.745.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1097), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.63.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%select_n.253.clone.1, %select_n.253.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.928.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1513.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.928.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1511.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.63.clone.1, %mul.1513.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.493 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.927.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1512.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.927.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1510.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_4.493, %mul.1512.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.792.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1511.clone.1, %mul.1510.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1272 = f32[]{:T(128)S(6)} parameter(1) - %div.744.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1272), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.743.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.792.clone.1, %div.744.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.920.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1831.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.920.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1829.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.63.clone.1, %mul.1831.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.487 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.919.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1830.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.919.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1828.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_4.487, %mul.1830.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.778.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%mul.1829.clone.1, %mul.1828.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1276 = f32[]{:T(128)S(6)} parameter(1) + %div.744.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1276), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.743.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.778.clone.1, %div.744.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.61.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} sqrt(%div.743.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.926.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.791.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.926.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.790.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%sqrt.61.clone.1, %add.791.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.259.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%div.745.clone.1, %add.790.clone.1), metadata={op_name="multiply.39"} - %div.742.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.793.clone.1, %multiply.259.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1508.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_0.1109, %broadcast.564.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.789.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%div.742.clone.1, %mul.1508.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1507.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%mul.1509.clone.1, %add.789.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.788.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%param_0.1109, %mul.1507.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.180 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%add.788.clone.1, %add.788.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.996 = f32[]{:T(128)} constant(0) - %reduce.135 = f32[]{:T(128)} reduce(%square.180, %constant.996), dimensions={0,1,2,3}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.139.clone.1 = f32[]{:T(128)} reduce(%integer_pow.63.clone.1, %constant.996), dimensions={0,1,2,3}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.140 = (f32[]{:T(128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.135, %add.788.clone.1, %add.792.clone.1, %add.793.clone.1, %reduce.139.clone.1) -} - -%region_56.61 (reduce_sum.310: f32[], reduce_sum.311: f32[]) -> f32[] { - %reduce_sum.310 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.311 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.315 = f32[]{:T(128)} add(%reduce_sum.310, %reduce_sum.311), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_42.47 (reduce_sum.238: f32[], reduce_sum.239: f32[]) -> f32[] { - %reduce_sum.238 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.239 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.240 = f32[]{:T(128)} add(%reduce_sum.238, %reduce_sum.239), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.272 (param_0.1110: f32[32,4,128,4096], param_1.1273: f32[], param_2.1104: f32[], param_3.793: f32[], param_4.494: f32[32,4,128,4096], param_5.419: f32[], param_6.291: f32[4,32,128,4096], param_7.190: pred[], param_8.108: f32[32,4,128,4096]) -> (f32[], f32[32,4,128,4096], f32[32,4,128,4096], f32[32,4,128,4096], f32[]) { - %param_0.1110 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) - %param_3.793 = f32[]{:T(128)S(6)} parameter(3) - %mul.1519.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_3.793), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.918.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.777.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} broadcast(%constant.918.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.776.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%sqrt.61.clone.1, %add.777.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.186.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%div.745.clone.1, %add.776.clone.1), metadata={op_name="multiply.30"} + %div.742.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} divide(%add.779.clone.1, %multiply.186.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1826.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%param_0.1116, %broadcast.525.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.775.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%div.742.clone.1, %mul.1826.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1825.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%mul.1827.clone.1, %add.775.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.774.clone.1 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} add(%param_0.1116, %mul.1825.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.180 = f32[4096,4,32,128]{3,2,1,0:T(8,128)} multiply(%add.774.clone.1, %add.774.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.988 = f32[]{:T(128)} constant(0) + %reduce.96 = f32[]{:T(128)} reduce(%square.180, %constant.988), dimensions={0,1,2,3}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.100.clone.1 = f32[]{:T(128)} reduce(%integer_pow.63.clone.1, %constant.988), dimensions={0,1,2,3}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.140 = (f32[]{:T(128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[4096,4,32,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.96, %add.774.clone.1, %add.778.clone.1, %add.779.clone.1, %reduce.100.clone.1) +} + +%region_56.61 (reduce_sum.364: f32[], reduce_sum.368: f32[]) -> f32[] { + %reduce_sum.364 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.368 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.369 = f32[]{:T(128)} add(%reduce_sum.364, %reduce_sum.368), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_42.47 (reduce_sum.292: f32[], reduce_sum.293: f32[]) -> f32[] { + %reduce_sum.292 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.293 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.294 = f32[]{:T(128)} add(%reduce_sum.292, %reduce_sum.293), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.272 (param_0.1117: f32[32,4,128,4096], param_1.1277: f32[], param_2.1098: f32[], param_3.781: f32[], param_4.488: f32[32,4,128,4096], param_5.419: f32[], param_6.292: f32[4,32,128,4096], param_7.190: pred[], param_8.108: f32[32,4,128,4096]) -> (f32[], f32[32,4,128,4096], f32[32,4,128,4096], f32[32,4,128,4096], f32[]) { + %param_0.1117 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(0) + %param_3.781 = f32[]{:T(128)S(6)} parameter(3) + %mul.1837.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_3.781), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.190 = pred[]{:T(512)S(6)} parameter(7) %select_n.258.clone.1 = pred[32,4,128,4096]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.190), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.291 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.415.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_6.291), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.292 = f32[4,32,128,4096]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.415.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} bitcast(%param_6.292), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.419 = f32[]{:T(128)} parameter(5) %div.757.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_5.419), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.756.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%bitcast.415.clone.1, %div.757.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.257.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} select(%select_n.258.clone.1, %bitcast.415.clone.1, %div.756.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.931.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.566.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.931.clone.1), dimensions={}, metadata={op_name="broadcast.64"} - %mul.1525.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.257.clone.1, %broadcast.566.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.923.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.527.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.923.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %mul.1843.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.257.clone.1, %broadcast.527.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.108 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(8) - %constant.935.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1526.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.935.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1524.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_8.108, %mul.1526.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.799.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1525.clone.1, %mul.1524.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1104 = f32[]{:T(128)S(6)} parameter(2) - %div.753.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_2.1104), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.927.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1844.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.927.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1842.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_8.108, %mul.1844.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.785.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1843.clone.1, %mul.1842.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1098 = f32[]{:T(128)S(6)} parameter(2) + %div.753.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_2.1098), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.64.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%select_n.257.clone.1, %select_n.257.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.934.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1523.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.934.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1521.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%integer_pow.64.clone.1, %mul.1523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.494 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(4) - %constant.933.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1522.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.933.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1520.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_4.494, %mul.1522.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.798.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1521.clone.1, %mul.1520.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1273 = f32[]{:T(128)S(6)} parameter(1) - %div.752.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_1.1273), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.751.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.798.clone.1, %div.752.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.926.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1841.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.926.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1839.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%integer_pow.64.clone.1, %mul.1841.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.488 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} parameter(4) + %constant.925.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1840.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.925.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1838.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_4.488, %mul.1840.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.784.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%mul.1839.clone.1, %mul.1838.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1277 = f32[]{:T(128)S(6)} parameter(1) + %div.752.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%param_1.1277), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.751.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.784.clone.1, %div.752.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.62.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} sqrt(%div.751.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.932.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.797.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.932.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.796.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%sqrt.62.clone.1, %add.797.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.260.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%div.753.clone.1, %add.796.clone.1), metadata={op_name="multiply.38"} - %div.750.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.799.clone.1, %multiply.260.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1518.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_0.1110, %broadcast.566.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.795.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%div.750.clone.1, %mul.1518.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1517.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%mul.1519.clone.1, %add.795.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.794.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%param_0.1110, %mul.1517.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.181 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%add.794.clone.1, %add.794.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.997 = f32[]{:T(128)} constant(0) - %reduce.136 = f32[]{:T(128)} reduce(%square.181, %constant.997), dimensions={0,1,2,3}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.140.clone.1 = f32[]{:T(128)} reduce(%integer_pow.64.clone.1, %constant.997), dimensions={0,1,2,3}, to_apply=%region_42.47, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.141 = (f32[]{:T(128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.136, %add.794.clone.1, %add.798.clone.1, %add.799.clone.1, %reduce.140.clone.1) -} - -%region_47.52 (reduce_sum.262: f32[], reduce_sum.266: f32[]) -> f32[] { - %reduce_sum.262 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.266 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.267 = f32[]{:T(128)} add(%reduce_sum.262, %reduce_sum.266), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.279 (param_0.1129: bf16[4,128,128256], param_1.1288: f32[4,128], param_2.1115: s32[4,128], param_3.803: bf16[4,128]) -> f32[4,128] { - %param_2.1115 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %eq.30 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1115), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %constant.924.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.783.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} broadcast(%constant.924.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.782.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%sqrt.62.clone.1, %add.783.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.187.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%div.753.clone.1, %add.782.clone.1), metadata={op_name="multiply.29"} + %div.750.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} divide(%add.785.clone.1, %multiply.187.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1836.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%param_0.1117, %broadcast.527.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.781.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%div.750.clone.1, %mul.1836.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1835.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%mul.1837.clone.1, %add.781.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.780.clone.1 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} add(%param_0.1117, %mul.1835.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.181 = f32[32,4,128,4096]{3,2,1,0:T(8,128)} multiply(%add.780.clone.1, %add.780.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.989 = f32[]{:T(128)} constant(0) + %reduce.97 = f32[]{:T(128)} reduce(%square.181, %constant.989), dimensions={0,1,2,3}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.101.clone.1 = f32[]{:T(128)} reduce(%integer_pow.64.clone.1, %constant.989), dimensions={0,1,2,3}, to_apply=%region_42.47, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.141 = (f32[]{:T(128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[32,4,128,4096]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.97, %add.780.clone.1, %add.784.clone.1, %add.785.clone.1, %reduce.101.clone.1) +} + +%region_47.52 (reduce_sum.319: f32[], reduce_sum.320: f32[]) -> f32[] { + %reduce_sum.319 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.320 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.321 = f32[]{:T(128)} add(%reduce_sum.319, %reduce_sum.320), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.279 (param_0.1136: bf16[4,128,128256], param_1.1292: f32[4,128], param_2.1109: s32[4,128], param_3.791: bf16[4,128]) -> f32[4,128] { + %param_2.1109 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.30 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1109), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.25 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.24 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.30, %eq.25), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %param_0.1129 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.950 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1129), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_3.803 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) - %sub.73 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.803), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.64 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.950, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %param_1.1288 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %sub.71 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1288), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_0.1136 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.966 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1136), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_3.791 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.73 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.791), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.64 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.966, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_1.1292 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.71 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1292), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %sub.60 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%sub.64, %sub.71), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %constant.1017 = f32[]{:T(128)} constant(0) - %broadcast.511 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%constant.1017), dimensions={}, metadata={op_name="broadcast.83"} - %mul.1373 = f32[4,128,128256]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.511), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - ROOT %reduce.137 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1373, %constant.1017), dimensions={2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.1009 = f32[]{:T(128)} constant(0) + %broadcast.472 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%constant.1009), dimensions={}, metadata={op_name="broadcast.39"} + %mul.1674 = f32[4,128,128256]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.472), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.98 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1674, %constant.1009), dimensions={2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_7.10 (reduce_sum.93: f32[], reduce_sum.94: f32[]) -> f32[] { - %reduce_sum.93 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.94 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.95 = f32[]{:T(128)} add(%reduce_sum.93, %reduce_sum.94), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_7.10 (reduce_sum.123: f32[], reduce_sum.127: f32[]) -> f32[] { + %reduce_sum.123 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.127 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.128 = f32[]{:T(128)} add(%reduce_sum.123, %reduce_sum.127), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.284 (param_0.1130: bf16[4,128,128256], param_1.1289: bf16[4,128]) -> f32[4,128] { - %param_0.1130 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.956 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1130), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_1.1289 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) - %sub.74 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1289), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.70 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.956, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} +%fused_computation.284 (param_0.1137: bf16[4,128,128256], param_1.1293: bf16[4,128]) -> f32[4,128] { + %param_0.1137 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.972 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1137), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.1293 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.74 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1293), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.70 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.972, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.54 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.70), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %constant.1018 = f32[]{:T(128)} constant(0) - ROOT %reduce.138 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1018), dimensions={2}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.1010 = f32[]{:T(128)} constant(0) + ROOT %reduce.99 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1010), dimensions={2}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_31.36 (reduce_sum.184: f32[], reduce_sum.185: f32[]) -> f32[] { - %reduce_sum.184 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.185 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.189 = f32[]{:T(128)} add(%reduce_sum.184, %reduce_sum.185), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_31.36 (reduce_sum.238: f32[], reduce_sum.242: f32[]) -> f32[] { + %reduce_sum.238 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.242 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.243 = f32[]{:T(128)} add(%reduce_sum.238, %reduce_sum.242), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_28.33 (reduce_sum.169: f32[], reduce_sum.170: f32[]) -> f32[] { - %reduce_sum.169 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.170 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.171 = f32[]{:T(128)} add(%reduce_sum.169, %reduce_sum.170), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_28.33 (reduce_sum.223: f32[], reduce_sum.224: f32[]) -> f32[] { + %reduce_sum.223 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.224 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.228 = f32[]{:T(128)} add(%reduce_sum.223, %reduce_sum.224), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.290 (param_0.1122: f32[4,4096,8,128], param_1.1282: f32[4,4096,8,128]) -> (f32[], f32[]) { - %param_0.1122 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.350 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1122), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.290 (param_0.1129: f32[4,4096,8,128], param_1.1286: f32[4,4096,8,128]) -> (f32[], f32[]) { + %param_0.1129 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.350 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1129), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.184 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.350, %bitcast.350), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1009 = f32[]{:T(128)} constant(0) - %reduce.141 = f32[]{:T(128)} reduce(%square.184, %constant.1009), dimensions={0,1,2,3}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1282 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(1) - %bitcast.354.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1282), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %constant.1001 = f32[]{:T(128)} constant(0) + %reduce.102 = f32[]{:T(128)} reduce(%square.184, %constant.1001), dimensions={0,1,2,3}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1286 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(1) + %bitcast.354.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1286), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.187.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.354.clone.1, %bitcast.354.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.142.clone.1 = f32[]{:T(128)} reduce(%square.187.clone.1, %constant.1009), dimensions={0,1,2,3}, to_apply=%region_28.33, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.156 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.141, %reduce.142.clone.1) + %reduce.103.clone.1 = f32[]{:T(128)} reduce(%square.187.clone.1, %constant.1001), dimensions={0,1,2,3}, to_apply=%region_28.33, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.156 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.102, %reduce.103.clone.1) } -%fused_computation.293 (param_0.807: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { - %param_0.807 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %copy.238 = bf16[4096,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.807), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} +%fused_computation.293 (param_0.812: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { + %param_0.812 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %copy.238 = bf16[4096,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.812), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} ROOT %bitcast.355 = bf16[4,4096,8,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%copy.238), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_58.63 (reduce_sum.324: f32[], reduce_sum.325: f32[]) -> f32[] { - %reduce_sum.324 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.325 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.322 = f32[]{:T(128)} add(%reduce_sum.324, %reduce_sum.325), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_58.63 (reduce_sum.376: f32[], reduce_sum.377: f32[]) -> f32[] { + %reduce_sum.376 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.377 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.378 = f32[]{:T(128)} add(%reduce_sum.376, %reduce_sum.377), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_44.49 (reduce_sum.247: f32[], reduce_sum.248: f32[]) -> f32[] { - %reduce_sum.247 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.248 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.252 = f32[]{:T(128)} add(%reduce_sum.247, %reduce_sum.248), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_44.49 (reduce_sum.301: f32[], reduce_sum.305: f32[]) -> f32[] { + %reduce_sum.301 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.305 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.306 = f32[]{:T(128)} add(%reduce_sum.301, %reduce_sum.305), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.294 (param_0.1108: f32[4096,4,8,128], param_1.1271: f32[], param_2.1102: f32[], param_3.791: f32[], param_4.492: f32[4096,4,8,128], param_5.417: f32[], param_6.289: f32[4,4096,8,128], param_7.188: pred[], param_8.106: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { - %param_0.1108 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.791 = f32[]{:T(128)S(6)} parameter(3) - %mul.1502.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.791), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.294 (param_0.1115: f32[4096,4,8,128], param_1.1275: f32[], param_2.1096: f32[], param_3.779: f32[], param_4.486: f32[4096,4,8,128], param_5.417: f32[], param_6.290: f32[4,4096,8,128], param_7.188: pred[], param_8.106: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { + %param_0.1115 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.779 = f32[]{:T(128)S(6)} parameter(3) + %mul.1820.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.779), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.188 = pred[]{:T(512)S(6)} parameter(7) %select_n.250.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.188), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.289 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.411.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.289), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.290 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.411.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.290), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.417 = f32[]{:T(128)} parameter(5) %div.741.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.417), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.740.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.411.clone.1, %div.741.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.249.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.250.clone.1, %bitcast.411.clone.1, %div.740.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.919.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.562.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.919.clone.1), dimensions={}, metadata={op_name="broadcast.66"} - %mul.1506.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.249.clone.1, %broadcast.562.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.911.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.523.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.911.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1824.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.249.clone.1, %broadcast.523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.106 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.923.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.561.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.923.clone.1), dimensions={}, metadata={op_name="broadcast.65"} - %mul.1505.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.106, %broadcast.561.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.787.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1506.clone.1, %mul.1505.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1102 = f32[]{:T(128)S(6)} parameter(2) - %div.737.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1102), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.915.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.522.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.915.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %mul.1823.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.106, %broadcast.522.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.773.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1824.clone.1, %mul.1823.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1096 = f32[]{:T(128)S(6)} parameter(2) + %div.737.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1096), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.62.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.249.clone.1, %select_n.249.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.922.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.560.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.922.clone.1), dimensions={}, metadata={op_name="broadcast.56"} - %mul.1504.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.62.clone.1, %broadcast.560.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.492 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.921.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.559.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.921.clone.1), dimensions={}, metadata={op_name="broadcast.55"} - %mul.1503.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.492, %broadcast.559.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.786.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1504.clone.1, %mul.1503.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1271 = f32[]{:T(128)S(6)} parameter(1) - %div.736.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1271), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.735.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.786.clone.1, %div.736.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.914.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.521.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.914.clone.1), dimensions={}, metadata={op_name="broadcast.56"} + %mul.1822.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.62.clone.1, %broadcast.521.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.486 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.913.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.520.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.913.clone.1), dimensions={}, metadata={op_name="broadcast.55"} + %mul.1821.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.486, %broadcast.520.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.772.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1822.clone.1, %mul.1821.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1275 = f32[]{:T(128)S(6)} parameter(1) + %div.736.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1275), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.735.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.772.clone.1, %div.736.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.60.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.735.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.920.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.557.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.920.clone.1), dimensions={}, metadata={op_name="broadcast.52"} - %add.785.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.60.clone.1, %broadcast.557.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.258.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.737.clone.1, %add.785.clone.1), metadata={op_name="multiply.40"} - %div.734.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.787.clone.1, %multiply.258.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1501.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1108, %broadcast.562.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.784.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.734.clone.1, %mul.1501.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1500.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1502.clone.1, %add.784.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.783.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1108, %mul.1500.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.188 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.783.clone.1, %add.783.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.995 = f32[]{:T(128)} constant(0) - %reduce.143 = f32[]{:T(128)} reduce(%square.188, %constant.995), dimensions={0,1,2,3}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.145.clone.1 = f32[]{:T(128)} reduce(%integer_pow.62.clone.1, %constant.995), dimensions={0,1,2,3}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.142 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.143, %add.783.clone.1, %add.786.clone.1, %add.787.clone.1, %reduce.145.clone.1) -} - -%region_55.60 (reduce_sum.304: f32[], reduce_sum.308: f32[]) -> f32[] { - %reduce_sum.304 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.308 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.309 = f32[]{:T(128)} add(%reduce_sum.304, %reduce_sum.308), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_41.46 (reduce_sum.232: f32[], reduce_sum.233: f32[]) -> f32[] { - %reduce_sum.232 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.233 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.234 = f32[]{:T(128)} add(%reduce_sum.232, %reduce_sum.233), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.912.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.518.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.912.clone.1), dimensions={}, metadata={op_name="broadcast.52"} + %add.771.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.60.clone.1, %broadcast.518.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.185.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.737.clone.1, %add.771.clone.1), metadata={op_name="multiply.31"} + %div.734.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.773.clone.1, %multiply.185.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1819.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1115, %broadcast.523.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.770.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.734.clone.1, %mul.1819.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1818.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1820.clone.1, %add.770.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.769.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1115, %mul.1818.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.188 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.769.clone.1, %add.769.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.987 = f32[]{:T(128)} constant(0) + %reduce.104 = f32[]{:T(128)} reduce(%square.188, %constant.987), dimensions={0,1,2,3}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.106.clone.1 = f32[]{:T(128)} reduce(%integer_pow.62.clone.1, %constant.987), dimensions={0,1,2,3}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.142 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.104, %add.769.clone.1, %add.772.clone.1, %add.773.clone.1, %reduce.106.clone.1) +} + +%region_55.60 (reduce_sum.361: f32[], reduce_sum.362: f32[]) -> f32[] { + %reduce_sum.361 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.362 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.363 = f32[]{:T(128)} add(%reduce_sum.361, %reduce_sum.362), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_41.46 (reduce_sum.286: f32[], reduce_sum.287: f32[]) -> f32[] { + %reduce_sum.286 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.287 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.291 = f32[]{:T(128)} add(%reduce_sum.286, %reduce_sum.287), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.295 (param_0.1111: f32[4096,4,8,128], param_1.1274: f32[], param_2.1105: f32[], param_3.794: f32[], param_4.495: f32[4096,4,8,128], param_5.420: f32[], param_6.292: f32[4,4096,8,128], param_7.191: pred[], param_8.109: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { - %param_0.1111 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.794 = f32[]{:T(128)S(6)} parameter(3) - %mul.1529.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.794), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.295 (param_0.1118: f32[4096,4,8,128], param_1.1278: f32[], param_2.1099: f32[], param_3.782: f32[], param_4.489: f32[4096,4,8,128], param_5.420: f32[], param_6.293: f32[4,4096,8,128], param_7.191: pred[], param_8.109: f32[4096,4,8,128]) -> (f32[], f32[4096,4,8,128], f32[4096,4,8,128], f32[4096,4,8,128], f32[]) { + %param_0.1118 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.782 = f32[]{:T(128)S(6)} parameter(3) + %mul.1847.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.782), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.191 = pred[]{:T(512)S(6)} parameter(7) %select_n.262.clone.1 = pred[4096,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.191), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.292 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.417.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.292), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.293 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.417.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.293), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.420 = f32[]{:T(128)} parameter(5) %div.765.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.420), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.764.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.417.clone.1, %div.765.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.261.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.262.clone.1, %bitcast.417.clone.1, %div.764.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.937.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.572.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.937.clone.1), dimensions={}, metadata={op_name="broadcast.66"} - %mul.1533.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.261.clone.1, %broadcast.572.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.929.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.533.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.929.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.1851.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.261.clone.1, %broadcast.533.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.109 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.941.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.571.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.941.clone.1), dimensions={}, metadata={op_name="broadcast.65"} - %mul.1532.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.109, %broadcast.571.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.804.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1533.clone.1, %mul.1532.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1105 = f32[]{:T(128)S(6)} parameter(2) - %div.761.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1105), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.933.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.532.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.933.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %mul.1850.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.109, %broadcast.532.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.790.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1851.clone.1, %mul.1850.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1099 = f32[]{:T(128)S(6)} parameter(2) + %div.761.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1099), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.65.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.261.clone.1, %select_n.261.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.940.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.570.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.940.clone.1), dimensions={}, metadata={op_name="broadcast.56"} - %mul.1531.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %broadcast.570.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.495 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.939.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.569.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.939.clone.1), dimensions={}, metadata={op_name="broadcast.55"} - %mul.1530.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.495, %broadcast.569.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.803.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1531.clone.1, %mul.1530.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1274 = f32[]{:T(128)S(6)} parameter(1) - %div.760.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1274), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.759.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.803.clone.1, %div.760.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.932.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.531.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.932.clone.1), dimensions={}, metadata={op_name="broadcast.56"} + %mul.1849.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %broadcast.531.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.489 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.931.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.530.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.931.clone.1), dimensions={}, metadata={op_name="broadcast.55"} + %mul.1848.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.489, %broadcast.530.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.789.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1849.clone.1, %mul.1848.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1278 = f32[]{:T(128)S(6)} parameter(1) + %div.760.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1278), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.759.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.789.clone.1, %div.760.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.63.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.759.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.938.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.567.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.938.clone.1), dimensions={}, metadata={op_name="broadcast.52"} - %add.802.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.567.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.261.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.761.clone.1, %add.802.clone.1), metadata={op_name="multiply.37"} - %div.758.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.804.clone.1, %multiply.261.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1528.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1111, %broadcast.572.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.801.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.758.clone.1, %mul.1528.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1527.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1529.clone.1, %add.801.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.800.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1111, %mul.1527.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.189 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.800.clone.1, %add.800.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.998 = f32[]{:T(128)} constant(0) - %reduce.144 = f32[]{:T(128)} reduce(%square.189, %constant.998), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.146.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.998), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.143 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.144, %add.800.clone.1, %add.803.clone.1, %add.804.clone.1, %reduce.146.clone.1) -} - -%fused_computation.311 (param_0.872: bf16[4,128,4096], param_1.941: f32[4,128], param_2.726: f32[4,128], param_3.452: bf16[4,128,4096], param_4.271: bf16[4096]) -> bf16[4,128,4096] { - %param_3.452 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.271 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %dot_general.375 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.271), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.365 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_3.452, %dot_general.375), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.973 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.365), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_2.726 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %mul.1423 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_2.726), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1415 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.973, %mul.1423), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %param_0.872 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.984 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.872), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_1.941 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %mul.1422 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.941), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1421 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.984, %mul.1422), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %add_any.138 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1415, %mul.1421), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} - ROOT %convert_element_type.971 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.138), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} -} - -%region_5.8 (reduce_sum.87: f32[], reduce_sum.88: f32[]) -> f32[] { - %reduce_sum.87 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - %reduce_sum.88 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - ROOT %reduce_sum.92 = f32[]{:T(128)} add(%reduce_sum.87, %reduce_sum.88), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.312 (param_0.1131: bf16[4,128,4096]) -> f32[4,128] { - %param_0.1131 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.975 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1131), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %square.192 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.975, %convert_element_type.975), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} - %constant.1019 = f32[]{:T(128)} constant(0) - ROOT %reduce.147 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.192, %constant.1019), dimensions={2}, to_apply=%region_5.8, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} -} - -%region_10.13 (reduce_sum.102: f32[], reduce_sum.106: f32[]) -> f32[] { - %reduce_sum.102 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - %reduce_sum.106 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - ROOT %reduce_sum.107 = f32[]{:T(128)} add(%reduce_sum.102, %reduce_sum.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.314 (param_0.1126: bf16[4,128,4096], param_1.1285: bf16[4,128,4096], param_2.1113: bf16[4096]) -> f32[4,128] { - %param_0.1126 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.982 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1126), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_1.1285 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %param_2.1113 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.374 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1113), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.364 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1285, %dot_general.374), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.981 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.364), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %mul.1419 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.982, %convert_element_type.981), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %constant.1013 = f32[]{:T(128)} constant(0) - ROOT %reduce.148 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1419, %constant.1013), dimensions={2}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} -} - -%region_8.11 (dot_general.182: bf16[], dot_general.183: bf16[]) -> bf16[] { - %dot_general.182 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} - %dot_general.183 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} - ROOT %add.168 = bf16[]{:T(256)} add(%dot_general.182, %dot_general.183), metadata={op_name="add.54"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.235.clone.clone (param_0.1095: f32[4096,128256]) -> bf16[4096,128256,1] { - %param_0.1095 = f32[4096,128256]{1,0:T(8,128)} parameter(0) - %convert_element_type.1033 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1095), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - ROOT %bitcast.449 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1033), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} -} - -%fused_computation.280.clone.1.clone.clone (param_0.1096: bf16[4,128,128256], param_1.1261: s32[4,128], param_2.1081: f32[4,128], param_3.782: f32[4,128], param_4.484: bf16[4,128], param_5.409: f32[4,128]) -> bf16[4,128,128256] { + %constant.930.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.528.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.930.clone.1), dimensions={}, metadata={op_name="broadcast.52"} + %add.788.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.528.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.188.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.761.clone.1, %add.788.clone.1), metadata={op_name="multiply.28"} + %div.758.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} divide(%add.790.clone.1, %multiply.188.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1846.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1118, %broadcast.533.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.787.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%div.758.clone.1, %mul.1846.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1845.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1847.clone.1, %add.787.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.786.clone.1 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1118, %mul.1845.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.189 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.786.clone.1, %add.786.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.990 = f32[]{:T(128)} constant(0) + %reduce.105 = f32[]{:T(128)} reduce(%square.189, %constant.990), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.107.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.990), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.143 = (f32[]{:T(128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[4096,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.105, %add.786.clone.1, %add.789.clone.1, %add.790.clone.1, %reduce.107.clone.1) +} + +%fused_computation.311 (param_0.877: bf16[4,128,4096], param_1.943: f32[4,128], param_2.720: f32[4,128], param_3.440: bf16[4,128,4096], param_4.265: bf16[4096]) -> bf16[4,128,4096] { + %param_3.440 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.265 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %mul.1754 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.265), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1728 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_3.440, %mul.1754), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.989 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.1728), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.720 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %mul.1725 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_2.720), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1716 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.989, %mul.1725), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.877 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1000 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.877), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.943 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.1723 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.943), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1722 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1000, %mul.1723), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %add_any.138 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1716, %mul.1722), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} + ROOT %convert_element_type.987 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.138), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} +} + +%region_4.7 (reduce_sum.114: f32[], reduce_sum.115: f32[]) -> f32[] { + %reduce_sum.114 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.115 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.116 = f32[]{:T(128)} add(%reduce_sum.114, %reduce_sum.115), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.312 (param_0.1138: bf16[4,128,4096]) -> f32[4,128] { + %param_0.1138 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.991 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1138), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.192 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.991, %convert_element_type.991), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} + %constant.1011 = f32[]{:T(128)} constant(0) + ROOT %reduce.108 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.192, %constant.1011), dimensions={2}, to_apply=%region_4.7, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} +} + +%region_10.13 (reduce_sum.141: f32[], reduce_sum.142: f32[]) -> f32[] { + %reduce_sum.141 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.142 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.143 = f32[]{:T(128)} add(%reduce_sum.141, %reduce_sum.142), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.314 (param_0.1133: bf16[4,128,4096], param_1.1289: bf16[4,128,4096], param_2.1107: bf16[4096]) -> f32[4,128] { + %param_0.1133 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.998 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1133), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.1289 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %param_2.1107 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.1753 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1107), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1727 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1289, %mul.1753), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.997 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.1727), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %mul.1720 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.998, %convert_element_type.997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1005 = f32[]{:T(128)} constant(0) + ROOT %reduce.109 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1720, %constant.1005), dimensions={2}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} +} + +%region_8.11 (reduce_sum.129: bf16[], reduce_sum.130: bf16[]) -> bf16[] { + %reduce_sum.129 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.130 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.134 = bf16[]{:T(256)} add(%reduce_sum.129, %reduce_sum.130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.235.clone.clone (param_0.1102: f32[4096,128256]) -> bf16[4096,128256,1] { + %param_0.1102 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %convert_element_type.1049 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1102), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + ROOT %bitcast.449 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1049), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} +} + +%fused_computation.280.clone.1.clone.clone (param_0.1103: bf16[4,128,128256], param_1.1265: s32[4,128], param_2.1075: f32[4,128], param_3.770: f32[4,128], param_4.478: bf16[4,128], param_5.409: f32[4,128]) -> bf16[4,128,128256] { %param_5.409 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.1603 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.409), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.782 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.1602 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.782), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.1096 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1036 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1096), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_4.484 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %sub.88 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.484), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.87 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1036, %sub.88), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %mul.1925 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_5.409), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.770 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.1924 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_3.770), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1103 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1052 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%param_0.1103), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.478 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.88 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_4.478), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.87 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%convert_element_type.1052, %sub.88), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.60 = f32[4,128,128256]{2,1,0:T(8,128)} exponential(%sub.87), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.1601 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1602, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.1081 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.819 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1081), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.818 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1601, %div.819), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.1261 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.43 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1261), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %mul.1923 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1924, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1075 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.819 = f32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_2.1075), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.818 = f32[4,128,128256]{2,1,0:T(8,128)} divide(%mul.1923, %div.819), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1265 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.43 = s32[4,128,128256]{2,1,0:T(8,128)} broadcast(%param_1.1265), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.42 = s32[4,128,128256]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.41 = pred[4,128,128256]{2,1,0:T(8,128)(4,1)} compare(%eq.43, %eq.42), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.1035 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.86 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.818, %convert_element_type.1035), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.1600 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1603, %sub.86), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.1034 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1600), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} -} - -%fused_computation.315 (param_0.1094: f32[4,128], param_1.1260: bf16[4,128,4096], param_2.1082: f32[4096,128256], param_3.783: bf16[4,128,128256], param_4.485: s32[4,128], param_5.410: f32[4,128], param_6.284: f32[4,128], param_7.183: bf16[4,128], param_8.102: f32[4,128]) -> (bf16[4096], bf16[4,128,4096]) { - %param_3.783 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(3) - %param_4.485 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %convert_element_type.1051 = f32[4,128,128256]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.86 = f32[4,128,128256]{2,1,0:T(8,128)} subtract(%div.818, %convert_element_type.1051), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.1922 = f32[4,128,128256]{2,1,0:T(8,128)} multiply(%mul.1925, %sub.86), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1050 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convert(%mul.1922), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.315 (param_0.1101: f32[4,128], param_1.1264: bf16[4,128,4096], param_2.1076: f32[4096,128256], param_3.771: bf16[4,128,128256], param_4.479: s32[4,128], param_5.410: f32[4,128], param_6.285: f32[4,128], param_7.183: bf16[4,128], param_8.102: f32[4,128]) -> (bf16[4096], bf16[4,128,4096]) { + %param_1.1264 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1010 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1264), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1101 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1742 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1101), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1741 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1010, %mul.1742), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1009 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1741), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_3.771 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.479 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) %param_5.410 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %param_6.284 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_6.285 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) %param_7.183 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) %param_8.102 = f32[4,128]{1,0:T(4,128)S(1)} parameter(8) - %multiply_convert_fusion.2.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_3.783, %param_4.485, %param_5.410, %param_6.284, %param_7.183, /*index=5*/%param_8.102), kind=kLoop, calls=%fused_computation.280.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_2.1082 = f32[4096,128256]{1,0:T(8,128)} parameter(2) - %fusion.219.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1082), kind=kLoop, calls=%fused_computation.235.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + %multiply_convert_fusion.2.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} fusion(%param_3.771, %param_4.479, %param_5.410, %param_6.285, %param_7.183, /*index=5*/%param_8.102), kind=kLoop, calls=%fused_computation.280.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_2.1076 = f32[4096,128256]{1,0:T(8,128)} parameter(2) + %fusion.219.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1076), kind=kLoop, calls=%fused_computation.235.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} %convolution.86.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.219.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %param_1.1260 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.994 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1260), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1094 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1434 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1094), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1433 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.994, %mul.1434), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.993 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1433), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %multiply.252 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convolution.86.clone.1, %convert_element_type.993), metadata={op_name="multiply.206"} - %constant.874 = bf16[]{:T(256)} constant(0) - %reduce.149 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%multiply.252, %constant.874), dimensions={0,1}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %tuple.153 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.149, %convolution.86.clone.1) -} - -%fused_computation.323 (param_0.904: f32[64], param_1.974: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { - %param_1.974 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %div.621 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.974), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} - %param_0.904 = f32[64]{0:T(128)S(1)} parameter(0) - %div.619 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.904), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %mul.1724 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1009, %convolution.86.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.878 = bf16[]{:T(256)} constant(0) + %reduce.110 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%mul.1724, %constant.878), dimensions={0,1}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} + ROOT %tuple.153 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.110, %convolution.86.clone.1) +} + +%fused_computation.323 (param_0.911: f32[64], param_1.978: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.978 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %div.621 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.978), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %param_0.911 = f32[64]{0:T(128)S(1)} parameter(0) + %div.619 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.911), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} %div.618 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.621, %div.619), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} %sin.38 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.618), metadata={op_name="jit(train_step)/layers/sin" stack_frame_id=0} - %convert_element_type.1002 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1018 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} %cos.41.clone.1 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.618), metadata={op_name="jit(train_step)/layers/cos" stack_frame_id=0} - %convert_element_type.1001.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.150 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1002, %convert_element_type.1001.clone.1) + %convert_element_type.1017.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.150 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1018, %convert_element_type.1017.clone.1) } -%fused_computation.324 (param_0.901: bf16[4,128,1,64]) -> bf16[4,128,1,128] { - %param_0.901 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.866 = bf16[]{:T(256)} constant(-inf) - %pad.38 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.901, %constant.866), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} - %pad.37 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.901, %constant.866), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +%fused_computation.324 (param_0.908: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.908 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.858 = bf16[]{:T(256)} constant(-inf) + %pad.38 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.908, %constant.858), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.37 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.908, %constant.858), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} ROOT %maximum.34 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.38, %pad.37), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} } -%fused_computation.325 (param_0.903: bf16[4,128,1,64]) -> bf16[4,128,1,128] { - %param_0.903 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.865 = bf16[]{:T(256)} constant(-inf) - %pad.40 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.903, %constant.865), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} - %pad.39 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.903, %constant.865), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +%fused_computation.325 (param_0.910: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.910 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.857 = bf16[]{:T(256)} constant(-inf) + %pad.40 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.910, %constant.857), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.39 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.910, %constant.857), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} ROOT %maximum.35 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.40, %pad.39), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} } -%region_27.32 (reduce_sum.163: f32[], reduce_sum.164: f32[]) -> f32[] { - %reduce_sum.163 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.164 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.168 = f32[]{:T(128)} add(%reduce_sum.163, %reduce_sum.164), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_27.32 (reduce_sum.217: f32[], reduce_sum.221: f32[]) -> f32[] { + %reduce_sum.217 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.221 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.222 = f32[]{:T(128)} add(%reduce_sum.217, %reduce_sum.221), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_26.31 (reduce_sum.157: f32[], reduce_sum.161: f32[]) -> f32[] { - %reduce_sum.157 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.161 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.162 = f32[]{:T(128)} add(%reduce_sum.157, %reduce_sum.161), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_26.31 (reduce_sum.214: f32[], reduce_sum.215: f32[]) -> f32[] { + %reduce_sum.214 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.215 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.216 = f32[]{:T(128)} add(%reduce_sum.214, %reduce_sum.215), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.329 (param_0.1123: f32[4,4096], param_1.1283: f32[4,4096]) -> (f32[], f32[]) { - %param_0.1123 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(0) - %bitcast.371 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_0.1123), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.329 (param_0.1130: f32[4,4096], param_1.1287: f32[4,4096]) -> (f32[], f32[]) { + %param_0.1130 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(0) + %bitcast.371 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_0.1130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.195 = f32[4096,4]{0,1:T(4,128)} multiply(%bitcast.371, %bitcast.371), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1010 = f32[]{:T(128)} constant(0) - %reduce.150 = f32[]{:T(128)} reduce(%square.195, %constant.1010), dimensions={0,1}, to_apply=%region_27.32, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1283 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(1) - %bitcast.375.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_1.1283), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %constant.1002 = f32[]{:T(128)} constant(0) + %reduce.111 = f32[]{:T(128)} reduce(%square.195, %constant.1002), dimensions={0,1}, to_apply=%region_27.32, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1287 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(1) + %bitcast.375.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_1.1287), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %square.198.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%bitcast.375.clone.1, %bitcast.375.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.151.clone.1 = f32[]{:T(128)} reduce(%square.198.clone.1, %constant.1010), dimensions={0,1}, to_apply=%region_26.31, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.157 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.150, %reduce.151.clone.1) + %reduce.112.clone.1 = f32[]{:T(128)} reduce(%square.198.clone.1, %constant.1002), dimensions={0,1}, to_apply=%region_26.31, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.157 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.111, %reduce.112.clone.1) } -%region_54.59 (reduce_sum.301: f32[], reduce_sum.302: f32[]) -> f32[] { - %reduce_sum.301 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.302 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.303 = f32[]{:T(128)} add(%reduce_sum.301, %reduce_sum.302), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_54.59 (reduce_sum.355: f32[], reduce_sum.356: f32[]) -> f32[] { + %reduce_sum.355 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.356 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.357 = f32[]{:T(128)} add(%reduce_sum.355, %reduce_sum.356), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_40.45 (reduce_sum.226: f32[], reduce_sum.227: f32[]) -> f32[] { - %reduce_sum.226 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.227 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.231 = f32[]{:T(128)} add(%reduce_sum.226, %reduce_sum.227), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_40.45 (reduce_sum.280: f32[], reduce_sum.284: f32[]) -> f32[] { + %reduce_sum.280 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.284 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.285 = f32[]{:T(128)} add(%reduce_sum.280, %reduce_sum.284), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.332 (param_0.1112: f32[4096,4], param_1.1275: f32[], param_2.1106: f32[], param_3.795: f32[], param_4.496: f32[4096,4], param_5.421: f32[], param_6.293: f32[4,4096], param_7.192: pred[], param_8.110: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { - %param_0.1112 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.795 = f32[]{:T(128)S(6)} parameter(3) - %mul.1536.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.795), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.332 (param_0.1119: f32[4096,4], param_1.1279: f32[], param_2.1100: f32[], param_3.783: f32[], param_4.490: f32[4096,4], param_5.421: f32[], param_6.294: f32[4,4096], param_7.192: pred[], param_8.110: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { + %param_0.1119 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.783 = f32[]{:T(128)S(6)} parameter(3) + %mul.1854.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.783), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.192 = pred[]{:T(512)S(6)} parameter(7) %select_n.266.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.192), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.293 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) - %bitcast.419.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.293), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.294 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.419.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.294), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.421 = f32[]{:T(128)} parameter(5) %div.773.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_5.421), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.772.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%bitcast.419.clone.1, %div.773.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.265.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%select_n.266.clone.1, %bitcast.419.clone.1, %div.772.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.943.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.578.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.943.clone.1), dimensions={}, metadata={op_name="broadcast.68"} - %mul.1540.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.265.clone.1, %broadcast.578.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.935.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.539.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.935.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1858.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.265.clone.1, %broadcast.539.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.110 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.947.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.577.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.947.clone.1), dimensions={}, metadata={op_name="broadcast.67"} - %mul.1539.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.110, %broadcast.577.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.809.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1540.clone.1, %mul.1539.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1106 = f32[]{:T(128)S(6)} parameter(2) - %div.769.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1106), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.939.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.538.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.939.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1857.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.110, %broadcast.538.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.795.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1858.clone.1, %mul.1857.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1100 = f32[]{:T(128)S(6)} parameter(2) + %div.769.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1100), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.66.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.265.clone.1, %select_n.265.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.946.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.576.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.946.clone.1), dimensions={}, metadata={op_name="broadcast.58"} - %mul.1538.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.66.clone.1, %broadcast.576.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.496 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.945.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.575.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.945.clone.1), dimensions={}, metadata={op_name="broadcast.57"} - %mul.1537.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.496, %broadcast.575.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.808.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1538.clone.1, %mul.1537.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1275 = f32[]{:T(128)S(6)} parameter(1) - %div.768.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1275), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.767.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.808.clone.1, %div.768.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.938.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.537.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.938.clone.1), dimensions={}, metadata={op_name="broadcast.58"} + %mul.1856.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.66.clone.1, %broadcast.537.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.490 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.937.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.536.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.937.clone.1), dimensions={}, metadata={op_name="broadcast.57"} + %mul.1855.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.490, %broadcast.536.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.794.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1856.clone.1, %mul.1855.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1279 = f32[]{:T(128)S(6)} parameter(1) + %div.768.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1279), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.767.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.794.clone.1, %div.768.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.64.clone.1 = f32[4096,4]{0,1:T(4,128)} sqrt(%div.767.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.944.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.573.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.944.clone.1), dimensions={}, metadata={op_name="broadcast.53"} - %add.807.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.573.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.262.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.769.clone.1, %add.807.clone.1), metadata={op_name="multiply.36"} - %div.766.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.809.clone.1, %multiply.262.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1535.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1112, %broadcast.578.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.806.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.766.clone.1, %mul.1535.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1534.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1536.clone.1, %add.806.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.805.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1112, %mul.1534.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.199 = f32[4096,4]{0,1:T(4,128)} multiply(%add.805.clone.1, %add.805.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.999 = f32[]{:T(128)} constant(0) - %reduce.152 = f32[]{:T(128)} reduce(%square.199, %constant.999), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.154.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.999), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.144 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.152, %add.805.clone.1, %add.808.clone.1, %add.809.clone.1, %reduce.154.clone.1) -} - -%region_53.58 (reduce_sum.295: f32[], reduce_sum.296: f32[]) -> f32[] { - %reduce_sum.295 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.296 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.297 = f32[]{:T(128)} add(%reduce_sum.295, %reduce_sum.296), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_39.44 (reduce_sum.220: f32[], reduce_sum.224: f32[]) -> f32[] { - %reduce_sum.220 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.224 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.225 = f32[]{:T(128)} add(%reduce_sum.220, %reduce_sum.224), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.333 (param_0.1113: f32[4096,4], param_1.1276: f32[], param_2.1107: f32[], param_3.796: f32[], param_4.497: f32[4096,4], param_5.422: f32[], param_6.294: f32[4,4096], param_7.193: pred[], param_8.111: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { - %param_0.1113 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.796 = f32[]{:T(128)S(6)} parameter(3) - %mul.1543.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.796), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.936.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.534.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.936.clone.1), dimensions={}, metadata={op_name="broadcast.53"} + %add.793.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.534.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.189.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.769.clone.1, %add.793.clone.1), metadata={op_name="multiply.27"} + %div.766.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.795.clone.1, %multiply.189.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1853.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1119, %broadcast.539.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.792.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.766.clone.1, %mul.1853.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1852.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1854.clone.1, %add.792.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.791.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1119, %mul.1852.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.199 = f32[4096,4]{0,1:T(4,128)} multiply(%add.791.clone.1, %add.791.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.991 = f32[]{:T(128)} constant(0) + %reduce.113 = f32[]{:T(128)} reduce(%square.199, %constant.991), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.115.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.991), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.144 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.113, %add.791.clone.1, %add.794.clone.1, %add.795.clone.1, %reduce.115.clone.1) +} + +%region_53.58 (reduce_sum.349: f32[], reduce_sum.350: f32[]) -> f32[] { + %reduce_sum.349 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.350 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.354 = f32[]{:T(128)} add(%reduce_sum.349, %reduce_sum.350), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_39.44 (reduce_sum.277: f32[], reduce_sum.278: f32[]) -> f32[] { + %reduce_sum.277 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.278 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.279 = f32[]{:T(128)} add(%reduce_sum.277, %reduce_sum.278), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.333 (param_0.1120: f32[4096,4], param_1.1280: f32[], param_2.1101: f32[], param_3.784: f32[], param_4.491: f32[4096,4], param_5.422: f32[], param_6.295: f32[4,4096], param_7.193: pred[], param_8.111: f32[4096,4]) -> (f32[], f32[4096,4], f32[4096,4], f32[4096,4], f32[]) { + %param_0.1120 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.784 = f32[]{:T(128)S(6)} parameter(3) + %mul.1861.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_3.784), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.193 = pred[]{:T(512)S(6)} parameter(7) %select_n.270.clone.1 = pred[4096,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.193), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.294 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) - %bitcast.421.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.294), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_6.295 = f32[4,4096]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.421.clone.1 = f32[4096,4]{0,1:T(4,128)} bitcast(%param_6.295), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} %param_5.422 = f32[]{:T(128)} parameter(5) %div.781.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_5.422), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %div.780.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%bitcast.421.clone.1, %div.781.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %select_n.269.clone.1 = f32[4096,4]{0,1:T(4,128)} select(%select_n.270.clone.1, %bitcast.421.clone.1, %div.780.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.949.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.584.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.949.clone.1), dimensions={}, metadata={op_name="broadcast.68"} - %mul.1547.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.269.clone.1, %broadcast.584.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %constant.941.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.545.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.941.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.1865.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.269.clone.1, %broadcast.545.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.111 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.953.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.583.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.953.clone.1), dimensions={}, metadata={op_name="broadcast.67"} - %mul.1546.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.111, %broadcast.583.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.814.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1547.clone.1, %mul.1546.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1107 = f32[]{:T(128)S(6)} parameter(2) - %div.777.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1107), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.945.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.544.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.945.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.1864.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_8.111, %broadcast.544.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.800.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1865.clone.1, %mul.1864.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1101 = f32[]{:T(128)S(6)} parameter(2) + %div.777.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_2.1101), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.67.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%select_n.269.clone.1, %select_n.269.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.952.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.582.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.952.clone.1), dimensions={}, metadata={op_name="broadcast.58"} - %mul.1545.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.582.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.497 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.951.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.581.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.951.clone.1), dimensions={}, metadata={op_name="broadcast.57"} - %mul.1544.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.497, %broadcast.581.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.813.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1545.clone.1, %mul.1544.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1276 = f32[]{:T(128)S(6)} parameter(1) - %div.776.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1276), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.775.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.813.clone.1, %div.776.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.944.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.543.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.944.clone.1), dimensions={}, metadata={op_name="broadcast.58"} + %mul.1863.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.543.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.491 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.943.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.542.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.943.clone.1), dimensions={}, metadata={op_name="broadcast.57"} + %mul.1862.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_4.491, %broadcast.542.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.799.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%mul.1863.clone.1, %mul.1862.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1280 = f32[]{:T(128)S(6)} parameter(1) + %div.776.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%param_1.1280), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.775.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.799.clone.1, %div.776.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.65.clone.1 = f32[4096,4]{0,1:T(4,128)} sqrt(%div.775.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.950.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.579.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.950.clone.1), dimensions={}, metadata={op_name="broadcast.53"} - %add.812.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.65.clone.1, %broadcast.579.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.263.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.777.clone.1, %add.812.clone.1), metadata={op_name="multiply.35"} - %div.774.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.814.clone.1, %multiply.263.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1542.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1113, %broadcast.584.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.811.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.774.clone.1, %mul.1542.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1541.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1543.clone.1, %add.811.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.810.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1113, %mul.1541.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.200 = f32[4096,4]{0,1:T(4,128)} multiply(%add.810.clone.1, %add.810.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1000 = f32[]{:T(128)} constant(0) - %reduce.153 = f32[]{:T(128)} reduce(%square.200, %constant.1000), dimensions={0,1}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.155.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.1000), dimensions={0,1}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.145 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.153, %add.810.clone.1, %add.813.clone.1, %add.814.clone.1, %reduce.155.clone.1) -} - -%region_9.12 (reduce_sum.99: f32[], reduce_sum.100: f32[]) -> f32[] { - %reduce_sum.100 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.99 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.101 = f32[]{:T(128)} add(%reduce_sum.99, %reduce_sum.100), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.344 (param_0.1127: bf16[4096]) -> f32[] { - %param_0.1127 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) - %convert_element_type.1006 = f32[4096]{0:T(1024)} convert(%param_0.1127), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %square.203 = f32[4096]{0:T(1024)} multiply(%convert_element_type.1006, %convert_element_type.1006), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1014 = f32[]{:T(128)} constant(0) - ROOT %reduce.156 = f32[]{:T(128)} reduce(%square.203, %constant.1014), dimensions={0}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %constant.942.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.540.clone.1 = f32[4096,4]{0,1:T(4,128)} broadcast(%constant.942.clone.1), dimensions={}, metadata={op_name="broadcast.53"} + %add.798.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%sqrt.65.clone.1, %broadcast.540.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.190.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%div.777.clone.1, %add.798.clone.1), metadata={op_name="multiply.26"} + %div.774.clone.1 = f32[4096,4]{0,1:T(4,128)} divide(%add.800.clone.1, %multiply.190.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1860.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%param_0.1120, %broadcast.545.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.797.clone.1 = f32[4096,4]{0,1:T(4,128)} add(%div.774.clone.1, %mul.1860.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1859.clone.1 = f32[4096,4]{0,1:T(4,128)} multiply(%mul.1861.clone.1, %add.797.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.796.clone.1 = f32[4096,4]{0,1:T(4,128)S(1)} add(%param_0.1120, %mul.1859.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.200 = f32[4096,4]{0,1:T(4,128)} multiply(%add.796.clone.1, %add.796.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.992 = f32[]{:T(128)} constant(0) + %reduce.114 = f32[]{:T(128)} reduce(%square.200, %constant.992), dimensions={0,1}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.116.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.992), dimensions={0,1}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.145 = (f32[]{:T(128)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[4096,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.114, %add.796.clone.1, %add.799.clone.1, %add.800.clone.1, %reduce.116.clone.1) +} + +%region_9.12 (reduce_sum.135: f32[], reduce_sum.136: f32[]) -> f32[] { + %reduce_sum.135 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.136 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.137 = f32[]{:T(128)} add(%reduce_sum.135, %reduce_sum.136), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.344 (param_0.1134: bf16[4096]) -> f32[] { + %param_0.1134 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %convert_element_type.1022 = f32[4096]{0:T(1024)} convert(%param_0.1134), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.203 = f32[4096]{0:T(1024)} multiply(%convert_element_type.1022, %convert_element_type.1022), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1006 = f32[]{:T(128)} constant(0) + ROOT %reduce.117 = f32[]{:T(128)} reduce(%square.203, %constant.1006), dimensions={0}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_49.54 (reduce_sum.274: f32[], reduce_sum.275: f32[]) -> f32[] { - %reduce_sum.274 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.275 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.276 = f32[]{:T(128)} add(%reduce_sum.274, %reduce_sum.275), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_49.54 (reduce_sum.328: f32[], reduce_sum.329: f32[]) -> f32[] { + %reduce_sum.328 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.329 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.333 = f32[]{:T(128)} add(%reduce_sum.328, %reduce_sum.329), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_35.40 (reduce_sum.199: f32[], reduce_sum.203: f32[]) -> f32[] { - %reduce_sum.199 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.203 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.204 = f32[]{:T(128)} add(%reduce_sum.199, %reduce_sum.203), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_35.40 (reduce_sum.256: f32[], reduce_sum.257: f32[]) -> f32[] { + %reduce_sum.256 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.257 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.258 = f32[]{:T(128)} add(%reduce_sum.256, %reduce_sum.257), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.345 (param_0.1117: f32[4096], param_1.1280: f32[], param_2.1111: f32[], param_3.800: f32[], param_4.501: f32[4096], param_5.426: f32[], param_6.298: bf16[4096], param_7.197: pred[], param_8.115: f32[4096]) -> (f32[], f32[4096], f32[4096], f32[4096], f32[]) { - %param_0.1117 = f32[4096]{0:T(1024)S(1)} parameter(0) - %param_3.800 = f32[]{:T(128)S(6)} parameter(3) - %mul.1574.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_3.800), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.345 (param_0.1124: f32[4096], param_1.1284: f32[], param_2.1105: f32[], param_3.788: f32[], param_4.495: f32[4096], param_5.426: f32[], param_6.299: bf16[4096], param_7.197: pred[], param_8.115: f32[4096]) -> (f32[], f32[4096], f32[4096], f32[4096], f32[]) { + %param_0.1124 = f32[4096]{0:T(1024)S(1)} parameter(0) + %param_3.788 = f32[]{:T(128)S(6)} parameter(3) + %mul.1892.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_3.788), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_7.197 = pred[]{:T(512)S(6)} parameter(7) %select_n.286.clone.1 = pred[4096]{0:T(1024)(128)(4,1)} broadcast(%param_7.197), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.298 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(6) - %convert_element_type.1021.clone.1 = f32[4096]{0:T(1024)} convert(%param_6.298), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_6.299 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(6) + %convert_element_type.1037.clone.1 = f32[4096]{0:T(1024)} convert(%param_6.299), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} %param_5.426 = f32[]{:T(128)} parameter(5) %div.813.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_5.426), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.812.clone.1 = f32[4096]{0:T(1024)} divide(%convert_element_type.1021.clone.1, %div.813.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.285.clone.1 = f32[4096]{0:T(1024)} select(%select_n.286.clone.1, %convert_element_type.1021.clone.1, %div.812.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.973.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.600.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.973.clone.1), dimensions={}, metadata={op_name="broadcast.72"} - %mul.1580.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.285.clone.1, %broadcast.600.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %div.812.clone.1 = f32[4096]{0:T(1024)} divide(%convert_element_type.1037.clone.1, %div.813.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.285.clone.1 = f32[4096]{0:T(1024)} select(%select_n.286.clone.1, %convert_element_type.1037.clone.1, %div.812.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.965.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.561.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.965.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.1898.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.285.clone.1, %broadcast.561.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.115 = f32[4096]{0:T(1024)S(1)} parameter(8) - %constant.977.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1581.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.977.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1579.clone.1 = f32[4096]{0:T(1024)} multiply(%param_8.115, %mul.1581.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.836.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1580.clone.1, %mul.1579.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1111 = f32[]{:T(128)S(6)} parameter(2) - %div.809.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_2.1111), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.969.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.1899.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.969.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1897.clone.1 = f32[4096]{0:T(1024)} multiply(%param_8.115, %mul.1899.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.822.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1898.clone.1, %mul.1897.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1105 = f32[]{:T(128)S(6)} parameter(2) + %div.809.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_2.1105), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.71.clone.1 = f32[4096]{0:T(1024)} multiply(%select_n.285.clone.1, %select_n.285.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.976.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1578.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.976.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1576.clone.1 = f32[4096]{0:T(1024)} multiply(%integer_pow.71.clone.1, %mul.1578.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.501 = f32[4096]{0:T(1024)S(1)} parameter(4) - %constant.975.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1577.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.975.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1575.clone.1 = f32[4096]{0:T(1024)} multiply(%param_4.501, %mul.1577.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.835.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1576.clone.1, %mul.1575.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1280 = f32[]{:T(128)S(6)} parameter(1) - %div.808.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_1.1280), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.807.clone.1 = f32[4096]{0:T(1024)} divide(%add.835.clone.1, %div.808.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.968.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.1896.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.968.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1894.clone.1 = f32[4096]{0:T(1024)} multiply(%integer_pow.71.clone.1, %mul.1896.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.495 = f32[4096]{0:T(1024)S(1)} parameter(4) + %constant.967.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.1895.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.967.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.1893.clone.1 = f32[4096]{0:T(1024)} multiply(%param_4.495, %mul.1895.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.821.clone.1 = f32[4096]{0:T(1024)S(1)} add(%mul.1894.clone.1, %mul.1893.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1284 = f32[]{:T(128)S(6)} parameter(1) + %div.808.clone.1 = f32[4096]{0:T(1024)} broadcast(%param_1.1284), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.807.clone.1 = f32[4096]{0:T(1024)} divide(%add.821.clone.1, %div.808.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.69.clone.1 = f32[4096]{0:T(1024)} sqrt(%div.807.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.974.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.834.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.974.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.833.clone.1 = f32[4096]{0:T(1024)} add(%sqrt.69.clone.1, %add.834.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.267.clone.1 = f32[4096]{0:T(1024)} multiply(%div.809.clone.1, %add.833.clone.1), metadata={op_name="multiply.31"} - %div.806.clone.1 = f32[4096]{0:T(1024)} divide(%add.836.clone.1, %multiply.267.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1573.clone.1 = f32[4096]{0:T(1024)} multiply(%param_0.1117, %broadcast.600.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.832.clone.1 = f32[4096]{0:T(1024)} add(%div.806.clone.1, %mul.1573.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1572.clone.1 = f32[4096]{0:T(1024)} multiply(%mul.1574.clone.1, %add.832.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.831.clone.1 = f32[4096]{0:T(1024)S(1)} add(%param_0.1117, %mul.1572.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.204 = f32[4096]{0:T(1024)} multiply(%add.831.clone.1, %add.831.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1004 = f32[]{:T(128)} constant(0) - %reduce.157 = f32[]{:T(128)} reduce(%square.204, %constant.1004), dimensions={0}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.158.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.1004), dimensions={0}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.148 = (f32[]{:T(128)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.157, %add.831.clone.1, %add.835.clone.1, %add.836.clone.1, %reduce.158.clone.1) -} - -%fused_computation.351 (param_0.964: s32[512]) -> s32[1024] { - %constant.801 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %broadcast.539 = s32[1024]{0:T(1024)} broadcast(%constant.801), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %param_0.964 = s32[512]{0:T(512)S(1)} parameter(0) - %constant.802 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %pad.41 = s32[1024]{0:T(1024)} pad(%param_0.964, %constant.802), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %constant.800 = s32[] constant(128255), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %broadcast.538 = s32[1024]{0:T(1024)} broadcast(%constant.800), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.539, %pad.41, %broadcast.538), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} -} - -%fused_computation.352 (param_0.963: s32[4,128]) -> s32[512] { - %param_0.963 = s32[4,128]{1,0:T(4,128)} parameter(0) - %constant.888 = s32[]{:T(128)} constant(0) - %broadcast.546 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.888), dimensions={}, metadata={op_name="broadcast.81"} - %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.963, %broadcast.546), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} - %constant.875 = s32[]{:T(128)} constant(128256) - %add.760 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.875), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} - %add.748 = s32[4,128]{1,0:T(4,128)} add(%param_0.963, %add.760), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} - %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.748, %param_0.963), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} + %constant.966.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.820.clone.1 = f32[4096]{0:T(1024)} broadcast(%constant.966.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.819.clone.1 = f32[4096]{0:T(1024)} add(%sqrt.69.clone.1, %add.820.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.194.clone.1 = f32[4096]{0:T(1024)} multiply(%div.809.clone.1, %add.819.clone.1), metadata={op_name="multiply.22"} + %div.806.clone.1 = f32[4096]{0:T(1024)} divide(%add.822.clone.1, %multiply.194.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.1891.clone.1 = f32[4096]{0:T(1024)} multiply(%param_0.1124, %broadcast.561.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.818.clone.1 = f32[4096]{0:T(1024)} add(%div.806.clone.1, %mul.1891.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.1890.clone.1 = f32[4096]{0:T(1024)} multiply(%mul.1892.clone.1, %add.818.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.817.clone.1 = f32[4096]{0:T(1024)S(1)} add(%param_0.1124, %mul.1890.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.204 = f32[4096]{0:T(1024)} multiply(%add.817.clone.1, %add.817.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.996 = f32[]{:T(128)} constant(0) + %reduce.118 = f32[]{:T(128)} reduce(%square.204, %constant.996), dimensions={0}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.119.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.996), dimensions={0}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.148 = (f32[]{:T(128)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[4096]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.118, %add.817.clone.1, %add.821.clone.1, %add.822.clone.1, %reduce.119.clone.1) +} + +%fused_computation.351 (param_0.971: s32[512]) -> s32[1024] { + %constant.793 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.500 = s32[1024]{0:T(1024)} broadcast(%constant.793), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %param_0.971 = s32[512]{0:T(512)S(1)} parameter(0) + %constant.794 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %pad.41 = s32[1024]{0:T(1024)} pad(%param_0.971, %constant.794), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %constant.792 = s32[] constant(128255), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.499 = s32[1024]{0:T(1024)} broadcast(%constant.792), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.500, %pad.41, %broadcast.499), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} +} + +%fused_computation.352 (param_0.970: s32[4,128]) -> s32[512] { + %param_0.970 = s32[4,128]{1,0:T(4,128)} parameter(0) + %constant.882 = s32[]{:T(128)} constant(0) + %broadcast.508 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.882), dimensions={}, metadata={op_name="broadcast.81"} + %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.970, %broadcast.508), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} + %constant.879 = s32[]{:T(128)} constant(128256) + %add.746 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.879), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %add.734 = s32[4,128]{1,0:T(4,128)} add(%param_0.970, %add.746), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.734, %param_0.970), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} ROOT %bitcast.376 = s32[512]{0:T(512)S(1)} bitcast(%select_n.178), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} } -%region_61.66 (reduce_sum.345: f32[], reduce_sum.346: f32[]) -> f32[] { - %reduce_sum.345 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.346 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.330 = f32[]{:T(128)} add(%reduce_sum.345, %reduce_sum.346), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_48.53 (reduce_sum.268: f32[], reduce_sum.269: f32[]) -> f32[] { - %reduce_sum.268 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.269 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.273 = f32[]{:T(128)} add(%reduce_sum.268, %reduce_sum.269), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.353 (param_0.1128: bf16[4,128], param_1.1287: f32[4,128], param_2.1114: f32[4,128], param_3.802: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { - %param_3.802 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %constant.979.clone.1 = s32[]{:T(128)} constant(0) - %broadcast.601.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.979.clone.1), dimensions={}, metadata={op_name="broadcast.81"} - %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.802, %broadcast.601.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} - %param_1.1287 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1287), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} - %param_0.1128 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) - %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1128), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} - %add.762 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} - %square.207 = f32[4,128]{1,0:T(4,128)} multiply(%add.762, %add.762), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} - %constant.1016 = f32[]{:T(128)} constant(0) - %broadcast.543 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1016), dimensions={}, metadata={op_name="broadcast.32"} - %mul.1473 = f32[4,128]{1,0:T(4,128)} multiply(%square.207, %broadcast.543), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %mul.1465 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.1473, %broadcast.543), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %reduce.159 = f32[]{:T(128)} reduce(%mul.1465, %constant.1016), dimensions={0,1}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} - %param_2.1114 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1114), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} - %add.749.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.1473), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} - %mul.1466.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.749.clone.1, %broadcast.543), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %reduce.160.clone.1 = f32[]{:T(128)} reduce(%mul.1466.clone.1, %constant.1016), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} - %mul.1471.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.762, %broadcast.543), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %constant.891.clone.1 = f32[]{:T(128)} constant(1) - %add.757.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.891.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} - %add.750.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.1471.clone.1, %add.757.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} - ROOT %tuple.149 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.159, %reduce.160.clone.1, %ne.6.clone.1, %add.750.clone.1) -} - -%fused_computation.356 (param_0.987: f32[4,128], param_1.1101: f32[4,128]) -> f32[4,128] { - %param_0.987 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %param_1.1101 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %constant.869 = f32[]{:T(128)} constant(0.000244140625) - %broadcast.549 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.869), dimensions={}, metadata={op_name="broadcast.264"} - %div.656 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1101, %broadcast.549), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.867 = f32[]{:T(128)} constant(1e-05) - %add.770 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.867), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %add.769 = f32[4,128]{1,0:T(4,128)} add(%div.656, %add.770), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %rsqrt.90 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.769), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} - %div.649 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.90, %add.769), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.864 = f32[]{:T(128)} constant(-0.5) - %mul.1477 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.864), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1470 = f32[4,128]{1,0:T(4,128)} multiply(%div.649, %mul.1477), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1469 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.987, %mul.1470), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %constant.863 = f32[]{:T(128)} constant(0.00048828125) - %mul.1476 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.863), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - ROOT %mul.1468 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1469, %mul.1476), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} -} - -%region_0.1 (reduce_sum.67: s32[], reduce_sum.71: s32[]) -> s32[] { - %reduce_sum.67 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.71 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.72 = s32[]{:T(128)} add(%reduce_sum.67, %reduce_sum.71), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} -} - -%fused_computation.360 (param_0.1004: pred[4,128]) -> s32[] { - %param_0.1004 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) - %convert_element_type.1013 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1004), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} - %constant.889 = s32[]{:T(128)} constant(0) - ROOT %reduce.161 = s32[]{:T(128)} reduce(%convert_element_type.1013, %constant.889), dimensions={0,1}, to_apply=%region_0.1, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} -} - -%fused_computation.361 (param_0.989: f32[4,128]) -> f32[4,128] { - %param_0.989 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.870 = f32[]{:T(128)} constant(0.000244140625) - %broadcast.541 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.870), dimensions={}, metadata={op_name="broadcast.264"} - %div.654 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.989, %broadcast.541), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.868 = f32[]{:T(128)} constant(1e-05) - %add.759 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.868), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %add.756 = f32[4,128]{1,0:T(4,128)} add(%div.654, %add.759), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - ROOT %rsqrt.88 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.756), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} -} - -%fused_computation.362 (param_0.990: pred[4,128], param_1.1286: f32[]) -> f32[4,128] { - %param_0.990 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) - %param_1.1286 = f32[]{:T(128)S(6)} parameter(1) - %broadcast_in_dim.272 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1286), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} - %constant.1015 = f32[]{:T(128)} constant(0) - %broadcast.545 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1015), dimensions={}, metadata={op_name="broadcast.32"} - ROOT %mul.1478 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.990, %broadcast_in_dim.272, %broadcast.545), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} +%region_61.66 (reduce_sum.391: f32[], reduce_sum.392: f32[]) -> f32[] { + %reduce_sum.391 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.392 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.396 = f32[]{:T(128)} add(%reduce_sum.391, %reduce_sum.392), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.53 (reduce_sum.322: f32[], reduce_sum.326: f32[]) -> f32[] { + %reduce_sum.322 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.326 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.327 = f32[]{:T(128)} add(%reduce_sum.322, %reduce_sum.326), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.353 (param_0.1135: bf16[4,128], param_1.1291: f32[4,128], param_2.1108: f32[4,128], param_3.790: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { + %param_3.790 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %constant.971.clone.1 = s32[]{:T(128)} constant(0) + %broadcast.562.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.971.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.790, %broadcast.562.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} + %param_1.1291 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1291), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} + %param_0.1135 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) + %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1135), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + %add.748 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %square.207 = f32[4,128]{1,0:T(4,128)} multiply(%add.748, %add.748), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} + %constant.1008 = f32[]{:T(128)} constant(0) + %broadcast.502 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1008), dimensions={}, metadata={op_name="broadcast.32"} + %mul.1791 = f32[4,128]{1,0:T(4,128)} multiply(%square.207, %broadcast.502), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %mul.1783 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.1791, %broadcast.502), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.120 = f32[]{:T(128)} reduce(%mul.1783, %constant.1008), dimensions={0,1}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %param_2.1108 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1108), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} + %add.735.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.1791), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %mul.1784.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.735.clone.1, %broadcast.502), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.121.clone.1 = f32[]{:T(128)} reduce(%mul.1784.clone.1, %constant.1008), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %mul.1789.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.748, %broadcast.502), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %constant.883.clone.1 = f32[]{:T(128)} constant(1) + %add.743.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.883.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + %add.736.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.1789.clone.1, %add.743.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + ROOT %tuple.149 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.120, %reduce.121.clone.1, %ne.6.clone.1, %add.736.clone.1) +} + +%fused_computation.356 (param_0.994: f32[4,128], param_1.1105: f32[4,128]) -> f32[4,128] { + %param_0.994 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1105 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %constant.873 = f32[]{:T(128)} constant(0.000244140625) + %broadcast.510 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.873), dimensions={}, metadata={op_name="broadcast.245"} + %div.656 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1105, %broadcast.510), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.871 = f32[]{:T(128)} constant(1e-05) + %add.756 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.871), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.755 = f32[4,128]{1,0:T(4,128)} add(%div.656, %add.756), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %rsqrt.90 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.755), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} + %div.649 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.90, %add.755), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.856 = f32[]{:T(128)} constant(-0.5) + %mul.1795 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.856), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1788 = f32[4,128]{1,0:T(4,128)} multiply(%div.649, %mul.1795), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1787 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.994, %mul.1788), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.855 = f32[]{:T(128)} constant(0.00048828125) + %mul.1794 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.855), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.1786 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1787, %mul.1794), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%region_5.8 (reduce_sum.120: s32[], reduce_sum.121: s32[]) -> s32[] { + %reduce_sum.120 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.121 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.122 = s32[]{:T(128)} add(%reduce_sum.120, %reduce_sum.121), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} +} + +%fused_computation.359 (param_0.1011: pred[4,128]) -> s32[] { + %param_0.1011 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %convert_element_type.1029 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1011), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} + %constant.881 = s32[]{:T(128)} constant(0) + ROOT %reduce.122 = s32[]{:T(128)} reduce(%convert_element_type.1029, %constant.881), dimensions={0,1}, to_apply=%region_5.8, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%fused_computation.361 (param_0.997: f32[4,128]) -> f32[4,128] { + %param_0.997 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.874 = f32[]{:T(128)} constant(0.000244140625) + %broadcast.506 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.874), dimensions={}, metadata={op_name="broadcast.245"} + %div.654 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.997, %broadcast.506), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.872 = f32[]{:T(128)} constant(1e-05) + %add.745 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.872), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.742 = f32[4,128]{1,0:T(4,128)} add(%div.654, %add.745), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + ROOT %rsqrt.88 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.742), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} +} + +%fused_computation.362 (param_0.996: pred[4,128], param_1.1290: f32[]) -> f32[4,128] { + %param_0.996 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %param_1.1290 = f32[]{:T(128)S(6)} parameter(1) + %broadcast_in_dim.283 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1290), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} + %constant.1007 = f32[]{:T(128)} constant(0) + %broadcast.504 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1007), dimensions={}, metadata={op_name="broadcast.32"} + ROOT %mul.1796 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.996, %broadcast_in_dim.283, %broadcast.504), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} } %fused_computation.364 () -> f32[64] { - %constant.873 = f32[]{:T(128)} constant(500000) - %broadcast.552 = f32[64]{0:T(128)} broadcast(%constant.873), dimensions={}, metadata={op_name="broadcast.255"} + %constant.877 = f32[]{:T(128)} constant(500000) + %broadcast.513 = f32[64]{0:T(128)} broadcast(%constant.877), dimensions={}, metadata={op_name="broadcast.236"} %iota.46 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/layers/iota" stack_frame_id=0} - %constant.872 = s32[]{:T(128)} constant(2) - %broadcast.551 = s32[64]{0:T(128)} broadcast(%constant.872), dimensions={}, metadata={op_name="broadcast.256"} - %mul.1479 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.551), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} - %convert_element_type.1014 = f32[64]{0:T(128)} convert(%mul.1479), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - %constant.871 = f32[]{:T(128)} constant(0.0078125) - %broadcast.550 = f32[64]{0:T(128)} broadcast(%constant.871), dimensions={}, metadata={op_name="broadcast.257"} - %div.657 = f32[64]{0:T(128)} multiply(%convert_element_type.1014, %broadcast.550), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} - ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.552, %div.657), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} + %constant.876 = s32[]{:T(128)} constant(2) + %broadcast.512 = s32[64]{0:T(128)} broadcast(%constant.876), dimensions={}, metadata={op_name="broadcast.237"} + %mul.1797 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.512), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} + %convert_element_type.1030 = f32[64]{0:T(128)} convert(%mul.1797), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %constant.875 = f32[]{:T(128)} constant(0.0078125) + %broadcast.511 = f32[64]{0:T(128)} broadcast(%constant.875), dimensions={}, metadata={op_name="broadcast.238"} + %div.657 = f32[64]{0:T(128)} multiply(%convert_element_type.1030, %broadcast.511), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.513, %div.657), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} } -%fused_computation.365 (param_0.1002: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { - %param_0.1002 = s32[4,128]{1,0:T(4,128)} parameter(0) - %convert_element_type.1015 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1002), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - %bitcast.377 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1015), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.151 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.377, %convert_element_type.1015) +%fused_computation.365 (param_0.1009: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { + %param_0.1009 = s32[4,128]{1,0:T(4,128)} parameter(0) + %convert_element_type.1031 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1009), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %bitcast.377 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1031), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.151 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.377, %convert_element_type.1031) } -%fused_computation.369 (param_0.1103: f32[4096,4]) -> bf16[4,4096] { - %param_0.1103 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) - %bitcast.451 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1103), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.106 = bf16[4,4096]{1,0:T(4,128)(2,1)} convert(%bitcast.451) +%fused_computation.369 (param_0.1110: f32[4096,4]) -> bf16[4,4096] { + %param_0.1110 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %bitcast.451 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1110), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.69 = bf16[4,4096]{1,0:T(4,128)(2,1)S(1)} convert(%bitcast.451) } -%fused_computation.370 (param_0.1104: f32[4096,4]) -> bf16[4,4096] { - %param_0.1104 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) - %bitcast.452 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1104), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.108 = bf16[4,4096]{1,0:T(4,128)(2,1)S(1)} convert(%bitcast.452) +%fused_computation.370 (param_0.1111: f32[4096,4]) -> bf16[4,4096] { + %param_0.1111 = f32[4096,4]{0,1:T(4,128)S(1)} parameter(0) + %bitcast.452 = f32[4,4096]{1,0:T(4,128)} bitcast(%param_0.1111), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.71 = bf16[4,4096]{1,0:T(4,128)(2,1)} convert(%bitcast.452) } %region_6.9 (reduce_max.6: bf16[], reduce_max.8: bf16[]) -> bf16[] { @@ -1371,364 +1371,364 @@ StackFrames ROOT %reduce_max.9 = bf16[]{:T(256)} maximum(%reduce_max.6, %reduce_max.8), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.237.clone.clone (param_0.1090: f32[4096,128256]) -> bf16[4096,128256,1] { - %param_0.1090 = f32[4096,128256]{1,0:T(8,128)} parameter(0) - %convert_element_type.1026 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1090), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} - ROOT %bitcast.447 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1026), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} +%fused_computation.237.clone.clone (param_0.1097: f32[4096,128256]) -> bf16[4096,128256,1] { + %param_0.1097 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %convert_element_type.1042 = bf16[4096,128256]{1,0:T(8,128)(2,1)} convert(%param_0.1097), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} + ROOT %bitcast.447 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} bitcast(%convert_element_type.1042), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} } -%fused_computation.317.clone.clone (param_0.1091: f32[4,128], param_1.1257: bf16[4,128,4096], param_2.1077: bf16[4096]) -> bf16[4,128,4096] { - %param_2.1077 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.383 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1077), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1257 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1028 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1257), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1091 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1595 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1091), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1594 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1028, %mul.1595), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1027 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1594), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.382 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.383, %convert_element_type.1027), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} +%fused_computation.317.clone.clone (param_0.1098: f32[4,128], param_1.1261: bf16[4,128,4096], param_2.1071: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1261 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1044 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1261), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1098 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.1916 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1098), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.1915 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1044, %mul.1916), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1043 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1915), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1071 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.1917 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1071), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.1914 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1043, %mul.1917), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} } -%fused_computation.371 (param_0.1105: f32[4096,128256], param_1.1268: f32[4,128], param_2.1099: bf16[4,128,4096], param_3.788: bf16[4096]) -> (bf16[4,128], bf16[4,128,128256]) { - %param_1.1268 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1099 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %param_3.788 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %fusion.240.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1268, %param_2.1099, %param_3.788), kind=kLoop, calls=%fused_computation.317.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1105 = f32[4096,128256]{1,0:T(8,128)} parameter(0) - %fusion.221.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1105), kind=kLoop, calls=%fused_computation.237.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} +%fused_computation.371 (param_0.1112: f32[4096,128256], param_1.1272: f32[4,128], param_2.1093: bf16[4,128,4096], param_3.776: bf16[4096]) -> (bf16[4,128], bf16[4,128,128256]) { + %param_1.1272 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1093 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.776 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.240.clone.1 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1272, %param_2.1093, %param_3.776), kind=kLoop, calls=%fused_computation.317.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1112 = f32[4096,128256]{1,0:T(8,128)} parameter(0) + %fusion.221.clone.1 = bf16[4096,128256,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1112), kind=kLoop, calls=%fused_computation.237.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/convert_element_type" stack_frame_id=0} %convolution.87.clone.1 = bf16[4,128,128256]{2,1,0:T(8,128)(2,1)} convolution(%fusion.240.clone.1, %fusion.221.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/logits_dense/dot_general" stack_frame_id=0} - %constant.992 = bf16[]{:T(256)} constant(-inf) - %reduce.162 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.87.clone.1, %constant.992), dimensions={2}, to_apply=%region_6.9, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} - ROOT %tuple.152 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,128256]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.162, %convolution.87.clone.1) + %constant.984 = bf16[]{:T(256)} constant(-inf) + %reduce.123 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.87.clone.1, %constant.984), dimensions={2}, to_apply=%region_6.9, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + ROOT %tuple.152 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,128256]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.123, %convolution.87.clone.1) } -%fused_computation.372 (param_0.1102: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { - %param_0.1102 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %bitcast.450 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1102), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.110 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.450) +%fused_computation.372 (param_0.1109: f32[4096,4,8,128]) -> bf16[4,4096,8,128] { + %param_0.1109 = f32[4096,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %bitcast.450 = f32[4,4096,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1109), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.73 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.450) } -%convert_element_type.525.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { - %lhs.1 = bf16[] parameter(0) - %rhs.1 = bf16[] parameter(1) - ROOT %add.624 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%convert_element_type.541.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { + %lhs = bf16[] parameter(0) + %rhs = bf16[] parameter(1) + ROOT %add.609 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.121.clone.clone (param_0.1242: bf16[4,4096], param_1.1376: s32[]) -> bf16[4096] { - %param_0.1242 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1376 = s32[]{:T(128)S(6)} parameter(1) - %constant.1116 = s32[]{:T(128)} constant(0) - %dynamic_slice.316 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1242, %param_1.1376, %constant.1116), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %constant.1117 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %reduce.174 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.316, %constant.1117), dimensions={0}, to_apply=%convert_element_type.525.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.122.clone.clone (param_0.1249: bf16[4,4096], param_1.1380: s32[]) -> bf16[4096] { + %param_0.1249 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1380 = s32[]{:T(128)S(6)} parameter(1) + %constant.1108 = s32[]{:T(128)} constant(0) + %dynamic_slice.316 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1249, %param_1.1380, %constant.1108), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1109 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.135 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.316, %constant.1109), dimensions={0}, to_apply=%convert_element_type.541.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%region_12.14 (reduce_sum.108: f32[], reduce_sum.109: f32[]) -> f32[] { - %reduce_sum.108 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.109 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.113 = f32[]{:T(128)} add(%reduce_sum.108, %reduce_sum.109), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_12.14 (reduce_sum.144: f32[], reduce_sum.148: f32[]) -> f32[] { + %reduce_sum.144 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.148 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.149 = f32[]{:T(128)} add(%reduce_sum.144, %reduce_sum.148), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.58.clone.clone (param_0.1243: bf16[4,4,128,4096], param_1.1377: s32[]) -> f32[4,128] { - %param_0.1243 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1377 = s32[]{:T(128)S(6)} parameter(1) - %constant.1118 = s32[]{:T(128)} constant(0) - %dynamic_slice.317 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1243, %param_1.1377, %constant.1118, %constant.1118, %constant.1118), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.58.clone.clone (param_0.1250: bf16[4,4,128,4096], param_1.1381: s32[]) -> f32[4,128] { + %param_0.1250 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1381 = s32[]{:T(128)S(6)} parameter(1) + %constant.1110 = s32[]{:T(128)} constant(0) + %dynamic_slice.317 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1250, %param_1.1381, %constant.1110, %constant.1110, %constant.1110), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.548 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.317), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1093 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.548), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.214 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1093, %convert_element_type.1093), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1119 = f32[]{:T(128)} constant(0) - ROOT %reduce.175 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.214, %constant.1119), dimensions={2}, to_apply=%region_12.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + %convert_element_type.1109 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.548), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.214 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1109, %convert_element_type.1109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1111 = f32[]{:T(128)} constant(0) + ROOT %reduce.136 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.214, %constant.1111), dimensions={2}, to_apply=%region_12.14, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} } -%fused_computation.143.clone.1.clone (param_0.1244: f32[4,128]) -> f32[4,128] { - %param_0.1244 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.1121 = f32[]{:T(128)} constant(0.000244140625) - %closed_call.81 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1121), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.842 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1244, %closed_call.81), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1120 = f32[]{:T(128)} constant(1e-05) - %closed_call.80 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1120), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.858 = f32[4,128]{1,0:T(4,128)} add(%div.842, %closed_call.80), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.97 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.858), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +%fused_computation.143.clone.1.clone (param_0.1251: f32[4,128]) -> f32[4,128] { + %param_0.1251 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1113 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.81 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1113), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.842 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1251, %closed_call.81), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1112 = f32[]{:T(128)} constant(1e-05) + %closed_call.80 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1112), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.844 = f32[4,128]{1,0:T(4,128)} add(%div.842, %closed_call.80), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.97 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.844), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} } -%fused_computation.24.clone.1.clone.clone (param_0.1258: bf16[4,4096,32,128], param_1.1387: s32[]) -> bf16[4096,32,128,1] { - %param_0.1258 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %param_1.1387 = s32[]{:T(128)S(6)} parameter(1) - %constant.1134 = s32[]{:T(128)} constant(0) - %dynamic_slice.323 = bf16[1,4096,32,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1258, %param_1.1387, %constant.1134, %constant.1134, %constant.1134), dynamic_slice_sizes={1,4096,32,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.24.clone.1.clone.clone (param_0.1265: bf16[4,4096,32,128], param_1.1391: s32[]) -> bf16[4096,32,128,1] { + %param_0.1265 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1391 = s32[]{:T(128)S(6)} parameter(1) + %constant.1126 = s32[]{:T(128)} constant(0) + %dynamic_slice.323 = bf16[1,4096,32,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1265, %param_1.1391, %constant.1126, %constant.1126, %constant.1126), dynamic_slice_sizes={1,4096,32,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.559 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.323), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.91.clone.clone (param_0.1259: f32[4,128], param_1.1388: bf16[4,4,128,4096], param_2.1176: s32[], param_3.847: bf16[4096]) -> bf16[4,128,4096,1] { - %param_3.847 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.428 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.847), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1388 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1176 = s32[]{:T(128)S(6)} parameter(2) - %constant.1135 = s32[]{:T(128)} constant(0) - %dynamic_slice.324 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1388, %param_2.1176, %constant.1135, %constant.1135, %constant.1135), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.91.clone.clone (param_0.1266: bf16[4096], param_1.1392: f32[4,128], param_2.1170: bf16[4,4,128,4096], param_3.835: s32[]) -> bf16[4,128,4096,1] { + %param_2.1170 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.835 = s32[]{:T(128)S(6)} parameter(3) + %constant.1127 = s32[]{:T(128)} constant(0) + %dynamic_slice.324 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1170, %param_3.835, %constant.1127, %constant.1127, %constant.1127), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.561 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.324), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1101 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.561), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1259 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1709 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1259), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1708 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1101, %mul.1709), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1100 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1708), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.427 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.428, %convert_element_type.1100), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.560 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.427), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.36.clone.clone (param_0.1260: bf16[4,4096,32,128], param_1.1389: s32[], param_2.1177: f32[4,128], param_3.848: bf16[4,4,128,4096], param_4.530: bf16[4096]) -> bf16[4,128,32,128] { - %param_2.1177 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.848 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1389 = s32[]{:T(128)S(6)} parameter(1) - %param_4.530 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.343 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1177, %param_3.848, %param_1.1389, %param_4.530), kind=kLoop, calls=%fused_computation.91.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1260 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %fusion.342 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1260, %param_1.1389), kind=kLoop, calls=%fused_computation.24.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1117 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.561), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1392 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2081 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1392), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2080 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1117, %mul.2081), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1116 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2080), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1266 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %mul.2079 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_0.1266), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2078 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1116, %mul.2079), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.560 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2078), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.36.clone.clone (param_0.1267: bf16[4,4096,32,128], param_1.1393: s32[], param_2.1171: f32[4,128], param_3.836: bf16[4,4,128,4096], param_4.524: bf16[4096]) -> bf16[4,128,32,128] { + %param_4.524 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %param_2.1171 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.836 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1393 = s32[]{:T(128)S(6)} parameter(1) + %fusion.343 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_4.524, %param_2.1171, %param_3.836, %param_1.1393), kind=kLoop, calls=%fused_computation.91.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1267 = bf16[4,4096,32,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.342 = bf16[4096,32,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1267, %param_1.1393), kind=kLoop, calls=%fused_computation.24.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.113 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.343, %fusion.342), window={size=1x32 pad=0_0x31_31 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} } -%fused_computation.70.clone.clone (param_0.1261: bf16[4,128,32,128]) -> (bf16[4,128,32,64], bf16[4,128,32,64]) { - %param_0.1261 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %split.160 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1261), slice={[0:4], [0:128], [0:32], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} +%fused_computation.70.clone.clone (param_0.1268: bf16[4,128,32,128]) -> (bf16[4,128,32,64], bf16[4,128,32,64]) { + %param_0.1268 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.160 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1268), slice={[0:4], [0:128], [0:32], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} %neg.129 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.160), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} - %split.161 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1261), slice={[0:4], [0:128], [0:32], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %split.161 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1268), slice={[0:4], [0:128], [0:32], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} ROOT %tuple.187 = (bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) } %fused_computation.145.clone.clone () -> f32[64] { - %constant.1124 = f32[]{:T(128)} constant(500000) - %closed_call.84 = f32[64]{0:T(128)} broadcast(%constant.1124), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %constant.1116 = f32[]{:T(128)} constant(500000) + %closed_call.84 = f32[64]{0:T(128)} broadcast(%constant.1116), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} %iota.51 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/iota" stack_frame_id=0} - %constant.1123 = s32[]{:T(128)} constant(2) - %closed_call.83 = s32[64]{0:T(128)} broadcast(%constant.1123), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %mul.1699 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.83), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1094 = f32[64]{0:T(128)} convert(%mul.1699), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %constant.1122 = f32[]{:T(128)} constant(0.0078125) - %closed_call.82 = f32[64]{0:T(128)} broadcast(%constant.1122), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.843 = f32[64]{0:T(128)} multiply(%convert_element_type.1094, %closed_call.82), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1115 = s32[]{:T(128)} constant(2) + %closed_call.83 = s32[64]{0:T(128)} broadcast(%constant.1115), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2065 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.83), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1110 = f32[64]{0:T(128)} convert(%mul.2065), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %constant.1114 = f32[]{:T(128)} constant(0.0078125) + %closed_call.82 = f32[64]{0:T(128)} broadcast(%constant.1114), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.843 = f32[64]{0:T(128)} multiply(%convert_element_type.1110, %closed_call.82), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} ROOT %pow.38 = f32[64]{0:T(128)S(1)} power(%closed_call.84, %div.843), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/pow" stack_frame_id=0} } -%fused_computation.117.clone.clone (param_0.1245: f32[64], param_1.1378: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { - %param_1.1378 = f32[4,128]{1,0:T(4,128)} parameter(1) - %div.846 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1378), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %param_0.1245 = f32[64]{0:T(128)S(1)} parameter(0) - %div.845 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1245), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} +%fused_computation.117.clone.clone (param_0.1252: f32[64], param_1.1382: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1382 = f32[4,128]{1,0:T(4,128)} parameter(1) + %div.846 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1382), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %param_0.1252 = f32[64]{0:T(128)S(1)} parameter(0) + %div.845 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1252), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} %div.844 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.846, %div.845), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} %cos.43 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.844), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/cos" stack_frame_id=0} - %convert_element_type.1095 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1111 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %sin.35.clone.3 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.844), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/sin" stack_frame_id=0} - %convert_element_type.829.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.185 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1095, %convert_element_type.829.clone.3) + %convert_element_type.845.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.185 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1111, %convert_element_type.845.clone.3) } -%fused_computation.120.clone.clone (param_0.1252: bf16[4,128,1,64]) -> bf16[4,128,128] { - %param_0.1252 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1130 = bf16[]{:T(256)} constant(-inf) - %pad.61 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1252, %constant.1130), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %pad.60 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1252, %constant.1130), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.120.clone.clone (param_0.1259: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1259 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1122 = bf16[]{:T(256)} constant(-inf) + %pad.61 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1259, %constant.1122), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.60 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1259, %constant.1122), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.61, %pad.60), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} ROOT %bitcast.554 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.45), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} } -%fused_computation.119.clone.clone (param_0.1246: bf16[4,128,1,64]) -> bf16[4,128,128] { - %param_0.1246 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1125 = bf16[]{:T(256)} constant(-inf) - %pad.59 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1246, %constant.1125), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %pad.58 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1246, %constant.1125), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.119.clone.clone (param_0.1253: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1253 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1117 = bf16[]{:T(256)} constant(-inf) + %pad.59 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1253, %constant.1117), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.58 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1253, %constant.1117), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.44 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.59, %pad.58), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} ROOT %bitcast.549 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} } -%fused_computation.73.clone.clone (param_0.1262: bf16[4,128,32,64], param_1.1390: bf16[4,128,32,64], param_2.1178: bf16[4,128,32,128], param_3.849: bf16[4,128,128], param_4.531: bf16[4,128,128]) -> bf16[4,32,128,128] { - %param_2.1178 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) - %param_4.531 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) - %mul.1713 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.531), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1711 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1178, %mul.1713), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %param_1.1390 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1136 = bf16[]{:T(256)} constant(-inf) - %pad.65 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1390, %constant.1136), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_0.1262 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %pad.64 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1262, %constant.1136), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.73.clone.clone (param_0.1269: bf16[4,128,32,64], param_1.1394: bf16[4,128,32,64], param_2.1172: bf16[4,128,32,128], param_3.837: bf16[4,128,128], param_4.525: bf16[4,128,128]) -> bf16[4,32,128,128] { + %param_2.1172 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) + %param_4.525 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) + %mul.2085 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.525), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2083 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1172, %mul.2085), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1394 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1128 = bf16[]{:T(256)} constant(-inf) + %pad.65 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1394, %constant.1128), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1269 = bf16[4,128,32,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.64 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1269, %constant.1128), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.47 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.65, %pad.64), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_3.849 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %mul.1712 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.849), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1710 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.47, %mul.1712), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %add.860 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.1711, %mul.1710), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %bitcast.562 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.860), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%fused_computation.90.clone.clone (param_0.1254: f32[4,128], param_1.1384: bf16[4,4,128,4096], param_2.1173: s32[], param_3.844: bf16[4096]) -> bf16[4,128,4096,1] { - %param_3.844 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.426 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.844), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1384 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1173 = s32[]{:T(128)S(6)} parameter(2) - %constant.1132 = s32[]{:T(128)} constant(0) - %dynamic_slice.322 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1384, %param_2.1173, %constant.1132, %constant.1132, %constant.1132), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %param_3.837 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2084 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.837), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2082 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.47, %mul.2084), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.846 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2083, %mul.2082), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %bitcast.562 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.846), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.90.clone.clone (param_0.1261: bf16[4096], param_1.1388: f32[4,128], param_2.1167: bf16[4,4,128,4096], param_3.832: s32[]) -> bf16[4,128,4096,1] { + %param_2.1167 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.832 = s32[]{:T(128)S(6)} parameter(3) + %constant.1124 = s32[]{:T(128)} constant(0) + %dynamic_slice.322 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1167, %param_3.832, %constant.1124, %constant.1124, %constant.1124), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.557 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.322), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1099 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.557), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1254 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1703 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1254), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1702 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1099, %mul.1703), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1098 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1702), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.425 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.426, %convert_element_type.1098), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.556 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.425), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.64.clone.1.clone.clone (param_0.1253: bf16[4,4096,8,128], param_1.1383: s32[]) -> bf16[4096,8,128,1] { - %param_0.1253 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %param_1.1383 = s32[]{:T(128)S(6)} parameter(1) - %constant.1131 = s32[]{:T(128)} constant(0) - %dynamic_slice.321 = bf16[1,4096,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1253, %param_1.1383, %constant.1131, %constant.1131, %constant.1131), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1115 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.557), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1388 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2073 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1388), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2072 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1115, %mul.2073), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1114 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2072), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1261 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %mul.2071 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_0.1261), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2070 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1114, %mul.2071), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.556 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2070), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.64.clone.1.clone.clone (param_0.1260: bf16[4,4096,8,128], param_1.1387: s32[]) -> bf16[4096,8,128,1] { + %param_0.1260 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1387 = s32[]{:T(128)S(6)} parameter(1) + %constant.1123 = s32[]{:T(128)} constant(0) + %dynamic_slice.321 = bf16[1,4096,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1260, %param_1.1387, %constant.1123, %constant.1123, %constant.1123), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.555 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.321), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.89.clone.clone (param_0.1255: bf16[4,4096,8,128], param_1.1385: s32[], param_2.1174: f32[4,128], param_3.845: bf16[4,4,128,4096], param_4.528: bf16[4096]) -> bf16[4,128,8,128] { - %param_2.1174 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.845 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1385 = s32[]{:T(128)S(6)} parameter(1) - %param_4.528 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.340 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1174, %param_3.845, %param_1.1385, %param_4.528), kind=kLoop, calls=%fused_computation.90.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1255 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %fusion.341 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1255, %param_1.1385), kind=kLoop, calls=%fused_computation.64.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.89.clone.clone (param_0.1262: bf16[4,4096,8,128], param_1.1389: s32[], param_2.1168: f32[4,128], param_3.833: bf16[4,4,128,4096], param_4.522: bf16[4096]) -> bf16[4,128,8,128] { + %param_4.522 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %param_2.1168 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.833 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1389 = s32[]{:T(128)S(6)} parameter(1) + %fusion.340 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_4.522, %param_2.1168, %param_3.833, %param_1.1389), kind=kLoop, calls=%fused_computation.90.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1262 = bf16[4,4096,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.341 = bf16[4096,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1262, %param_1.1389), kind=kLoop, calls=%fused_computation.64.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.112 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.340, %fusion.341), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} } -%fused_computation.106.clone.clone (param_0.1256: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { - %param_0.1256 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1256), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} +%fused_computation.106.clone.clone (param_0.1263: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { + %param_0.1263 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1263), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} %neg.128 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.158), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} - %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1256), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1263), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} ROOT %tuple.186 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) } -%fused_computation.109.clone.clone (param_0.1257: bf16[4,128,8,64], param_1.1386: bf16[4,128,8,64], param_2.1175: bf16[4,128,8,128], param_3.846: bf16[4,128,128], param_4.529: bf16[4,128,128]) -> bf16[4,8,128,128] { - %param_2.1175 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) - %param_4.529 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) - %mul.1707 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.529), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1705 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1175, %mul.1707), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %param_1.1386 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1133 = bf16[]{:T(256)} constant(-inf) - %pad.63 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1386, %constant.1133), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_0.1257 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %pad.62 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1257, %constant.1133), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.109.clone.clone (param_0.1264: bf16[4,128,8,64], param_1.1390: bf16[4,128,8,64], param_2.1169: bf16[4,128,8,128], param_3.834: bf16[4,128,128], param_4.523: bf16[4,128,128]) -> bf16[4,8,128,128] { + %param_2.1169 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(2) + %param_4.523 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(4) + %mul.2077 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.523), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2075 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%param_2.1169, %mul.2077), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1390 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1125 = bf16[]{:T(256)} constant(-inf) + %pad.63 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1390, %constant.1125), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1264 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.62 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1264, %constant.1125), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.46 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.63, %pad.62), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_3.846 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %mul.1706 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.846), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1704 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.46, %mul.1706), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %add.859 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.1705, %mul.1704), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %bitcast.558 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.859), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} + %param_3.834 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2076 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.834), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2074 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.46, %mul.2076), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.845 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2075, %mul.2074), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %bitcast.558 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.845), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} } -%fused_computation.135.clone.clone (param_0.1248: bf16[4,4096,8,128], param_1.1380: s32[]) -> bf16[1,4096,8,128] { - %param_0.1248 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) - %param_1.1380 = s32[]{:T(128)S(6)} parameter(1) - %constant.1128 = s32[]{:T(128)} constant(0) - ROOT %dynamic_slice.319 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1248, %param_1.1380, %constant.1128, %constant.1128, %constant.1128), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.135.clone.clone (param_0.1255: bf16[4,4096,8,128], param_1.1384: s32[]) -> bf16[1,4096,8,128] { + %param_0.1255 = bf16[4,4096,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) + %param_1.1384 = s32[]{:T(128)S(6)} parameter(1) + %constant.1120 = s32[]{:T(128)} constant(0) + ROOT %dynamic_slice.319 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1255, %param_1.1384, %constant.1120, %constant.1120, %constant.1120), dynamic_slice_sizes={1,4096,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} } -%fused_computation.65.clone.1.clone.clone.clone.clone (param_0.1249: bf16[1,4096,8,128]) -> bf16[4096,8,128,1] { - %param_0.1249 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) - %copy.248 = bf16[1,4096,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1249), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} +%fused_computation.65.clone.1.clone.clone.clone.clone (param_0.1256: bf16[1,4096,8,128]) -> bf16[4096,8,128,1] { + %param_0.1256 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %copy.248 = bf16[1,4096,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1256), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} ROOT %bitcast.550 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} bitcast(%copy.248), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.88.clone.clone.clone.clone (param_0.1250: f32[4,128], param_1.1381: bf16[4,4,128,4096], param_2.1171: s32[], param_3.842: bf16[4096]) -> bf16[4,128,4096,1] { - %param_3.842 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.424 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.842), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1381 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1171 = s32[]{:T(128)S(6)} parameter(2) - %constant.1129 = s32[]{:T(128)} constant(0) - %dynamic_slice.320 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1381, %param_2.1171, %constant.1129, %constant.1129, %constant.1129), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.88.clone.clone.clone.clone (param_0.1257: bf16[4096], param_1.1385: f32[4,128], param_2.1165: bf16[4,4,128,4096], param_3.830: s32[]) -> bf16[4,128,4096,1] { + %param_2.1165 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.830 = s32[]{:T(128)S(6)} parameter(3) + %constant.1121 = s32[]{:T(128)} constant(0) + %dynamic_slice.320 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1165, %param_3.830, %constant.1121, %constant.1121, %constant.1121), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.552 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.320), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1097 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.552), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1250 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1701 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1250), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1700 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1097, %mul.1701), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1096 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1700), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.423 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.424, %convert_element_type.1096), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.551 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.423), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.114.clone.clone (param_0.1251: bf16[1,4096,8,128], param_1.1382: f32[4,128], param_2.1172: bf16[4,4,128,4096], param_3.843: s32[], param_4.527: bf16[4096]) -> bf16[4,8,128,128] { - %param_1.1382 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1172 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) - %param_3.843 = s32[]{:T(128)S(6)} parameter(3) - %param_4.527 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.339 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_1.1382, %param_2.1172, %param_3.843, %param_4.527), kind=kLoop, calls=%fused_computation.88.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1251 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) - %fusion.338 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1251), kind=kLoop, calls=%fused_computation.65.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1113 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%bitcast.552), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1385 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2069 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1385), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2068 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1113, %mul.2069), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1112 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2068), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1257 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %mul.2067 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_0.1257), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2066 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1112, %mul.2067), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.551 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2066), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.114.clone.clone (param_0.1258: bf16[1,4096,8,128], param_1.1386: f32[4,128], param_2.1166: bf16[4,4,128,4096], param_3.831: s32[], param_4.521: bf16[4096]) -> bf16[4,8,128,128] { + %param_4.521 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %param_1.1386 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1166 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.831 = s32[]{:T(128)S(6)} parameter(3) + %fusion.339 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_4.521, %param_1.1386, %param_2.1166, %param_3.831), kind=kLoop, calls=%fused_computation.88.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1258 = bf16[1,4096,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %fusion.338 = bf16[4096,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1258), kind=kLoop, calls=%fused_computation.65.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %convolution.111 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.339, %fusion.338), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} ROOT %bitcast.553 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%convolution.111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} } -%fused_computation.366.clone.clone (param_0.1286: f32[4,32,128,128]) -> (f32[4,32,128,1], f32[4,32,128]) { - %param_0.1286 = f32[4,32,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) - %slice.11 = f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1286), slice={[0:4], [0:32], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} +%fused_computation.366.clone.clone (param_0.1293: f32[4,32,128,128]) -> (f32[4,32,128,1], f32[4,32,128]) { + %param_0.1293 = f32[4,32,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) + %slice.11 = f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1293), slice={[0:4], [0:32], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} %bitcast.262.clone.3 = f32[4,32,128]{2,1,0:T(8,128)S(1)} bitcast(%slice.11), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/squeeze" stack_frame_id=0} ROOT %tuple.192 = (f32[4,32,128,1]{2,1,0,3:T(8,128)S(1)}, f32[4,32,128]{2,1,0:T(8,128)S(1)}) tuple(%slice.11, %bitcast.262.clone.3) } -%region_13.16 (reduce_sum.120: f32[], reduce_sum.121: f32[]) -> f32[] { - %reduce_sum.120 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.121 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.122 = f32[]{:T(128)} add(%reduce_sum.120, %reduce_sum.121), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_13.16 (reduce_sum.150: f32[], reduce_sum.151: f32[]) -> f32[] { + %reduce_sum.150 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.151 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.155 = f32[]{:T(128)} add(%reduce_sum.150, %reduce_sum.151), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone (param_0.1263: bf16[4,32,128,4096], param_1.1391: s32[]) -> bf16[32,128,4096,1] { - %param_0.1263 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1391 = s32[]{:T(128)S(6)} parameter(1) - %constant.1137 = s32[]{:T(128)} constant(0) - %dynamic_slice.325 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1263, %param_1.1391, %constant.1137, %constant.1137, %constant.1137), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone (param_0.1270: bf16[4,32,128,4096], param_1.1395: s32[]) -> bf16[32,128,4096,1] { + %param_0.1270 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1395 = s32[]{:T(128)S(6)} parameter(1) + %constant.1129 = s32[]{:T(128)} constant(0) + %dynamic_slice.325 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1270, %param_1.1395, %constant.1129, %constant.1129, %constant.1129), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.563 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.325), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.80.clone.clone.clone.clone.clone.clone (param_0.1264: bf16[4,32,128,128]) -> bf16[4,128,32,128] { - %param_0.1264 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.564 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1264), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +%fused_computation.81.clone.clone.clone.clone.clone.clone (param_0.1271: bf16[4,32,128,128]) -> bf16[4,128,32,128] { + %param_0.1271 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.564 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1271), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} } -%fused_computation.61.clone.clone (param_0.1265: bf16[4,32,128,4096], param_1.1392: s32[], param_2.1179: bf16[4,32,128,128], param_3.850: bf16[4,4,128,4096]) -> (f32[4,128], bf16[4,128,4096]) { - %param_3.850 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1392 = s32[]{:T(128)S(6)} parameter(1) - %constant.365.clone.1.clone.3 = s32[]{:T(128)} constant(0) - %dynamic_slice.208.clone.3 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.850, %param_1.1392, %constant.365.clone.1.clone.3, %constant.365.clone.1.clone.3, %constant.365.clone.1.clone.3), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.61.clone.clone (param_0.1272: bf16[4,32,128,4096], param_1.1396: s32[], param_2.1173: bf16[4,32,128,128], param_3.838: bf16[4,4,128,4096]) -> (f32[4,128], bf16[4,128,4096]) { + %param_3.838 = bf16[4,4,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1396 = s32[]{:T(128)S(6)} parameter(1) + %constant.357.clone.1.clone.3 = s32[]{:T(128)} constant(0) + %dynamic_slice.208.clone.3 = bf16[1,4,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.838, %param_1.1396, %constant.357.clone.1.clone.3, %constant.357.clone.1.clone.3, %constant.357.clone.1.clone.3), dynamic_slice_sizes={1,4,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} %bitcast.207.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.208.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %param_2.1179 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %fusion.83.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1179), kind=kLoop, calls=%fused_computation.80.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} - %param_0.1265 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %fusion.82.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1265, %param_1.1392), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1173 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %fusion.83.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1173), kind=kLoop, calls=%fused_computation.81.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} + %param_0.1272 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %fusion.82.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1272, %param_1.1396), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %convolution.62.clone.3 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} convolution(%fusion.83.clone.3, %fusion.82.clone.3), window={size=1x32}, dim_labels=0b1f_1io0->0bf1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} %bitcast.182.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%convolution.62.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %add.635.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.207.clone.3, %bitcast.182.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %convert_element_type.1102 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%add.635.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.215 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1102, %convert_element_type.1102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1138 = f32[]{:T(128)} constant(0) - %reduce.177 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.215, %constant.1138), dimensions={2}, to_apply=%region_13.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.188 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.177, %add.635.clone.3) + %add.621.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.207.clone.3, %bitcast.182.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %convert_element_type.1118 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%add.621.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.215 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1118, %convert_element_type.1118), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1130 = f32[]{:T(128)} constant(0) + %reduce.138 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.215, %constant.1130), dimensions={2}, to_apply=%region_13.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.188 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.138, %add.621.clone.3) } -%convert_element_type.523.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { - %lhs = bf16[] parameter(0) - %rhs = bf16[] parameter(1) - ROOT %add.623 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%convert_element_type.556.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { + %lhs.1 = bf16[] parameter(0) + %rhs.1 = bf16[] parameter(1) + ROOT %add.610 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.122.clone.clone (param_0.1247: bf16[4,4096], param_1.1379: s32[]) -> bf16[4096] { - %param_0.1247 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1379 = s32[]{:T(128)S(6)} parameter(1) - %constant.1126 = s32[]{:T(128)} constant(0) - %dynamic_slice.318 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1247, %param_1.1379, %constant.1126), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %constant.1127 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %reduce.176 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.318, %constant.1127), dimensions={0}, to_apply=%convert_element_type.523.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.121.clone.clone (param_0.1254: bf16[4,4096], param_1.1383: s32[]) -> bf16[4096] { + %param_0.1254 = bf16[4,4096]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1383 = s32[]{:T(128)S(6)} parameter(1) + %constant.1118 = s32[]{:T(128)} constant(0) + %dynamic_slice.318 = bf16[1,4096]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1254, %param_1.1383, %constant.1118), dynamic_slice_sizes={1,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1119 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.137 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.318, %constant.1119), dimensions={0}, to_apply=%convert_element_type.556.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.12.clone.clone.clone (param_0.1266: bf16[4,14336,4096], param_1.1393: s32[]) -> bf16[14336,4096,1] { - %param_0.1266 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1393 = s32[]{:T(128)S(6)} parameter(1) - %constant.1139 = s32[]{:T(128)} constant(0) - %dynamic_slice.326 = bf16[1,14336,4096]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1266, %param_1.1393, %constant.1139, %constant.1139), dynamic_slice_sizes={1,14336,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.12.clone.clone.clone (param_0.1273: bf16[4,14336,4096], param_1.1397: s32[]) -> bf16[14336,4096,1] { + %param_0.1273 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1397 = s32[]{:T(128)S(6)} parameter(1) + %constant.1131 = s32[]{:T(128)} constant(0) + %dynamic_slice.326 = bf16[1,14336,4096]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1273, %param_1.1397, %constant.1131, %constant.1131), dynamic_slice_sizes={1,14336,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.566 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.326), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } @@ -1737,264 +1737,264 @@ StackFrames ROOT %bitcast.565 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.12) } -%fused_computation.13.clone.clone (param_0.1267: bf16[4,128,4096], param_1.1394: bf16[4,14336,4096], param_2.1180: s32[]) -> bf16[14336,4,128] { - %param_1.1394 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1180 = s32[]{:T(128)S(6)} parameter(2) - %fusion.344 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} fusion(%param_1.1394, %param_2.1180), kind=kLoop, calls=%fused_computation.12.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1267 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %fusion.345 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1267), kind=kLoop, calls=%bitcast_fusion.3.clone.clone +%fused_computation.13.clone.clone (param_0.1274: bf16[4,128,4096], param_1.1398: bf16[4,14336,4096], param_2.1174: s32[]) -> bf16[14336,4,128] { + %param_1.1398 = bf16[4,14336,4096]{2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1174 = s32[]{:T(128)S(6)} parameter(2) + %fusion.344 = bf16[14336,4096,1]{1,0,2:T(8,128)(2,1)} fusion(%param_1.1398, %param_2.1174), kind=kLoop, calls=%fused_computation.12.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1274 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %fusion.345 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1274), kind=kLoop, calls=%bitcast_fusion.3.clone.clone ROOT %convolution.114 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} convolution(%fusion.344, %fusion.345), window={size=4 pad=3_3 rhs_reversal=1}, dim_labels=bf0_0oi->b0f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} } -%fused_computation.144.clone.1.clone (param_0.1268: f32[4,128]) -> f32[4,128] { - %param_0.1268 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.1141 = f32[]{:T(128)} constant(0.000244140625) - %closed_call.86 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1141), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.847 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1268, %closed_call.86), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1140 = f32[]{:T(128)} constant(1e-05) - %closed_call.85 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1140), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.861 = f32[4,128]{1,0:T(4,128)} add(%div.847, %closed_call.85), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.98 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.861), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +%fused_computation.144.clone.1.clone (param_0.1275: f32[4,128]) -> f32[4,128] { + %param_0.1275 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1133 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.86 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1133), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.847 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1275, %closed_call.86), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1132 = f32[]{:T(128)} constant(1e-05) + %closed_call.85 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1132), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.847 = f32[4,128]{1,0:T(4,128)} add(%div.847, %closed_call.85), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.98 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.847), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} } -%fused_computation.11.clone.1.clone.clone (param_0.1272: bf16[4,4096,14336], param_1.1398: s32[]) -> bf16[4096,14336,1] { - %param_0.1272 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1398 = s32[]{:T(128)S(6)} parameter(1) - %constant.1143 = s32[]{:T(128)} constant(0) - %dynamic_slice.328 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1272, %param_1.1398, %constant.1143, %constant.1143), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.11.clone.1.clone.clone (param_0.1279: bf16[4,4096,14336], param_1.1402: s32[]) -> bf16[4096,14336,1] { + %param_0.1279 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1402 = s32[]{:T(128)S(6)} parameter(1) + %constant.1135 = s32[]{:T(128)} constant(0) + %dynamic_slice.328 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1279, %param_1.1402, %constant.1135, %constant.1135), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.568 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.328), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.96.clone.2.clone.clone (param_0.1273: f32[4,128], param_1.1399: bf16[4,128,4096], param_2.1183: bf16[4096]) -> bf16[4,128,4096] { - %param_2.1183 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.432 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1183), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1399 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1106 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1399), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1273 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1717 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1273), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1716 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1106, %mul.1717), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1105 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1716), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.431 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.432, %convert_element_type.1105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.23.clone.clone (param_0.1274: bf16[4,4096,14336], param_1.1400: s32[], param_2.1184: f32[4,128], param_3.852: bf16[4,128,4096], param_4.533: bf16[4096]) -> bf16[4,128,14336] { - %param_2.1184 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.852 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.533 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.349 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1184, %param_3.852, %param_4.533), kind=kLoop, calls=%fused_computation.96.clone.2.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1274 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1400 = s32[]{:T(128)S(6)} parameter(1) - %fusion.348 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1274, %param_1.1400), kind=kLoop, calls=%fused_computation.11.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.96.clone.2.clone.clone (param_0.1280: f32[4,128], param_1.1403: bf16[4,128,4096], param_2.1177: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1403 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1122 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1403), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1280 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2092 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1280), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2091 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1122, %mul.2092), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1121 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2091), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1177 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2093 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1177), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2090 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1121, %mul.2093), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.23.clone.clone (param_0.1281: bf16[4,4096,14336], param_1.1404: s32[], param_2.1178: f32[4,128], param_3.840: bf16[4,128,4096], param_4.527: bf16[4096]) -> bf16[4,128,14336] { + %param_2.1178 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.840 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.527 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.349 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1178, %param_3.840, %param_4.527), kind=kLoop, calls=%fused_computation.96.clone.2.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1281 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1404 = s32[]{:T(128)S(6)} parameter(1) + %fusion.348 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1281, %param_1.1404), kind=kLoop, calls=%fused_computation.11.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.116 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.349, %fusion.348), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} } -%fused_computation.14.clone.1.clone.clone (param_0.1275: bf16[4,4096,14336], param_1.1401: s32[]) -> bf16[4096,14336,1] { - %param_0.1275 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1401 = s32[]{:T(128)S(6)} parameter(1) - %constant.1144 = s32[]{:T(128)} constant(0) - %dynamic_slice.329 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1275, %param_1.1401, %constant.1144, %constant.1144), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.14.clone.1.clone.clone (param_0.1282: bf16[4,4096,14336], param_1.1405: s32[]) -> bf16[4096,14336,1] { + %param_0.1282 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1405 = s32[]{:T(128)S(6)} parameter(1) + %constant.1136 = s32[]{:T(128)} constant(0) + %dynamic_slice.329 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1282, %param_1.1405, %constant.1136, %constant.1136), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.569 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.329), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.39.clone.1.clone.clone (param_0.1276: bf16[14336,4,128], param_1.1402: bf16[4,128,14336]) -> bf16[4,128,14336] { - %param_1.1402 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1145 = bf16[]{:T(256)} constant(1) - %jit_silu_.44 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1145), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} - %neg.130 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_1.1402), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} +%fused_computation.39.clone.1.clone.clone (param_0.1283: bf16[14336,4,128], param_1.1406: bf16[4,128,14336]) -> bf16[4,128,14336] { + %param_1.1406 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1137 = bf16[]{:T(256)} constant(1) + %jit_silu_.44 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1137), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} + %neg.130 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_1.1406), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} %exp.69 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} exponential(%neg.130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/exp" stack_frame_id=0} - %add.862 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.69, %jit_silu_.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} - %div.848 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.44, %add.862), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} - %mul.1719 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1402, %div.848), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} - %param_0.1276 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(0) - %bitcast.570 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_0.1276), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} - ROOT %mul.1718 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.1719, %bitcast.570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} -} - -%fused_computation.21.clone.clone (param_0.1277: bf16[4,4096,14336], param_1.1403: s32[], param_2.1185: bf16[14336,4,128], param_3.853: bf16[4,128,14336]) -> bf16[4,128,4096] { - %param_2.1185 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) - %param_3.853 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %bitcast_multiply_fusion.15 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1185, %param_3.853), kind=kLoop, calls=%fused_computation.39.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %param_0.1277 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1403 = s32[]{:T(128)S(6)} parameter(1) - %fusion.350 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1277, %param_1.1403), kind=kLoop, calls=%fused_computation.14.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %add.848 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.69, %jit_silu_.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} + %div.848 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.44, %add.848), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} + %mul.2095 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1406, %div.848), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} + %param_0.1283 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(0) + %bitcast.570 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_0.1283), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} + ROOT %mul.2094 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.2095, %bitcast.570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} +} + +%fused_computation.21.clone.clone (param_0.1284: bf16[4,4096,14336], param_1.1407: s32[], param_2.1179: bf16[14336,4,128], param_3.841: bf16[4,128,14336]) -> bf16[4,128,4096] { + %param_2.1179 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) + %param_3.841 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %bitcast_multiply_fusion.15 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1179, %param_3.841), kind=kLoop, calls=%fused_computation.39.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %param_0.1284 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1407 = s32[]{:T(128)S(6)} parameter(1) + %fusion.350 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1284, %param_1.1407), kind=kLoop, calls=%fused_computation.14.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.117 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} convolution(%bitcast_multiply_fusion.15, %fusion.350), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} } -%fused_computation.14.clone.clone.clone (param_0.1269: bf16[4,4096,14336], param_1.1395: s32[]) -> bf16[4096,14336,1] { - %param_0.1269 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1395 = s32[]{:T(128)S(6)} parameter(1) - %constant.1142 = s32[]{:T(128)} constant(0) - %dynamic_slice.327 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1269, %param_1.1395, %constant.1142, %constant.1142), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.14.clone.clone.clone (param_0.1276: bf16[4,4096,14336], param_1.1399: s32[]) -> bf16[4096,14336,1] { + %param_0.1276 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1399 = s32[]{:T(128)S(6)} parameter(1) + %constant.1134 = s32[]{:T(128)} constant(0) + %dynamic_slice.327 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1276, %param_1.1399, %constant.1134, %constant.1134), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.567 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.327), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.96.clone.1.clone.clone (param_0.1270: f32[4,128], param_1.1396: bf16[4,128,4096], param_2.1181: bf16[4096]) -> bf16[4,128,4096] { - %param_2.1181 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.430 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1181), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1396 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1104 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1396), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1270 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1715 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1270), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1714 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1104, %mul.1715), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1103 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1714), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.429 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.430, %convert_element_type.1103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.20.clone.clone (param_0.1271: bf16[4,4096,14336], param_1.1397: s32[], param_2.1182: f32[4,128], param_3.851: bf16[4,128,4096], param_4.532: bf16[4096]) -> bf16[4,128,14336] { - %param_2.1182 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.851 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.532 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.347 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1182, %param_3.851, %param_4.532), kind=kLoop, calls=%fused_computation.96.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1271 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1397 = s32[]{:T(128)S(6)} parameter(1) - %fusion.346 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1271, %param_1.1397), kind=kLoop, calls=%fused_computation.14.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.96.clone.1.clone.clone (param_0.1277: f32[4,128], param_1.1400: bf16[4,128,4096], param_2.1175: bf16[4096]) -> bf16[4,128,4096] { + %param_1.1400 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1120 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_1.1400), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1277 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2088 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_0.1277), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2087 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1120, %mul.2088), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1119 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2087), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1175 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2089 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1175), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2086 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1119, %mul.2089), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.20.clone.clone (param_0.1278: bf16[4,4096,14336], param_1.1401: s32[], param_2.1176: f32[4,128], param_3.839: bf16[4,128,4096], param_4.526: bf16[4096]) -> bf16[4,128,14336] { + %param_2.1176 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.839 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.526 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.347 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1176, %param_3.839, %param_4.526), kind=kLoop, calls=%fused_computation.96.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1278 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1401 = s32[]{:T(128)S(6)} parameter(1) + %fusion.346 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1278, %param_1.1401), kind=kLoop, calls=%fused_computation.14.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} ROOT %convolution.115 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.347, %fusion.346), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} } -%region_14.17 (reduce_sum.126: f32[], reduce_sum.127: f32[]) -> f32[] { - %reduce_sum.126 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} - %reduce_sum.127 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} - ROOT %reduce_sum.128 = f32[]{:T(128)} add(%reduce_sum.126, %reduce_sum.127), metadata={op_name="checkpoint/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_14.17 (reduce_sum.166: f32[], reduce_sum.167: f32[]) -> f32[] { + %reduce_sum.166 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + %reduce_sum.167 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + ROOT %reduce_sum.168 = f32[]{:T(128)} add(%reduce_sum.166, %reduce_sum.167), metadata={op_name="checkpoint/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.11.clone.clone.clone.clone.clone.clone.clone (param_0.1278: bf16[4,4096,14336], param_1.1404: s32[]) -> bf16[4096,14336,1] { - %param_0.1278 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1404 = s32[]{:T(128)S(6)} parameter(1) - %constant.1146 = s32[]{:T(128)} constant(0) - %dynamic_slice.330 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1278, %param_1.1404, %constant.1146, %constant.1146), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +%fused_computation.11.clone.clone.clone.clone.clone.clone.clone (param_0.1285: bf16[4,4096,14336], param_1.1408: s32[]) -> bf16[4096,14336,1] { + %param_0.1285 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1408 = s32[]{:T(128)S(6)} parameter(1) + %constant.1138 = s32[]{:T(128)} constant(0) + %dynamic_slice.330 = bf16[1,4096,14336]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1285, %param_1.1408, %constant.1138, %constant.1138), dynamic_slice_sizes={1,4096,14336}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.571 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.330), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.38.clone.1.clone.clone.clone.clone (param_0.1279: bf16[4,128,14336], param_1.1405: bf16[4,128,14336], param_2.1186: bf16[14336,4,128]) -> bf16[4,128,14336] { - %param_2.1186 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) - %bitcast.572 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_2.1186), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} - %param_1.1405 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %mul.1724 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%bitcast.572, %param_1.1405), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %constant.1147 = bf16[]{:T(256)} constant(1) - %jit_silu_.45 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1147), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} - %param_0.1279 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %neg.131 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_0.1279), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} +%fused_computation.38.clone.1.clone.clone.clone.clone (param_0.1286: bf16[4,128,14336], param_1.1409: bf16[4,128,14336], param_2.1180: bf16[14336,4,128]) -> bf16[4,128,14336] { + %param_2.1180 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(2) + %bitcast.572 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} bitcast(%param_2.1180), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} + %param_1.1409 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %mul.2100 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%bitcast.572, %param_1.1409), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1139 = bf16[]{:T(256)} constant(1) + %jit_silu_.45 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1139), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} + %param_0.1286 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %neg.131 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} negate(%param_0.1286), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} %exp.70 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} exponential(%neg.131), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/exp" stack_frame_id=0} - %add.863 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.70, %jit_silu_.45), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} - %div.849 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.45, %add.863), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} - %mul.1723 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.1724, %div.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} - %mul.1722 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1279, %mul.1724), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} + %add.849 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%exp.70, %jit_silu_.45), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} + %div.849 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.45, %add.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} + %mul.2099 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.2100, %div.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} + %mul.2098 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1286, %mul.2100), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} %sub.98 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} subtract(%jit_silu_.45, %div.849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/sub" stack_frame_id=0} - %mul.1721 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%div.849, %sub.98), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} - %mul.1720 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.1722, %mul.1721), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} - ROOT %add_any.145 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%mul.1723, %mul.1720), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} + %mul.2097 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%div.849, %sub.98), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} + %mul.2096 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} multiply(%mul.2098, %mul.2097), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/mul" stack_frame_id=0} + ROOT %add_any.145 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} add(%mul.2099, %mul.2096), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} } -%fused_computation.63.clone.clone (param_0.1280: bf16[4,128,4096], param_1.1406: bf16[4096], param_2.1187: bf16[4,128,4096], param_3.854: bf16[4,4096,14336], param_4.534: s32[], param_5.435: bf16[4,128,14336], param_6.304: bf16[4,128,14336], param_7.200: bf16[14336,4,128]) -> (f32[4,128], bf16[4,128,4096]) { - %param_0.1280 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.1108 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1280), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_2.1187 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) +%fused_computation.63.clone.clone (param_0.1287: bf16[4,128,4096], param_1.1410: bf16[4096], param_2.1181: bf16[4,128,4096], param_3.842: bf16[4,4096,14336], param_4.528: s32[], param_5.435: bf16[4,128,14336], param_6.305: bf16[4,128,14336], param_7.200: bf16[14336,4,128]) -> (f32[4,128], bf16[4,128,4096]) { + %param_0.1287 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1124 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_0.1287), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1181 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) %param_5.435 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(5) - %param_6.304 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(6) + %param_6.305 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)S(1)} parameter(6) %param_7.200 = bf16[14336,4,128]{0,2,1:T(8,128)(2,1)S(1)} parameter(7) - %fusion.134.clone.3 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_5.435, %param_6.304, %param_7.200), kind=kLoop, calls=%fused_computation.38.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} - %param_3.854 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(3) - %param_4.534 = s32[]{:T(128)S(6)} parameter(4) - %fusion.79.clone.3 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_3.854, %param_4.534), kind=kLoop, calls=%fused_computation.11.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %fusion.134.clone.3 = bf16[4,128,14336]{2,1,0:T(8,128)(2,1)} fusion(%param_5.435, %param_6.305, %param_7.200), kind=kLoop, calls=%fused_computation.38.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/jit(silu)/add_any" stack_frame_id=0} + %param_3.842 = bf16[4,4096,14336]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.528 = s32[]{:T(128)S(6)} parameter(4) + %fusion.79.clone.3 = bf16[4096,14336,1]{1,0,2:T(8,128)(2,1)} fusion(%param_3.842, %param_4.528), kind=kLoop, calls=%fused_computation.11.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %convolution.60.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convolution(%fusion.134.clone.3, %fusion.79.clone.3), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} - %add_any.132.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_2.1187, %convolution.60.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} - %param_1.1406 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(1) - %dot_general.434 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_1.1406), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.433 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%add_any.132.clone.3, %dot_general.434), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.1107 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.433), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} - %mul.1725 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1108, %convert_element_type.1107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %constant.1148 = f32[]{:T(128)} constant(0) - %reduce.178 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1725, %constant.1148), dimensions={2}, to_apply=%region_14.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.189 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.178, %add_any.132.clone.3) -} - -%fused_computation.140.clone.clone (param_0.1281: f32[4,128], param_1.1407: f32[4,128]) -> f32[4,128] { - %param_0.1281 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %param_1.1407 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %constant.1152 = f32[]{:T(128)} constant(0.000244140625) - %closed_call.89 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1152), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.851 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1407, %closed_call.89), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1151 = f32[]{:T(128)} constant(1e-05) - %closed_call.88 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1151), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.864 = f32[4,128]{1,0:T(4,128)} add(%div.851, %closed_call.88), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %rsqrt.99 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.864), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} - %div.850 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.99, %add.864), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1150 = f32[]{:T(128)} constant(-0.5) - %closed_call.87 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1150), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %mul.1728 = f32[4,128]{1,0:T(4,128)} multiply(%div.850, %closed_call.87), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1727 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1281, %mul.1728), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %constant.1149 = f32[]{:T(128)} constant(0.00048828125) - %mul.1729 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1149), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - ROOT %mul.1726 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1727, %mul.1729), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} -} - -%region_20.24 (dot_general.187: bf16[], dot_general.188: bf16[]) -> bf16[] { - %dot_general.187 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general"} - %dot_general.188 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general"} - ROOT %add.173 = bf16[]{:T(256)} add(%dot_general.187, %dot_general.188), metadata={op_name="add.39"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.94.clone.clone (param_0.1282: bf16[4,128,4096], param_1.1408: f32[4,128], param_2.1188: bf16[4,128,4096], param_3.855: bf16[4,128,4096], param_4.535: f32[4,128], param_5.436: bf16[4096]) -> (bf16[4096], bf16[4,128,4096]) { - %param_0.1282 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %param_2.1188 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %convert_element_type.1110 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_2.1188), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_1.1408 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %mul.1731 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1408), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.1730 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1110, %mul.1731), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1109 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.1730), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %multiply.271 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1282, %convert_element_type.1109), metadata={op_name="multiply.204"} - %constant.1153 = bf16[]{:T(256)} constant(0) - %reduce.179 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%multiply.271, %constant.1153), dimensions={0,1}, to_apply=%region_20.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_3.855 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %add_any.132.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_2.1181, %convolution.60.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + %param_1.1410 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2103 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_1.1410), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2102 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%add_any.132.clone.3, %mul.2103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %convert_element_type.1123 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.2102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} + %mul.2101 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1124, %convert_element_type.1123), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1140 = f32[]{:T(128)} constant(0) + %reduce.139 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.2101, %constant.1140), dimensions={2}, to_apply=%region_14.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.189 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.139, %add_any.132.clone.3) +} + +%fused_computation.140.clone.clone (param_0.1288: f32[4,128], param_1.1411: f32[4,128]) -> f32[4,128] { + %param_0.1288 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1411 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %constant.1144 = f32[]{:T(128)} constant(0.000244140625) + %closed_call.89 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1144), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.851 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1411, %closed_call.89), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1143 = f32[]{:T(128)} constant(1e-05) + %closed_call.88 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1143), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.850 = f32[4,128]{1,0:T(4,128)} add(%div.851, %closed_call.88), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %rsqrt.99 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.850), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} + %div.850 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.99, %add.850), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1142 = f32[]{:T(128)} constant(-0.5) + %closed_call.87 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1142), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2106 = f32[4,128]{1,0:T(4,128)} multiply(%div.850, %closed_call.87), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2105 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1288, %mul.2106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1141 = f32[]{:T(128)} constant(0.00048828125) + %mul.2107 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1141), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + ROOT %mul.2104 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.2105, %mul.2107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} +} + +%region_20.24 (reduce_sum.175: bf16[], reduce_sum.179: bf16[]) -> bf16[] { + %reduce_sum.175 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + %reduce_sum.179 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum"} + ROOT %reduce_sum.180 = bf16[]{:T(256)} add(%reduce_sum.175, %reduce_sum.179), metadata={op_name="checkpoint/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.93.clone.clone (param_0.1289: bf16[4,128,4096], param_1.1412: f32[4,128], param_2.1182: bf16[4,128,4096], param_3.843: bf16[4,128,4096], param_4.529: f32[4,128], param_5.436: bf16[4096]) -> (bf16[4096], bf16[4,128,4096]) { + %param_2.1182 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %convert_element_type.1126 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%param_2.1182), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1412 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2110 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_1.1412), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2109 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1126, %mul.2110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1125 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%mul.2109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1289 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %mul.2108 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1125, %param_0.1289), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %constant.1145 = bf16[]{:T(256)} constant(0) + %reduce.140 = bf16[4096]{0:T(1024)(128)(2,1)} reduce(%mul.2108, %constant.1145), dimensions={0,1}, to_apply=%region_20.24, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/reduce_sum" stack_frame_id=0} + %param_3.843 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) %param_5.436 = bf16[4096]{0:T(1024)(128)(2,1)S(1)} parameter(5) - %dot_general.286.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_5.436), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.263.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1282, %dot_general.286.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.753.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%dot_general.263.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} - %mul.1142.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.753.clone.3, %mul.1731), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %param_4.535 = f32[4,128]{1,0:T(4,128)S(1)} parameter(4) - %mul.1151.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_4.535), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %mul.1141.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1110, %mul.1151.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} - %add_any.126.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1142.clone.3, %mul.1141.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} - %convert_element_type.751.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.126.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} - %add_any.124.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_3.855, %convert_element_type.751.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} - ROOT %tuple.190 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.179, %add_any.124.clone.3) -} - -%region_15.18 (dot_general.184: f32[], dot_general.185: f32[]) -> f32[] { - %dot_general.184 = f32[]{:T(128)} parameter(0), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} - %dot_general.185 = f32[]{:T(128)} parameter(1), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} - ROOT %add.169 = f32[]{:T(128)} add(%dot_general.184, %dot_general.185), metadata={op_name="add.31"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.25.clone.clone.clone.clone.clone.clone.clone (param_0.1283: bf16[4,32,128,4096], param_1.1409: s32[]) -> bf16[32,128,4096,1] { - %param_0.1283 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1409 = s32[]{:T(128)S(6)} parameter(1) - %constant.1154 = s32[]{:T(128)} constant(0) - %dynamic_slice.331 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1283, %param_1.1409, %constant.1154, %constant.1154, %constant.1154), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %mul.1399.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} broadcast(%param_5.436), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.1353.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} multiply(%param_0.1289, %mul.1399.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %convert_element_type.769.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} convert(%mul.1353.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} + %mul.1333.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.769.clone.3, %mul.2110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %param_4.529 = f32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %mul.1344.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} broadcast(%param_4.529), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %mul.1332.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} multiply(%convert_element_type.1126, %mul.1344.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/mul" stack_frame_id=0} + %add_any.126.clone.3 = f32[4,128,4096]{2,1,0:T(8,128)} add(%mul.1333.clone.3, %mul.1332.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + %convert_element_type.767.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)} convert(%add_any.126.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/convert_element_type" stack_frame_id=0} + %add_any.124.clone.3 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} add(%param_3.843, %convert_element_type.767.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + ROOT %tuple.190 = (bf16[4096]{0:T(1024)(128)(2,1)}, bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.140, %add_any.124.clone.3) +} + +%region_15.18 (dot_general.157: f32[], dot_general.158: f32[]) -> f32[] { + %dot_general.157 = f32[]{:T(128)} parameter(0), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} + %dot_general.158 = f32[]{:T(128)} parameter(1), metadata={op_name="vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general"} + ROOT %add.157 = f32[]{:T(128)} add(%dot_general.157, %dot_general.158), metadata={op_name="add.31"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.25.clone.clone.clone.clone.clone.clone.clone (param_0.1290: bf16[4,32,128,4096], param_1.1413: s32[]) -> bf16[32,128,4096,1] { + %param_0.1290 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1413 = s32[]{:T(128)S(6)} parameter(1) + %constant.1146 = s32[]{:T(128)} constant(0) + %dynamic_slice.331 = bf16[1,32,128,4096]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1290, %param_1.1413, %constant.1146, %constant.1146, %constant.1146), dynamic_slice_sizes={1,32,128,4096}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} ROOT %bitcast.573 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.331), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.76.clone.clone.clone.clone.clone.clone (param_0.1284: bf16[4,128,4096]) -> bf16[4,128,4096,1] { - %param_0.1284 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.574 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%param_0.1284), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} +%fused_computation.76.clone.clone.clone.clone.clone.clone (param_0.1291: bf16[4,128,4096]) -> bf16[4,128,4096,1] { + %param_0.1291 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.574 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%param_0.1291), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} } -%fused_computation.66.clone.clone (param_0.1285: bf16[4,32,128,128], param_1.1410: bf16[4,32,128,4096], param_2.1189: s32[], param_3.856: bf16[4,128,4096]) -> (f32[4,32,128], bf16[4,32,128,128]) { - %param_0.1285 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert.124 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%param_0.1285), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert" stack_frame_id=0} - %param_3.856 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %fusion.95.clone.3 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_3.856), kind=kLoop, calls=%fused_computation.76.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} - %param_1.1410 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1189 = s32[]{:T(128)S(6)} parameter(2) - %fusion.94.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.1410, %param_2.1189), kind=kLoop, calls=%fused_computation.25.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.66.clone.clone (param_0.1292: bf16[4,32,128,128], param_1.1414: bf16[4,32,128,4096], param_2.1183: s32[], param_3.844: bf16[4,128,4096]) -> (f32[4,32,128], bf16[4,32,128,128]) { + %param_0.1292 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert.87 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%param_0.1292), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert" stack_frame_id=0} + %param_3.844 = bf16[4,128,4096]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %fusion.95.clone.3 = bf16[4,128,4096,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_3.844), kind=kLoop, calls=%fused_computation.76.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/add_any" stack_frame_id=0} + %param_1.1414 = bf16[4,32,128,4096]{3,2,1,0:T(8,128)(2,1)} parameter(1) + %param_2.1183 = s32[]{:T(128)S(6)} parameter(2) + %fusion.94.clone.3 = bf16[32,128,4096,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_1.1414, %param_2.1183), kind=kLoop, calls=%fused_computation.25.clone.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %convolution.64.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.95.clone.3, %fusion.94.clone.3), window={size=1x32 pad=0_0x31_31 rhs_reversal=0x1}, dim_labels=0bf1_1oi0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} - %constant.619.clone.3 = bf16[]{:T(256)} constant(0.25) - %div.442.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.619.clone.3), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} + %constant.611.clone.3 = bf16[]{:T(256)} constant(0.25) + %div.442.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.611.clone.3), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} %div.441.clone.3 = bf16[4,128,32,128]{3,1,2,0:T(8,128)(2,1)} multiply(%convolution.64.clone.3, %div.442.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} %bitcast.209.clone.3 = bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%div.441.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/div" stack_frame_id=0} - %convert.123 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%bitcast.209.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert.1" stack_frame_id=0} - %multiply.272 = f32[4,32,128,128]{3,2,1,0:T(8,128)} multiply(%convert.124, %convert.123), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/multiply" stack_frame_id=0} - %constant.1155 = f32[]{:T(128)} constant(0) - %dot_general.435 = f32[4,32,128]{2,1,0:T(8,128)S(1)} reduce(%multiply.272, %constant.1155), dimensions={3}, to_apply=%region_15.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general" stack_frame_id=0} - ROOT %tuple.191 = (f32[4,32,128]{2,1,0:T(8,128)S(1)}, bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)}) tuple(%dot_general.435, %bitcast.209.clone.3) + %convert.86 = f32[4,32,128,128]{3,2,1,0:T(8,128)} convert(%bitcast.209.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/convert.1" stack_frame_id=0} + %multiply.196 = f32[4,32,128,128]{3,2,1,0:T(8,128)} multiply(%convert.87, %convert.86), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/multiply" stack_frame_id=0} + %constant.1147 = f32[]{:T(128)} constant(0) + %dot_general.189 = f32[4,32,128]{2,1,0:T(8,128)S(1)} reduce(%multiply.196, %constant.1147), dimensions={3}, to_apply=%region_15.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/shard_map/vmap(jit(_splash_attention))/hsd,hsd->hs/dot_general" stack_frame_id=0} + ROOT %tuple.191 = (f32[4,32,128]{2,1,0:T(8,128)S(1)}, bf16[4,32,128,128]{3,2,1,0:T(8,128)(2,1)S(1)}) tuple(%dot_general.189, %bitcast.209.clone.3) } diff --git a/tests/utils/reference_hlo_qwen3_1.7b.txt b/tests/utils/reference_hlo_qwen3_1.7b.txt index f1ede66966..6bdc2b6141 100644 --- a/tests/utils/reference_hlo_qwen3_1.7b.txt +++ b/tests/utils/reference_hlo_qwen3_1.7b.txt @@ -14,1446 +14,1446 @@ StackFrames %param_1.7 = s32[1024]{0:T(1024)S(1)} parameter(1) %custom-call.1 = s32[1024]{0:T(1024)} custom-call(%param_1.7), custom_call_target="AssumeGatherIndicesInBound", operand_layout_constraints={s32[1024]{0:T(1024)}}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} %slice.6 = s32[512]{0:T(512)} slice(%custom-call.1), slice={[0:512]}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %reshape.444 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.461 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.444), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %gather.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.461), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,2048}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %transpose.460 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %reshape.443 = bf16[512,2048]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.460), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %reshape.554 = s32[4,128]{1,0:T(4,128)} reshape(%slice.6), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.261 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.554), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %gather.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} gather(%param_0.2, %transpose.261), offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2, slice_sizes={1,2048}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %transpose.260 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%gather.4), dimensions={0,1,2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %reshape.553 = bf16[512,2048]{1,0:T(8,128)(2,1)S(1)} reshape(%transpose.260), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} } %region_42.47.clone (scatter-add.6: bf16[], scatter-add.7: bf16[]) -> bf16[] { %scatter-add.7 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} %scatter-add.6 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add"} - ROOT %add.584 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + ROOT %add.560 = bf16[]{:T(256)} add(%scatter-add.6, %scatter-add.7), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } %fused_computation.1 (param_0.3: bf16[151936,2048], param_1.5: s32[512], param_2.4: bf16[512,2048]) -> bf16[151936,2048] { %param_0.3 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) %param_1.5 = s32[512]{0:T(512)S(1)} parameter(1) - %reshape.451 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %transpose.466 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.451), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} - %param_2.4 = bf16[512,2048]{1,0:T(8,128)(2,1)S(1)} parameter(2) - %reshape.452 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} - %transpose.467 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%reshape.452), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} - ROOT %scatter.2 = bf16[151936,2048]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.466, %transpose.467), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_42.47.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} + %reshape.561 = s32[4,128]{1,0:T(4,128)} reshape(%param_1.5), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %transpose.266 = s32[4,128]{1,0:T(4,128)} transpose(%reshape.561), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} + %param_2.4 = bf16[512,2048]{1,0:T(8,128)(2,1)} parameter(2) + %reshape.562 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} reshape(%param_2.4), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + %transpose.267 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} transpose(%reshape.562), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while" stack_frame_id=0} + ROOT %scatter.2 = bf16[151936,2048]{1,0:T(8,128)(2,1)} scatter(%param_0.3, %transpose.266, %transpose.267), update_window_dims={2}, inserted_window_dims={0}, scatter_dims_to_operand_dims={0}, index_vector_dim=2, to_apply=%region_42.47.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/scatter-add" stack_frame_id=0} } -%region_71.76 (reduce_sum.464: f32[], reduce_sum.465: f32[]) -> f32[] { - %reduce_sum.465 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.464 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.466 = f32[]{:T(128)} add(%reduce_sum.464, %reduce_sum.465), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_71.76 (reduce_sum.569: f32[], reduce_sum.570: f32[]) -> f32[] { + %reduce_sum.570 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.569 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.571 = f32[]{:T(128)} add(%reduce_sum.569, %reduce_sum.570), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_56.61 (reduce_sum.386: f32[], reduce_sum.387: f32[]) -> f32[] { - %reduce_sum.387 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.386 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.388 = f32[]{:T(128)} add(%reduce_sum.386, %reduce_sum.387), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_56.61 (reduce_sum.488: f32[], reduce_sum.492: f32[]) -> f32[] { + %reduce_sum.492 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.488 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.493 = f32[]{:T(128)} add(%reduce_sum.488, %reduce_sum.492), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.277 (param_0.1368: f32[151936,2048], param_1.1556: f32[], param_2.1314: f32[], param_3.918: f32[], param_4.556: f32[151936,2048], param_5.468: f32[], param_6.358: bf16[151936,2048], param_7.201: bf16[151936,2048,1], param_8.118: pred[], param_9.97: f32[151936,2048]) -> (f32[], f32[151936,2048], f32[151936,2048], f32[151936,2048], f32[]) { - %param_0.1368 = f32[151936,2048]{1,0:T(8,128)} parameter(0) - %param_3.918 = f32[]{:T(128)S(6)} parameter(3) - %mul.1926.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_3.918), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.288 (param_0.1409: f32[151936,2048], param_1.1583: f32[], param_2.1325: f32[], param_3.908: f32[], param_4.547: f32[151936,2048], param_5.481: f32[], param_6.356: bf16[151936,2048], param_7.200: bf16[151936,2048,1], param_8.118: pred[], param_9.97: f32[151936,2048]) -> (f32[], f32[151936,2048], f32[151936,2048], f32[151936,2048], f32[]) { + %param_0.1409 = f32[151936,2048]{1,0:T(8,128)} parameter(0) + %param_3.908 = f32[]{:T(128)S(6)} parameter(3) + %mul.2449.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_3.908), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.118 = pred[]{:T(512)S(6)} parameter(8) %select_n.268.clone.1 = pred[151936,2048]{1,0:T(8,128)(4,1)} broadcast(%param_8.118), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_7.201 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} parameter(7) - %bitcast.464.clone.1 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%param_7.201), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %convert_element_type.1409.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.464.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_6.358 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(6) - %convert_element_type.1408.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%param_6.358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %add_any.197.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1409.clone.1, %convert_element_type.1408.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} - %param_5.468 = f32[]{:T(128)} parameter(5) - %div.860.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_5.468), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.859.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add_any.197.clone.1, %div.860.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.267.clone.1 = f32[151936,2048]{1,0:T(8,128)} select(%select_n.268.clone.1, %add_any.197.clone.1, %div.859.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1092.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.844.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1092.clone.1), dimensions={}, metadata={op_name="broadcast.74"} - %mul.1932.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.267.clone.1, %broadcast.844.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.200 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} parameter(7) + %bitcast.445.clone.1 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%param_7.200), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %convert_element_type.1433.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.445.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_6.356 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(6) + %convert_element_type.1432.clone.1 = f32[151936,2048]{1,0:T(8,128)} convert(%param_6.356), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %add_any.188.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1433.clone.1, %convert_element_type.1432.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} + %param_5.481 = f32[]{:T(128)} parameter(5) + %div.860.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_5.481), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.859.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add_any.188.clone.1, %div.860.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.267.clone.1 = f32[151936,2048]{1,0:T(8,128)} select(%select_n.268.clone.1, %add_any.188.clone.1, %div.859.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1080.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.754.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1080.clone.1), dimensions={}, metadata={op_name="broadcast.74"} + %mul.2455.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.267.clone.1, %broadcast.754.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_9.97 = f32[151936,2048]{1,0:T(8,128)} parameter(9) - %constant.1096.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1933.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1096.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1931.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_9.97, %mul.1933.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.941.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.1932.clone.1, %mul.1931.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1314 = f32[]{:T(128)S(6)} parameter(2) - %div.856.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_2.1314), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1084.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2456.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1084.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2454.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_9.97, %mul.2456.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.917.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.2455.clone.1, %mul.2454.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1325 = f32[]{:T(128)S(6)} parameter(2) + %div.856.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_2.1325), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.65.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%select_n.267.clone.1, %select_n.267.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1095.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1930.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1095.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1928.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %mul.1930.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.556 = f32[151936,2048]{1,0:T(8,128)} parameter(4) - %constant.1094.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1929.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1094.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1927.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_4.556, %mul.1929.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.940.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.1928.clone.1, %mul.1927.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1556 = f32[]{:T(128)S(6)} parameter(1) - %div.855.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_1.1556), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.854.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.940.clone.1, %div.855.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1083.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2453.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1083.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2451.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%integer_pow.65.clone.1, %mul.2453.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.547 = f32[151936,2048]{1,0:T(8,128)} parameter(4) + %constant.1082.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2452.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1082.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2450.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_4.547, %mul.2452.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.916.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%mul.2451.clone.1, %mul.2450.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1583 = f32[]{:T(128)S(6)} parameter(1) + %div.855.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%param_1.1583), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.854.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.916.clone.1, %div.855.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.62.clone.1 = f32[151936,2048]{1,0:T(8,128)} sqrt(%div.854.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1093.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.939.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1093.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.938.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%sqrt.62.clone.1, %add.939.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.426.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%div.856.clone.1, %add.938.clone.1), metadata={op_name="multiply.61"} - %div.853.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.941.clone.1, %multiply.426.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1925.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_0.1368, %broadcast.844.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.937.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%div.853.clone.1, %mul.1925.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1924.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%mul.1926.clone.1, %add.937.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.936.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%param_0.1368, %mul.1924.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.214 = f32[151936,2048]{1,0:T(8,128)} multiply(%add.936.clone.1, %add.936.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1200 = f32[]{:T(128)} constant(0) - %reduce.176 = f32[]{:T(128)} reduce(%square.214, %constant.1200), dimensions={0,1}, to_apply=%region_71.76, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.178.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.1200), dimensions={0,1}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.144 = (f32[]{:T(128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.176, %add.936.clone.1, %add.940.clone.1, %add.941.clone.1, %reduce.178.clone.1) -} - -%region_43.48 (reduce_sum.317: f32[], reduce_sum.318: f32[]) -> f32[] { - %reduce_sum.318 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.317 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.319 = f32[]{:T(128)} add(%reduce_sum.317, %reduce_sum.318), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.367.clone.clone (param_0.1355: f32[4,128], param_1.1549: bf16[4,128,2048], param_2.1290: bf16[2048]) -> bf16[4,128,2048] { - %param_2.1290 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.480 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1290), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1549 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1451 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1549), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1355 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2083 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1355), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.2082 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1451, %mul.2083), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1450 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2082), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.479 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.480, %convert_element_type.1450), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.289.clone.clone.clone (param_0.1356: bf16[4,128,151936], param_1.1550: s32[4,128], param_2.1291: f32[4,128], param_3.911: f32[4,128], param_4.546: bf16[4,128], param_5.446: f32[4,128]) -> bf16[4,128,151936] { - %param_5.446 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.2087 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.446), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.911 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.2086 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.911), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.1356 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1454 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1356), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_4.546 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %sub.94 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.546), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.93 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1454, %sub.94), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %exp.62 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.93), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.2085 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2086, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.1291 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.966 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1291), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.965 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2085, %div.966), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.1550 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.49 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1550), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} + %constant.1081.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.915.clone.1 = f32[151936,2048]{1,0:T(8,128)} broadcast(%constant.1081.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.914.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%sqrt.62.clone.1, %add.915.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.287.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%div.856.clone.1, %add.914.clone.1), metadata={op_name="multiply.46"} + %div.853.clone.1 = f32[151936,2048]{1,0:T(8,128)} divide(%add.917.clone.1, %multiply.287.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2448.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%param_0.1409, %broadcast.754.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.913.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%div.853.clone.1, %mul.2448.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2447.clone.1 = f32[151936,2048]{1,0:T(8,128)} multiply(%mul.2449.clone.1, %add.913.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.912.clone.1 = f32[151936,2048]{1,0:T(8,128)} add(%param_0.1409, %mul.2447.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.214 = f32[151936,2048]{1,0:T(8,128)} multiply(%add.912.clone.1, %add.912.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1188 = f32[]{:T(128)} constant(0) + %reduce.106 = f32[]{:T(128)} reduce(%square.214, %constant.1188), dimensions={0,1}, to_apply=%region_71.76, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.108.clone.1 = f32[]{:T(128)} reduce(%integer_pow.65.clone.1, %constant.1188), dimensions={0,1}, to_apply=%region_56.61, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.145 = (f32[]{:T(128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[151936,2048]{1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.106, %add.912.clone.1, %add.916.clone.1, %add.917.clone.1, %reduce.108.clone.1) +} + +%region_43.48 (reduce_sum.422: f32[], reduce_sum.423: f32[]) -> f32[] { + %reduce_sum.423 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.422 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.424 = f32[]{:T(128)} add(%reduce_sum.422, %reduce_sum.423), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.378.clone.clone (param_0.1396: f32[4,128], param_1.1576: bf16[4,128,2048], param_2.1301: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1576 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1475 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1576), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1396 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2627 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1396), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2626 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1475, %mul.2627), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1474 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2626), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1301 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2628 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1301), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.2625 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1474, %mul.2628), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%fused_computation.300.clone.clone.clone (param_0.1397: bf16[4,128,151936], param_1.1577: s32[4,128], param_2.1302: f32[4,128], param_3.901: f32[4,128], param_4.537: bf16[4,128], param_5.459: f32[4,128]) -> bf16[4,128,151936] { + %param_5.459 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.2632 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.459), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.901 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.2631 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.901), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1397 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1478 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1397), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.537 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.92 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.537), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.91 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1478, %sub.92), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.62 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.91), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.2630 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2631, %exp.62), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1302 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.966 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1302), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.965 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2630, %div.966), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1577 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.49 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1577), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.48 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.47 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.49, %eq.48), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.1453 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.92 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.965, %convert_element_type.1453), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.2084 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2087, %sub.92), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.1452 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2084), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %convert_element_type.1477 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.47), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.90 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.965, %convert_element_type.1477), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.2629 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2632, %sub.90), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1476 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2629), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} } -%fused_computation.281 (param_0.1381: bf16[151936,2048], param_1.1569: f32[4,128], param_2.1327: bf16[4,128,2048], param_3.931: bf16[2048], param_4.569: bf16[4,128,151936], param_5.481: s32[4,128], param_6.371: f32[4,128], param_7.214: f32[4,128], param_8.131: bf16[4,128], param_9.98: f32[4,128]) -> (f32[], bf16[151936,2048,1]) { - %param_4.569 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(4) - %param_5.481 = s32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %param_6.371 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) - %param_7.214 = f32[4,128]{1,0:T(4,128)S(1)} parameter(7) +%fused_computation.292 (param_0.1422: bf16[151936,2048], param_1.1596: f32[4,128], param_2.1338: bf16[4,128,2048], param_3.921: bf16[2048], param_4.560: bf16[4,128,151936], param_5.494: s32[4,128], param_6.369: f32[4,128], param_7.213: f32[4,128], param_8.131: bf16[4,128], param_9.98: f32[4,128]) -> (f32[], bf16[151936,2048,1]) { + %param_4.560 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(4) + %param_5.494 = s32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.369 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.213 = f32[4,128]{1,0:T(4,128)S(1)} parameter(7) %param_8.131 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(8) %param_9.98 = f32[4,128]{1,0:T(4,128)S(1)} parameter(9) - %multiply_convert_fusion.1.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_4.569, %param_5.481, %param_6.371, %param_7.214, %param_8.131, /*index=5*/%param_9.98), kind=kLoop, calls=%fused_computation.289.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_1.1569 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1327 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %param_3.931 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %fusion.269.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1569, %param_2.1327, %param_3.931), kind=kLoop, calls=%fused_computation.367.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convolution.86.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} convolution(%multiply_convert_fusion.1.clone.1, %fusion.269.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %bitcast.333 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%convolution.86.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %convert_element_type.1323 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.333), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_0.1381 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1322 = f32[151936,2048]{1,0:T(8,128)} convert(%param_0.1381), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %add_any.184 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1323, %convert_element_type.1322), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} - %square.215 = f32[151936,2048]{1,0:T(8,128)} multiply(%add_any.184, %add_any.184), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1213 = f32[]{:T(128)} constant(0) - %reduce.177 = f32[]{:T(128)} reduce(%square.215, %constant.1213), dimensions={0,1}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.166 = (f32[]{:T(128)}, bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.177, %convolution.86.clone.1) + %multiply_convert_fusion.2.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_4.560, %param_5.494, %param_6.369, %param_7.213, %param_8.131, /*index=5*/%param_9.98), kind=kLoop, calls=%fused_computation.300.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.1596 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1338 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.921 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.279.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1596, %param_2.1338, %param_3.921), kind=kLoop, calls=%fused_computation.378.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convolution.86.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.279.clone.1), window={size=4}, dim_labels=0fb_0io->bf0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %bitcast.314 = bf16[151936,2048]{1,0:T(8,128)(2,1)} bitcast(%convolution.86.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %convert_element_type.1347 = f32[151936,2048]{1,0:T(8,128)} convert(%bitcast.314), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_0.1422 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1346 = f32[151936,2048]{1,0:T(8,128)} convert(%param_0.1422), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %add_any.175 = f32[151936,2048]{1,0:T(8,128)} add(%convert_element_type.1347, %convert_element_type.1346), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add_any" stack_frame_id=0} + %square.215 = f32[151936,2048]{1,0:T(8,128)} multiply(%add_any.175, %add_any.175), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1201 = f32[]{:T(128)} constant(0) + %reduce.107 = f32[]{:T(128)} reduce(%square.215, %constant.1201), dimensions={0,1}, to_apply=%region_43.48, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.167 = (f32[]{:T(128)}, bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)}) tuple(%reduce.107, %convolution.86.clone.1) } -%region_57.62 (reduce_sum.389: f32[], reduce_sum.393: f32[]) -> f32[] { - %reduce_sum.393 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.389 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.394 = f32[]{:T(128)} add(%reduce_sum.389, %reduce_sum.393), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_57.62 (reduce_sum.494: f32[], reduce_sum.495: f32[]) -> f32[] { + %reduce_sum.495 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.494 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.499 = f32[]{:T(128)} add(%reduce_sum.494, %reduce_sum.495), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.288 (param_0.1392: bf16[4,128,151936], param_1.1577: f32[4,128], param_2.1330: s32[4,128], param_3.933: bf16[4,128]) -> f32[4,128] { - %param_2.1330 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %eq.30 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1330), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} +%fused_computation.299 (param_0.1433: bf16[4,128,151936], param_1.1604: f32[4,128], param_2.1341: s32[4,128], param_3.923: bf16[4,128]) -> f32[4,128] { + %param_2.1341 = s32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %eq.30 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1341), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.25 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.24 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.30, %eq.25), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %param_0.1392 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1340 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1392), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_3.933 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) - %sub.73 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.933), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.64 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1340, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %param_1.1577 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %sub.71 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1577), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_0.1433 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1364 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1433), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_3.923 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(3) + %sub.73 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.923), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.64 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1364, %sub.73), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %param_1.1604 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %sub.71 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1604), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %sub.60 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%sub.64, %sub.71), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %constant.1225 = f32[]{:T(128)} constant(0) - %broadcast.769 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%constant.1225), dimensions={}, metadata={op_name="broadcast.109"} - %mul.1765 = f32[4,128,151936]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.769), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - ROOT %reduce.179 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1765, %constant.1225), dimensions={2}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.1213 = f32[]{:T(128)} constant(0) + %broadcast.681 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%constant.1213), dimensions={}, metadata={op_name="broadcast.99"} + %mul.2269 = f32[4,128,151936]{2,1,0:T(8,128)} select(%eq.24, %sub.60, %broadcast.681), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + ROOT %reduce.109 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.2269, %constant.1213), dimensions={2}, to_apply=%region_57.62, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_9.12 (reduce_sum.186: f32[], reduce_sum.190: f32[]) -> f32[] { - %reduce_sum.190 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.186 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.191 = f32[]{:T(128)} add(%reduce_sum.186, %reduce_sum.190), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_9.12 (reduce_sum.237: f32[], reduce_sum.241: f32[]) -> f32[] { + %reduce_sum.241 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.237 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.262 = f32[]{:T(128)} add(%reduce_sum.237, %reduce_sum.241), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.293 (param_0.1393: bf16[4,128,151936], param_1.1578: bf16[4,128]) -> f32[4,128] { - %param_0.1393 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1346 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1393), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_1.1578 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) - %sub.74 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1578), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.70 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1346, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} +%fused_computation.304 (param_0.1434: bf16[4,128,151936], param_1.1605: bf16[4,128]) -> f32[4,128] { + %param_0.1434 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1370 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1434), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_1.1605 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(1) + %sub.74 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1605), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.70 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1370, %sub.74), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} %exp.54 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.70), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %constant.1226 = f32[]{:T(128)} constant(0) - ROOT %reduce.180 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1226), dimensions={2}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %constant.1214 = f32[]{:T(128)} constant(0) + ROOT %reduce.110 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%exp.54, %constant.1214), dimensions={2}, to_apply=%region_9.12, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} } -%region_33.38 (reduce_sum.269: f32[], reduce_sum.270: f32[]) -> f32[] { - %reduce_sum.270 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.269 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.274 = f32[]{:T(128)} add(%reduce_sum.269, %reduce_sum.270), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_33.38 (reduce_sum.374: f32[], reduce_sum.375: f32[]) -> f32[] { + %reduce_sum.375 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.374 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.376 = f32[]{:T(128)} add(%reduce_sum.374, %reduce_sum.375), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.298 (param_0.1387: f32[4,6144,2048]) -> f32[] { - %param_0.1387 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(0) - %bitcast.347 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_0.1387), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.218 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%bitcast.347, %bitcast.347), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1219 = f32[]{:T(128)} constant(0) - ROOT %reduce.181 = f32[]{:T(128)} reduce(%square.218, %constant.1219), dimensions={0,1,2}, to_apply=%region_33.38, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.309 (param_0.1428: f32[4,6144,2048]) -> f32[] { + %param_0.1428 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(0) + %bitcast.328 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_0.1428), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.218 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%bitcast.328, %bitcast.328), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1207 = f32[]{:T(128)} constant(0) + ROOT %reduce.111 = f32[]{:T(128)} reduce(%square.218, %constant.1207), dimensions={0,1,2}, to_apply=%region_33.38, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_32.37 (reduce_sum.263: f32[], reduce_sum.267: f32[]) -> f32[] { - %reduce_sum.267 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.263 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.268 = f32[]{:T(128)} add(%reduce_sum.263, %reduce_sum.267), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_32.37 (reduce_sum.368: f32[], reduce_sum.369: f32[]) -> f32[] { + %reduce_sum.369 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.368 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.373 = f32[]{:T(128)} add(%reduce_sum.368, %reduce_sum.369), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_31.36 (reduce_sum.260: f32[], reduce_sum.261: f32[]) -> f32[] { - %reduce_sum.261 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.260 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.262 = f32[]{:T(128)} add(%reduce_sum.260, %reduce_sum.261), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_31.36 (reduce_sum.362: f32[], reduce_sum.366: f32[]) -> f32[] { + %reduce_sum.366 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.362 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.367 = f32[]{:T(128)} add(%reduce_sum.362, %reduce_sum.366), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.300 (param_0.1388: f32[4,2048,6144], param_1.1573: f32[4,2048,6144]) -> (f32[], f32[]) { - %param_0.1388 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(0) - %bitcast.351 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_0.1388), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.221 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.351, %bitcast.351), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1220 = f32[]{:T(128)} constant(0) - %reduce.182 = f32[]{:T(128)} reduce(%square.221, %constant.1220), dimensions={0,1,2}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1573 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(1) - %bitcast.355.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_1.1573), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.224.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.355.clone.1, %bitcast.355.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.183.clone.1 = f32[]{:T(128)} reduce(%square.224.clone.1, %constant.1220), dimensions={0,1,2}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.167 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.182, %reduce.183.clone.1) +%fused_computation.311 (param_0.1429: f32[4,2048,6144], param_1.1600: f32[4,2048,6144]) -> (f32[], f32[]) { + %param_0.1429 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(0) + %bitcast.332 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_0.1429), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.221 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.332, %bitcast.332), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1208 = f32[]{:T(128)} constant(0) + %reduce.112 = f32[]{:T(128)} reduce(%square.221, %constant.1208), dimensions={0,1,2}, to_apply=%region_32.37, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1600 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(1) + %bitcast.336.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_1.1600), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.224.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%bitcast.336.clone.1, %bitcast.336.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.113.clone.1 = f32[]{:T(128)} reduce(%square.224.clone.1, %constant.1208), dimensions={0,1,2}, to_apply=%region_31.36, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.168 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.112, %reduce.113.clone.1) } -%fused_computation.303 (param_0.901: f32[6144,4,2048]) -> bf16[4,6144,2048] { - %param_0.901 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) - %copy.190 = bf16[6144,4,2048]{2,0,1:T(8,128)(2,1)} copy(%param_0.901), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} - ROOT %bitcast.356 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} bitcast(%copy.190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.314 (param_0.940: f32[6144,4,2048]) -> bf16[4,6144,2048] { + %param_0.940 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) + %copy.186 = bf16[6144,4,2048]{2,0,1:T(8,128)(2,1)} copy(%param_0.940), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wo\'][\'kernel\']"} + ROOT %bitcast.337 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} bitcast(%copy.186), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%fused_computation.304 (param_0.903: f32[2048,4,6144]) -> bf16[4,2048,6144] { - %param_0.903 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) - %copy.191 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.903), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} - ROOT %bitcast.357 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.191), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.315 (param_0.942: f32[2048,4,6144]) -> bf16[4,2048,6144] { + %param_0.942 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %copy.187 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.942), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_1\'][\'kernel\']"} + ROOT %bitcast.338 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.187), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%fused_computation.305 (param_0.905: f32[2048,4,6144]) -> bf16[4,2048,6144] { - %param_0.905 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) - %copy.192 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.905), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} - ROOT %bitcast.358 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.192), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.316 (param_0.944: f32[2048,4,6144]) -> bf16[4,2048,6144] { + %param_0.944 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %copy.188 = bf16[2048,4,6144]{2,0,1:T(8,128)(2,1)} copy(%param_0.944), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'mlp\'][\'wi_0\'][\'kernel\']"} + ROOT %bitcast.339 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} bitcast(%copy.188), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_62.67 (reduce_sum.416: f32[], reduce_sum.417: f32[]) -> f32[] { - %reduce_sum.417 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.416 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.421 = f32[]{:T(128)} add(%reduce_sum.416, %reduce_sum.417), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_62.67 (reduce_sum.521: f32[], reduce_sum.522: f32[]) -> f32[] { + %reduce_sum.522 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.521 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.523 = f32[]{:T(128)} add(%reduce_sum.521, %reduce_sum.522), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_47.52 (reduce_sum.338: f32[], reduce_sum.339: f32[]) -> f32[] { - %reduce_sum.339 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.338 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.340 = f32[]{:T(128)} add(%reduce_sum.338, %reduce_sum.339), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_47.52 (reduce_sum.443: f32[], reduce_sum.444: f32[]) -> f32[] { + %reduce_sum.444 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.443 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.445 = f32[]{:T(128)} add(%reduce_sum.443, %reduce_sum.444), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.306 (param_0.1377: f32[6144,4,2048], param_1.1565: f32[], param_2.1323: f32[], param_3.927: f32[], param_4.565: f32[6144,4,2048], param_5.477: f32[], param_6.367: f32[4,6144,2048], param_7.210: pred[], param_8.127: f32[6144,4,2048]) -> (f32[], f32[6144,4,2048], f32[6144,4,2048], f32[6144,4,2048], f32[]) { - %param_0.1377 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) - %param_3.927 = f32[]{:T(128)S(6)} parameter(3) - %mul.1998.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_3.927), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.210 = pred[]{:T(512)S(6)} parameter(7) - %select_n.304.clone.1 = pred[6144,4,2048]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.210), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.367 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(6) - %bitcast.482.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_6.367), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.477 = f32[]{:T(128)} parameter(5) - %div.932.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_5.477), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.931.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%bitcast.482.clone.1, %div.932.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.303.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} select(%select_n.304.clone.1, %bitcast.482.clone.1, %div.931.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1146.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.886.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1146.clone.1), dimensions={}, metadata={op_name="broadcast.83"} - %mul.2004.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.303.clone.1, %broadcast.886.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.317 (param_0.1418: f32[6144,4,2048], param_1.1592: f32[], param_2.1334: f32[], param_3.917: f32[], param_4.556: f32[6144,4,2048], param_5.490: f32[], param_6.365: f32[4,6144,2048], param_7.209: pred[], param_8.127: f32[6144,4,2048]) -> (f32[], f32[6144,4,2048], f32[6144,4,2048], f32[6144,4,2048], f32[]) { + %param_0.1418 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(0) + %param_3.917 = f32[]{:T(128)S(6)} parameter(3) + %mul.2521.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_3.917), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.209 = pred[]{:T(512)S(6)} parameter(7) + %select_n.304.clone.1 = pred[6144,4,2048]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.209), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.365 = f32[4,6144,2048]{2,0,1:T(4,128)} parameter(6) + %bitcast.463.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} bitcast(%param_6.365), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.490 = f32[]{:T(128)} parameter(5) + %div.932.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_5.490), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.931.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%bitcast.463.clone.1, %div.932.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.303.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} select(%select_n.304.clone.1, %bitcast.463.clone.1, %div.931.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1134.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.796.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1134.clone.1), dimensions={}, metadata={op_name="broadcast.83"} + %mul.2527.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.303.clone.1, %broadcast.796.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.127 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(8) - %constant.1150.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.2005.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1150.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2003.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_8.127, %mul.2005.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.989.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2004.clone.1, %mul.2003.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1323 = f32[]{:T(128)S(6)} parameter(2) - %div.928.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_2.1323), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1138.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2528.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1138.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2526.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_8.127, %mul.2528.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.965.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2527.clone.1, %mul.2526.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1334 = f32[]{:T(128)S(6)} parameter(2) + %div.928.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_2.1334), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.74.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%select_n.303.clone.1, %select_n.303.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1149.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.2002.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1149.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2000.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%integer_pow.74.clone.1, %mul.2002.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.565 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(4) - %constant.1148.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.2001.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1148.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1999.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_4.565, %mul.2001.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.988.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2000.clone.1, %mul.1999.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1565 = f32[]{:T(128)S(6)} parameter(1) - %div.927.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_1.1565), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.926.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.988.clone.1, %div.927.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1137.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2525.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1137.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2523.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%integer_pow.74.clone.1, %mul.2525.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.556 = f32[6144,4,2048]{2,1,0:T(4,128)} parameter(4) + %constant.1136.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2524.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1136.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2522.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_4.556, %mul.2524.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.964.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%mul.2523.clone.1, %mul.2522.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1592 = f32[]{:T(128)S(6)} parameter(1) + %div.927.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%param_1.1592), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.926.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.964.clone.1, %div.927.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.71.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} sqrt(%div.926.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1147.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.987.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1147.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.986.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%sqrt.71.clone.1, %add.987.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.435.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%div.928.clone.1, %add.986.clone.1), metadata={op_name="multiply.52"} - %div.925.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.989.clone.1, %multiply.435.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1997.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_0.1377, %broadcast.886.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.985.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%div.925.clone.1, %mul.1997.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1996.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%mul.1998.clone.1, %add.985.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.984.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%param_0.1377, %mul.1996.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.225 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%add.984.clone.1, %add.984.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1209 = f32[]{:T(128)} constant(0) - %reduce.184 = f32[]{:T(128)} reduce(%square.225, %constant.1209), dimensions={0,1,2}, to_apply=%region_62.67, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.187.clone.1 = f32[]{:T(128)} reduce(%integer_pow.74.clone.1, %constant.1209), dimensions={0,1,2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.145 = (f32[]{:T(128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.184, %add.984.clone.1, %add.988.clone.1, %add.989.clone.1, %reduce.187.clone.1) -} - -%region_61.66 (reduce_sum.410: f32[], reduce_sum.414: f32[]) -> f32[] { - %reduce_sum.414 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.410 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.415 = f32[]{:T(128)} add(%reduce_sum.410, %reduce_sum.414), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_46.51 (reduce_sum.332: f32[], reduce_sum.333: f32[]) -> f32[] { - %reduce_sum.333 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.332 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.337 = f32[]{:T(128)} add(%reduce_sum.332, %reduce_sum.333), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1135.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.963.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} broadcast(%constant.1135.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.962.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%sqrt.71.clone.1, %add.963.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.296.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%div.928.clone.1, %add.962.clone.1), metadata={op_name="multiply.37"} + %div.925.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} divide(%add.965.clone.1, %multiply.296.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2520.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%param_0.1418, %broadcast.796.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.961.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%div.925.clone.1, %mul.2520.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2519.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%mul.2521.clone.1, %add.961.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.960.clone.1 = f32[6144,4,2048]{2,1,0:T(4,128)} add(%param_0.1418, %mul.2519.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.225 = f32[6144,4,2048]{2,1,0:T(4,128)} multiply(%add.960.clone.1, %add.960.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1197 = f32[]{:T(128)} constant(0) + %reduce.114 = f32[]{:T(128)} reduce(%square.225, %constant.1197), dimensions={0,1,2}, to_apply=%region_62.67, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.117.clone.1 = f32[]{:T(128)} reduce(%integer_pow.74.clone.1, %constant.1197), dimensions={0,1,2}, to_apply=%region_47.52, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.146 = (f32[]{:T(128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[6144,4,2048]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.114, %add.960.clone.1, %add.964.clone.1, %add.965.clone.1, %reduce.117.clone.1) +} + +%region_61.66 (reduce_sum.515: f32[], reduce_sum.516: f32[]) -> f32[] { + %reduce_sum.516 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.515 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.520 = f32[]{:T(128)} add(%reduce_sum.515, %reduce_sum.516), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_46.51 (reduce_sum.437: f32[], reduce_sum.438: f32[]) -> f32[] { + %reduce_sum.438 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.437 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.439 = f32[]{:T(128)} add(%reduce_sum.437, %reduce_sum.438), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.307 (param_0.1378: f32[2048,4,6144], param_1.1566: f32[], param_2.1324: f32[], param_3.928: f32[], param_4.566: f32[2048,4,6144], param_5.478: f32[], param_6.368: f32[4,2048,6144], param_7.211: pred[], param_8.128: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { - %param_0.1378 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) - %param_3.928 = f32[]{:T(128)S(6)} parameter(3) - %mul.2008.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.928), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.211 = pred[]{:T(512)S(6)} parameter(7) - %select_n.308.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.211), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.368 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) - %bitcast.484.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.368), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.478 = f32[]{:T(128)} parameter(5) - %div.940.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.478), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.939.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.484.clone.1, %div.940.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.307.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.308.clone.1, %bitcast.484.clone.1, %div.939.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1152.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.892.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1152.clone.1), dimensions={}, metadata={op_name="broadcast.85"} - %mul.2012.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.307.clone.1, %broadcast.892.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.318 (param_0.1419: f32[2048,4,6144], param_1.1593: f32[], param_2.1335: f32[], param_3.918: f32[], param_4.557: f32[2048,4,6144], param_5.491: f32[], param_6.366: f32[4,2048,6144], param_7.210: pred[], param_8.128: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { + %param_0.1419 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %param_3.918 = f32[]{:T(128)S(6)} parameter(3) + %mul.2531.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.918), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.210 = pred[]{:T(512)S(6)} parameter(7) + %select_n.308.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.210), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.366 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) + %bitcast.465.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.366), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.491 = f32[]{:T(128)} parameter(5) + %div.940.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.491), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.939.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.465.clone.1, %div.940.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.307.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.308.clone.1, %bitcast.465.clone.1, %div.939.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1140.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.802.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1140.clone.1), dimensions={}, metadata={op_name="broadcast.85"} + %mul.2535.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.307.clone.1, %broadcast.802.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.128 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(8) - %constant.1156.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.891.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1156.clone.1), dimensions={}, metadata={op_name="broadcast.84"} - %mul.2011.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.128, %broadcast.891.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.994.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2012.clone.1, %mul.2011.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1324 = f32[]{:T(128)S(6)} parameter(2) - %div.936.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1324), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1144.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.801.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1144.clone.1), dimensions={}, metadata={op_name="broadcast.84"} + %mul.2534.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.128, %broadcast.801.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.970.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2535.clone.1, %mul.2534.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1335 = f32[]{:T(128)S(6)} parameter(2) + %div.936.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1335), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.75.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.307.clone.1, %select_n.307.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1155.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.890.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1155.clone.1), dimensions={}, metadata={op_name="broadcast.73"} - %mul.2010.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.75.clone.1, %broadcast.890.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.566 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) - %constant.1154.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.889.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1154.clone.1), dimensions={}, metadata={op_name="broadcast.72"} - %mul.2009.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.566, %broadcast.889.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.993.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2010.clone.1, %mul.2009.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1566 = f32[]{:T(128)S(6)} parameter(1) - %div.935.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1566), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.934.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.993.clone.1, %div.935.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1143.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.800.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1143.clone.1), dimensions={}, metadata={op_name="broadcast.73"} + %mul.2533.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.75.clone.1, %broadcast.800.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.557 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) + %constant.1142.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.799.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1142.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.2532.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.557, %broadcast.799.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.969.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2533.clone.1, %mul.2532.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1593 = f32[]{:T(128)S(6)} parameter(1) + %div.935.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1593), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.934.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.969.clone.1, %div.935.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.72.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} sqrt(%div.934.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1153.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.887.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1153.clone.1), dimensions={}, metadata={op_name="broadcast.65"} - %add.992.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.72.clone.1, %broadcast.887.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.436.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.936.clone.1, %add.992.clone.1), metadata={op_name="multiply.51"} - %div.933.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.994.clone.1, %multiply.436.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.2007.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1378, %broadcast.892.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.991.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.933.clone.1, %mul.2007.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.2006.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2008.clone.1, %add.991.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.990.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1378, %mul.2006.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.226 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.990.clone.1, %add.990.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1210 = f32[]{:T(128)} constant(0) - %reduce.185 = f32[]{:T(128)} reduce(%square.226, %constant.1210), dimensions={0,1,2}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.188.clone.1 = f32[]{:T(128)} reduce(%integer_pow.75.clone.1, %constant.1210), dimensions={0,1,2}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.146 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.185, %add.990.clone.1, %add.993.clone.1, %add.994.clone.1, %reduce.188.clone.1) -} - -%region_60.65 (reduce_sum.407: f32[], reduce_sum.408: f32[]) -> f32[] { - %reduce_sum.408 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.407 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.409 = f32[]{:T(128)} add(%reduce_sum.407, %reduce_sum.408), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_45.50 (reduce_sum.326: f32[], reduce_sum.330: f32[]) -> f32[] { - %reduce_sum.330 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.326 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.331 = f32[]{:T(128)} add(%reduce_sum.326, %reduce_sum.330), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1141.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.797.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1141.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %add.968.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.72.clone.1, %broadcast.797.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.297.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.936.clone.1, %add.968.clone.1), metadata={op_name="multiply.36"} + %div.933.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.970.clone.1, %multiply.297.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2530.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1419, %broadcast.802.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.967.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.933.clone.1, %mul.2530.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2529.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2531.clone.1, %add.967.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.966.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1419, %mul.2529.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.226 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.966.clone.1, %add.966.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1198 = f32[]{:T(128)} constant(0) + %reduce.115 = f32[]{:T(128)} reduce(%square.226, %constant.1198), dimensions={0,1,2}, to_apply=%region_61.66, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.118.clone.1 = f32[]{:T(128)} reduce(%integer_pow.75.clone.1, %constant.1198), dimensions={0,1,2}, to_apply=%region_46.51, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.147 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.115, %add.966.clone.1, %add.969.clone.1, %add.970.clone.1, %reduce.118.clone.1) +} + +%region_60.65 (reduce_sum.509: f32[], reduce_sum.513: f32[]) -> f32[] { + %reduce_sum.513 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.509 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.514 = f32[]{:T(128)} add(%reduce_sum.509, %reduce_sum.513), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_45.50 (reduce_sum.431: f32[], reduce_sum.432: f32[]) -> f32[] { + %reduce_sum.432 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.431 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.436 = f32[]{:T(128)} add(%reduce_sum.431, %reduce_sum.432), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.308 (param_0.1379: f32[2048,4,6144], param_1.1567: f32[], param_2.1325: f32[], param_3.929: f32[], param_4.567: f32[2048,4,6144], param_5.479: f32[], param_6.369: f32[4,2048,6144], param_7.212: pred[], param_8.129: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { - %param_0.1379 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) - %param_3.929 = f32[]{:T(128)S(6)} parameter(3) - %mul.2015.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.929), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.212 = pred[]{:T(512)S(6)} parameter(7) - %select_n.312.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.212), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.369 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) - %bitcast.486.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.369), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.479 = f32[]{:T(128)} parameter(5) - %div.948.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.479), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.947.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.486.clone.1, %div.948.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.311.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.312.clone.1, %bitcast.486.clone.1, %div.947.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1158.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.898.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1158.clone.1), dimensions={}, metadata={op_name="broadcast.85"} - %mul.2019.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.311.clone.1, %broadcast.898.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.319 (param_0.1420: f32[2048,4,6144], param_1.1594: f32[], param_2.1336: f32[], param_3.919: f32[], param_4.558: f32[2048,4,6144], param_5.492: f32[], param_6.367: f32[4,2048,6144], param_7.211: pred[], param_8.129: f32[2048,4,6144]) -> (f32[], f32[2048,4,6144], f32[2048,4,6144], f32[2048,4,6144], f32[]) { + %param_0.1420 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(0) + %param_3.919 = f32[]{:T(128)S(6)} parameter(3) + %mul.2538.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_3.919), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.211 = pred[]{:T(512)S(6)} parameter(7) + %select_n.312.clone.1 = pred[2048,4,6144]{2,1,0:T(4,128)(4,1)} broadcast(%param_7.211), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.367 = f32[4,2048,6144]{2,0,1:T(4,128)} parameter(6) + %bitcast.467.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} bitcast(%param_6.367), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.492 = f32[]{:T(128)} parameter(5) + %div.948.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_5.492), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.947.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%bitcast.467.clone.1, %div.948.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.311.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} select(%select_n.312.clone.1, %bitcast.467.clone.1, %div.947.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1146.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.808.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1146.clone.1), dimensions={}, metadata={op_name="broadcast.85"} + %mul.2542.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.311.clone.1, %broadcast.808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.129 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(8) - %constant.1162.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.897.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1162.clone.1), dimensions={}, metadata={op_name="broadcast.84"} - %mul.2018.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.129, %broadcast.897.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.999.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2019.clone.1, %mul.2018.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1325 = f32[]{:T(128)S(6)} parameter(2) - %div.944.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1325), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1150.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.807.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1150.clone.1), dimensions={}, metadata={op_name="broadcast.84"} + %mul.2541.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_8.129, %broadcast.807.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.975.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2542.clone.1, %mul.2541.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1336 = f32[]{:T(128)S(6)} parameter(2) + %div.944.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_2.1336), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.76.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%select_n.311.clone.1, %select_n.311.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1161.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.896.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1161.clone.1), dimensions={}, metadata={op_name="broadcast.73"} - %mul.2017.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.76.clone.1, %broadcast.896.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.567 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) - %constant.1160.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.895.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1160.clone.1), dimensions={}, metadata={op_name="broadcast.72"} - %mul.2016.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.567, %broadcast.895.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.998.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2017.clone.1, %mul.2016.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1567 = f32[]{:T(128)S(6)} parameter(1) - %div.943.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1567), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.942.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.998.clone.1, %div.943.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1149.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.806.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1149.clone.1), dimensions={}, metadata={op_name="broadcast.73"} + %mul.2540.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%integer_pow.76.clone.1, %broadcast.806.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.558 = f32[2048,4,6144]{2,1,0:T(4,128)} parameter(4) + %constant.1148.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.805.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1148.clone.1), dimensions={}, metadata={op_name="broadcast.72"} + %mul.2539.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_4.558, %broadcast.805.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.974.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%mul.2540.clone.1, %mul.2539.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1594 = f32[]{:T(128)S(6)} parameter(1) + %div.943.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%param_1.1594), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.942.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.974.clone.1, %div.943.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.73.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} sqrt(%div.942.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1159.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.893.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1159.clone.1), dimensions={}, metadata={op_name="broadcast.65"} - %add.997.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.73.clone.1, %broadcast.893.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.437.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.944.clone.1, %add.997.clone.1), metadata={op_name="multiply.50"} - %div.941.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.999.clone.1, %multiply.437.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.2014.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1379, %broadcast.898.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.996.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.941.clone.1, %mul.2014.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.2013.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2015.clone.1, %add.996.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.995.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1379, %mul.2013.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.227 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.995.clone.1, %add.995.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1211 = f32[]{:T(128)} constant(0) - %reduce.186 = f32[]{:T(128)} reduce(%square.227, %constant.1211), dimensions={0,1,2}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.189.clone.1 = f32[]{:T(128)} reduce(%integer_pow.76.clone.1, %constant.1211), dimensions={0,1,2}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.147 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.186, %add.995.clone.1, %add.998.clone.1, %add.999.clone.1, %reduce.189.clone.1) -} - -%region_39.44 (reduce_sum.302: f32[], reduce_sum.303: f32[]) -> f32[] { - %reduce_sum.303 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.302 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.304 = f32[]{:T(128)} add(%reduce_sum.302, %reduce_sum.303), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1147.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.803.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} broadcast(%constant.1147.clone.1), dimensions={}, metadata={op_name="broadcast.65"} + %add.973.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%sqrt.73.clone.1, %broadcast.803.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.298.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%div.944.clone.1, %add.973.clone.1), metadata={op_name="multiply.35"} + %div.941.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} divide(%add.975.clone.1, %multiply.298.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2537.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%param_0.1420, %broadcast.808.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.972.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%div.941.clone.1, %mul.2537.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2536.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%mul.2538.clone.1, %add.972.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.971.clone.1 = f32[2048,4,6144]{2,1,0:T(4,128)} add(%param_0.1420, %mul.2536.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.227 = f32[2048,4,6144]{2,1,0:T(4,128)} multiply(%add.971.clone.1, %add.971.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1199 = f32[]{:T(128)} constant(0) + %reduce.116 = f32[]{:T(128)} reduce(%square.227, %constant.1199), dimensions={0,1,2}, to_apply=%region_60.65, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.119.clone.1 = f32[]{:T(128)} reduce(%integer_pow.76.clone.1, %constant.1199), dimensions={0,1,2}, to_apply=%region_45.50, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.148 = (f32[]{:T(128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[2048,4,6144]{2,1,0:T(4,128)}, f32[]{:T(128)}) tuple(%reduce.116, %add.971.clone.1, %add.974.clone.1, %add.975.clone.1, %reduce.119.clone.1) +} + +%region_39.44 (reduce_sum.404: f32[], reduce_sum.408: f32[]) -> f32[] { + %reduce_sum.408 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.404 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.409 = f32[]{:T(128)} add(%reduce_sum.404, %reduce_sum.408), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.324 (param_0.1382: f32[4,2048,16,128]) -> f32[] { - %param_0.1382 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.362 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1382), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.230 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%bitcast.362, %bitcast.362), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1214 = f32[]{:T(128)} constant(0) - ROOT %reduce.190 = f32[]{:T(128)} reduce(%square.230, %constant.1214), dimensions={0,1,2,3}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.335 (param_0.1423: f32[4,2048,16,128]) -> f32[] { + %param_0.1423 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.343 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1423), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.230 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%bitcast.343, %bitcast.343), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1202 = f32[]{:T(128)} constant(0) + ROOT %reduce.120 = f32[]{:T(128)} reduce(%square.230, %constant.1202), dimensions={0,1,2,3}, to_apply=%region_39.44, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%region_38.43 (reduce_sum.296: f32[], reduce_sum.297: f32[]) -> f32[] { - %reduce_sum.297 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.296 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.298 = f32[]{:T(128)} add(%reduce_sum.296, %reduce_sum.297), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_38.43 (reduce_sum.401: f32[], reduce_sum.402: f32[]) -> f32[] { + %reduce_sum.402 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.401 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.403 = f32[]{:T(128)} add(%reduce_sum.401, %reduce_sum.402), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.326 (param_0.1383: f32[4,16,128,2048]) -> f32[] { - %param_0.1383 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(0) - %bitcast.366 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_0.1383), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.233 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%bitcast.366, %bitcast.366), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1215 = f32[]{:T(128)} constant(0) - ROOT %reduce.191 = f32[]{:T(128)} reduce(%square.233, %constant.1215), dimensions={0,1,2,3}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%fused_computation.337 (param_0.1424: f32[4,16,128,2048]) -> f32[] { + %param_0.1424 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(0) + %bitcast.347 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_0.1424), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.233 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%bitcast.347, %bitcast.347), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1203 = f32[]{:T(128)} constant(0) + ROOT %reduce.121 = f32[]{:T(128)} reduce(%square.233, %constant.1203), dimensions={0,1,2,3}, to_apply=%region_38.43, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.327 (param_0.950: f32[16,4,128,2048]) -> bf16[4,16,128,2048] { - %param_0.950 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) - %copy.193 = bf16[16,4,128,2048]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.950), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} - ROOT %bitcast.367 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.193), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +%fused_computation.338 (param_0.989: f32[16,4,128,2048]) -> bf16[4,16,128,2048] { + %param_0.989 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) + %copy.189 = bf16[16,4,128,2048]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.989), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'out\'][\'kernel\']"} + ROOT %bitcast.348 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.189), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} } -%region_68.73 (reduce_sum.449: f32[], reduce_sum.450: f32[]) -> f32[] { - %reduce_sum.450 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.449 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.451 = f32[]{:T(128)} add(%reduce_sum.449, %reduce_sum.450), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_68.73 (reduce_sum.551: f32[], reduce_sum.555: f32[]) -> f32[] { + %reduce_sum.555 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.551 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.556 = f32[]{:T(128)} add(%reduce_sum.551, %reduce_sum.555), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_53.58 (reduce_sum.368: f32[], reduce_sum.372: f32[]) -> f32[] { - %reduce_sum.372 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.368 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.373 = f32[]{:T(128)} add(%reduce_sum.368, %reduce_sum.372), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_53.58 (reduce_sum.473: f32[], reduce_sum.474: f32[]) -> f32[] { + %reduce_sum.474 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.473 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.478 = f32[]{:T(128)} add(%reduce_sum.473, %reduce_sum.474), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.328 (param_0.1371: f32[2048,4,16,128], param_1.1559: f32[], param_2.1317: f32[], param_3.921: f32[], param_4.559: f32[2048,4,16,128], param_5.471: f32[], param_6.361: f32[4,2048,16,128], param_7.204: pred[], param_8.121: f32[2048,4,16,128]) -> (f32[], f32[2048,4,16,128], f32[2048,4,16,128], f32[2048,4,16,128], f32[]) { - %param_0.1371 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.921 = f32[]{:T(128)S(6)} parameter(3) - %mul.1950.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_3.921), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.204 = pred[]{:T(512)S(6)} parameter(7) - %select_n.280.clone.1 = pred[2048,4,16,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.204), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.361 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.470.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_6.361), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.471 = f32[]{:T(128)} parameter(5) - %div.884.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_5.471), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.883.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%bitcast.470.clone.1, %div.884.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.279.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} select(%select_n.280.clone.1, %bitcast.470.clone.1, %div.883.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1110.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.858.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1110.clone.1), dimensions={}, metadata={op_name="broadcast.75"} - %mul.1956.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.279.clone.1, %broadcast.858.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.339 (param_0.1412: f32[2048,4,16,128], param_1.1586: f32[], param_2.1328: f32[], param_3.911: f32[], param_4.550: f32[2048,4,16,128], param_5.484: f32[], param_6.359: f32[4,2048,16,128], param_7.203: pred[], param_8.121: f32[2048,4,16,128]) -> (f32[], f32[2048,4,16,128], f32[2048,4,16,128], f32[2048,4,16,128], f32[]) { + %param_0.1412 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.911 = f32[]{:T(128)S(6)} parameter(3) + %mul.2473.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_3.911), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.203 = pred[]{:T(512)S(6)} parameter(7) + %select_n.280.clone.1 = pred[2048,4,16,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.203), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.359 = f32[4,2048,16,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.451.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} bitcast(%param_6.359), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.484 = f32[]{:T(128)} parameter(5) + %div.884.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_5.484), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.883.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%bitcast.451.clone.1, %div.884.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.279.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} select(%select_n.280.clone.1, %bitcast.451.clone.1, %div.883.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1098.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.768.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1098.clone.1), dimensions={}, metadata={op_name="broadcast.75"} + %mul.2479.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.279.clone.1, %broadcast.768.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.121 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.1114.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1957.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1114.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1955.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_8.121, %mul.1957.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.957.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.1956.clone.1, %mul.1955.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1317 = f32[]{:T(128)S(6)} parameter(2) - %div.880.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1317), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1102.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2480.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1102.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2478.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_8.121, %mul.2480.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.933.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.2479.clone.1, %mul.2478.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1328 = f32[]{:T(128)S(6)} parameter(2) + %div.880.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1328), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.68.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%select_n.279.clone.1, %select_n.279.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1113.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1954.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1113.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1952.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.68.clone.1, %mul.1954.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.559 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.1112.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1953.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1112.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1951.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_4.559, %mul.1953.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.956.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.1952.clone.1, %mul.1951.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1559 = f32[]{:T(128)S(6)} parameter(1) - %div.879.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1559), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.878.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.956.clone.1, %div.879.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1101.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2477.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1101.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2475.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.68.clone.1, %mul.2477.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.550 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1100.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2476.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1100.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2474.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_4.550, %mul.2476.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.932.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%mul.2475.clone.1, %mul.2474.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1586 = f32[]{:T(128)S(6)} parameter(1) + %div.879.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1586), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.878.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.932.clone.1, %div.879.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.65.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} sqrt(%div.878.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1111.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.955.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1111.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.954.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%sqrt.65.clone.1, %add.955.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.429.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%div.880.clone.1, %add.954.clone.1), metadata={op_name="multiply.58"} - %div.877.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.957.clone.1, %multiply.429.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1949.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_0.1371, %broadcast.858.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.953.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%div.877.clone.1, %mul.1949.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1948.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%mul.1950.clone.1, %add.953.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.952.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%param_0.1371, %mul.1948.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.234 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%add.952.clone.1, %add.952.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1203 = f32[]{:T(128)} constant(0) - %reduce.192 = f32[]{:T(128)} reduce(%square.234, %constant.1203), dimensions={0,1,2,3}, to_apply=%region_68.73, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.194.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.1203), dimensions={0,1,2,3}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.148 = (f32[]{:T(128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.192, %add.952.clone.1, %add.956.clone.1, %add.957.clone.1, %reduce.194.clone.1) + %constant.1099.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.931.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} broadcast(%constant.1099.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.930.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%sqrt.65.clone.1, %add.931.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.290.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%div.880.clone.1, %add.930.clone.1), metadata={op_name="multiply.43"} + %div.877.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} divide(%add.933.clone.1, %multiply.290.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2472.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%param_0.1412, %broadcast.768.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.929.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%div.877.clone.1, %mul.2472.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2471.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%mul.2473.clone.1, %add.929.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.928.clone.1 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} add(%param_0.1412, %mul.2471.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.234 = f32[2048,4,16,128]{3,2,1,0:T(8,128)} multiply(%add.928.clone.1, %add.928.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1191 = f32[]{:T(128)} constant(0) + %reduce.122 = f32[]{:T(128)} reduce(%square.234, %constant.1191), dimensions={0,1,2,3}, to_apply=%region_68.73, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.124.clone.1 = f32[]{:T(128)} reduce(%integer_pow.68.clone.1, %constant.1191), dimensions={0,1,2,3}, to_apply=%region_53.58, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.149 = (f32[]{:T(128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[2048,4,16,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.122, %add.928.clone.1, %add.932.clone.1, %add.933.clone.1, %reduce.124.clone.1) +} + +%region_67.72 (reduce_sum.548: f32[], reduce_sum.549: f32[]) -> f32[] { + %reduce_sum.549 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.548 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.550 = f32[]{:T(128)} add(%reduce_sum.548, %reduce_sum.549), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_52.57 (reduce_sum.467: f32[], reduce_sum.471: f32[]) -> f32[] { + %reduce_sum.471 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.467 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.472 = f32[]{:T(128)} add(%reduce_sum.467, %reduce_sum.471), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.340 (param_0.1413: f32[16,4,128,2048], param_1.1587: f32[], param_2.1329: f32[], param_3.912: f32[], param_4.551: f32[16,4,128,2048], param_5.485: f32[], param_6.360: f32[4,16,128,2048], param_7.204: pred[], param_8.122: f32[16,4,128,2048]) -> (f32[], f32[16,4,128,2048], f32[16,4,128,2048], f32[16,4,128,2048], f32[]) { + %param_0.1413 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) + %param_3.912 = f32[]{:T(128)S(6)} parameter(3) + %mul.2483.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_3.912), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.204 = pred[]{:T(512)S(6)} parameter(7) + %select_n.284.clone.1 = pred[16,4,128,2048]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.204), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.360 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.453.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_6.360), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.485 = f32[]{:T(128)} parameter(5) + %div.892.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_5.485), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.891.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%bitcast.453.clone.1, %div.892.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.283.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} select(%select_n.284.clone.1, %bitcast.453.clone.1, %div.891.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1104.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.770.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1104.clone.1), dimensions={}, metadata={op_name="broadcast.76"} + %mul.2489.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %broadcast.770.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.122 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(8) + %constant.1108.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2490.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1108.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2488.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_8.122, %mul.2490.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.939.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.2489.clone.1, %mul.2488.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1329 = f32[]{:T(128)S(6)} parameter(2) + %div.888.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_2.1329), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.69.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %select_n.283.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1107.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2487.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1107.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2485.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%integer_pow.69.clone.1, %mul.2487.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.551 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(4) + %constant.1106.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2486.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1106.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2484.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_4.551, %mul.2486.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.938.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.2485.clone.1, %mul.2484.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1587 = f32[]{:T(128)S(6)} parameter(1) + %div.887.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_1.1587), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.886.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.938.clone.1, %div.887.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.66.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} sqrt(%div.886.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1105.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.937.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1105.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.936.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%sqrt.66.clone.1, %add.937.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.291.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%div.888.clone.1, %add.936.clone.1), metadata={op_name="multiply.42"} + %div.885.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.939.clone.1, %multiply.291.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2482.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_0.1413, %broadcast.770.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.935.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%div.885.clone.1, %mul.2482.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2481.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%mul.2483.clone.1, %add.935.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.934.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%param_0.1413, %mul.2481.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.235 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%add.934.clone.1, %add.934.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1192 = f32[]{:T(128)} constant(0) + %reduce.123 = f32[]{:T(128)} reduce(%square.235, %constant.1192), dimensions={0,1,2,3}, to_apply=%region_67.72, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.125.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.1192), dimensions={0,1,2,3}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.150 = (f32[]{:T(128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.123, %add.934.clone.1, %add.938.clone.1, %add.939.clone.1, %reduce.125.clone.1) +} + +%region_41.46 (reduce_sum.416: f32[], reduce_sum.417: f32[]) -> f32[] { + %reduce_sum.417 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.416 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.418 = f32[]{:T(128)} add(%reduce_sum.416, %reduce_sum.417), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_67.72 (reduce_sum.443: f32[], reduce_sum.444: f32[]) -> f32[] { - %reduce_sum.444 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.443 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.445 = f32[]{:T(128)} add(%reduce_sum.443, %reduce_sum.444), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_36.41 (reduce_sum.389: f32[], reduce_sum.390: f32[]) -> f32[] { + %reduce_sum.390 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.389 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.394 = f32[]{:T(128)} add(%reduce_sum.389, %reduce_sum.390), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_52.57 (reduce_sum.365: f32[], reduce_sum.366: f32[]) -> f32[] { - %reduce_sum.366 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.365 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.367 = f32[]{:T(128)} add(%reduce_sum.365, %reduce_sum.366), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%fused_computation.352 (param_0.1426: f32[4,2048,8,128], param_1.1598: f32[4,2048,8,128]) -> (f32[], f32[]) { + %param_0.1426 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(0) + %bitcast.352 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1426), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.238 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.352, %bitcast.352), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1205 = f32[]{:T(128)} constant(0) + %reduce.126 = f32[]{:T(128)} reduce(%square.238, %constant.1205), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1598 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(1) + %bitcast.356.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1598), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.241.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.356.clone.1, %bitcast.356.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.127.clone.1 = f32[]{:T(128)} reduce(%square.241.clone.1, %constant.1205), dimensions={0,1,2,3}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.169 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.126, %reduce.127.clone.1) +} + +%fused_computation.355 (param_0.1021: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { + %param_0.1021 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %copy.190 = bf16[2048,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.1021), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} + ROOT %bitcast.357 = bf16[4,2048,8,128]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} +} + +%region_70.75 (reduce_sum.563: f32[], reduce_sum.564: f32[]) -> f32[] { + %reduce_sum.564 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.563 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.565 = f32[]{:T(128)} add(%reduce_sum.563, %reduce_sum.564), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_55.60 (reduce_sum.485: f32[], reduce_sum.486: f32[]) -> f32[] { + %reduce_sum.486 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.485 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.487 = f32[]{:T(128)} add(%reduce_sum.485, %reduce_sum.486), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.356 (param_0.1410: f32[2048,4,8,128], param_1.1584: f32[], param_2.1326: f32[], param_3.909: f32[], param_4.548: f32[2048,4,8,128], param_5.482: f32[], param_6.357: f32[4,2048,8,128], param_7.201: pred[], param_8.119: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { + %param_0.1410 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.909 = f32[]{:T(128)S(6)} parameter(3) + %mul.2459.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.909), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.201 = pred[]{:T(512)S(6)} parameter(7) + %select_n.272.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.201), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.357 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(6) + %bitcast.447.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.357), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.482 = f32[]{:T(128)} parameter(5) + %div.868.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.482), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.867.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.447.clone.1, %div.868.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.271.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.272.clone.1, %bitcast.447.clone.1, %div.867.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1086.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.760.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1086.clone.1), dimensions={}, metadata={op_name="broadcast.80"} + %mul.2463.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %broadcast.760.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.119 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) + %constant.1090.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.759.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1090.clone.1), dimensions={}, metadata={op_name="broadcast.79"} + %mul.2462.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.119, %broadcast.759.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.922.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2463.clone.1, %mul.2462.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1326 = f32[]{:T(128)S(6)} parameter(2) + %div.864.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1326), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.66.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %select_n.271.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %constant.1089.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.758.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1089.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.2461.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.66.clone.1, %broadcast.758.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.548 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) + %constant.1088.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.757.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1088.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.2460.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.548, %broadcast.757.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.921.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2461.clone.1, %mul.2460.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1584 = f32[]{:T(128)S(6)} parameter(1) + %div.863.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1584), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.862.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.921.clone.1, %div.863.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.63.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.862.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %constant.1087.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.755.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1087.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %add.920.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.755.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.288.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.864.clone.1, %add.920.clone.1), metadata={op_name="multiply.45"} + %div.861.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.922.clone.1, %multiply.288.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2458.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1410, %broadcast.760.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.919.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.861.clone.1, %mul.2458.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2457.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.2459.clone.1, %add.919.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.918.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1410, %mul.2457.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.242 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.918.clone.1, %add.918.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1189 = f32[]{:T(128)} constant(0) + %reduce.128 = f32[]{:T(128)} reduce(%square.242, %constant.1189), dimensions={0,1,2,3}, to_apply=%region_70.75, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.130.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.1189), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.151 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.128, %add.918.clone.1, %add.921.clone.1, %add.922.clone.1, %reduce.130.clone.1) +} + +%region_65.70 (reduce_sum.536: f32[], reduce_sum.537: f32[]) -> f32[] { + %reduce_sum.537 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.536 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.541 = f32[]{:T(128)} add(%reduce_sum.536, %reduce_sum.537), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_50.55 (reduce_sum.458: f32[], reduce_sum.459: f32[]) -> f32[] { + %reduce_sum.459 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.458 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.460 = f32[]{:T(128)} add(%reduce_sum.458, %reduce_sum.459), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.329 (param_0.1372: f32[16,4,128,2048], param_1.1560: f32[], param_2.1318: f32[], param_3.922: f32[], param_4.560: f32[16,4,128,2048], param_5.472: f32[], param_6.362: f32[4,16,128,2048], param_7.205: pred[], param_8.122: f32[16,4,128,2048]) -> (f32[], f32[16,4,128,2048], f32[16,4,128,2048], f32[16,4,128,2048], f32[]) { - %param_0.1372 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(0) - %param_3.922 = f32[]{:T(128)S(6)} parameter(3) - %mul.1960.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_3.922), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.205 = pred[]{:T(512)S(6)} parameter(7) - %select_n.284.clone.1 = pred[16,4,128,2048]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.205), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.362 = f32[4,16,128,2048]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.472.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} bitcast(%param_6.362), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.472 = f32[]{:T(128)} parameter(5) - %div.892.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_5.472), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.891.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%bitcast.472.clone.1, %div.892.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.283.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} select(%select_n.284.clone.1, %bitcast.472.clone.1, %div.891.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} +%fused_computation.357 (param_0.1415: f32[2048,4,8,128], param_1.1589: f32[], param_2.1331: f32[], param_3.914: f32[], param_4.553: f32[2048,4,8,128], param_5.487: f32[], param_6.362: f32[4,2048,8,128], param_7.206: pred[], param_8.124: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { + %param_0.1415 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) + %param_3.914 = f32[]{:T(128)S(6)} parameter(3) + %mul.2500.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.914), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.206 = pred[]{:T(512)S(6)} parameter(7) + %select_n.292.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.206), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.362 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(6) + %bitcast.457.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.362), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.487 = f32[]{:T(128)} parameter(5) + %div.908.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.487), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.907.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.457.clone.1, %div.908.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.291.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.292.clone.1, %bitcast.457.clone.1, %div.907.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} %constant.1116.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.860.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1116.clone.1), dimensions={}, metadata={op_name="broadcast.76"} - %mul.1966.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %broadcast.860.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_8.122 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(8) + %broadcast.782.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1116.clone.1), dimensions={}, metadata={op_name="broadcast.80"} + %mul.2504.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %broadcast.782.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_8.124 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) %constant.1120.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.1967.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1120.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1965.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_8.122, %mul.1967.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.963.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.1966.clone.1, %mul.1965.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1318 = f32[]{:T(128)S(6)} parameter(2) - %div.888.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_2.1318), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.69.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%select_n.283.clone.1, %select_n.283.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} + %broadcast.781.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1120.clone.1), dimensions={}, metadata={op_name="broadcast.79"} + %mul.2503.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.124, %broadcast.781.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.949.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2504.clone.1, %mul.2503.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1331 = f32[]{:T(128)S(6)} parameter(2) + %div.904.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1331), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %integer_pow.71.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %select_n.291.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} %constant.1119.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.1964.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1119.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1962.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%integer_pow.69.clone.1, %mul.1964.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.560 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} parameter(4) + %broadcast.780.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1119.clone.1), dimensions={}, metadata={op_name="broadcast.69"} + %mul.2502.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.71.clone.1, %broadcast.780.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.553 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) %constant.1118.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.1963.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1118.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.1961.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_4.560, %mul.1963.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.962.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%mul.1962.clone.1, %mul.1961.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1560 = f32[]{:T(128)S(6)} parameter(1) - %div.887.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_1.1560), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.886.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.962.clone.1, %div.887.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.66.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} sqrt(%div.886.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} + %broadcast.779.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1118.clone.1), dimensions={}, metadata={op_name="broadcast.68"} + %mul.2501.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.553, %broadcast.779.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.948.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.2502.clone.1, %mul.2501.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1589 = f32[]{:T(128)S(6)} parameter(1) + %div.903.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1589), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.902.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.948.clone.1, %div.903.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %sqrt.68.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.902.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} %constant.1117.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.961.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%constant.1117.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.960.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%sqrt.66.clone.1, %add.961.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.430.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%div.888.clone.1, %add.960.clone.1), metadata={op_name="multiply.57"} - %div.885.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} divide(%add.963.clone.1, %multiply.430.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1959.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%param_0.1372, %broadcast.860.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.959.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%div.885.clone.1, %mul.1959.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1958.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%mul.1960.clone.1, %add.959.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.958.clone.1 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} add(%param_0.1372, %mul.1958.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.235 = f32[16,4,128,2048]{3,2,1,0:T(8,128)} multiply(%add.958.clone.1, %add.958.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1204 = f32[]{:T(128)} constant(0) - %reduce.193 = f32[]{:T(128)} reduce(%square.235, %constant.1204), dimensions={0,1,2,3}, to_apply=%region_67.72, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.195.clone.1 = f32[]{:T(128)} reduce(%integer_pow.69.clone.1, %constant.1204), dimensions={0,1,2,3}, to_apply=%region_52.57, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.149 = (f32[]{:T(128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[16,4,128,2048]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.193, %add.958.clone.1, %add.962.clone.1, %add.963.clone.1, %reduce.195.clone.1) -} - -%region_41.46 (reduce_sum.311: f32[], reduce_sum.312: f32[]) -> f32[] { - %reduce_sum.312 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.311 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.316 = f32[]{:T(128)} add(%reduce_sum.311, %reduce_sum.312), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_36.41 (reduce_sum.284: f32[], reduce_sum.288: f32[]) -> f32[] { - %reduce_sum.288 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.284 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.289 = f32[]{:T(128)} add(%reduce_sum.284, %reduce_sum.288), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.341 (param_0.1385: f32[4,2048,8,128], param_1.1571: f32[4,2048,8,128]) -> (f32[], f32[]) { - %param_0.1385 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(0) - %bitcast.371 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_0.1385), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.238 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.371, %bitcast.371), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1217 = f32[]{:T(128)} constant(0) - %reduce.196 = f32[]{:T(128)} reduce(%square.238, %constant.1217), dimensions={0,1,2,3}, to_apply=%region_41.46, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1571 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(1) - %bitcast.375.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_1.1571), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.241.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%bitcast.375.clone.1, %bitcast.375.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.197.clone.1 = f32[]{:T(128)} reduce(%square.241.clone.1, %constant.1217), dimensions={0,1,2,3}, to_apply=%region_36.41, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.168 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.196, %reduce.197.clone.1) -} - -%fused_computation.344 (param_0.982: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { - %param_0.982 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) - %copy.194 = bf16[2048,4,8,128]{3,2,0,1:T(8,128)(2,1)} copy(%param_0.982), sharding={replicated}, metadata={op_name="state.params[\'params\'][\'decoder\'][\'layers\'][\'self_attention\'][\'value\'][\'kernel\']"} - ROOT %bitcast.376 = bf16[4,2048,8,128]{3,2,1,0:T(8,128)(2,1)} bitcast(%copy.194), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} -} - -%region_70.75 (reduce_sum.458: f32[], reduce_sum.459: f32[]) -> f32[] { - %reduce_sum.459 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.458 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.463 = f32[]{:T(128)} add(%reduce_sum.458, %reduce_sum.459), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %broadcast.777.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1117.clone.1), dimensions={}, metadata={op_name="broadcast.63"} + %add.947.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.68.clone.1, %broadcast.777.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.293.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.904.clone.1, %add.947.clone.1), metadata={op_name="multiply.40"} + %div.901.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.949.clone.1, %multiply.293.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2499.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1415, %broadcast.782.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.946.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.901.clone.1, %mul.2499.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2498.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.2500.clone.1, %add.946.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.945.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1415, %mul.2498.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.243 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.945.clone.1, %add.945.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1194 = f32[]{:T(128)} constant(0) + %reduce.129 = f32[]{:T(128)} reduce(%square.243, %constant.1194), dimensions={0,1,2,3}, to_apply=%region_65.70, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.131.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.1194), dimensions={0,1,2,3}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.152 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.129, %add.945.clone.1, %add.948.clone.1, %add.949.clone.1, %reduce.131.clone.1) +} + +%fused_computation.373 (param_0.1095: bf16[4,128,2048], param_1.1142: f32[4,128], param_2.842: f32[4,128], param_3.484: bf16[4,128,2048], param_4.283: bf16[2048]) -> bf16[4,128,2048] { + %param_3.484 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.283 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %mul.2385 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.283), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2359 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_3.484, %mul.2385), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1387 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%mul.2359), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.842 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %mul.2356 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_2.842), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2347 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1387, %mul.2356), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1095 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1398 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1095), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.1142 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %mul.2354 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_1.1142), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2353 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1398, %mul.2354), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %add_any.184 = f32[4,128,2048]{2,1,0:T(8,128)} add(%mul.2347, %mul.2353), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} + ROOT %convert_element_type.1385 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%add_any.184), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} +} + +%region_6.9 (reduce_sum.228: f32[], reduce_sum.229: f32[]) -> f32[] { + %reduce_sum.229 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.228 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.230 = f32[]{:T(128)} add(%reduce_sum.228, %reduce_sum.229), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.374 (param_0.1435: bf16[4,128,2048]) -> f32[4,128] { + %param_0.1435 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1389 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1435), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.246 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1389, %convert_element_type.1389), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} + %constant.1215 = f32[]{:T(128)} constant(0) + ROOT %reduce.132 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.246, %constant.1215), dimensions={2}, to_apply=%region_6.9, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} } -%region_55.60 (reduce_sum.380: f32[], reduce_sum.381: f32[]) -> f32[] { - %reduce_sum.381 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.380 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.382 = f32[]{:T(128)} add(%reduce_sum.380, %reduce_sum.381), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_12.15 (reduce_sum.275: f32[], reduce_sum.276: f32[]) -> f32[] { + %reduce_sum.276 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.275 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.277 = f32[]{:T(128)} add(%reduce_sum.275, %reduce_sum.276), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.345 (param_0.1369: f32[2048,4,8,128], param_1.1557: f32[], param_2.1315: f32[], param_3.919: f32[], param_4.557: f32[2048,4,8,128], param_5.469: f32[], param_6.359: f32[4,2048,8,128], param_7.202: pred[], param_8.119: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { - %param_0.1369 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.919 = f32[]{:T(128)S(6)} parameter(3) - %mul.1936.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.919), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.202 = pred[]{:T(512)S(6)} parameter(7) - %select_n.272.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.202), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.359 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} parameter(6) - %bitcast.466.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.359), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.469 = f32[]{:T(128)} parameter(5) - %div.868.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.469), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.867.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.466.clone.1, %div.868.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.271.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.272.clone.1, %bitcast.466.clone.1, %div.867.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1098.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.850.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1098.clone.1), dimensions={}, metadata={op_name="broadcast.80"} - %mul.1940.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %broadcast.850.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_8.119 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.1102.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.849.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1102.clone.1), dimensions={}, metadata={op_name="broadcast.79"} - %mul.1939.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.119, %broadcast.849.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.946.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1940.clone.1, %mul.1939.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1315 = f32[]{:T(128)S(6)} parameter(2) - %div.864.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1315), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.66.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.271.clone.1, %select_n.271.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1101.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.848.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1101.clone.1), dimensions={}, metadata={op_name="broadcast.69"} - %mul.1938.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.66.clone.1, %broadcast.848.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.557 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.1100.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.847.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1100.clone.1), dimensions={}, metadata={op_name="broadcast.68"} - %mul.1937.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.557, %broadcast.847.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.945.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1938.clone.1, %mul.1937.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1557 = f32[]{:T(128)S(6)} parameter(1) - %div.863.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1557), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.862.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.945.clone.1, %div.863.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.63.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.862.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1099.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.845.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1099.clone.1), dimensions={}, metadata={op_name="broadcast.63"} - %add.944.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.63.clone.1, %broadcast.845.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.427.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.864.clone.1, %add.944.clone.1), metadata={op_name="multiply.60"} - %div.861.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.946.clone.1, %multiply.427.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1935.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1369, %broadcast.850.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.943.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.861.clone.1, %mul.1935.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1934.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1936.clone.1, %add.943.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.942.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1369, %mul.1934.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.242 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.942.clone.1, %add.942.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1201 = f32[]{:T(128)} constant(0) - %reduce.198 = f32[]{:T(128)} reduce(%square.242, %constant.1201), dimensions={0,1,2,3}, to_apply=%region_70.75, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.200.clone.1 = f32[]{:T(128)} reduce(%integer_pow.66.clone.1, %constant.1201), dimensions={0,1,2,3}, to_apply=%region_55.60, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.150 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.198, %add.942.clone.1, %add.945.clone.1, %add.946.clone.1, %reduce.200.clone.1) +%fused_computation.376 (param_0.1430: bf16[4,128,2048], param_1.1601: bf16[4,128,2048], param_2.1339: bf16[2048]) -> f32[4,128] { + %param_0.1430 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) + %convert_element_type.1396 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1430), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_1.1601 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %param_2.1339 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2384 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1339), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2358 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1601, %mul.2384), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1395 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%mul.2358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %mul.2351 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1396, %convert_element_type.1395), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1209 = f32[]{:T(128)} constant(0) + ROOT %reduce.133 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.2351, %constant.1209), dimensions={2}, to_apply=%region_12.15, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} } -%region_65.70 (reduce_sum.431: f32[], reduce_sum.435: f32[]) -> f32[] { - %reduce_sum.435 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.431 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.436 = f32[]{:T(128)} add(%reduce_sum.431, %reduce_sum.435), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_10.13 (reduce_sum.263: bf16[], reduce_sum.264: bf16[]) -> bf16[] { + %reduce_sum.264 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + %reduce_sum.263 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} + ROOT %reduce_sum.268 = bf16[]{:T(256)} add(%reduce_sum.263, %reduce_sum.264), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_50.55 (reduce_sum.353: f32[], reduce_sum.354: f32[]) -> f32[] { - %reduce_sum.354 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.353 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.358 = f32[]{:T(128)} add(%reduce_sum.353, %reduce_sum.354), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%fused_computation.296.clone.clone (param_0.1392: bf16[151936,2048]) -> bf16[151936,2048,1] { + %param_0.1392 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.505 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1392), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} } -%fused_computation.346 (param_0.1374: f32[2048,4,8,128], param_1.1562: f32[], param_2.1320: f32[], param_3.924: f32[], param_4.562: f32[2048,4,8,128], param_5.474: f32[], param_6.364: f32[4,2048,8,128], param_7.207: pred[], param_8.124: f32[2048,4,8,128]) -> (f32[], f32[2048,4,8,128], f32[2048,4,8,128], f32[2048,4,8,128], f32[]) { - %param_0.1374 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(0) - %param_3.924 = f32[]{:T(128)S(6)} parameter(3) - %mul.1977.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_3.924), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.207 = pred[]{:T(512)S(6)} parameter(7) - %select_n.292.clone.1 = pred[2048,4,8,128]{3,2,1,0:T(8,128)(4,1)} broadcast(%param_7.207), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.364 = f32[4,2048,8,128]{3,2,0,1:T(8,128)S(1)} parameter(6) - %bitcast.476.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} bitcast(%param_6.364), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.474 = f32[]{:T(128)} parameter(5) - %div.908.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_5.474), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.907.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%bitcast.476.clone.1, %div.908.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.291.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} select(%select_n.292.clone.1, %bitcast.476.clone.1, %div.907.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1128.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.872.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1128.clone.1), dimensions={}, metadata={op_name="broadcast.80"} - %mul.1981.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %broadcast.872.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_8.124 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(8) - %constant.1132.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.871.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1132.clone.1), dimensions={}, metadata={op_name="broadcast.79"} - %mul.1980.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_8.124, %broadcast.871.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.973.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1981.clone.1, %mul.1980.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1320 = f32[]{:T(128)S(6)} parameter(2) - %div.904.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_2.1320), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %integer_pow.71.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%select_n.291.clone.1, %select_n.291.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1131.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.870.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1131.clone.1), dimensions={}, metadata={op_name="broadcast.69"} - %mul.1979.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%integer_pow.71.clone.1, %broadcast.870.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.562 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} parameter(4) - %constant.1130.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.869.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1130.clone.1), dimensions={}, metadata={op_name="broadcast.68"} - %mul.1978.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_4.562, %broadcast.869.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.972.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%mul.1979.clone.1, %mul.1978.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1562 = f32[]{:T(128)S(6)} parameter(1) - %div.903.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%param_1.1562), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.902.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.972.clone.1, %div.903.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %sqrt.68.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} sqrt(%div.902.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1129.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.867.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} broadcast(%constant.1129.clone.1), dimensions={}, metadata={op_name="broadcast.63"} - %add.971.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%sqrt.68.clone.1, %broadcast.867.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.432.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%div.904.clone.1, %add.971.clone.1), metadata={op_name="multiply.55"} - %div.901.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} divide(%add.973.clone.1, %multiply.432.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1976.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%param_0.1374, %broadcast.872.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.970.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%div.901.clone.1, %mul.1976.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1975.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%mul.1977.clone.1, %add.970.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.969.clone.1 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} add(%param_0.1374, %mul.1975.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.243 = f32[2048,4,8,128]{3,2,1,0:T(8,128)} multiply(%add.969.clone.1, %add.969.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1206 = f32[]{:T(128)} constant(0) - %reduce.199 = f32[]{:T(128)} reduce(%square.243, %constant.1206), dimensions={0,1,2,3}, to_apply=%region_65.70, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.201.clone.1 = f32[]{:T(128)} reduce(%integer_pow.71.clone.1, %constant.1206), dimensions={0,1,2,3}, to_apply=%region_50.55, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.151 = (f32[]{:T(128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[2048,4,8,128]{3,2,1,0:T(8,128)}, f32[]{:T(128)}) tuple(%reduce.199, %add.969.clone.1, %add.972.clone.1, %add.973.clone.1, %reduce.201.clone.1) -} - -%fused_computation.362 (param_0.1056: bf16[4,128,2048], param_1.1117: f32[4,128], param_2.830: f32[4,128], param_3.495: bf16[4,128,2048], param_4.296: bf16[2048]) -> bf16[4,128,2048] { - %param_3.495 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.296 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %dot_general.448 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_4.296), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.438 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_3.495, %dot_general.448), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.1363 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%dot_general.438), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_2.830 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %mul.1851 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_2.830), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1843 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1363, %mul.1851), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %param_0.1056 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.1374 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1056), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_1.1117 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %mul.1850 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_1.1117), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1849 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1374, %mul.1850), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %add_any.193 = f32[4,128,2048]{2,1,0:T(8,128)} add(%mul.1843, %mul.1849), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add_any" stack_frame_id=0} - ROOT %convert_element_type.1361 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%add_any.193), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} -} - -%region_7.10 (reduce_sum.171: f32[], reduce_sum.184: f32[]) -> f32[] { - %reduce_sum.184 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - %reduce_sum.171 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - ROOT %reduce_sum.185 = f32[]{:T(128)} add(%reduce_sum.171, %reduce_sum.184), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.363 (param_0.1394: bf16[4,128,2048]) -> f32[4,128] { - %param_0.1394 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.1365 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1394), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %square.246 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1365, %convert_element_type.1365), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/square" stack_frame_id=0} - %constant.1227 = f32[]{:T(128)} constant(0) - ROOT %reduce.202 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.246, %constant.1227), dimensions={2}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} -} - -%region_12.15 (reduce_sum.198: f32[], reduce_sum.199: f32[]) -> f32[] { - %reduce_sum.199 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - %reduce_sum.198 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum"} - ROOT %reduce_sum.200 = f32[]{:T(128)} add(%reduce_sum.198, %reduce_sum.199), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.365 (param_0.1389: bf16[4,128,2048], param_1.1574: bf16[4,128,2048], param_2.1328: bf16[2048]) -> f32[4,128] { - %param_0.1389 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %convert_element_type.1372 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_0.1389), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_1.1574 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %param_2.1328 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.447 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1328), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %dot_general.437 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1574, %dot_general.447), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %convert_element_type.1371 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%dot_general.437), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %mul.1847 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1372, %convert_element_type.1371), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %constant.1221 = f32[]{:T(128)} constant(0) - ROOT %reduce.203 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%mul.1847, %constant.1221), dimensions={2}, to_apply=%region_12.15, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} -} - -%region_10.13 (dot_general.190: bf16[], dot_general.191: bf16[]) -> bf16[] { - %dot_general.191 = bf16[]{:T(256)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} - %dot_general.190 = bf16[]{:T(256)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general"} - ROOT %add.419 = bf16[]{:T(256)} add(%dot_general.190, %dot_general.191), metadata={op_name="add.82"}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.285.clone.clone (param_0.1351: bf16[151936,2048]) -> bf16[151936,2048,1] { - %param_0.1351 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) - ROOT %bitcast.528 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1351), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} -} - -%fused_computation.289.clone.1.clone.clone (param_0.1352: bf16[4,128,151936], param_1.1546: s32[4,128], param_2.1285: f32[4,128], param_3.906: f32[4,128], param_4.542: bf16[4,128], param_5.442: f32[4,128]) -> bf16[4,128,151936] { - %param_5.442 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %mul.2075 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.442), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_3.906 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %mul.2074 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.906), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_0.1352 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) - %convert_element_type.1444 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1352), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_4.542 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) - %sub.88 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.542), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %sub.87 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1444, %sub.88), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} - %exp.60 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.87), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} - %mul.2073 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2074, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %param_2.1285 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %div.962 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1285), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %div.961 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2073, %div.962), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} - %param_1.1546 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %eq.43 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1546), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} +%fused_computation.300.clone.1.clone.clone (param_0.1393: bf16[4,128,151936], param_1.1573: s32[4,128], param_2.1296: f32[4,128], param_3.896: f32[4,128], param_4.533: bf16[4,128], param_5.455: f32[4,128]) -> bf16[4,128,151936] { + %param_5.455 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %mul.2616 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_5.455), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_3.896 = f32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %mul.2615 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_3.896), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_0.1393 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(0) + %convert_element_type.1468 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%param_0.1393), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_4.533 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(4) + %sub.86 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_4.533), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %sub.85 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%convert_element_type.1468, %sub.86), metadata={op_name="jit(train_step)/jvp()/sub" stack_frame_id=0} + %exp.60 = f32[4,128,151936]{2,1,0:T(8,128)} exponential(%sub.85), metadata={op_name="jit(train_step)/jvp()/exp" stack_frame_id=0} + %mul.2614 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2615, %exp.60), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %param_2.1296 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %div.962 = f32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_2.1296), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %div.961 = f32[4,128,151936]{2,1,0:T(8,128)} divide(%mul.2614, %div.962), metadata={op_name="jit(train_step)/transpose(jvp())/div" stack_frame_id=0} + %param_1.1573 = s32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %eq.43 = s32[4,128,151936]{2,1,0:T(8,128)} broadcast(%param_1.1573), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.42 = s32[4,128,151936]{2,1,0:T(8,128)} iota(), iota_dimension=2, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} %eq.41 = pred[4,128,151936]{2,1,0:T(8,128)(4,1)} compare(%eq.43, %eq.42), direction=EQ, metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/eq" stack_frame_id=0} - %convert_element_type.1443 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} - %sub.86 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.961, %convert_element_type.1443), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} - %mul.2072 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2075, %sub.86), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - ROOT %convert_element_type.1442 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2072), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} -} - -%fused_computation.366 (param_0.1350: f32[4,128], param_1.1545: bf16[4,128,2048], param_2.1286: bf16[151936,2048], param_3.907: bf16[4,128,151936], param_4.543: s32[4,128], param_5.443: f32[4,128], param_6.340: f32[4,128], param_7.199: bf16[4,128], param_8.116: f32[4,128]) -> (bf16[2048], bf16[4,128,2048]) { - %param_3.907 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(3) - %param_4.543 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) - %param_5.443 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) - %param_6.340 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) - %param_7.199 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) + %convert_element_type.1467 = f32[4,128,151936]{2,1,0:T(8,128)} convert(%eq.41), metadata={op_name="jit(train_step)/jvp(jit(_one_hot))/convert_element_type" stack_frame_id=0} + %sub.84 = f32[4,128,151936]{2,1,0:T(8,128)} subtract(%div.961, %convert_element_type.1467), metadata={op_name="jit(train_step)/transpose(jvp())/sub" stack_frame_id=0} + %mul.2613 = f32[4,128,151936]{2,1,0:T(8,128)} multiply(%mul.2616, %sub.84), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + ROOT %convert_element_type.1466 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convert(%mul.2613), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} +} + +%fused_computation.377 (param_0.1391: f32[4,128], param_1.1572: bf16[4,128,2048], param_2.1297: bf16[151936,2048], param_3.897: bf16[4,128,151936], param_4.534: s32[4,128], param_5.456: f32[4,128], param_6.342: f32[4,128], param_7.198: bf16[4,128], param_8.116: f32[4,128]) -> (bf16[2048], bf16[4,128,2048]) { + %param_1.1572 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1408 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1572), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1391 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2373 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1391), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2372 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1408, %mul.2373), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1407 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2372), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_3.897 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.534 = s32[4,128]{1,0:T(4,128)S(1)} parameter(4) + %param_5.456 = f32[4,128]{1,0:T(4,128)S(1)} parameter(5) + %param_6.342 = f32[4,128]{1,0:T(4,128)S(1)} parameter(6) + %param_7.198 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(7) %param_8.116 = f32[4,128]{1,0:T(4,128)S(1)} parameter(8) - %multiply_convert_fusion.2.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_3.907, %param_4.543, %param_5.443, %param_6.340, %param_7.199, /*index=5*/%param_8.116), kind=kLoop, calls=%fused_computation.289.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} - %param_2.1286 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(2) - %fusion.251.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1286), kind=kLoop, calls=%fused_computation.285.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %convolution.84.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.2.clone.1, %fusion.251.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %param_1.1545 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1384 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1545), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1350 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.1862 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1350), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1861 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1384, %mul.1862), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1383 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.1861), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %multiply.420 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convolution.84.clone.1, %convert_element_type.1383), metadata={op_name="multiply.362"} - %constant.1050 = bf16[]{:T(256)} constant(0) - %reduce.204 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%multiply.420, %constant.1050), dimensions={0,1}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %tuple.165 = (bf16[2048]{0:T(1024)(128)(2,1)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.204, %convolution.84.clone.1) -} - -%fused_computation.374 (param_0.1088: f32[64], param_1.1150: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { - %param_1.1150 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %div.720 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1150), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} - %param_0.1088 = f32[64]{0:T(128)S(1)} parameter(0) - %div.718 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1088), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %multiply_convert_fusion.3.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} fusion(%param_3.897, %param_4.534, %param_5.456, %param_6.342, %param_7.198, /*index=5*/%param_8.116), kind=kLoop, calls=%fused_computation.300.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/convert_element_type" stack_frame_id=0} + %param_2.1297 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(2) + %fusion.261.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_2.1297), kind=kLoop, calls=%fused_computation.296.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %convolution.84.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} convolution(%multiply_convert_fusion.3.clone.1, %fusion.261.clone.1), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %mul.2355 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1407, %convolution.84.clone.1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1051 = bf16[]{:T(256)} constant(0) + %reduce.134 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%mul.2355, %constant.1051), dimensions={0,1}, to_apply=%region_10.13, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/reduce_sum" stack_frame_id=0} + ROOT %tuple.166 = (bf16[2048]{0:T(1024)(128)(2,1)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.134, %convolution.84.clone.1) +} + +%fused_computation.385 (param_0.1129: f32[64], param_1.1177: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1177 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %div.720 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1177), dimensions={0,1}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + %param_0.1129 = f32[64]{0:T(128)S(1)} parameter(0) + %div.718 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1129), dimensions={3}, metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} %div.717 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.720, %div.718), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} %sin.38 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.717), metadata={op_name="jit(train_step)/layers/sin" stack_frame_id=0} - %convert_element_type.1392 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1416 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.38), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} %cos.41.clone.1 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.717), metadata={op_name="jit(train_step)/layers/cos" stack_frame_id=0} - %convert_element_type.1391.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.158 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1392, %convert_element_type.1391.clone.1) + %convert_element_type.1415.clone.1 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.41.clone.1), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.159 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1416, %convert_element_type.1415.clone.1) } -%fused_computation.375 (param_0.1085: bf16[4,128,1,64]) -> bf16[4,128,1,128] { - %param_0.1085 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1042 = bf16[]{:T(256)} constant(-inf) - %pad.46 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1085, %constant.1042), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} - %pad.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1085, %constant.1042), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +%fused_computation.386 (param_0.1126: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1126 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1030 = bf16[]{:T(256)} constant(-inf) + %pad.46 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1126, %constant.1030), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.45 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1126, %constant.1030), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} ROOT %maximum.42 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.46, %pad.45), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} } -%fused_computation.376 (param_0.1087: bf16[4,128,1,64]) -> bf16[4,128,1,128] { - %param_0.1087 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1041 = bf16[]{:T(256)} constant(-inf) - %pad.48 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1087, %constant.1041), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} - %pad.47 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1087, %constant.1041), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} +%fused_computation.387 (param_0.1128: bf16[4,128,1,64]) -> bf16[4,128,1,128] { + %param_0.1128 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1029 = bf16[]{:T(256)} constant(-inf) + %pad.48 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1128, %constant.1029), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} + %pad.47 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1128, %constant.1029), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} ROOT %maximum.43 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.48, %pad.47), metadata={op_name="jit(train_step)/layers/concatenate" stack_frame_id=0} } -%region_35.40 (reduce_sum.281: f32[], reduce_sum.282: f32[]) -> f32[] { - %reduce_sum.282 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.281 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.283 = f32[]{:T(128)} add(%reduce_sum.281, %reduce_sum.282), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_35.40 (reduce_sum.383: f32[], reduce_sum.387: f32[]) -> f32[] { + %reduce_sum.387 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.383 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.388 = f32[]{:T(128)} add(%reduce_sum.383, %reduce_sum.387), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_34.39 (reduce_sum.275: f32[], reduce_sum.276: f32[]) -> f32[] { - %reduce_sum.276 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.275 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.277 = f32[]{:T(128)} add(%reduce_sum.275, %reduce_sum.276), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_34.39 (reduce_sum.380: f32[], reduce_sum.381: f32[]) -> f32[] { + %reduce_sum.381 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.380 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.382 = f32[]{:T(128)} add(%reduce_sum.380, %reduce_sum.381), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.380 (param_0.1386: f32[4,2048], param_1.1572: f32[4,2048]) -> (f32[], f32[]) { - %param_0.1386 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(0) - %bitcast.404 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_0.1386), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.249 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.404, %bitcast.404), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1218 = f32[]{:T(128)} constant(0) - %reduce.205 = f32[]{:T(128)} reduce(%square.249, %constant.1218), dimensions={0,1}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1572 = f32[4,2048]{1,0:T(4,128)} parameter(1) - %bitcast.408.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_1.1572), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.252.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.408.clone.1, %bitcast.408.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.206.clone.1 = f32[]{:T(128)} reduce(%square.252.clone.1, %constant.1218), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.169 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.205, %reduce.206.clone.1) +%fused_computation.391 (param_0.1427: f32[4,2048], param_1.1599: f32[4,2048]) -> (f32[], f32[]) { + %param_0.1427 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(0) + %bitcast.385 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_0.1427), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.249 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.385, %bitcast.385), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1206 = f32[]{:T(128)} constant(0) + %reduce.135 = f32[]{:T(128)} reduce(%square.249, %constant.1206), dimensions={0,1}, to_apply=%region_35.40, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1599 = f32[4,2048]{1,0:T(4,128)} parameter(1) + %bitcast.389.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_1.1599), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.252.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%bitcast.389.clone.1, %bitcast.389.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.136.clone.1 = f32[]{:T(128)} reduce(%square.252.clone.1, %constant.1206), dimensions={0,1}, to_apply=%region_34.39, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.170 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.135, %reduce.136.clone.1) } -%region_64.69 (reduce_sum.428: f32[], reduce_sum.429: f32[]) -> f32[] { - %reduce_sum.429 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.428 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.430 = f32[]{:T(128)} add(%reduce_sum.428, %reduce_sum.429), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_64.69 (reduce_sum.530: f32[], reduce_sum.534: f32[]) -> f32[] { + %reduce_sum.534 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.530 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.535 = f32[]{:T(128)} add(%reduce_sum.530, %reduce_sum.534), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_49.54 (reduce_sum.347: f32[], reduce_sum.351: f32[]) -> f32[] { - %reduce_sum.351 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.347 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.352 = f32[]{:T(128)} add(%reduce_sum.347, %reduce_sum.351), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_49.54 (reduce_sum.452: f32[], reduce_sum.453: f32[]) -> f32[] { + %reduce_sum.453 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.452 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.457 = f32[]{:T(128)} add(%reduce_sum.452, %reduce_sum.453), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.383 (param_0.1375: f32[2048,4], param_1.1563: f32[], param_2.1321: f32[], param_3.925: f32[], param_4.563: f32[2048,4], param_5.475: f32[], param_6.365: f32[4,2048], param_7.208: pred[], param_8.125: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { - %param_0.1375 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.925 = f32[]{:T(128)S(6)} parameter(3) - %mul.1984.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.925), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.208 = pred[]{:T(512)S(6)} parameter(7) - %select_n.296.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.208), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.365 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(6) - %bitcast.478.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.365), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.475 = f32[]{:T(128)} parameter(5) - %div.916.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.475), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.915.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.478.clone.1, %div.916.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.295.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.296.clone.1, %bitcast.478.clone.1, %div.915.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1134.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.878.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1134.clone.1), dimensions={}, metadata={op_name="broadcast.82"} - %mul.1988.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.295.clone.1, %broadcast.878.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.394 (param_0.1416: f32[2048,4], param_1.1590: f32[], param_2.1332: f32[], param_3.915: f32[], param_4.554: f32[2048,4], param_5.488: f32[], param_6.363: f32[4,2048], param_7.207: pred[], param_8.125: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { + %param_0.1416 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.915 = f32[]{:T(128)S(6)} parameter(3) + %mul.2507.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.915), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.207 = pred[]{:T(512)S(6)} parameter(7) + %select_n.296.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.207), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.363 = f32[4,2048]{1,0:T(4,128)S(1)} parameter(6) + %bitcast.459.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.363), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.488 = f32[]{:T(128)} parameter(5) + %div.916.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.488), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.915.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.459.clone.1, %div.916.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.295.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.296.clone.1, %bitcast.459.clone.1, %div.915.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1122.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.788.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1122.clone.1), dimensions={}, metadata={op_name="broadcast.82"} + %mul.2511.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.295.clone.1, %broadcast.788.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.125 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.1138.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.877.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1138.clone.1), dimensions={}, metadata={op_name="broadcast.81"} - %mul.1987.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.125, %broadcast.877.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.978.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1988.clone.1, %mul.1987.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1321 = f32[]{:T(128)S(6)} parameter(2) - %div.912.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1321), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1126.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.787.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1126.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %mul.2510.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.125, %broadcast.787.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.954.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2511.clone.1, %mul.2510.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1332 = f32[]{:T(128)S(6)} parameter(2) + %div.912.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1332), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.72.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.295.clone.1, %select_n.295.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1137.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.876.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1137.clone.1), dimensions={}, metadata={op_name="broadcast.71"} - %mul.1986.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.72.clone.1, %broadcast.876.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.563 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.1136.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.875.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1136.clone.1), dimensions={}, metadata={op_name="broadcast.70"} - %mul.1985.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.563, %broadcast.875.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.977.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1986.clone.1, %mul.1985.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1563 = f32[]{:T(128)S(6)} parameter(1) - %div.911.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1563), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.910.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.977.clone.1, %div.911.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1125.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.786.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1125.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.2509.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.72.clone.1, %broadcast.786.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.554 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1124.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.785.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1124.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.2508.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.554, %broadcast.785.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.953.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2509.clone.1, %mul.2508.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1590 = f32[]{:T(128)S(6)} parameter(1) + %div.911.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1590), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.910.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.953.clone.1, %div.911.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.69.clone.1 = f32[2048,4]{0,1:T(4,128)} sqrt(%div.910.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1135.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.873.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1135.clone.1), dimensions={}, metadata={op_name="broadcast.64"} - %add.976.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.69.clone.1, %broadcast.873.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.433.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.912.clone.1, %add.976.clone.1), metadata={op_name="multiply.54"} - %div.909.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.978.clone.1, %multiply.433.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1983.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1375, %broadcast.878.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.975.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.909.clone.1, %mul.1983.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1982.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.1984.clone.1, %add.975.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.974.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1375, %mul.1982.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.253 = f32[2048,4]{0,1:T(4,128)} multiply(%add.974.clone.1, %add.974.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1207 = f32[]{:T(128)} constant(0) - %reduce.207 = f32[]{:T(128)} reduce(%square.253, %constant.1207), dimensions={0,1}, to_apply=%region_64.69, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.209.clone.1 = f32[]{:T(128)} reduce(%integer_pow.72.clone.1, %constant.1207), dimensions={0,1}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.152 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.207, %add.974.clone.1, %add.977.clone.1, %add.978.clone.1, %reduce.209.clone.1) -} - -%region_63.68 (reduce_sum.422: f32[], reduce_sum.423: f32[]) -> f32[] { - %reduce_sum.423 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.422 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.424 = f32[]{:T(128)} add(%reduce_sum.422, %reduce_sum.423), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_48.53 (reduce_sum.344: f32[], reduce_sum.345: f32[]) -> f32[] { - %reduce_sum.345 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.344 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.346 = f32[]{:T(128)} add(%reduce_sum.344, %reduce_sum.345), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1123.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.783.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1123.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %add.952.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.69.clone.1, %broadcast.783.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.294.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.912.clone.1, %add.952.clone.1), metadata={op_name="multiply.39"} + %div.909.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.954.clone.1, %multiply.294.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2506.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1416, %broadcast.788.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.951.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.909.clone.1, %mul.2506.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2505.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.2507.clone.1, %add.951.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.950.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1416, %mul.2505.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.253 = f32[2048,4]{0,1:T(4,128)} multiply(%add.950.clone.1, %add.950.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1195 = f32[]{:T(128)} constant(0) + %reduce.137 = f32[]{:T(128)} reduce(%square.253, %constant.1195), dimensions={0,1}, to_apply=%region_64.69, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.139.clone.1 = f32[]{:T(128)} reduce(%integer_pow.72.clone.1, %constant.1195), dimensions={0,1}, to_apply=%region_49.54, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.153 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.137, %add.950.clone.1, %add.953.clone.1, %add.954.clone.1, %reduce.139.clone.1) +} + +%region_63.68 (reduce_sum.527: f32[], reduce_sum.528: f32[]) -> f32[] { + %reduce_sum.528 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.527 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.529 = f32[]{:T(128)} add(%reduce_sum.527, %reduce_sum.528), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_48.53 (reduce_sum.446: f32[], reduce_sum.450: f32[]) -> f32[] { + %reduce_sum.450 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.446 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.451 = f32[]{:T(128)} add(%reduce_sum.446, %reduce_sum.450), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.384 (param_0.1376: f32[2048,4], param_1.1564: f32[], param_2.1322: f32[], param_3.926: f32[], param_4.564: f32[2048,4], param_5.476: f32[], param_6.366: f32[4,2048], param_7.209: pred[], param_8.126: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { - %param_0.1376 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.926 = f32[]{:T(128)S(6)} parameter(3) - %mul.1991.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.926), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.209 = pred[]{:T(512)S(6)} parameter(7) - %select_n.300.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.209), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.366 = f32[4,2048]{1,0:T(4,128)} parameter(6) - %bitcast.480.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.366), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.476 = f32[]{:T(128)} parameter(5) - %div.924.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.476), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.923.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.480.clone.1, %div.924.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.299.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.300.clone.1, %bitcast.480.clone.1, %div.923.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1140.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.884.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1140.clone.1), dimensions={}, metadata={op_name="broadcast.82"} - %mul.1995.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.299.clone.1, %broadcast.884.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.395 (param_0.1417: f32[2048,4], param_1.1591: f32[], param_2.1333: f32[], param_3.916: f32[], param_4.555: f32[2048,4], param_5.489: f32[], param_6.364: f32[4,2048], param_7.208: pred[], param_8.126: f32[2048,4]) -> (f32[], f32[2048,4], f32[2048,4], f32[2048,4], f32[]) { + %param_0.1417 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.916 = f32[]{:T(128)S(6)} parameter(3) + %mul.2514.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_3.916), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.208 = pred[]{:T(512)S(6)} parameter(7) + %select_n.300.clone.1 = pred[2048,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.208), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.364 = f32[4,2048]{1,0:T(4,128)} parameter(6) + %bitcast.461.clone.1 = f32[2048,4]{0,1:T(4,128)} bitcast(%param_6.364), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.489 = f32[]{:T(128)} parameter(5) + %div.924.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_5.489), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.923.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%bitcast.461.clone.1, %div.924.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.299.clone.1 = f32[2048,4]{0,1:T(4,128)} select(%select_n.300.clone.1, %bitcast.461.clone.1, %div.923.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1128.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.794.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1128.clone.1), dimensions={}, metadata={op_name="broadcast.82"} + %mul.2518.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.299.clone.1, %broadcast.794.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.126 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.1144.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.883.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1144.clone.1), dimensions={}, metadata={op_name="broadcast.81"} - %mul.1994.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.126, %broadcast.883.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.983.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1995.clone.1, %mul.1994.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1322 = f32[]{:T(128)S(6)} parameter(2) - %div.920.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1322), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1132.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.793.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1132.clone.1), dimensions={}, metadata={op_name="broadcast.81"} + %mul.2517.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_8.126, %broadcast.793.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.959.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2518.clone.1, %mul.2517.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1333 = f32[]{:T(128)S(6)} parameter(2) + %div.920.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_2.1333), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.73.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%select_n.299.clone.1, %select_n.299.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1143.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.882.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1143.clone.1), dimensions={}, metadata={op_name="broadcast.71"} - %mul.1993.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.73.clone.1, %broadcast.882.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.564 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.1142.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.881.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1142.clone.1), dimensions={}, metadata={op_name="broadcast.70"} - %mul.1992.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.564, %broadcast.881.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.982.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.1993.clone.1, %mul.1992.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1564 = f32[]{:T(128)S(6)} parameter(1) - %div.919.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1564), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.918.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.982.clone.1, %div.919.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1131.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.792.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1131.clone.1), dimensions={}, metadata={op_name="broadcast.71"} + %mul.2516.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%integer_pow.73.clone.1, %broadcast.792.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.555 = f32[2048,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1130.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.791.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1130.clone.1), dimensions={}, metadata={op_name="broadcast.70"} + %mul.2515.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_4.555, %broadcast.791.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.958.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%mul.2516.clone.1, %mul.2515.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1591 = f32[]{:T(128)S(6)} parameter(1) + %div.919.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%param_1.1591), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.918.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.958.clone.1, %div.919.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.70.clone.1 = f32[2048,4]{0,1:T(4,128)} sqrt(%div.918.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1141.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.879.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1141.clone.1), dimensions={}, metadata={op_name="broadcast.64"} - %add.981.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.70.clone.1, %broadcast.879.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.434.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.920.clone.1, %add.981.clone.1), metadata={op_name="multiply.53"} - %div.917.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.983.clone.1, %multiply.434.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1990.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1376, %broadcast.884.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.980.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.917.clone.1, %mul.1990.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1989.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.1991.clone.1, %add.980.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.979.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1376, %mul.1989.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.254 = f32[2048,4]{0,1:T(4,128)} multiply(%add.979.clone.1, %add.979.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1208 = f32[]{:T(128)} constant(0) - %reduce.208 = f32[]{:T(128)} reduce(%square.254, %constant.1208), dimensions={0,1}, to_apply=%region_63.68, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.210.clone.1 = f32[]{:T(128)} reduce(%integer_pow.73.clone.1, %constant.1208), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.153 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.208, %add.979.clone.1, %add.982.clone.1, %add.983.clone.1, %reduce.210.clone.1) + %constant.1129.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.789.clone.1 = f32[2048,4]{0,1:T(4,128)} broadcast(%constant.1129.clone.1), dimensions={}, metadata={op_name="broadcast.64"} + %add.957.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%sqrt.70.clone.1, %broadcast.789.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.295.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%div.920.clone.1, %add.957.clone.1), metadata={op_name="multiply.38"} + %div.917.clone.1 = f32[2048,4]{0,1:T(4,128)} divide(%add.959.clone.1, %multiply.295.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2513.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%param_0.1417, %broadcast.794.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.956.clone.1 = f32[2048,4]{0,1:T(4,128)} add(%div.917.clone.1, %mul.2513.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2512.clone.1 = f32[2048,4]{0,1:T(4,128)} multiply(%mul.2514.clone.1, %add.956.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.955.clone.1 = f32[2048,4]{0,1:T(4,128)S(1)} add(%param_0.1417, %mul.2512.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.254 = f32[2048,4]{0,1:T(4,128)} multiply(%add.955.clone.1, %add.955.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1196 = f32[]{:T(128)} constant(0) + %reduce.138 = f32[]{:T(128)} reduce(%square.254, %constant.1196), dimensions={0,1}, to_apply=%region_63.68, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.140.clone.1 = f32[]{:T(128)} reduce(%integer_pow.73.clone.1, %constant.1196), dimensions={0,1}, to_apply=%region_48.53, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.154 = (f32[]{:T(128)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[2048,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.138, %add.955.clone.1, %add.958.clone.1, %add.959.clone.1, %reduce.140.clone.1) +} + +%region_11.14 (reduce_sum.269: f32[], reduce_sum.270: f32[]) -> f32[] { + %reduce_sum.270 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.269 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.271 = f32[]{:T(128)} add(%reduce_sum.269, %reduce_sum.270), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_11.14 (reduce_sum.192: f32[], reduce_sum.193: f32[]) -> f32[] { - %reduce_sum.193 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.192 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.197 = f32[]{:T(128)} add(%reduce_sum.192, %reduce_sum.193), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%fused_computation.406 (param_0.1431: bf16[2048]) -> f32[] { + %param_0.1431 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(0) + %convert_element_type.1420 = f32[2048]{0:T(1024)} convert(%param_0.1431), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %square.257 = f32[2048]{0:T(1024)} multiply(%convert_element_type.1420, %convert_element_type.1420), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1210 = f32[]{:T(128)} constant(0) + ROOT %reduce.141 = f32[]{:T(128)} reduce(%square.257, %constant.1210), dimensions={0}, to_apply=%region_11.14, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} } -%fused_computation.395 (param_0.1390: bf16[2048]) -> f32[] { - %param_0.1390 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(0) - %convert_element_type.1396 = f32[2048]{0:T(1024)} convert(%param_0.1390), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %square.257 = f32[2048]{0:T(1024)} multiply(%convert_element_type.1396, %convert_element_type.1396), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1222 = f32[]{:T(128)} constant(0) - ROOT %reduce.211 = f32[]{:T(128)} reduce(%square.257, %constant.1222), dimensions={0}, to_apply=%region_11.14, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} +%region_59.64 (reduce_sum.506: f32[], reduce_sum.507: f32[]) -> f32[] { + %reduce_sum.507 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.506 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.508 = f32[]{:T(128)} add(%reduce_sum.506, %reduce_sum.507), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_59.64 (reduce_sum.401: f32[], reduce_sum.402: f32[]) -> f32[] { - %reduce_sum.402 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.401 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.403 = f32[]{:T(128)} add(%reduce_sum.401, %reduce_sum.402), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_44.49 (reduce_sum.425: f32[], reduce_sum.429: f32[]) -> f32[] { + %reduce_sum.429 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.425 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.430 = f32[]{:T(128)} add(%reduce_sum.425, %reduce_sum.429), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_44.49 (reduce_sum.323: f32[], reduce_sum.324: f32[]) -> f32[] { - %reduce_sum.324 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.323 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.325 = f32[]{:T(128)} add(%reduce_sum.323, %reduce_sum.324), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.396 (param_0.1380: f32[2048], param_1.1568: f32[], param_2.1326: f32[], param_3.930: f32[], param_4.568: f32[2048], param_5.480: f32[], param_6.370: bf16[2048], param_7.213: pred[], param_8.130: f32[2048]) -> (f32[], f32[2048], f32[2048], f32[2048], f32[]) { - %param_0.1380 = f32[2048]{0:T(1024)S(1)} parameter(0) - %param_3.930 = f32[]{:T(128)S(6)} parameter(3) - %mul.2022.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_3.930), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.213 = pred[]{:T(512)S(6)} parameter(7) - %select_n.316.clone.1 = pred[2048]{0:T(1024)(128)(4,1)} broadcast(%param_7.213), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.370 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(6) - %convert_element_type.1411.clone.1 = f32[2048]{0:T(1024)} convert(%param_6.370), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_5.480 = f32[]{:T(128)} parameter(5) - %div.956.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_5.480), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.955.clone.1 = f32[2048]{0:T(1024)} divide(%convert_element_type.1411.clone.1, %div.956.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.315.clone.1 = f32[2048]{0:T(1024)} select(%select_n.316.clone.1, %convert_element_type.1411.clone.1, %div.955.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1164.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.900.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1164.clone.1), dimensions={}, metadata={op_name="broadcast.86"} - %mul.2028.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.315.clone.1, %broadcast.900.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.407 (param_0.1421: f32[2048], param_1.1595: f32[], param_2.1337: f32[], param_3.920: f32[], param_4.559: f32[2048], param_5.493: f32[], param_6.368: bf16[2048], param_7.212: pred[], param_8.130: f32[2048]) -> (f32[], f32[2048], f32[2048], f32[2048], f32[]) { + %param_0.1421 = f32[2048]{0:T(1024)S(1)} parameter(0) + %param_3.920 = f32[]{:T(128)S(6)} parameter(3) + %mul.2545.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_3.920), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.212 = pred[]{:T(512)S(6)} parameter(7) + %select_n.316.clone.1 = pred[2048]{0:T(1024)(128)(4,1)} broadcast(%param_7.212), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.368 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(6) + %convert_element_type.1435.clone.1 = f32[2048]{0:T(1024)} convert(%param_6.368), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_5.493 = f32[]{:T(128)} parameter(5) + %div.956.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_5.493), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.955.clone.1 = f32[2048]{0:T(1024)} divide(%convert_element_type.1435.clone.1, %div.956.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.315.clone.1 = f32[2048]{0:T(1024)} select(%select_n.316.clone.1, %convert_element_type.1435.clone.1, %div.955.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1152.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.810.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1152.clone.1), dimensions={}, metadata={op_name="broadcast.86"} + %mul.2551.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.315.clone.1, %broadcast.810.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.130 = f32[2048]{0:T(1024)S(1)} parameter(8) - %constant.1168.clone.1 = f32[]{:T(128)} constant(0.9) - %mul.2029.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1168.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2027.clone.1 = f32[2048]{0:T(1024)} multiply(%param_8.130, %mul.2029.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.1005.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2028.clone.1, %mul.2027.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1326 = f32[]{:T(128)S(6)} parameter(2) - %div.952.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_2.1326), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1156.clone.1 = f32[]{:T(128)} constant(0.9) + %mul.2552.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1156.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2550.clone.1 = f32[2048]{0:T(1024)} multiply(%param_8.130, %mul.2552.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.981.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2551.clone.1, %mul.2550.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1337 = f32[]{:T(128)S(6)} parameter(2) + %div.952.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_2.1337), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.77.clone.1 = f32[2048]{0:T(1024)} multiply(%select_n.315.clone.1, %select_n.315.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1167.clone.1 = f32[]{:T(128)} constant(0.05) - %mul.2026.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1167.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2024.clone.1 = f32[2048]{0:T(1024)} multiply(%integer_pow.77.clone.1, %mul.2026.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.568 = f32[2048]{0:T(1024)S(1)} parameter(4) - %constant.1166.clone.1 = f32[]{:T(128)} constant(0.95) - %mul.2025.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1166.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %mul.2023.clone.1 = f32[2048]{0:T(1024)} multiply(%param_4.568, %mul.2025.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.1004.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2024.clone.1, %mul.2023.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1568 = f32[]{:T(128)S(6)} parameter(1) - %div.951.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_1.1568), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.950.clone.1 = f32[2048]{0:T(1024)} divide(%add.1004.clone.1, %div.951.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1155.clone.1 = f32[]{:T(128)} constant(0.05) + %mul.2549.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1155.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2547.clone.1 = f32[2048]{0:T(1024)} multiply(%integer_pow.77.clone.1, %mul.2549.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.559 = f32[2048]{0:T(1024)S(1)} parameter(4) + %constant.1154.clone.1 = f32[]{:T(128)} constant(0.95) + %mul.2548.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1154.clone.1), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %mul.2546.clone.1 = f32[2048]{0:T(1024)} multiply(%param_4.559, %mul.2548.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.980.clone.1 = f32[2048]{0:T(1024)S(1)} add(%mul.2547.clone.1, %mul.2546.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1595 = f32[]{:T(128)S(6)} parameter(1) + %div.951.clone.1 = f32[2048]{0:T(1024)} broadcast(%param_1.1595), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.950.clone.1 = f32[2048]{0:T(1024)} divide(%add.980.clone.1, %div.951.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.74.clone.1 = f32[2048]{0:T(1024)} sqrt(%div.950.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1165.clone.1 = f32[]{:T(128)} constant(1e-08) - %add.1003.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1165.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %add.1002.clone.1 = f32[2048]{0:T(1024)} add(%sqrt.74.clone.1, %add.1003.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.438.clone.1 = f32[2048]{0:T(1024)} multiply(%div.952.clone.1, %add.1002.clone.1), metadata={op_name="multiply.49"} - %div.949.clone.1 = f32[2048]{0:T(1024)} divide(%add.1005.clone.1, %multiply.438.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.2021.clone.1 = f32[2048]{0:T(1024)} multiply(%param_0.1380, %broadcast.900.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.1001.clone.1 = f32[2048]{0:T(1024)} add(%div.949.clone.1, %mul.2021.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.2020.clone.1 = f32[2048]{0:T(1024)} multiply(%mul.2022.clone.1, %add.1001.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.1000.clone.1 = f32[2048]{0:T(1024)S(1)} add(%param_0.1380, %mul.2020.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.258 = f32[2048]{0:T(1024)} multiply(%add.1000.clone.1, %add.1000.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1212 = f32[]{:T(128)} constant(0) - %reduce.212 = f32[]{:T(128)} reduce(%square.258, %constant.1212), dimensions={0}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.213.clone.1 = f32[]{:T(128)} reduce(%integer_pow.77.clone.1, %constant.1212), dimensions={0}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.156 = (f32[]{:T(128)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.212, %add.1000.clone.1, %add.1004.clone.1, %add.1005.clone.1, %reduce.213.clone.1) -} - -%fused_computation.402 (param_0.1150: s32[512]) -> s32[1024] { - %constant.972 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %broadcast.815 = s32[1024]{0:T(1024)} broadcast(%constant.972), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %param_0.1150 = s32[512]{0:T(512)S(1)} parameter(0) - %constant.973 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %pad.49 = s32[1024]{0:T(1024)} pad(%param_0.1150, %constant.973), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %constant.971 = s32[] constant(151935), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - %broadcast.814 = s32[1024]{0:T(1024)} broadcast(%constant.971), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} - ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.815, %pad.49, %broadcast.814), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} -} - -%fused_computation.405 (param_0.1149: s32[4,128]) -> s32[512] { - %param_0.1149 = s32[4,128]{1,0:T(4,128)} parameter(0) - %constant.1065 = s32[]{:T(128)} constant(0) - %broadcast.834 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1065), dimensions={}, metadata={op_name="broadcast.95"} - %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.1149, %broadcast.834), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} - %constant.1051 = s32[]{:T(128)} constant(151936) - %add.925 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1051), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} - %add.903 = s32[4,128]{1,0:T(4,128)} add(%param_0.1149, %add.925), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} - %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.903, %param_0.1149), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} - ROOT %bitcast.409 = s32[512]{0:T(512)S(1)} bitcast(%select_n.178), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} -} - -%region_40.45 (reduce_sum.305: f32[], reduce_sum.309: f32[]) -> f32[] { - %reduce_sum.309 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.305 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.310 = f32[]{:T(128)} add(%reduce_sum.305, %reduce_sum.309), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_37.42 (reduce_sum.290: f32[], reduce_sum.291: f32[]) -> f32[] { - %reduce_sum.291 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.290 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.295 = f32[]{:T(128)} add(%reduce_sum.290, %reduce_sum.291), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.407 (param_0.1384: f32[4,128], param_1.1570: f32[4,128]) -> (f32[], f32[]) { - %param_0.1384 = f32[4,128]{1,0:T(4,128)} parameter(0) - %bitcast.413 = f32[128,4]{0,1:T(4,128)} bitcast(%param_0.1384), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.261 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.413, %bitcast.413), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1216 = f32[]{:T(128)} constant(0) - %reduce.214 = f32[]{:T(128)} reduce(%square.261, %constant.1216), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %param_1.1570 = f32[4,128]{1,0:T(4,128)} parameter(1) - %bitcast.417.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_1.1570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %square.264.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.417.clone.1, %bitcast.417.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %reduce.215.clone.1 = f32[]{:T(128)} reduce(%square.264.clone.1, %constant.1216), dimensions={0,1}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.170 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.214, %reduce.215.clone.1) -} - -%region_72.77 (reduce_sum.470: f32[], reduce_sum.471: f32[]) -> f32[] { - %reduce_sum.471 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.470 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.472 = f32[]{:T(128)} add(%reduce_sum.470, %reduce_sum.471), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_58.63 (reduce_sum.395: f32[], reduce_sum.396: f32[]) -> f32[] { - %reduce_sum.396 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.395 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.400 = f32[]{:T(128)} add(%reduce_sum.395, %reduce_sum.396), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.410 (param_0.1391: bf16[4,128], param_1.1576: f32[4,128], param_2.1329: f32[4,128], param_3.932: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { - %param_3.932 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) - %constant.1170.clone.1 = s32[]{:T(128)} constant(0) - %broadcast.901.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1170.clone.1), dimensions={}, metadata={op_name="broadcast.95"} - %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.932, %broadcast.901.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} - %param_1.1576 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1576), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} - %param_0.1391 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) - %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1391), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} - %add.927 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} - %square.269 = f32[4,128]{1,0:T(4,128)} multiply(%add.927, %add.927), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} - %constant.1224 = f32[]{:T(128)} constant(0) - %broadcast.831 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1224), dimensions={}, metadata={op_name="broadcast.99"} - %mul.1913 = f32[4,128]{1,0:T(4,128)} multiply(%square.269, %broadcast.831), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %mul.1893 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.1913, %broadcast.831), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %reduce.216 = f32[]{:T(128)} reduce(%mul.1893, %constant.1224), dimensions={0,1}, to_apply=%region_72.77, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} - %param_2.1329 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1329), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} - %add.904.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.1913), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} - %mul.1894.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.904.clone.1, %broadcast.831), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} - %reduce.219.clone.1 = f32[]{:T(128)} reduce(%mul.1894.clone.1, %constant.1224), dimensions={0,1}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} - %mul.1911.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.927, %broadcast.831), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} - %constant.1068.clone.1 = f32[]{:T(128)} constant(1) - %add.922.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1068.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} - %add.915.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.1911.clone.1, %add.922.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} - ROOT %tuple.157 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.216, %reduce.219.clone.1, %ne.6.clone.1, %add.915.clone.1) -} - -%region_69.74 (reduce_sum.452: f32[], reduce_sum.456: f32[]) -> f32[] { - %reduce_sum.456 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.452 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.457 = f32[]{:T(128)} add(%reduce_sum.452, %reduce_sum.456), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1153.clone.1 = f32[]{:T(128)} constant(1e-08) + %add.979.clone.1 = f32[2048]{0:T(1024)} broadcast(%constant.1153.clone.1), dimensions={}, metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %add.978.clone.1 = f32[2048]{0:T(1024)} add(%sqrt.74.clone.1, %add.979.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.299.clone.1 = f32[2048]{0:T(1024)} multiply(%div.952.clone.1, %add.978.clone.1), metadata={op_name="multiply.34"} + %div.949.clone.1 = f32[2048]{0:T(1024)} divide(%add.981.clone.1, %multiply.299.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2544.clone.1 = f32[2048]{0:T(1024)} multiply(%param_0.1421, %broadcast.810.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.977.clone.1 = f32[2048]{0:T(1024)} add(%div.949.clone.1, %mul.2544.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2543.clone.1 = f32[2048]{0:T(1024)} multiply(%mul.2545.clone.1, %add.977.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.976.clone.1 = f32[2048]{0:T(1024)S(1)} add(%param_0.1421, %mul.2543.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.258 = f32[2048]{0:T(1024)} multiply(%add.976.clone.1, %add.976.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1200 = f32[]{:T(128)} constant(0) + %reduce.142 = f32[]{:T(128)} reduce(%square.258, %constant.1200), dimensions={0}, to_apply=%region_59.64, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.143.clone.1 = f32[]{:T(128)} reduce(%integer_pow.77.clone.1, %constant.1200), dimensions={0}, to_apply=%region_44.49, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.157 = (f32[]{:T(128)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[2048]{0:T(1024)S(1)}, f32[]{:T(128)}) tuple(%reduce.142, %add.976.clone.1, %add.980.clone.1, %add.981.clone.1, %reduce.143.clone.1) +} + +%fused_computation.413 (param_0.1191: s32[512]) -> s32[1024] { + %constant.960 = s32[] constant(0), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.727 = s32[1024]{0:T(1024)} broadcast(%constant.960), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %param_0.1191 = s32[512]{0:T(512)S(1)} parameter(0) + %constant.961 = s32[] constant(2147483647), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %pad.49 = s32[1024]{0:T(1024)} pad(%param_0.1191, %constant.961), padding=0_512, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %constant.959 = s32[] constant(151935), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + %broadcast.726 = s32[1024]{0:T(1024)} broadcast(%constant.959), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} + ROOT %clamp.1 = s32[1024]{0:T(1024)} clamp(%broadcast.727, %pad.49, %broadcast.726), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/gather" stack_frame_id=0} +} + +%fused_computation.416 (param_0.1190: s32[4,128]) -> s32[512] { + %param_0.1190 = s32[4,128]{1,0:T(4,128)} parameter(0) + %constant.1055 = s32[]{:T(128)} constant(0) + %broadcast.747 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1055), dimensions={}, metadata={op_name="broadcast.95"} + %lt.32 = pred[4,128]{1,0:T(4,128)(4,1)} compare(%param_0.1190, %broadcast.747), direction=LT, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/lt" stack_frame_id=0} + %constant.1052 = s32[]{:T(128)} constant(151936) + %add.901 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1052), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %add.879 = s32[4,128]{1,0:T(4,128)} add(%param_0.1190, %add.901), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/add" stack_frame_id=0} + %select_n.178 = s32[4,128]{1,0:T(4,128)} select(%lt.32, %add.879, %param_0.1190), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/select_n" stack_frame_id=0} + ROOT %bitcast.390 = s32[512]{0:T(512)S(1)} bitcast(%select_n.178), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/broadcast_in_dim" stack_frame_id=0} +} + +%region_40.45 (reduce_sum.410: f32[], reduce_sum.411: f32[]) -> f32[] { + %reduce_sum.411 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.410 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.415 = f32[]{:T(128)} add(%reduce_sum.410, %reduce_sum.411), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%region_54.59 (reduce_sum.374: f32[], reduce_sum.375: f32[]) -> f32[] { - %reduce_sum.375 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.374 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.379 = f32[]{:T(128)} add(%reduce_sum.374, %reduce_sum.375), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_37.42 (reduce_sum.395: f32[], reduce_sum.396: f32[]) -> f32[] { + %reduce_sum.396 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.395 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.397 = f32[]{:T(128)} add(%reduce_sum.395, %reduce_sum.396), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.411 (param_0.1370: f32[128,4], param_1.1558: f32[], param_2.1316: f32[], param_3.920: f32[], param_4.558: f32[128,4], param_5.470: f32[], param_6.360: f32[4,128], param_7.203: pred[], param_8.120: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { - %param_0.1370 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.920 = f32[]{:T(128)S(6)} parameter(3) - %mul.1943.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.920), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.203 = pred[]{:T(512)S(6)} parameter(7) - %select_n.276.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.203), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.360 = f32[4,128]{1,0:T(4,128)} parameter(6) - %bitcast.468.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.360), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.470 = f32[]{:T(128)} parameter(5) - %div.876.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.470), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.875.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.468.clone.1, %div.876.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.275.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.276.clone.1, %bitcast.468.clone.1, %div.875.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1104.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.856.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1104.clone.1), dimensions={}, metadata={op_name="broadcast.78"} - %mul.1947.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.275.clone.1, %broadcast.856.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.418 (param_0.1425: f32[4,128], param_1.1597: f32[4,128]) -> (f32[], f32[]) { + %param_0.1425 = f32[4,128]{1,0:T(4,128)} parameter(0) + %bitcast.394 = f32[128,4]{0,1:T(4,128)} bitcast(%param_0.1425), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.261 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.394, %bitcast.394), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1204 = f32[]{:T(128)} constant(0) + %reduce.144 = f32[]{:T(128)} reduce(%square.261, %constant.1204), dimensions={0,1}, to_apply=%region_40.45, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %param_1.1597 = f32[4,128]{1,0:T(4,128)} parameter(1) + %bitcast.398.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_1.1597), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %square.264.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%bitcast.398.clone.1, %bitcast.398.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %reduce.145.clone.1 = f32[]{:T(128)} reduce(%square.264.clone.1, %constant.1204), dimensions={0,1}, to_apply=%region_37.42, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.171 = (f32[]{:T(128)}, f32[]{:T(128)}) tuple(%reduce.144, %reduce.145.clone.1) +} + +%region_72.77 (reduce_sum.572: f32[], reduce_sum.576: f32[]) -> f32[] { + %reduce_sum.576 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.572 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.577 = f32[]{:T(128)} add(%reduce_sum.572, %reduce_sum.576), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_58.63 (reduce_sum.500: f32[], reduce_sum.501: f32[]) -> f32[] { + %reduce_sum.501 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.500 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.502 = f32[]{:T(128)} add(%reduce_sum.500, %reduce_sum.501), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.421 (param_0.1432: bf16[4,128], param_1.1603: f32[4,128], param_2.1340: f32[4,128], param_3.922: s32[4,128]) -> (f32[], f32[], pred[4,128], f32[4,128]) { + %param_3.922 = s32[4,128]{1,0:T(4,128)S(1)} parameter(3) + %constant.1158.clone.1 = s32[]{:T(128)} constant(0) + %broadcast.811.clone.1 = s32[4,128]{1,0:T(4,128)} broadcast(%constant.1158.clone.1), dimensions={}, metadata={op_name="broadcast.95"} + %ne.6.clone.1 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} compare(%param_3.922, %broadcast.811.clone.1), direction=NE, metadata={op_name="jit(train_step)/jvp()/ne" stack_frame_id=0} + %param_1.1603 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %log.16 = f32[4,128]{1,0:T(4,128)} log(%param_1.1603), metadata={op_name="jit(train_step)/jvp()/log" stack_frame_id=0} + %param_0.1432 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} parameter(0) + %reduce_max.15 = f32[4,128]{1,0:T(4,128)} convert(%param_0.1432), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + %add.903 = f32[4,128]{1,0:T(4,128)} add(%log.16, %reduce_max.15), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %square.269 = f32[4,128]{1,0:T(4,128)} multiply(%add.903, %add.903), metadata={op_name="jit(train_step)/jvp()/square" stack_frame_id=0} + %constant.1212 = f32[]{:T(128)} constant(0) + %broadcast.741 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1212), dimensions={}, metadata={op_name="broadcast.50"} + %mul.2434 = f32[4,128]{1,0:T(4,128)} multiply(%square.269, %broadcast.741), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %mul.2414 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %mul.2434, %broadcast.741), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.146 = f32[]{:T(128)} reduce(%mul.2414, %constant.1212), dimensions={0,1}, to_apply=%region_72.77, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %param_2.1340 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %neg.115.clone.1 = f32[4,128]{1,0:T(4,128)} negate(%param_2.1340), metadata={op_name="jit(train_step)/jvp()/neg" stack_frame_id=0} + %add.880.clone.1 = f32[4,128]{1,0:T(4,128)} add(%neg.115.clone.1, %mul.2434), metadata={op_name="jit(train_step)/jvp()/add" stack_frame_id=0} + %mul.2415.clone.1 = f32[4,128]{1,0:T(4,128)} select(%ne.6.clone.1, %add.880.clone.1, %broadcast.741), metadata={op_name="jit(train_step)/jvp()/mul" stack_frame_id=0} + %reduce.149.clone.1 = f32[]{:T(128)} reduce(%mul.2415.clone.1, %constant.1212), dimensions={0,1}, to_apply=%region_58.63, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} + %mul.2432.clone.1 = f32[4,128]{1,0:T(4,128)} multiply(%add.903, %broadcast.741), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %constant.1056.clone.1 = f32[]{:T(128)} constant(1) + %add.898.clone.1 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1056.clone.1), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + %add.891.clone.1 = f32[4,128]{1,0:T(4,128)S(1)} add(%mul.2432.clone.1, %add.898.clone.1), metadata={op_name="jit(train_step)/transpose(jvp())/add" stack_frame_id=0} + ROOT %tuple.158 = (f32[]{:T(128)}, f32[]{:T(128)}, pred[4,128]{1,0:T(4,128)(4,1)S(1)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%reduce.146, %reduce.149.clone.1, %ne.6.clone.1, %add.891.clone.1) +} + +%region_69.74 (reduce_sum.557: f32[], reduce_sum.558: f32[]) -> f32[] { + %reduce_sum.558 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.557 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.562 = f32[]{:T(128)} add(%reduce_sum.557, %reduce_sum.558), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_54.59 (reduce_sum.479: f32[], reduce_sum.480: f32[]) -> f32[] { + %reduce_sum.480 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.479 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.481 = f32[]{:T(128)} add(%reduce_sum.479, %reduce_sum.480), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.422 (param_0.1411: f32[128,4], param_1.1585: f32[], param_2.1327: f32[], param_3.910: f32[], param_4.549: f32[128,4], param_5.483: f32[], param_6.358: f32[4,128], param_7.202: pred[], param_8.120: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { + %param_0.1411 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.910 = f32[]{:T(128)S(6)} parameter(3) + %mul.2466.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.910), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.202 = pred[]{:T(512)S(6)} parameter(7) + %select_n.276.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.202), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.358 = f32[4,128]{1,0:T(4,128)} parameter(6) + %bitcast.449.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.358), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.483 = f32[]{:T(128)} parameter(5) + %div.876.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.483), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.875.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.449.clone.1, %div.876.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.275.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.276.clone.1, %bitcast.449.clone.1, %div.875.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1092.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.766.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1092.clone.1), dimensions={}, metadata={op_name="broadcast.78"} + %mul.2470.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.275.clone.1, %broadcast.766.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.120 = f32[128,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.1108.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.855.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1108.clone.1), dimensions={}, metadata={op_name="broadcast.77"} - %mul.1946.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.120, %broadcast.855.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.951.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1947.clone.1, %mul.1946.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1316 = f32[]{:T(128)S(6)} parameter(2) - %div.872.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1316), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1096.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.765.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1096.clone.1), dimensions={}, metadata={op_name="broadcast.77"} + %mul.2469.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.120, %broadcast.765.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.927.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2470.clone.1, %mul.2469.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1327 = f32[]{:T(128)S(6)} parameter(2) + %div.872.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1327), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.67.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.275.clone.1, %select_n.275.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1107.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.854.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1107.clone.1), dimensions={}, metadata={op_name="broadcast.67"} - %mul.1945.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.854.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.558 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.1106.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.853.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1106.clone.1), dimensions={}, metadata={op_name="broadcast.66"} - %mul.1944.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.558, %broadcast.853.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.950.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1945.clone.1, %mul.1944.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1558 = f32[]{:T(128)S(6)} parameter(1) - %div.871.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1558), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.870.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.950.clone.1, %div.871.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1095.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.764.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1095.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.2468.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.67.clone.1, %broadcast.764.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.549 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1094.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.763.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1094.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.2467.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.549, %broadcast.763.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.926.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2468.clone.1, %mul.2467.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1585 = f32[]{:T(128)S(6)} parameter(1) + %div.871.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1585), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.870.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.926.clone.1, %div.871.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.64.clone.1 = f32[128,4]{0,1:T(4,128)} sqrt(%div.870.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1105.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.851.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1105.clone.1), dimensions={}, metadata={op_name="broadcast.62"} - %add.949.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.851.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.428.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.872.clone.1, %add.949.clone.1), metadata={op_name="multiply.59"} - %div.869.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.951.clone.1, %multiply.428.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1942.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1370, %broadcast.856.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.948.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.869.clone.1, %mul.1942.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1941.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.1943.clone.1, %add.948.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.947.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1370, %mul.1941.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.265 = f32[128,4]{0,1:T(4,128)} multiply(%add.947.clone.1, %add.947.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1202 = f32[]{:T(128)} constant(0) - %reduce.217 = f32[]{:T(128)} reduce(%square.265, %constant.1202), dimensions={0,1}, to_apply=%region_69.74, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.221.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.1202), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.159 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.217, %add.947.clone.1, %add.950.clone.1, %add.951.clone.1, %reduce.221.clone.1) -} - -%region_66.71 (reduce_sum.437: f32[], reduce_sum.438: f32[]) -> f32[] { - %reduce_sum.438 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.437 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.442 = f32[]{:T(128)} add(%reduce_sum.437, %reduce_sum.438), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%region_51.56 (reduce_sum.359: f32[], reduce_sum.360: f32[]) -> f32[] { - %reduce_sum.360 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} - %reduce_sum.359 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} - ROOT %reduce_sum.361 = f32[]{:T(128)} add(%reduce_sum.359, %reduce_sum.360), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %constant.1093.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.761.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1093.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %add.925.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.64.clone.1, %broadcast.761.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.289.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.872.clone.1, %add.925.clone.1), metadata={op_name="multiply.44"} + %div.869.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.927.clone.1, %multiply.289.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2465.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1411, %broadcast.766.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.924.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.869.clone.1, %mul.2465.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2464.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.2466.clone.1, %add.924.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.923.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1411, %mul.2464.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.265 = f32[128,4]{0,1:T(4,128)} multiply(%add.923.clone.1, %add.923.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1190 = f32[]{:T(128)} constant(0) + %reduce.147 = f32[]{:T(128)} reduce(%square.265, %constant.1190), dimensions={0,1}, to_apply=%region_69.74, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.151.clone.1 = f32[]{:T(128)} reduce(%integer_pow.67.clone.1, %constant.1190), dimensions={0,1}, to_apply=%region_54.59, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.160 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.147, %add.923.clone.1, %add.926.clone.1, %add.927.clone.1, %reduce.151.clone.1) +} + +%region_66.71 (reduce_sum.542: f32[], reduce_sum.543: f32[]) -> f32[] { + %reduce_sum.543 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.542 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.544 = f32[]{:T(128)} add(%reduce_sum.542, %reduce_sum.543), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%region_51.56 (reduce_sum.464: f32[], reduce_sum.465: f32[]) -> f32[] { + %reduce_sum.465 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/reduce_sum"} + %reduce_sum.464 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/reduce_sum"} + ROOT %reduce_sum.466 = f32[]{:T(128)} add(%reduce_sum.464, %reduce_sum.465), metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.412 (param_0.1373: f32[128,4], param_1.1561: f32[], param_2.1319: f32[], param_3.923: f32[], param_4.561: f32[128,4], param_5.473: f32[], param_6.363: f32[4,128], param_7.206: pred[], param_8.123: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { - %param_0.1373 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) - %param_3.923 = f32[]{:T(128)S(6)} parameter(3) - %mul.1970.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.923), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_7.206 = pred[]{:T(512)S(6)} parameter(7) - %select_n.288.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.206), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %param_6.363 = f32[4,128]{1,0:T(4,128)} parameter(6) - %bitcast.474.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.363), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - %param_5.473 = f32[]{:T(128)} parameter(5) - %div.900.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.473), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.899.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.474.clone.1, %div.900.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %select_n.287.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.288.clone.1, %bitcast.474.clone.1, %div.899.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} - %constant.1122.clone.1 = f32[]{:T(128)} constant(0.1) - %broadcast.866.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1122.clone.1), dimensions={}, metadata={op_name="broadcast.78"} - %mul.1974.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.287.clone.1, %broadcast.866.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} +%fused_computation.423 (param_0.1414: f32[128,4], param_1.1588: f32[], param_2.1330: f32[], param_3.913: f32[], param_4.552: f32[128,4], param_5.486: f32[], param_6.361: f32[4,128], param_7.205: pred[], param_8.123: f32[128,4]) -> (f32[], f32[128,4], f32[128,4], f32[128,4], f32[]) { + %param_0.1414 = f32[128,4]{0,1:T(4,128)S(1)} parameter(0) + %param_3.913 = f32[]{:T(128)S(6)} parameter(3) + %mul.2493.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_3.913), dimensions={}, metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_7.205 = pred[]{:T(512)S(6)} parameter(7) + %select_n.288.clone.1 = pred[128,4]{0,1:T(4,128)(4,1)} broadcast(%param_7.205), dimensions={}, metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %param_6.361 = f32[4,128]{1,0:T(4,128)} parameter(6) + %bitcast.455.clone.1 = f32[128,4]{0,1:T(4,128)} bitcast(%param_6.361), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + %param_5.486 = f32[]{:T(128)} parameter(5) + %div.900.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_5.486), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.899.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%bitcast.455.clone.1, %div.900.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %select_n.287.clone.1 = f32[128,4]{0,1:T(4,128)} select(%select_n.288.clone.1, %bitcast.455.clone.1, %div.899.clone.1), metadata={op_name="jit(train_step)/select_n" stack_frame_id=0} + %constant.1110.clone.1 = f32[]{:T(128)} constant(0.1) + %broadcast.776.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1110.clone.1), dimensions={}, metadata={op_name="broadcast.78"} + %mul.2497.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.287.clone.1, %broadcast.776.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} %param_8.123 = f32[128,4]{0,1:T(4,128)S(1)} parameter(8) - %constant.1126.clone.1 = f32[]{:T(128)} constant(0.9) - %broadcast.865.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1126.clone.1), dimensions={}, metadata={op_name="broadcast.77"} - %mul.1973.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.123, %broadcast.865.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.968.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1974.clone.1, %mul.1973.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_2.1319 = f32[]{:T(128)S(6)} parameter(2) - %div.896.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1319), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1114.clone.1 = f32[]{:T(128)} constant(0.9) + %broadcast.775.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1114.clone.1), dimensions={}, metadata={op_name="broadcast.77"} + %mul.2496.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_8.123, %broadcast.775.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.944.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2497.clone.1, %mul.2496.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_2.1330 = f32[]{:T(128)S(6)} parameter(2) + %div.896.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_2.1330), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} %integer_pow.70.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%select_n.287.clone.1, %select_n.287.clone.1), metadata={op_name="jit(train_step)/integer_pow" stack_frame_id=0} - %constant.1125.clone.1 = f32[]{:T(128)} constant(0.05) - %broadcast.864.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1125.clone.1), dimensions={}, metadata={op_name="broadcast.67"} - %mul.1972.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.864.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %param_4.561 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) - %constant.1124.clone.1 = f32[]{:T(128)} constant(0.95) - %broadcast.863.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1124.clone.1), dimensions={}, metadata={op_name="broadcast.66"} - %mul.1971.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.561, %broadcast.863.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.967.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.1972.clone.1, %mul.1971.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %param_1.1561 = f32[]{:T(128)S(6)} parameter(1) - %div.895.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1561), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %div.894.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.967.clone.1, %div.895.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %constant.1113.clone.1 = f32[]{:T(128)} constant(0.05) + %broadcast.774.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1113.clone.1), dimensions={}, metadata={op_name="broadcast.67"} + %mul.2495.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%integer_pow.70.clone.1, %broadcast.774.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %param_4.552 = f32[128,4]{0,1:T(4,128)S(1)} parameter(4) + %constant.1112.clone.1 = f32[]{:T(128)} constant(0.95) + %broadcast.773.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1112.clone.1), dimensions={}, metadata={op_name="broadcast.66"} + %mul.2494.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_4.552, %broadcast.773.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.943.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%mul.2495.clone.1, %mul.2494.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %param_1.1588 = f32[]{:T(128)S(6)} parameter(1) + %div.895.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%param_1.1588), dimensions={}, metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %div.894.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.943.clone.1, %div.895.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} %sqrt.67.clone.1 = f32[128,4]{0,1:T(4,128)} sqrt(%div.894.clone.1), metadata={op_name="jit(train_step)/sqrt" stack_frame_id=0} - %constant.1123.clone.1 = f32[]{:T(128)} constant(1e-08) - %broadcast.861.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1123.clone.1), dimensions={}, metadata={op_name="broadcast.62"} - %add.966.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.67.clone.1, %broadcast.861.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %multiply.431.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.896.clone.1, %add.966.clone.1), metadata={op_name="multiply.56"} - %div.893.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.968.clone.1, %multiply.431.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} - %mul.1969.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1373, %broadcast.866.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.965.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.893.clone.1, %mul.1969.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %mul.1968.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.1970.clone.1, %add.965.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} - %add.964.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1373, %mul.1968.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} - %square.266 = f32[128,4]{0,1:T(4,128)} multiply(%add.964.clone.1, %add.964.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} - %constant.1205 = f32[]{:T(128)} constant(0) - %reduce.218 = f32[]{:T(128)} reduce(%square.266, %constant.1205), dimensions={0,1}, to_apply=%region_66.71, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - %reduce.222.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.1205), dimensions={0,1}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} - ROOT %tuple.160 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.218, %add.964.clone.1, %add.967.clone.1, %add.968.clone.1, %reduce.222.clone.1) -} - -%fused_computation.421 (param_0.1201: f32[4,128], param_1.1323: f32[4,128]) -> f32[4,128] { - %param_0.1201 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %param_1.1323 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %constant.1045 = f32[]{:T(128)} constant(0.00048828125) - %broadcast.837 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1045), dimensions={}, metadata={op_name="broadcast.399"} - %div.767 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1323, %broadcast.837), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.1043 = f32[]{:T(128)} constant(1e-06) - %add.935 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1043), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %add.934 = f32[4,128]{1,0:T(4,128)} add(%div.767, %add.935), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %rsqrt.168 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.934), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} - %div.754 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.168, %add.934), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} - %constant.1040 = f32[]{:T(128)} constant(-0.5) - %mul.1919 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1040), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1910 = f32[4,128]{1,0:T(4,128)} multiply(%div.754, %mul.1919), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.1909 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1201, %mul.1910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %constant.1039 = f32[]{:T(128)} constant(0.0009765625) - %mul.1918 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1039), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - ROOT %mul.1908 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.1909, %mul.1918), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} -} - -%region_0.1 (reduce_sum.137: s32[], reduce_sum.138: s32[]) -> s32[] { - %reduce_sum.138 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - %reduce_sum.137 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} - ROOT %reduce_sum.139 = s32[]{:T(128)} add(%reduce_sum.137, %reduce_sum.138), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} -} - -%fused_computation.425 (param_0.1220: pred[4,128]) -> s32[] { - %param_0.1220 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) - %convert_element_type.1403 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1220), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} - %constant.1066 = s32[]{:T(128)} constant(0) - ROOT %reduce.220 = s32[]{:T(128)} reduce(%convert_element_type.1403, %constant.1066), dimensions={0,1}, to_apply=%region_0.1, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} -} - -%fused_computation.428 (param_0.1203: f32[4,128]) -> f32[4,128] { - %param_0.1203 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1111.clone.1 = f32[]{:T(128)} constant(1e-08) + %broadcast.771.clone.1 = f32[128,4]{0,1:T(4,128)} broadcast(%constant.1111.clone.1), dimensions={}, metadata={op_name="broadcast.62"} + %add.942.clone.1 = f32[128,4]{0,1:T(4,128)} add(%sqrt.67.clone.1, %broadcast.771.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %multiply.292.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%div.896.clone.1, %add.942.clone.1), metadata={op_name="multiply.41"} + %div.893.clone.1 = f32[128,4]{0,1:T(4,128)} divide(%add.944.clone.1, %multiply.292.clone.1), metadata={op_name="jit(train_step)/div" stack_frame_id=0} + %mul.2492.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%param_0.1414, %broadcast.776.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.941.clone.1 = f32[128,4]{0,1:T(4,128)} add(%div.893.clone.1, %mul.2492.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %mul.2491.clone.1 = f32[128,4]{0,1:T(4,128)} multiply(%mul.2493.clone.1, %add.941.clone.1), metadata={op_name="jit(train_step)/mul" stack_frame_id=0} + %add.940.clone.1 = f32[128,4]{0,1:T(4,128)S(1)} add(%param_0.1414, %mul.2491.clone.1), metadata={op_name="jit(train_step)/add" stack_frame_id=0} + %square.266 = f32[128,4]{0,1:T(4,128)} multiply(%add.940.clone.1, %add.940.clone.1), metadata={op_name="jit(train_step)/square" stack_frame_id=0} + %constant.1193 = f32[]{:T(128)} constant(0) + %reduce.148 = f32[]{:T(128)} reduce(%square.266, %constant.1193), dimensions={0,1}, to_apply=%region_66.71, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + %reduce.152.clone.1 = f32[]{:T(128)} reduce(%integer_pow.70.clone.1, %constant.1193), dimensions={0,1}, to_apply=%region_51.56, metadata={op_name="jit(train_step)/reduce_sum" stack_frame_id=0} + ROOT %tuple.161 = (f32[]{:T(128)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[128,4]{0,1:T(4,128)S(1)}, f32[]{:T(128)}) tuple(%reduce.148, %add.940.clone.1, %add.943.clone.1, %add.944.clone.1, %reduce.152.clone.1) +} + +%fused_computation.432 (param_0.1242: f32[4,128], param_1.1350: f32[4,128]) -> f32[4,128] { + %param_0.1242 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %param_1.1350 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) %constant.1046 = f32[]{:T(128)} constant(0.00048828125) - %broadcast.829 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1046), dimensions={}, metadata={op_name="broadcast.399"} - %div.759 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1203, %broadcast.829), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %broadcast.749 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1046), dimensions={}, metadata={op_name="broadcast.362"} + %div.767 = f32[4,128]{1,0:T(4,128)} multiply(%param_1.1350, %broadcast.749), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} %constant.1044 = f32[]{:T(128)} constant(1e-06) - %add.924 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1044), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - %add.921 = f32[4,128]{1,0:T(4,128)} add(%div.759, %add.924), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} - ROOT %rsqrt.166 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.921), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} -} - -%fused_computation.429 (param_0.1204: pred[4,128], param_1.1575: f32[]) -> f32[4,128] { - %param_0.1204 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) - %param_1.1575 = f32[]{:T(128)S(6)} parameter(1) - %broadcast_in_dim.288 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1575), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} - %constant.1223 = f32[]{:T(128)} constant(0) - %broadcast.833 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1223), dimensions={}, metadata={op_name="broadcast.99"} - ROOT %mul.1920 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.1204, %broadcast_in_dim.288, %broadcast.833), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} + %add.911 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1044), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.910 = f32[4,128]{1,0:T(4,128)} add(%div.767, %add.911), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %rsqrt.168 = f32[4,128]{1,0:T(4,128)} rsqrt(%add.910), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} + %div.754 = f32[4,128]{1,0:T(4,128)} divide(%rsqrt.168, %add.910), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.1028 = f32[]{:T(128)} constant(-0.5) + %mul.2440 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1028), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2431 = f32[4,128]{1,0:T(4,128)} multiply(%div.754, %mul.2440), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2430 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1242, %mul.2431), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %constant.1027 = f32[]{:T(128)} constant(0.0009765625) + %mul.2439 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1027), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.2429 = f32[4,128]{1,0:T(4,128)S(1)} multiply(%mul.2430, %mul.2439), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} +} + +%region_7.10 (reduce_sum.234: s32[], reduce_sum.235: s32[]) -> s32[] { + %reduce_sum.235 = s32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + %reduce_sum.234 = s32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/jvp()/reduce_sum"} + ROOT %reduce_sum.236 = s32[]{:T(128)} add(%reduce_sum.234, %reduce_sum.235), metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[{"indices":["0","2"]}]}} +} + +%fused_computation.435 (param_0.1261: pred[4,128]) -> s32[] { + %param_0.1261 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %convert_element_type.1427 = s32[4,128]{1,0:T(4,128)} convert(%param_0.1261), metadata={op_name="jit(train_step)/jvp()/convert_element_type" stack_frame_id=0} + %constant.1054 = s32[]{:T(128)} constant(0) + ROOT %reduce.150 = s32[]{:T(128)} reduce(%convert_element_type.1427, %constant.1054), dimensions={0,1}, to_apply=%region_7.10, metadata={op_name="jit(train_step)/jvp()/reduce_sum" stack_frame_id=0} +} + +%fused_computation.439 (param_0.1245: f32[4,128]) -> f32[4,128] { + %param_0.1245 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1047 = f32[]{:T(128)} constant(0.00048828125) + %broadcast.745 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1047), dimensions={}, metadata={op_name="broadcast.362"} + %div.759 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1245, %broadcast.745), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/div" stack_frame_id=0} + %constant.1045 = f32[]{:T(128)} constant(1e-06) + %add.900 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1045), dimensions={}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + %add.897 = f32[4,128]{1,0:T(4,128)} add(%div.759, %add.900), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/add" stack_frame_id=0} + ROOT %rsqrt.166 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.897), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/rsqrt" stack_frame_id=0} +} + +%fused_computation.440 (param_0.1244: pred[4,128], param_1.1602: f32[]) -> f32[4,128] { + %param_0.1244 = pred[4,128]{1,0:T(4,128)(4,1)S(1)} parameter(0) + %param_1.1602 = f32[]{:T(128)S(6)} parameter(1) + %broadcast_in_dim.309 = f32[4,128]{1,0:T(4,128)} broadcast(%param_1.1602), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp())/broadcast_in_dim" stack_frame_id=0} + %constant.1211 = f32[]{:T(128)} constant(0) + %broadcast.743 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1211), dimensions={}, metadata={op_name="broadcast.50"} + ROOT %mul.2441 = f32[4,128]{1,0:T(4,128)S(1)} select(%param_0.1244, %broadcast_in_dim.309, %broadcast.743), metadata={op_name="jit(train_step)/transpose(jvp())/mul" stack_frame_id=0} } -%fused_computation.431 () -> f32[64] { - %constant.1049 = f32[]{:T(128)} constant(1e+06) - %broadcast.840 = f32[64]{0:T(128)} broadcast(%constant.1049), dimensions={}, metadata={op_name="broadcast.390"} +%fused_computation.442 () -> f32[64] { + %constant.1050 = f32[]{:T(128)} constant(1e+06) + %broadcast.752 = f32[64]{0:T(128)} broadcast(%constant.1050), dimensions={}, metadata={op_name="broadcast.353"} %iota.46 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/layers/iota" stack_frame_id=0} - %constant.1048 = s32[]{:T(128)} constant(2) - %broadcast.839 = s32[64]{0:T(128)} broadcast(%constant.1048), dimensions={}, metadata={op_name="broadcast.391"} - %mul.1921 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.839), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} - %convert_element_type.1404 = f32[64]{0:T(128)} convert(%mul.1921), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - %constant.1047 = f32[]{:T(128)} constant(0.0078125) - %broadcast.838 = f32[64]{0:T(128)} broadcast(%constant.1047), dimensions={}, metadata={op_name="broadcast.392"} - %div.768 = f32[64]{0:T(128)} multiply(%convert_element_type.1404, %broadcast.838), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} - ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.840, %div.768), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} + %constant.1049 = s32[]{:T(128)} constant(2) + %broadcast.751 = s32[64]{0:T(128)} broadcast(%constant.1049), dimensions={}, metadata={op_name="broadcast.354"} + %mul.2442 = s32[64]{0:T(128)} multiply(%iota.46, %broadcast.751), metadata={op_name="jit(train_step)/layers/mul" stack_frame_id=0} + %convert_element_type.1428 = f32[64]{0:T(128)} convert(%mul.2442), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %constant.1048 = f32[]{:T(128)} constant(0.0078125) + %broadcast.750 = f32[64]{0:T(128)} broadcast(%constant.1048), dimensions={}, metadata={op_name="broadcast.355"} + %div.768 = f32[64]{0:T(128)} multiply(%convert_element_type.1428, %broadcast.750), metadata={op_name="jit(train_step)/layers/div" stack_frame_id=0} + ROOT %pow.36 = f32[64]{0:T(128)S(1)} power(%broadcast.752, %div.768), metadata={op_name="jit(train_step)/layers/pow" stack_frame_id=0} } -%fused_computation.432 (param_0.1218: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { - %param_0.1218 = s32[4,128]{1,0:T(4,128)} parameter(0) - %convert_element_type.1405 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1218), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} - %bitcast.418 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1405), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.162 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.418, %convert_element_type.1405) +%fused_computation.443 (param_0.1259: s32[4,128]) -> (f32[4,128,1,1], f32[4,128]) { + %param_0.1259 = s32[4,128]{1,0:T(4,128)} parameter(0) + %convert_element_type.1429 = f32[4,128]{1,0:T(4,128)S(1)} convert(%param_0.1259), metadata={op_name="jit(train_step)/layers/convert_element_type" stack_frame_id=0} + %bitcast.399 = f32[4,128,1,1]{1,0,3,2:T(4,128)} bitcast(%convert_element_type.1429), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.163 = (f32[4,128,1,1]{1,0,3,2:T(4,128)}, f32[4,128]{1,0:T(4,128)S(1)}) tuple(%bitcast.399, %convert_element_type.1429) } -%fused_computation.435 (param_0.1360: f32[2048,4]) -> bf16[4,2048] { - %param_0.1360 = f32[2048,4]{0,1:T(4,128)} parameter(0) - %bitcast.531 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1360), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.145 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.531) +%fused_computation.446 (param_0.1400: f32[2048,4]) -> bf16[4,2048] { + %param_0.1400 = f32[2048,4]{0,1:T(4,128)} parameter(0) + %bitcast.507 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1400), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.79 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.507) } -%fused_computation.436 (param_0.1359: f32[2048,4]) -> bf16[4,2048] { - %param_0.1359 = f32[2048,4]{0,1:T(4,128)} parameter(0) - %bitcast.530 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1359), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.147 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.530) +%fused_computation.447 (param_0.1401: f32[2048,4]) -> bf16[4,2048] { + %param_0.1401 = f32[2048,4]{0,1:T(4,128)} parameter(0) + %bitcast.508 = f32[4,2048]{1,0:T(4,128)} bitcast(%param_0.1401), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.81 = bf16[4,2048]{1,0:T(4,128)(2,1)} convert(%bitcast.508) } -%fused_computation.437 (param_0.1361: f32[128,4]) -> bf16[4,128] { - %param_0.1361 = f32[128,4]{0,1:T(4,128)} parameter(0) - %bitcast.532 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1361), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.149 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.532) +%fused_computation.448 (param_0.1402: f32[128,4]) -> bf16[4,128] { + %param_0.1402 = f32[128,4]{0,1:T(4,128)} parameter(0) + %bitcast.509 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1402), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.83 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.509) } -%fused_computation.438 (param_0.1362: f32[128,4]) -> bf16[4,128] { - %param_0.1362 = f32[128,4]{0,1:T(4,128)} parameter(0) - %bitcast.533 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1362), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.151 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.533) +%fused_computation.449 (param_0.1403: f32[128,4]) -> bf16[4,128] { + %param_0.1403 = f32[128,4]{0,1:T(4,128)} parameter(0) + %bitcast.510 = f32[4,128]{1,0:T(4,128)} bitcast(%param_0.1403), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.85 = bf16[4,128]{1,0:T(4,128)(2,1)} convert(%bitcast.510) } %region_8.11 (reduce_max.6: bf16[], reduce_max.8: bf16[]) -> bf16[] { @@ -1462,539 +1462,539 @@ StackFrames ROOT %reduce_max.9 = bf16[]{:T(256)} maximum(%reduce_max.6, %reduce_max.8), metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.287.clone.clone (param_0.1346: bf16[151936,2048]) -> bf16[151936,2048,1] { - %param_0.1346 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) - ROOT %bitcast.526 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1346), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} +%fused_computation.298.clone.clone (param_0.1387: bf16[151936,2048]) -> bf16[151936,2048,1] { + %param_0.1387 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.503 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%param_0.1387), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} } -%fused_computation.368.clone.clone (param_0.1347: f32[4,128], param_1.1542: bf16[4,128,2048], param_2.1281: bf16[2048]) -> bf16[4,128,2048] { - %param_2.1281 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.476 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1281), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1542 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1438 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1542), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - %param_0.1347 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2067 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1347), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %mul.2066 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1438, %mul.2067), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} - %convert_element_type.1437 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2066), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} - ROOT %dot_general.475 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.476, %convert_element_type.1437), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} +%fused_computation.379.clone.clone (param_0.1388: f32[4,128], param_1.1569: bf16[4,128,2048], param_2.1292: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1569 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1462 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1569), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_0.1388 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2607 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1388), dimensions={0,1}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %mul.2606 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1462, %mul.2607), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %convert_element_type.1461 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2606), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/convert_element_type" stack_frame_id=0} + %param_2.1292 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2608 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1292), dimensions={2}, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + ROOT %mul.2605 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1461, %mul.2608), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} } -%fused_computation.439 (param_0.1363: bf16[151936,2048], param_1.1551: f32[4,128], param_2.1305: bf16[4,128,2048], param_3.913: bf16[2048]) -> (bf16[4,128], bf16[4,128,151936]) { - %param_1.1551 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1305 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %param_3.913 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %fusion.270.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1551, %param_2.1305, %param_3.913), kind=kLoop, calls=%fused_computation.368.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1363 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) - %fusion.253.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1363), kind=kLoop, calls=%fused_computation.287.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} - %convolution.85.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convolution(%fusion.270.clone.1, %fusion.253.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} - %constant.1195 = bf16[]{:T(256)} constant(-inf) - %reduce.223 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.85.clone.1, %constant.1195), dimensions={2}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} - ROOT %tuple.164 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,151936]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.223, %convolution.85.clone.1) +%fused_computation.450 (param_0.1404: bf16[151936,2048], param_1.1578: f32[4,128], param_2.1316: bf16[4,128,2048], param_3.903: bf16[2048]) -> (bf16[4,128], bf16[4,128,151936]) { + %param_1.1578 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1316 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %param_3.903 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %fusion.280.clone.1 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_1.1578, %param_2.1316, %param_3.903), kind=kLoop, calls=%fused_computation.379.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/decoder_norm/mul" stack_frame_id=0} + %param_0.1404 = bf16[151936,2048]{1,0:T(8,128)(2,1)} parameter(0) + %fusion.263.clone.1 = bf16[151936,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1404), kind=kLoop, calls=%fused_computation.298.clone.clone, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder._apply_embedding/token_embedder/convert_element_type" stack_frame_id=0} + %convolution.85.clone.1 = bf16[4,128,151936]{2,1,0:T(8,128)(2,1)} convolution(%fusion.280.clone.1, %fusion.263.clone.1), window={size=1}, dim_labels=0bf_oi0->0bf, metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/decoder.apply_output_head/dot_general" stack_frame_id=0} + %constant.1183 = bf16[]{:T(256)} constant(-inf) + %reduce.153 = bf16[4,128]{1,0:T(4,128)(2,1)S(1)} reduce(%convolution.85.clone.1, %constant.1183), dimensions={2}, to_apply=%region_8.11, metadata={op_name="jit(train_step)/jvp()/reduce_max" stack_frame_id=0} + ROOT %tuple.165 = (bf16[4,128]{1,0:T(4,128)(2,1)S(1)}, bf16[4,128,151936]{2,1,0:T(8,128)(2,1)}) tuple(%reduce.153, %convolution.85.clone.1) } -%fused_computation.440 (param_0.1358: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { - %param_0.1358 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) - %bitcast.529 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1358), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} - ROOT %convert.153 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.529) +%fused_computation.451 (param_0.1399: f32[2048,4,8,128]) -> bf16[4,2048,8,128] { + %param_0.1399 = f32[2048,4,8,128]{3,2,1,0:T(8,128)S(1)} parameter(0) + %bitcast.506 = f32[4,2048,8,128]{3,2,0,1:T(8,128)} bitcast(%param_0.1399), metadata={op_name="jit(train_step)/jvp(TransformerLinenPure.apply)/TransformerLinenPure/decoder/transpose" stack_frame_id=0} + ROOT %convert.87 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} convert(%bitcast.506) } -%convert_element_type.767.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { - %rhs.1 = bf16[] parameter(1) - %lhs.1 = bf16[] parameter(0) - ROOT %add.755 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%convert_element_type.785.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { + %rhs = bf16[] parameter(1) + %lhs = bf16[] parameter(0) + ROOT %add.730 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.155.clone.clone (param_0.1534: bf16[4,2048], param_1.1687: s32[]) -> bf16[2048] { - %param_0.1534 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1687 = s32[]{:T(128)S(6)} parameter(1) - %constant.1361 = s32[]{:T(128)} constant(0) - %dynamic_slice.388 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1534, %param_1.1687, %constant.1361), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %constant.1362 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %reduce.244 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.388, %constant.1362), dimensions={0}, to_apply=%convert_element_type.767.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.167.clone.clone (param_0.1572: bf16[4,2048], param_1.1711: s32[]) -> bf16[2048] { + %param_0.1572 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1711 = s32[]{:T(128)S(6)} parameter(1) + %constant.1348 = s32[]{:T(128)} constant(0) + %dynamic_slice.394 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1572, %param_1.1711, %constant.1348), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1349 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.174 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.394, %constant.1349), dimensions={0}, to_apply=%convert_element_type.785.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%region_14.16 (reduce_sum.204: f32[], reduce_sum.205: f32[]) -> f32[] { - %reduce_sum.205 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.204 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.206 = f32[]{:T(128)} add(%reduce_sum.204, %reduce_sum.205), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +%region_14.16 (reduce_sum.278: f32[], reduce_sum.282: f32[]) -> f32[] { + %reduce_sum.282 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.278 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.283 = f32[]{:T(128)} add(%reduce_sum.278, %reduce_sum.282), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.58.clone.clone (param_0.1535: bf16[4,4,128,2048], param_1.1688: s32[]) -> f32[4,128] { - %param_0.1535 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1688 = s32[]{:T(128)S(6)} parameter(1) +%fused_computation.61.clone.clone (param_0.1573: bf16[4,4,128,2048], param_1.1712: s32[]) -> f32[4,128] { + %param_0.1573 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1712 = s32[]{:T(128)S(6)} parameter(1) + %constant.1350 = s32[]{:T(128)} constant(0) + %dynamic_slice.395 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1573, %param_1.1712, %constant.1350, %constant.1350, %constant.1350), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.602 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.395), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %convert_element_type.1585 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.602), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.280 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1585, %convert_element_type.1585), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1351 = f32[]{:T(128)} constant(0) + ROOT %reduce.175 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.280, %constant.1351), dimensions={2}, to_apply=%region_14.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} +} + +%fused_computation.190.clone.1.clone (param_0.1574: f32[4,128]) -> f32[4,128] { + %param_0.1574 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1353 = f32[]{:T(128)} constant(0.00048828125) + %closed_call.106 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1353), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.999 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1574, %closed_call.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1352 = f32[]{:T(128)} constant(1e-06) + %closed_call.105 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1352), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.1015 = f32[4,128]{1,0:T(4,128)} add(%div.999, %closed_call.105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.181 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1015), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%region_15.17 (reduce_sum.284: f32[], reduce_sum.285: f32[]) -> f32[] { + %reduce_sum.285 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.284 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.289 = f32[]{:T(128)} add(%reduce_sum.284, %reduce_sum.285), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.25.clone.1.clone.clone.clone.clone (param_0.1587: bf16[4,2048,16,128], param_1.1721: s32[]) -> bf16[2048,16,128,1] { + %param_0.1587 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1721 = s32[]{:T(128)S(6)} parameter(1) %constant.1363 = s32[]{:T(128)} constant(0) - %dynamic_slice.389 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1535, %param_1.1688, %constant.1363, %constant.1363, %constant.1363), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.633 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.389), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1564 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.633), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.280 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1564, %convert_element_type.1564), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1364 = f32[]{:T(128)} constant(0) - ROOT %reduce.245 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.280, %constant.1364), dimensions={2}, to_apply=%region_14.16, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} -} - -%fused_computation.179.clone.1.clone (param_0.1536: f32[4,128]) -> f32[4,128] { - %param_0.1536 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.1366 = f32[]{:T(128)} constant(0.00048828125) - %closed_call.106 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1366), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.999 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1536, %closed_call.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1365 = f32[]{:T(128)} constant(1e-06) - %closed_call.105 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1365), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.1039 = f32[4,128]{1,0:T(4,128)} add(%div.999, %closed_call.105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.181 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1039), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} -} - -%region_15.17 (reduce_sum.207: f32[], reduce_sum.211: f32[]) -> f32[] { - %reduce_sum.211 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.207 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.212 = f32[]{:T(128)} add(%reduce_sum.207, %reduce_sum.211), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.25.clone.1.clone.clone.clone.clone (param_0.1550: bf16[4,2048,16,128], param_1.1698: s32[]) -> bf16[2048,16,128,1] { - %param_0.1550 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %param_1.1698 = s32[]{:T(128)S(6)} parameter(1) - %constant.1377 = s32[]{:T(128)} constant(0) - %dynamic_slice.395 = bf16[1,2048,16,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1550, %param_1.1698, %constant.1377, %constant.1377, %constant.1377), dynamic_slice_sizes={1,2048,16,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.644 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.395), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.114.clone.clone.clone.clone (param_0.1551: f32[4,128], param_1.1699: bf16[4,4,128,2048], param_2.1405: s32[], param_3.982: bf16[2048]) -> bf16[4,128,2048,1] { - %param_3.982 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.571 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.982), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1699 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1405 = s32[]{:T(128)S(6)} parameter(2) - %constant.1378 = s32[]{:T(128)} constant(0) - %dynamic_slice.396 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1699, %param_2.1405, %constant.1378, %constant.1378, %constant.1378), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.646 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.396), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1575 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.646), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1551 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2256 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1551), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2255 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1575, %mul.2256), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1574 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2255), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.570 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.571, %convert_element_type.1574), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.645 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.61.clone.clone (param_0.1552: bf16[4,2048,16,128], param_1.1700: s32[], param_2.1406: f32[4,128], param_3.983: bf16[4,4,128,2048], param_4.604: bf16[2048]) -> (f32[4,128,16], bf16[4,128,16,128]) { - %param_2.1406 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.983 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1700 = s32[]{:T(128)S(6)} parameter(1) - %param_4.604 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.74.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1406, %param_3.983, %param_1.1700, %param_4.604), kind=kLoop, calls=%fused_computation.114.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1552 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %fusion.49.clone.3 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1552, %param_1.1700), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %convolution.44.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.74.clone.3, %fusion.49.clone.3), window={size=1x16 pad=0_0x15_15 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %convert_element_type.1576 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%convolution.44.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.282 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1576, %convert_element_type.1576), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1379 = f32[]{:T(128)} constant(0) - %reduce.247 = f32[4,128,16]{1,2,0:T(8,128)S(1)} reduce(%square.282, %constant.1379), dimensions={3}, to_apply=%region_15.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.208 = (f32[4,128,16]{1,2,0:T(8,128)S(1)}, bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.247, %convolution.44.clone.3) -} - -%fused_computation.151.clone.1.clone (param_0.1553: f32[4,128,16]) -> f32[4,128,16] { - %param_0.1553 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) - %constant.1380 = f32[]{:T(128)} constant(0.0078125) - %closed_call.108 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1380), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.1001 = f32[4,128,16]{1,2,0:T(8,128)} multiply(%param_0.1553, %closed_call.108), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1381 = f32[]{:T(128)} constant(1e-06) - %add.1044 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1381), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %add.1043 = f32[4,128,16]{1,2,0:T(8,128)} add(%div.1001, %add.1044), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.183 = f32[4,128,16]{1,2,0:T(8,128)S(1)} rsqrt(%add.1043), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} -} - -%fused_computation.182.clone.clone (param_0.1549: bf16[4,128], param_1.1697: s32[]) -> bf16[128] { - %param_0.1549 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1697 = s32[]{:T(128)S(6)} parameter(1) - %constant.1376 = s32[]{:T(128)} constant(0) - %dynamic_slice.394 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1549, %param_1.1697, %constant.1376), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.643 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.394), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %dynamic_slice.400 = bf16[1,2048,16,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1587, %param_1.1721, %constant.1363, %constant.1363, %constant.1363), dynamic_slice_sizes={1,2048,16,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.610 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.400), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.2.clone.clone.clone.clone (param_0.1588: f32[4,128], param_1.1722: bf16[2048], param_2.1414: bf16[4,4,128,2048], param_3.972: s32[]) -> bf16[4,128,2048,1] { + %param_2.1414 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.972 = s32[]{:T(128)S(6)} parameter(3) + %constant.1364 = s32[]{:T(128)} constant(0) + %dynamic_slice.401 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1414, %param_3.972, %constant.1364, %constant.1364, %constant.1364), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1596 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} convert(%dynamic_slice.401), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1588 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2890 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_0.1588), dimensions={1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2889 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} multiply(%convert_element_type.1596, %mul.2890), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1595 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} convert(%mul.2889), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1722 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2891 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} broadcast(%param_1.1722), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2888 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1595, %mul.2891), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.611 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2888), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.64.clone.clone (param_0.1589: bf16[4,2048,16,128], param_1.1723: s32[], param_2.1415: f32[4,128], param_3.973: bf16[2048], param_4.596: bf16[4,4,128,2048]) -> (f32[4,128,16], bf16[4,128,16,128]) { + %param_2.1415 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.973 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %param_4.596 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(4) + %param_1.1723 = s32[]{:T(128)S(6)} parameter(1) + %fusion.91.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1415, %param_3.973, %param_4.596, %param_1.1723), kind=kLoop, calls=%fused_computation.89.clone.2.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1589 = bf16[4,2048,16,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.49.clone.3 = bf16[2048,16,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1589, %param_1.1723), kind=kLoop, calls=%fused_computation.25.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.44.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.91.clone.3, %fusion.49.clone.3), window={size=1x16 pad=0_0x15_15 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %convert_element_type.1597 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%convolution.44.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.282 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1597, %convert_element_type.1597), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1365 = f32[]{:T(128)} constant(0) + %reduce.177 = f32[4,128,16]{1,2,0:T(8,128)S(1)} reduce(%square.282, %constant.1365), dimensions={3}, to_apply=%region_15.17, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.210 = (f32[4,128,16]{1,2,0:T(8,128)S(1)}, bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.177, %convolution.44.clone.3) +} + +%fused_computation.162.clone.1.clone (param_0.1590: f32[4,128,16]) -> f32[4,128,16] { + %param_0.1590 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) + %constant.1366 = f32[]{:T(128)} constant(0.0078125) + %closed_call.108 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1366), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.1001 = f32[4,128,16]{1,2,0:T(8,128)} multiply(%param_0.1590, %closed_call.108), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1367 = f32[]{:T(128)} constant(1e-06) + %add.1020 = f32[4,128,16]{1,2,0:T(8,128)} broadcast(%constant.1367), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %add.1019 = f32[4,128,16]{1,2,0:T(8,128)} add(%div.1001, %add.1020), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.183 = f32[4,128,16]{1,2,0:T(8,128)S(1)} rsqrt(%add.1019), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.193.clone.clone (param_0.1591: bf16[4,128], param_1.1724: s32[]) -> bf16[128] { + %param_0.1591 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1724 = s32[]{:T(128)S(6)} parameter(1) + %constant.1368 = s32[]{:T(128)} constant(0) + %dynamic_slice.402 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1591, %param_1.1724, %constant.1368), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.612 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.402), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.121.clone.1.clone (param_0.1554: f32[4,128,16], param_1.1701: bf16[4,128,16,128], param_2.1407: bf16[128]) -> bf16[4,128,16,128] { - %param_2.1407 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) - %dot_general.573 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1407), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1701 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1578 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_1.1701), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1554 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) - %mul.2258 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1554), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2257 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1578, %mul.2258), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1577 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2257), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.572 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%dot_general.573, %convert_element_type.1577), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} +%fused_computation.118.clone.1.clone (param_0.1592: f32[4,128,16], param_1.1725: bf16[4,128,16,128], param_2.1416: bf16[128]) -> bf16[4,128,16,128] { + %param_1.1725 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1599 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_1.1725), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1592 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(0) + %mul.2894 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1592), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2893 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1599, %mul.2894), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1598 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2893), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1416 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) + %mul.2895 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1416), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2892 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%convert_element_type.1598, %mul.2895), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} } -%fused_computation.90.clone.clone (param_0.1555: bf16[4,128,16,128]) -> (bf16[4,128,16,64], bf16[4,128,16,64]) { - %param_0.1555 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %split.160 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1555), slice={[0:4], [0:128], [0:16], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} +%fused_computation.93.clone.clone (param_0.1593: bf16[4,128,16,128]) -> (bf16[4,128,16,64], bf16[4,128,16,64]) { + %param_0.1593 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.160 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1593), slice={[0:4], [0:128], [0:16], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} %neg.129 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.160), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} - %split.161 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1555), slice={[0:4], [0:128], [0:16], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} - ROOT %tuple.209 = (bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) + %split.161 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1593), slice={[0:4], [0:128], [0:16], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + ROOT %tuple.211 = (bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.129, %split.161) } -%fused_computation.187.clone.clone () -> f32[64] { - %constant.1355 = f32[]{:T(128)} constant(1e+06) - %closed_call.104 = f32[64]{0:T(128)} broadcast(%constant.1355), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} +%fused_computation.198.clone.clone () -> f32[64] { + %constant.1343 = f32[]{:T(128)} constant(1e+06) + %closed_call.104 = f32[64]{0:T(128)} broadcast(%constant.1343), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} %iota.51 = s32[64]{0:T(128)} iota(), iota_dimension=0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/iota" stack_frame_id=0} - %constant.1354 = s32[]{:T(128)} constant(2) - %closed_call.103 = s32[64]{0:T(128)} broadcast(%constant.1354), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %mul.2242 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1562 = f32[64]{0:T(128)} convert(%mul.2242), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %constant.1356 = f32[]{:T(128)} constant(0.0078125) - %closed_call.102 = f32[64]{0:T(128)} broadcast(%constant.1356), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.995 = f32[64]{0:T(128)} multiply(%convert_element_type.1562, %closed_call.102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1342 = s32[]{:T(128)} constant(2) + %closed_call.103 = s32[64]{0:T(128)} broadcast(%constant.1342), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2867 = s32[64]{0:T(128)} multiply(%iota.51, %closed_call.103), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1583 = f32[64]{0:T(128)} convert(%mul.2867), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %constant.1344 = f32[]{:T(128)} constant(0.0078125) + %closed_call.102 = f32[64]{0:T(128)} broadcast(%constant.1344), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.995 = f32[64]{0:T(128)} multiply(%convert_element_type.1583, %closed_call.102), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} ROOT %pow.38 = f32[64]{0:T(128)S(1)} power(%closed_call.104, %div.995), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/pow" stack_frame_id=0} } -%fused_computation.143.clone.clone (param_0.1529: f32[64], param_1.1683: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { - %param_1.1683 = f32[4,128]{1,0:T(4,128)} parameter(1) - %div.998 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1683), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %param_0.1529 = f32[64]{0:T(128)S(1)} parameter(0) - %div.997 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1529), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} +%fused_computation.154.clone.clone (param_0.1569: f32[64], param_1.1709: f32[4,128]) -> (bf16[4,128,1,64], bf16[4,128,1,64]) { + %param_1.1709 = f32[4,128]{1,0:T(4,128)} parameter(1) + %div.998 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_1.1709), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %param_0.1569 = f32[64]{0:T(128)S(1)} parameter(0) + %div.997 = f32[4,128,1,64]{3,1,0,2:T(8,128)} broadcast(%param_0.1569), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} %div.996 = f32[4,128,1,64]{3,1,0,2:T(8,128)} divide(%div.998, %div.997), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} %cos.43 = f32[4,128,1,64]{3,1,0,2:T(8,128)} cosine(%div.996), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/cos" stack_frame_id=0} - %convert_element_type.1563 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convert_element_type.1584 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%cos.43), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} %sin.35.clone.3 = f32[4,128,1,64]{3,1,0,2:T(8,128)} sine(%div.996), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/sin" stack_frame_id=0} - %convert_element_type.1189.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %tuple.205 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1563, %convert_element_type.1189.clone.3) + %convert_element_type.1213.clone.3 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} convert(%sin.35.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %tuple.207 = (bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}, bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)}) tuple(%convert_element_type.1584, %convert_element_type.1213.clone.3) } -%fused_computation.146.clone.1.clone (param_0.1530: bf16[4,128,1,64]) -> bf16[4,128,128] { - %param_0.1530 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1357 = bf16[]{:T(256)} constant(-inf) - %pad.69 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1530, %constant.1357), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %pad.68 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1530, %constant.1357), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.157.clone.1.clone (param_0.1570: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1570 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1345 = bf16[]{:T(256)} constant(-inf) + %pad.69 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1570, %constant.1345), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.68 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1570, %constant.1345), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.53 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.69, %pad.68), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - ROOT %bitcast.630 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.53), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.601 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.53), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} } -%fused_computation.145.clone.1.clone (param_0.1545: bf16[4,128,1,64]) -> bf16[4,128,128] { - %param_0.1545 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) - %constant.1374 = bf16[]{:T(256)} constant(-inf) - %pad.71 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1545, %constant.1374), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %pad.70 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1545, %constant.1374), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} +%fused_computation.156.clone.1.clone (param_0.1583: bf16[4,128,1,64]) -> bf16[4,128,128] { + %param_0.1583 = bf16[4,128,1,64]{3,1,0,2:T(8,128)(2,1)S(1)} parameter(0) + %constant.1361 = bf16[]{:T(256)} constant(-inf) + %pad.71 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1583, %constant.1361), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %pad.70 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} pad(%param_0.1583, %constant.1361), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.54 = bf16[4,128,1,128]{3,1,0,2:T(8,128)(2,1)} maximum(%pad.71, %pad.70), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - ROOT %bitcast.641 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.54), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} -} - -%fused_computation.94.clone.clone (param_0.1556: bf16[4,128,16,64], param_1.1702: bf16[4,128,16,64], param_2.1408: bf16[4,128,128], param_3.984: bf16[4,128,128], param_4.605: f32[4,128,16], param_5.499: bf16[4,128,16,128], param_6.384: bf16[128]) -> bf16[4,16,128,128] { - %param_6.384 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(6) - %dot_general.575 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_6.384), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_5.499 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(5) - %convert_element_type.1580 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_5.499), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_4.605 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(4) - %mul.2265 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_4.605), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2264 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1580, %mul.2265), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1579 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2264), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.574 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.575, %convert_element_type.1579), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_3.984 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %mul.2263 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.984), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2261 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.574, %mul.2263), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %param_1.1702 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1382 = bf16[]{:T(256)} constant(-inf) - %pad.75 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1702, %constant.1382), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_0.1556 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %pad.74 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1556, %constant.1382), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + ROOT %bitcast.608 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} bitcast(%maximum.54), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.102.clone.clone (param_0.1594: bf16[4,128,16,64], param_1.1726: bf16[4,128,16,64], param_2.1417: bf16[4,128,128], param_3.974: bf16[4,128,128], param_4.597: bf16[128], param_5.512: f32[4,128,16], param_6.382: bf16[4,128,16,128]) -> bf16[4,16,128,128] { + %param_6.382 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(6) + %convert_element_type.1601 = f32[4,128,16,128]{3,1,2,0:T(8,128)} convert(%param_6.382), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_5.512 = f32[4,128,16]{1,2,0:T(8,128)S(1)} parameter(5) + %mul.2904 = f32[4,128,16,128]{3,1,2,0:T(8,128)} broadcast(%param_5.512), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2903 = f32[4,128,16,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1601, %mul.2904), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1600 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2903), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_4.597 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(4) + %mul.2902 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.597), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2901 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%convert_element_type.1600, %mul.2902), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_3.974 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2900 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.974), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2898 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%mul.2901, %mul.2900), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1726 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1369 = bf16[]{:T(256)} constant(-inf) + %pad.75 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1726, %constant.1369), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1594 = bf16[4,128,16,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.74 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1594, %constant.1369), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.56 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.75, %pad.74), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_2.1408 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %mul.2262 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1408), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2260 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.56, %mul.2262), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %add.1045 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2261, %mul.2260), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %constant.1383 = bf16[]{:T(256)} constant(0.08838) - %closed_call.109 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.1383), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %mul.2259 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%add.1045, %closed_call.109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - ROOT %bitcast.647 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%mul.2259), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%region_16.18 (reduce_sum.213: f32[], reduce_sum.214: f32[]) -> f32[] { - %reduce_sum.214 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.213 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.218 = f32[]{:T(128)} add(%reduce_sum.213, %reduce_sum.214), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.69.clone.1.clone.clone.clone.clone (param_0.1541: bf16[4,2048,8,128], param_1.1692: s32[]) -> bf16[2048,8,128,1] { - %param_0.1541 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %param_1.1692 = s32[]{:T(128)S(6)} parameter(1) - %constant.1369 = s32[]{:T(128)} constant(0) - %dynamic_slice.392 = bf16[1,2048,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1541, %param_1.1692, %constant.1369, %constant.1369, %constant.1369), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.638 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.392), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.113.clone.clone.clone.clone (param_0.1542: f32[4,128], param_1.1693: bf16[4,4,128,2048], param_2.1401: s32[], param_3.979: bf16[2048]) -> bf16[4,128,2048,1] { - %param_3.979 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.565 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.979), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1693 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1401 = s32[]{:T(128)S(6)} parameter(2) - %constant.1370 = s32[]{:T(128)} constant(0) - %dynamic_slice.393 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1693, %param_2.1401, %constant.1370, %constant.1370, %constant.1370), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.640 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.393), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1568 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.640), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1542 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2246 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1542), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2245 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1568, %mul.2246), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1567 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2245), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.564 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.565, %convert_element_type.1567), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.639 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.564), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.84.clone.clone (param_0.1543: bf16[4,2048,8,128], param_1.1694: s32[], param_2.1402: f32[4,128], param_3.980: bf16[4,4,128,2048], param_4.602: bf16[2048]) -> (f32[4,128,8], bf16[4,128,8,128]) { - %param_2.1402 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.980 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1694 = s32[]{:T(128)S(6)} parameter(1) - %param_4.602 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.73.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1402, %param_3.980, %param_1.1694, %param_4.602), kind=kLoop, calls=%fused_computation.113.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1543 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) - %fusion.87.clone.3 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1543, %param_1.1694), kind=kLoop, calls=%fused_computation.69.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %convolution.50.clone.3 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.73.clone.3, %fusion.87.clone.3), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %convert_element_type.1569 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%convolution.50.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.281 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1569, %convert_element_type.1569), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1371 = f32[]{:T(128)} constant(0) - %reduce.246 = f32[4,128,8]{1,2,0:T(8,128)S(1)} reduce(%square.281, %constant.1371), dimensions={3}, to_apply=%region_16.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.206 = (f32[4,128,8]{1,2,0:T(8,128)S(1)}, bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.246, %convolution.50.clone.3) -} - -%fused_computation.154.clone.1.clone (param_0.1544: f32[4,128,8]) -> f32[4,128,8] { - %param_0.1544 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) - %constant.1372 = f32[]{:T(128)} constant(0.0078125) - %closed_call.107 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1372), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.1000 = f32[4,128,8]{1,2,0:T(8,128)} multiply(%param_0.1544, %closed_call.107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1373 = f32[]{:T(128)} constant(1e-06) - %add.1041 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1373), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %add.1040 = f32[4,128,8]{1,2,0:T(8,128)} add(%div.1000, %add.1041), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.182 = f32[4,128,8]{1,2,0:T(8,128)S(1)} rsqrt(%add.1040), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} -} - -%fused_computation.184.clone.clone (param_0.1528: bf16[4,128], param_1.1682: s32[]) -> bf16[128] { - %param_0.1528 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1682 = s32[]{:T(128)S(6)} parameter(1) - %constant.1353 = s32[]{:T(128)} constant(0) - %dynamic_slice.385 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1528, %param_1.1682, %constant.1353), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.629 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.385), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.139.clone.1.clone (param_0.1546: f32[4,128,8], param_1.1695: bf16[4,128,8,128], param_2.1403: bf16[128]) -> bf16[4,128,8,128] { - %param_2.1403 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) - %dot_general.567 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1403), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1695 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1571 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_1.1695), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1546 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) - %mul.2248 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1546), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2247 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1571, %mul.2248), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1570 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2247), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.566 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%dot_general.567, %convert_element_type.1570), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.126.clone.clone (param_0.1547: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { - %param_0.1547 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1547), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + %param_2.1417 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %mul.2899 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1417), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2897 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.56, %mul.2899), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.1021 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2898, %mul.2897), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %constant.1370 = bf16[]{:T(256)} constant(0.08838) + %closed_call.109 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%constant.1370), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %mul.2896 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} multiply(%add.1021, %closed_call.109), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.613 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%mul.2896), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%region_16.18 (reduce_sum.290: f32[], reduce_sum.291: f32[]) -> f32[] { + %reduce_sum.291 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.290 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.292 = f32[]{:T(128)} add(%reduce_sum.290, %reduce_sum.291), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.70.clone.1.clone.clone.clone.clone (param_0.1579: bf16[4,2048,8,128], param_1.1716: s32[]) -> bf16[2048,8,128,1] { + %param_0.1579 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %param_1.1716 = s32[]{:T(128)S(6)} parameter(1) + %constant.1356 = s32[]{:T(128)} constant(0) + %dynamic_slice.398 = bf16[1,2048,8,128]{1,3,2,0:T(8,128)(2,1)} dynamic-slice(%param_0.1579, %param_1.1716, %constant.1356, %constant.1356, %constant.1356), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.606 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} bitcast(%dynamic_slice.398), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.1.clone.clone.clone.clone (param_0.1580: f32[4,128], param_1.1717: bf16[2048], param_2.1410: bf16[4,4,128,2048], param_3.969: s32[]) -> bf16[4,128,2048,1] { + %param_2.1410 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.969 = s32[]{:T(128)S(6)} parameter(3) + %constant.1357 = s32[]{:T(128)} constant(0) + %dynamic_slice.399 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1410, %param_3.969, %constant.1357, %constant.1357, %constant.1357), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1589 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} convert(%dynamic_slice.399), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1580 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2874 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_0.1580), dimensions={1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2873 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} multiply(%convert_element_type.1589, %mul.2874), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1588 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} convert(%mul.2873), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1717 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2875 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} broadcast(%param_1.1717), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2872 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1588, %mul.2875), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.607 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2872), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.85.clone.clone (param_0.1581: bf16[4,2048,8,128], param_1.1718: s32[], param_2.1411: f32[4,128], param_3.970: bf16[2048], param_4.594: bf16[4,4,128,2048]) -> (f32[4,128,8], bf16[4,128,8,128]) { + %param_2.1411 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.970 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) + %param_4.594 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(4) + %param_1.1718 = s32[]{:T(128)S(6)} parameter(1) + %fusion.90.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_2.1411, %param_3.970, %param_4.594, %param_1.1718), kind=kLoop, calls=%fused_computation.89.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1581 = bf16[4,2048,8,128]{1,3,2,0:T(8,128)(2,1)} parameter(0) + %fusion.85.clone.3 = bf16[2048,8,128,1]{0,2,1,3:T(8,128)(2,1)} fusion(%param_0.1581, %param_1.1718), kind=kLoop, calls=%fused_computation.70.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.56.clone.3 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} convolution(%fusion.90.clone.3, %fusion.85.clone.3), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %convert_element_type.1590 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%convolution.56.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.281 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1590, %convert_element_type.1590), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1358 = f32[]{:T(128)} constant(0) + %reduce.176 = f32[4,128,8]{1,2,0:T(8,128)S(1)} reduce(%square.281, %constant.1358), dimensions={3}, to_apply=%region_16.18, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.208 = (f32[4,128,8]{1,2,0:T(8,128)S(1)}, bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%reduce.176, %convolution.56.clone.3) +} + +%fused_computation.165.clone.1.clone (param_0.1582: f32[4,128,8]) -> f32[4,128,8] { + %param_0.1582 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) + %constant.1359 = f32[]{:T(128)} constant(0.0078125) + %closed_call.107 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1359), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.1000 = f32[4,128,8]{1,2,0:T(8,128)} multiply(%param_0.1582, %closed_call.107), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1360 = f32[]{:T(128)} constant(1e-06) + %add.1017 = f32[4,128,8]{1,2,0:T(8,128)} broadcast(%constant.1360), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %add.1016 = f32[4,128,8]{1,2,0:T(8,128)} add(%div.1000, %add.1017), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.182 = f32[4,128,8]{1,2,0:T(8,128)S(1)} rsqrt(%add.1016), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.195.clone.clone (param_0.1568: bf16[4,128], param_1.1708: s32[]) -> bf16[128] { + %param_0.1568 = bf16[4,128]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1708 = s32[]{:T(128)S(6)} parameter(1) + %constant.1341 = s32[]{:T(128)} constant(0) + %dynamic_slice.392 = bf16[1,128]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1568, %param_1.1708, %constant.1341), dynamic_slice_sizes={1,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.600 = bf16[128]{0:T(256)(128)(2,1)S(1)} bitcast(%dynamic_slice.392), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.144.clone.1.clone (param_0.1584: f32[4,128,8], param_1.1719: bf16[4,128,8,128], param_2.1412: bf16[128]) -> bf16[4,128,8,128] { + %param_1.1719 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1592 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_1.1719), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1584 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(0) + %mul.2878 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_0.1584), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2877 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1592, %mul.2878), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1591 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2877), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1412 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(2) + %mul.2879 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1412), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2876 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} multiply(%convert_element_type.1591, %mul.2879), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.125.clone.clone (param_0.1585: bf16[4,128,8,128]) -> (bf16[4,128,8,64], bf16[4,128,8,64]) { + %param_0.1585 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %split.158 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)} slice(%param_0.1585), slice={[0:4], [0:128], [0:8], [64:128]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} %neg.128 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} negate(%split.158), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/neg" stack_frame_id=0} - %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1547), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} - ROOT %tuple.207 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) -} - -%fused_computation.129.clone.clone (param_0.1548: bf16[4,128,8,64], param_1.1696: bf16[4,128,8,64], param_2.1404: bf16[4,128,128], param_3.981: bf16[4,128,128], param_4.603: f32[4,128,8], param_5.498: bf16[4,128,8,128], param_6.383: bf16[128]) -> bf16[4,8,128,128] { - %param_6.383 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(6) - %dot_general.569 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_6.383), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_5.498 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(5) - %convert_element_type.1573 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_5.498), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_4.603 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(4) - %mul.2254 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_4.603), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2253 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1573, %mul.2254), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1572 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2253), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.568 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.569, %convert_element_type.1572), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_3.981 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %mul.2252 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.981), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2250 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%dot_general.568, %mul.2252), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %param_1.1696 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1375 = bf16[]{:T(256)} constant(-inf) - %pad.73 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1696, %constant.1375), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_0.1548 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) - %pad.72 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1548, %constant.1375), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %split.159 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} slice(%param_0.1585), slice={[0:4], [0:128], [0:8], [0:64]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/split" stack_frame_id=0} + ROOT %tuple.209 = (bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}, bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)}) tuple(%neg.128, %split.159) +} + +%fused_computation.128.clone.clone (param_0.1586: bf16[4,128,8,64], param_1.1720: bf16[4,128,8,64], param_2.1413: bf16[4,128,128], param_3.971: bf16[4,128,128], param_4.595: bf16[128], param_5.511: f32[4,128,8], param_6.381: bf16[4,128,8,128]) -> bf16[4,8,128,128] { + %param_6.381 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(6) + %convert_element_type.1594 = f32[4,128,8,128]{3,1,2,0:T(8,128)} convert(%param_6.381), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_5.511 = f32[4,128,8]{1,2,0:T(8,128)S(1)} parameter(5) + %mul.2887 = f32[4,128,8,128]{3,1,2,0:T(8,128)} broadcast(%param_5.511), dimensions={0,1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2886 = f32[4,128,8,128]{3,1,2,0:T(8,128)} multiply(%convert_element_type.1594, %mul.2887), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1593 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convert(%mul.2886), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_4.595 = bf16[128]{0:T(256)(128)(2,1)S(1)} parameter(4) + %mul.2885 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_4.595), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2884 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%convert_element_type.1593, %mul.2885), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_3.971 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %mul.2883 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_3.971), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2881 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%mul.2884, %mul.2883), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_1.1720 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(1) + %constant.1362 = bf16[]{:T(256)} constant(-inf) + %pad.73 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_1.1720, %constant.1362), padding=0_0x0_0x0_0x0_64, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} + %param_0.1586 = bf16[4,128,8,64]{3,1,2,0:T(8,128)(2,1)S(1)} parameter(0) + %pad.72 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} pad(%param_0.1586, %constant.1362), padding=0_0x0_0x0_0x64_0, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} %maximum.55 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} maximum(%pad.73, %pad.72), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/concatenate" stack_frame_id=0} - %param_2.1404 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %mul.2251 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1404), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2249 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.55, %mul.2251), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %add.1042 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2250, %mul.2249), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %bitcast.642 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.1042), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%fused_computation.169.clone.clone (param_0.1537: bf16[4,2048,8,128], param_1.1689: s32[]) -> bf16[1,2048,8,128] { - %param_0.1537 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) - %param_1.1689 = s32[]{:T(128)S(6)} parameter(1) - %constant.1367 = s32[]{:T(128)} constant(0) - ROOT %dynamic_slice.390 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1537, %param_1.1689, %constant.1367, %constant.1367, %constant.1367), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} -} - -%fused_computation.70.clone.1.clone.clone.clone.clone (param_0.1538: bf16[1,2048,8,128]) -> bf16[2048,8,128,1] { - %param_0.1538 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) - %copy.204 = bf16[1,2048,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1538), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} - ROOT %bitcast.634 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} bitcast(%copy.204), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.111.clone.clone.clone.clone (param_0.1539: f32[4,128], param_1.1690: bf16[4,4,128,2048], param_2.1399: s32[], param_3.977: bf16[2048]) -> bf16[4,128,2048,1] { - %param_3.977 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(3) - %dot_general.563 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_3.977), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1690 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1399 = s32[]{:T(128)S(6)} parameter(2) - %constant.1368 = s32[]{:T(128)} constant(0) - %dynamic_slice.391 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_1.1690, %param_2.1399, %constant.1368, %constant.1368, %constant.1368), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.636 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.391), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %convert_element_type.1566 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%bitcast.636), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1539 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2244 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1539), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2243 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1566, %mul.2244), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1565 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2243), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %dot_general.562 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.563, %convert_element_type.1565), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - ROOT %bitcast.635 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%dot_general.562), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.140.clone.clone (param_0.1540: bf16[1,2048,8,128], param_1.1691: f32[4,128], param_2.1400: bf16[4,4,128,2048], param_3.978: s32[], param_4.601: bf16[2048]) -> bf16[4,8,128,128] { - %param_1.1691 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) - %param_2.1400 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) - %param_3.978 = s32[]{:T(128)S(6)} parameter(3) - %param_4.601 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.373 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_1.1691, %param_2.1400, %param_3.978, %param_4.601), kind=kLoop, calls=%fused_computation.111.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1540 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) - %fusion.372 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1540), kind=kLoop, calls=%fused_computation.70.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %convolution.106 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.373, %fusion.372), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - ROOT %bitcast.637 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%convolution.106), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%fused_computation.188.clone.clone (param_0.1578: f32[4,16,128,128]) -> (f32[4,16,128], f32[4,16,128,1]) { - %param_0.1578 = f32[4,16,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) - %slice.11 = f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1578), slice={[0:4], [0:16], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} - %bitcast.660 = f32[4,16,128]{2,1,0:T(8,128)S(1)} bitcast(%slice.11), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/squeeze" stack_frame_id=0} - ROOT %tuple.213 = (f32[4,16,128]{2,1,0:T(8,128)S(1)}, f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)}) tuple(%bitcast.660, %slice.11) -} - -%region_17.20 (reduce_sum.219: f32[], reduce_sum.220: f32[]) -> f32[] { - %reduce_sum.220 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - %reduce_sum.219 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} - ROOT %reduce_sum.221 = f32[]{:T(128)} add(%reduce_sum.219, %reduce_sum.220), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} -} - -%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone (param_0.1557: bf16[4,16,128,2048], param_1.1703: s32[]) -> bf16[16,128,2048,1] { - %param_0.1557 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1703 = s32[]{:T(128)S(6)} parameter(1) - %constant.1384 = s32[]{:T(128)} constant(0) - %dynamic_slice.397 = bf16[1,16,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1557, %param_1.1703, %constant.1384, %constant.1384, %constant.1384), dynamic_slice_sizes={1,16,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.648 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.397), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.103.clone.clone.clone.clone.clone.clone (param_0.1558: bf16[4,16,128,128]) -> bf16[4,128,16,128] { - %param_0.1558 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.649 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1558), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} -} - -%fused_computation.64.clone.clone (param_0.1559: bf16[4,16,128,2048], param_1.1704: s32[], param_2.1409: bf16[4,16,128,128], param_3.985: bf16[4,4,128,2048]) -> (f32[4,128], bf16[4,128,2048]) { - %param_3.985 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) - %param_1.1704 = s32[]{:T(128)S(6)} parameter(1) - %constant.436.clone.1.clone.3 = s32[]{:T(128)} constant(0) - %dynamic_slice.242.clone.3 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.985, %param_1.1704, %constant.436.clone.1.clone.3, %constant.436.clone.1.clone.3, %constant.436.clone.1.clone.3), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %bitcast.227.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.242.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} - %param_2.1409 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) - %fusion.96.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1409), kind=kLoop, calls=%fused_computation.103.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} - %param_0.1559 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) - %fusion.95.clone.3 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1559, %param_1.1704), kind=kLoop, calls=%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %convolution.62.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} convolution(%fusion.96.clone.3, %fusion.95.clone.3), window={size=1x16}, dim_labels=0b1f_1io0->0bf1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %bitcast.203.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%convolution.62.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} - %add.768.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.227.clone.3, %bitcast.203.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - %convert_element_type.1581 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%add.768.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %square.283 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1581, %convert_element_type.1581), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} - %constant.1385 = f32[]{:T(128)} constant(0) - %reduce.248 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.283, %constant.1385), dimensions={2}, to_apply=%region_17.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} - ROOT %tuple.210 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.248, %add.768.clone.3) -} - -%convert_element_type.763.reduce_sub_computation (lhs: bf16[], rhs: bf16[]) -> bf16[] { - %rhs = bf16[] parameter(1) - %lhs = bf16[] parameter(0) - ROOT %add.754 = bf16[] add(%lhs, %rhs), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} + %param_2.1413 = bf16[4,128,128]{2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %mul.2882 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} broadcast(%param_2.1413), dimensions={0,1,3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2880 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} multiply(%maximum.55, %mul.2882), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %add.1018 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} add(%mul.2881, %mul.2880), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %bitcast.609 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%add.1018), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.181.clone.clone (param_0.1575: bf16[4,2048,8,128], param_1.1713: s32[]) -> bf16[1,2048,8,128] { + %param_0.1575 = bf16[4,2048,8,128]{3,2,0,1:T(8,128)(2,1)} parameter(0) + %param_1.1713 = s32[]{:T(128)S(6)} parameter(1) + %constant.1354 = s32[]{:T(128)} constant(0) + ROOT %dynamic_slice.396 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} dynamic-slice(%param_0.1575, %param_1.1713, %constant.1354, %constant.1354, %constant.1354), dynamic_slice_sizes={1,2048,8,128}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} +} + +%fused_computation.71.clone.1.clone.clone.clone.clone (param_0.1576: bf16[1,2048,8,128]) -> bf16[2048,8,128,1] { + %param_0.1576 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %copy.200 = bf16[1,2048,8,128]{3,1,2,0:T(8,128)(2,1)} copy(%param_0.1576), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0} + ROOT %bitcast.603 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} bitcast(%copy.200), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.89.clone.clone.clone.clone.clone (param_0.1577: f32[4,128], param_1.1714: bf16[2048], param_2.1408: bf16[4,4,128,2048], param_3.967: s32[]) -> bf16[4,128,2048,1] { + %param_2.1408 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(2) + %param_3.967 = s32[]{:T(128)S(6)} parameter(3) + %constant.1355 = s32[]{:T(128)} constant(0) + %dynamic_slice.397 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_2.1408, %param_3.967, %constant.1355, %constant.1355, %constant.1355), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %convert_element_type.1587 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} convert(%dynamic_slice.397), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1577 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2870 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} broadcast(%param_0.1577), dimensions={1,2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2869 = f32[1,4,128,2048]{3,2,1,0:T(8,128)} multiply(%convert_element_type.1587, %mul.2870), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1586 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} convert(%mul.2869), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_1.1714 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(1) + %mul.2871 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} broadcast(%param_1.1714), dimensions={3}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2868 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1586, %mul.2871), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %bitcast.604 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} bitcast(%mul.2868), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.151.clone.clone (param_0.1578: bf16[1,2048,8,128], param_1.1715: f32[4,128], param_2.1409: bf16[2048], param_3.968: bf16[4,4,128,2048], param_4.593: s32[]) -> bf16[4,8,128,128] { + %param_1.1715 = f32[4,128]{1,0:T(4,128)S(1)} parameter(1) + %param_2.1409 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %param_3.968 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_4.593 = s32[]{:T(128)S(6)} parameter(4) + %fusion.380 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} fusion(%param_1.1715, %param_2.1409, %param_3.968, %param_4.593), kind=kLoop, calls=%fused_computation.89.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1578 = bf16[1,2048,8,128]{3,2,0,1:T(8,128)(2,1)S(1)} parameter(0) + %fusion.379 = bf16[2048,8,128,1]{2,0,1,3:T(8,128)(2,1)} fusion(%param_0.1578), kind=kLoop, calls=%fused_computation.71.clone.1.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.105 = bf16[4,128,8,128]{3,1,2,0:T(8,128)(2,1)} convolution(%fusion.380, %fusion.379), window={size=1x8 pad=0_0x7_7 rhs_reversal=0x1}, dim_labels=0bf1_i1o0->0b1f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + ROOT %bitcast.605 = bf16[4,8,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} bitcast(%convolution.105), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.199.clone.clone (param_0.1618: f32[4,16,128,128]) -> (f32[4,16,128], f32[4,16,128,1]) { + %param_0.1618 = f32[4,16,128,128]{2,1,0,3:T(8,128)S(1)} parameter(0) + %slice.11 = f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)} slice(%param_0.1618), slice={[0:4], [0:16], [0:128], [0:1]}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/slice" stack_frame_id=0} + %bitcast.626 = f32[4,16,128]{2,1,0:T(8,128)S(1)} bitcast(%slice.11), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/shard_map/vmap(jit(_splash_attention))/squeeze" stack_frame_id=0} + ROOT %tuple.216 = (f32[4,16,128]{2,1,0:T(8,128)S(1)}, f32[4,16,128,1]{2,1,0,3:T(8,128)S(1)}) tuple(%bitcast.626, %slice.11) +} + +%region_17.20 (reduce_sum.296: f32[], reduce_sum.297: f32[]) -> f32[] { + %reduce_sum.297 = f32[]{:T(128)} parameter(1), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + %reduce_sum.296 = f32[]{:T(128)} parameter(0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum"} + ROOT %reduce_sum.298 = f32[]{:T(128)} add(%reduce_sum.296, %reduce_sum.297), metadata={op_name="checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} +} + +%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone (param_0.1595: bf16[4,16,128,2048], param_1.1727: s32[]) -> bf16[16,128,2048,1] { + %param_0.1595 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1727 = s32[]{:T(128)S(6)} parameter(1) + %constant.1371 = s32[]{:T(128)} constant(0) + %dynamic_slice.403 = bf16[1,16,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1595, %param_1.1727, %constant.1371, %constant.1371, %constant.1371), dynamic_slice_sizes={1,16,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.614 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} bitcast(%dynamic_slice.403), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.110.clone.clone.clone.clone.clone.clone (param_0.1596: bf16[4,16,128,128]) -> bf16[4,128,16,128] { + %param_0.1596 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(0) + ROOT %bitcast.615 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} bitcast(%param_0.1596), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} +} + +%fused_computation.67.clone.clone (param_0.1597: bf16[4,16,128,2048], param_1.1728: s32[], param_2.1418: bf16[4,16,128,128], param_3.975: bf16[4,4,128,2048]) -> (f32[4,128], bf16[4,128,2048]) { + %param_3.975 = bf16[4,4,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(3) + %param_1.1728 = s32[]{:T(128)S(6)} parameter(1) + %constant.414.clone.1.clone.3 = s32[]{:T(128)} constant(0) + %dynamic_slice.265.clone.3 = bf16[1,4,128,2048]{3,2,1,0:T(8,128)(2,1)} dynamic-slice(%param_3.975, %param_1.1728, %constant.414.clone.1.clone.3, %constant.414.clone.1.clone.3, %constant.414.clone.1.clone.3), dynamic_slice_sizes={1,4,128,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %bitcast.212.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%dynamic_slice.265.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/squeeze" stack_frame_id=0} + %param_2.1418 = bf16[4,16,128,128]{3,2,1,0:T(8,128)(2,1)S(1)} parameter(2) + %fusion.100.clone.3 = bf16[4,128,16,128]{3,1,2,0:T(8,128)(2,1)} fusion(%param_2.1418), kind=kLoop, calls=%fused_computation.110.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/transpose" stack_frame_id=0} + %param_0.1597 = bf16[4,16,128,2048]{3,2,1,0:T(8,128)(2,1)} parameter(0) + %fusion.99.clone.3 = bf16[16,128,2048,1]{2,1,0,3:T(8,128)(2,1)} fusion(%param_0.1597, %param_1.1728), kind=kLoop, calls=%fused_computation.26.clone.1.clone.clone.clone.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %convolution.62.clone.3 = bf16[4,128,2048,1]{2,1,3,0:T(8,128)(2,1)} convolution(%fusion.100.clone.3, %fusion.99.clone.3), window={size=1x16}, dim_labels=0b1f_1io0->0bf1, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %bitcast.204.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%convolution.62.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} + %add.744.clone.3 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} add(%bitcast.212.clone.3, %bitcast.204.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + %convert_element_type.1602 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%add.744.clone.3), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %square.283 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1602, %convert_element_type.1602), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/square" stack_frame_id=0} + %constant.1372 = f32[]{:T(128)} constant(0) + %reduce.178 = f32[4,128]{1,0:T(4,128)S(1)} reduce(%square.283, %constant.1372), dimensions={2}, to_apply=%region_17.20, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/reduce_sum" stack_frame_id=0} + ROOT %tuple.212 = (f32[4,128]{1,0:T(4,128)S(1)}, bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)}) tuple(%reduce.178, %add.744.clone.3) +} + +%convert_element_type.808.reduce_sub_computation (lhs.1: bf16[], rhs.1: bf16[]) -> bf16[] { + %rhs.1 = bf16[] parameter(1) + %lhs.1 = bf16[] parameter(0) + ROOT %add.731 = bf16[] add(%lhs.1, %rhs.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"used_scoped_memory_configs":[],"aliasing_operands":{"lists":[]}} } -%fused_computation.156.clone.clone (param_0.1531: bf16[4,2048], param_1.1684: s32[]) -> bf16[2048] { - %param_0.1531 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) - %param_1.1684 = s32[]{:T(128)S(6)} parameter(1) - %constant.1358 = s32[]{:T(128)} constant(0) - %dynamic_slice.386 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1531, %param_1.1684, %constant.1358), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - %constant.1359 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %reduce.243 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.386, %constant.1359), dimensions={0}, to_apply=%convert_element_type.763.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.166.clone.clone (param_0.1571: bf16[4,2048], param_1.1710: s32[]) -> bf16[2048] { + %param_0.1571 = bf16[4,2048]{1,0:T(4,128)(2,1)} parameter(0) + %param_1.1710 = s32[]{:T(128)S(6)} parameter(1) + %constant.1346 = s32[]{:T(128)} constant(0) + %dynamic_slice.393 = bf16[1,2048]{1,0:T(2,128)(2,1)} dynamic-slice(%param_0.1571, %param_1.1710, %constant.1346), dynamic_slice_sizes={1,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + %constant.1347 = bf16[]{:T(256)} constant(-0), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %reduce.173 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} reduce(%dynamic_slice.393, %constant.1347), dimensions={0}, to_apply=%convert_element_type.808.reduce_sub_computation, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } -%fused_computation.13.clone.clone.clone (param_0.1532: bf16[4,6144,2048], param_1.1685: s32[]) -> bf16[6144,2048,1] { - %param_0.1532 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1685 = s32[]{:T(128)S(6)} parameter(1) - %constant.1360 = s32[]{:T(128)} constant(0) - %dynamic_slice.387 = bf16[1,6144,2048]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1532, %param_1.1685, %constant.1360, %constant.1360), dynamic_slice_sizes={1,6144,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.632 = bf16[6144,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.387), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +%fused_computation.191.clone.1.clone (param_0.1598: f32[4,128]) -> f32[4,128] { + %param_0.1598 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %constant.1374 = f32[]{:T(128)} constant(0.00048828125) + %closed_call.111 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1374), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %div.1002 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1598, %closed_call.111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} + %constant.1373 = f32[]{:T(128)} constant(1e-06) + %closed_call.110 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1373), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} + %add.1022 = f32[4,128]{1,0:T(4,128)} add(%div.1002, %closed_call.110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} + ROOT %rsqrt.184 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1022), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} +} + +%fused_computation.11.clone.1.clone.clone (param_0.1599: bf16[4,2048,6144], param_1.1729: s32[]) -> bf16[2048,6144,1] { + %param_0.1599 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1729 = s32[]{:T(128)S(6)} parameter(1) + %constant.1375 = s32[]{:T(128)} constant(0) + %dynamic_slice.404 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1599, %param_1.1729, %constant.1375, %constant.1375), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.616 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.404), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.116.clone.3.clone.clone (param_0.1600: f32[4,128], param_1.1730: bf16[4,128,2048], param_2.1419: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1730 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1604 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1730), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1600 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2907 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1600), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2906 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1604, %mul.2907), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1603 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2906), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1419 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2908 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1419), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2905 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1603, %mul.2908), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.23.clone.clone (param_0.1601: bf16[4,2048,6144], param_1.1731: s32[], param_2.1420: f32[4,128], param_3.976: bf16[4,128,2048], param_4.598: bf16[2048]) -> bf16[4,128,6144] { + %param_2.1420 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.976 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.598 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.382 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1420, %param_3.976, %param_4.598), kind=kLoop, calls=%fused_computation.116.clone.3.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1601 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1731 = s32[]{:T(128)S(6)} parameter(1) + %fusion.381 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1601, %param_1.1731), kind=kLoop, calls=%fused_computation.11.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.106 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.382, %fusion.381), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.14.clone.clone.clone (param_0.1604: bf16[4,2048,6144], param_1.1734: s32[]) -> bf16[2048,6144,1] { + %param_0.1604 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1734 = s32[]{:T(128)S(6)} parameter(1) + %constant.1377 = s32[]{:T(128)} constant(0) + %dynamic_slice.406 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1604, %param_1.1734, %constant.1377, %constant.1377), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.619 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.406), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} +} + +%fused_computation.116.clone.2.clone.clone (param_0.1605: f32[4,128], param_1.1735: bf16[4,128,2048], param_2.1422: bf16[2048]) -> bf16[4,128,2048] { + %param_1.1735 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) + %convert_element_type.1606 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1735), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_0.1605 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) + %mul.2911 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1605), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %mul.2910 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1606, %mul.2911), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %convert_element_type.1605 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2910), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + %param_2.1422 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) + %mul.2912 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1422), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + ROOT %mul.2909 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%convert_element_type.1605, %mul.2912), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} +} + +%fused_computation.20.clone.clone (param_0.1606: bf16[4,2048,6144], param_1.1736: s32[], param_2.1423: f32[4,128], param_3.977: bf16[4,128,2048], param_4.599: bf16[2048]) -> bf16[4,128,6144] { + %param_2.1423 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) + %param_3.977 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) + %param_4.599 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) + %fusion.386 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1423, %param_3.977, %param_4.599), kind=kLoop, calls=%fused_computation.116.clone.2.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} + %param_0.1606 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1736 = s32[]{:T(128)S(6)} parameter(1) + %fusion.385 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1606, %param_1.1736), kind=kLoop, calls=%fused_computation.14.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} + ROOT %convolution.108 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.386, %fusion.385), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} +} + +%fused_computation.12.clone.clone.clone (param_0.1602: bf16[4,6144,2048], param_1.1732: s32[]) -> bf16[6144,2048,1] { + %param_0.1602 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} parameter(0) + %param_1.1732 = s32[]{:T(128)S(6)} parameter(1) + %constant.1376 = s32[]{:T(128)} constant(0) + %dynamic_slice.405 = bf16[1,6144,2048]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1602, %param_1.1732, %constant.1376, %constant.1376), dynamic_slice_sizes={1,6144,2048}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} + ROOT %bitcast.618 = bf16[6144,2048,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.405), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} } %bitcast_fusion.1.clone.clone (bitcast_input.4: bf16[4,128,2048]) -> bf16[4,128,2048] { - %bitcast_input.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - ROOT %bitcast.631 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.4) -} - -%fused_computation.14.clone.clone (param_0.1533: bf16[4,128,2048], param_1.1686: bf16[4,6144,2048], param_2.1398: s32[]) -> bf16[6144,4,128] { - %param_1.1686 = bf16[4,6144,2048]{2,1,0:T(8,128)(2,1)} parameter(1) - %param_2.1398 = s32[]{:T(128)S(6)} parameter(2) - %fusion.370 = bf16[6144,2048,1]{1,0,2:T(8,128)(2,1)} fusion(%param_1.1686, %param_2.1398), kind=kLoop, calls=%fused_computation.13.clone.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1533 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(0) - %fusion.371 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_0.1533), kind=kLoop, calls=%bitcast_fusion.1.clone.clone - ROOT %convolution.105 = bf16[6144,4,128]{0,2,1:T(8,128)(2,1)S(1)} convolution(%fusion.370, %fusion.371), window={size=4 pad=3_3 rhs_reversal=1}, dim_labels=bf0_0oi->b0f, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/layers/dot_general" stack_frame_id=0} -} - -%fused_computation.180.clone.1.clone (param_0.1560: f32[4,128]) -> f32[4,128] { - %param_0.1560 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %constant.1387 = f32[]{:T(128)} constant(0.00048828125) - %closed_call.111 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1387), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %div.1002 = f32[4,128]{1,0:T(4,128)} multiply(%param_0.1560, %closed_call.111), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/div" stack_frame_id=0} - %constant.1386 = f32[]{:T(128)} constant(1e-06) - %closed_call.110 = f32[4,128]{1,0:T(4,128)} broadcast(%constant.1386), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call" stack_frame_id=0} - %add.1046 = f32[4,128]{1,0:T(4,128)} add(%div.1002, %closed_call.110), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/add" stack_frame_id=0} - ROOT %rsqrt.184 = f32[4,128]{1,0:T(4,128)S(1)} rsqrt(%add.1046), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/rsqrt" stack_frame_id=0} -} - -%fused_computation.12.clone.1.clone.clone (param_0.1564: bf16[4,2048,6144], param_1.1708: s32[]) -> bf16[2048,6144,1] { - %param_0.1564 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1708 = s32[]{:T(128)S(6)} parameter(1) - %constant.1389 = s32[]{:T(128)} constant(0) - %dynamic_slice.399 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1564, %param_1.1708, %constant.1389, %constant.1389), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.651 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.399), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.119.clone.3.clone.clone (param_0.1565: f32[4,128], param_1.1709: bf16[4,128,2048], param_2.1412: bf16[2048]) -> bf16[4,128,2048] { - %param_2.1412 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(2) - %dot_general.579 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} broadcast(%param_2.1412), dimensions={2}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_1.1709 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %convert_element_type.1585 = f32[4,128,2048]{2,1,0:T(8,128)} convert(%param_1.1709), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - %param_0.1565 = f32[4,128]{1,0:T(4,128)S(1)} parameter(0) - %mul.2269 = f32[4,128,2048]{2,1,0:T(8,128)} broadcast(%param_0.1565), dimensions={0,1}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %mul.2268 = f32[4,128,2048]{2,1,0:T(8,128)} multiply(%convert_element_type.1585, %mul.2269), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/mul" stack_frame_id=0} - %convert_element_type.1584 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} convert(%mul.2268), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %dot_general.578 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} multiply(%dot_general.579, %convert_element_type.1584), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} -} - -%fused_computation.21.clone.clone (param_0.1566: bf16[4,2048,6144], param_1.1710: s32[], param_2.1413: f32[4,128], param_3.987: bf16[4,128,2048], param_4.607: bf16[2048]) -> bf16[4,128,6144] { - %param_2.1413 = f32[4,128]{1,0:T(4,128)S(1)} parameter(2) - %param_3.987 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)S(1)} parameter(3) - %param_4.607 = bf16[2048]{0:T(1024)(128)(2,1)S(1)} parameter(4) - %fusion.377 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} fusion(%param_2.1413, %param_3.987, %param_4.607), kind=kLoop, calls=%fused_computation.119.clone.3.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/...k,k->...k/dot_general" stack_frame_id=0} - %param_0.1566 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1710 = s32[]{:T(128)S(6)} parameter(1) - %fusion.376 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} fusion(%param_0.1566, %param_1.1710), kind=kLoop, calls=%fused_computation.12.clone.1.clone.clone, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} - ROOT %convolution.108 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} convolution(%fusion.377, %fusion.376), window={size=1}, dim_labels=0bf_io0->0bf, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/dot_general" stack_frame_id=0} -} - -%fused_computation.11.clone.1.clone.clone (param_0.1568: bf16[4,2048,6144], param_1.1712: s32[]) -> bf16[2048,6144,1] { - %param_0.1568 = bf16[4,2048,6144]{2,1,0:T(8,128)(2,1)} parameter(0) - %param_1.1712 = s32[]{:T(128)S(6)} parameter(1) - %constant.1391 = s32[]{:T(128)} constant(0) - %dynamic_slice.400 = bf16[1,2048,6144]{2,1,0:T(8,128)(2,1)} dynamic-slice(%param_0.1568, %param_1.1712, %constant.1391, %constant.1391), dynamic_slice_sizes={1,2048,6144}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/dynamic_slice" stack_frame_id=0}, backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"4294967292","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"},{"zeroes":"4294967295","ones":"0","bitwidth":"32"}],"is_index_aligned":[]},"used_scoped_memory_configs":[]} - ROOT %bitcast.653 = bf16[2048,6144,1]{1,0,2:T(8,128)(2,1)} bitcast(%dynamic_slice.400), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/convert_element_type" stack_frame_id=0} -} - -%fused_computation.47.clone.1.clone.clone (param_0.1567: bf16[6144,4,128], param_1.1711: bf16[4,128,6144]) -> bf16[4,128,6144] { - %param_1.1711 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)S(1)} parameter(1) - %constant.1390 = bf16[]{:T(256)} constant(1) - %jit_silu_.44 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} broadcast(%constant.1390), dimensions={}, metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)" stack_frame_id=0} - %neg.130 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} negate(%param_1.1711), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/neg" stack_frame_id=0} - %exp.69 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} exponential(%neg.130), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/exp" stack_frame_id=0} - %add.1047 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} add(%exp.69, %jit_silu_.44), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/add" stack_frame_id=0} - %div.1003 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} divide(%jit_silu_.44, %add.1047), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/div" stack_frame_id=0} - %mul.2271 = bf16[4,128,6144]{2,1,0:T(8,128)(2,1)} multiply(%param_1.1711, %div.1003), metadata={op_name="jit(train_step)/transpose(jvp(TransformerLinenPure.apply))/TransformerLinenPure/decoder/while/body/closed_call/checkpoint/rematted_computation/layers/jit(silu)/mul" stack_frame_id=0} + %bitcast_input.4 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} parameter(0) + ROOT %bitcast.617 = bf16[4,128,2048]{2,1,0:T(8,128)(2,1)} bitcast(%bitcast_input.4)