diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index dee93f0ef8..68cddbfd4f 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -1367,10 +1367,14 @@ def cudnn_flash_attention( qkv_layout = "THD_THD_THD" # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD' if decoder_segment_ids is None: decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=decoder_segment_ids, segment_pos=None) + attn_mask = SequenceDescriptor.from_segment_ids_and_pos( + segment_ids=decoder_segment_ids, segment_pos=None, is_thd=True, is_segment_ids_reordered=False + ) # Create dummy SequenceDescriptor for lazy_init dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None) + dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos( + segment_ids=dummy_segment_ids, segment_pos=None, is_thd=True, is_segment_ids_reordered=False + ) max_segments_per_seq = self.config.max_segments_per_seq elif using_context_parallelism or self.config.dataset_type == "synthetic": if self.attention_type == AttentionType.LOCAL_SLIDING: