From 41916d078a9e072362bd4c8f4be546102e0d5348 Mon Sep 17 00:00:00 2001 From: Andy Ye Date: Tue, 7 Apr 2026 09:59:48 -0400 Subject: [PATCH] Fix SequenceDescriptor.from_segment_ids_and_pos() for TE >= 2.12 TransformerEngine v2.12 (NVIDIA/TransformerEngine#2523) made `is_thd` and `is_segment_ids_reordered` required keyword arguments on `SequenceDescriptor.from_segment_ids_and_pos()` to fix incorrect segment position calculation for THD layouts. Since the packing branch in `cudnn_flash_attention` uses `qkv_layout="THD_THD_THD"` with standard (non-reordered) segment IDs, the correct values are `is_thd=True, is_segment_ids_reordered=False`. Without this fix, any configuration using `attention="cudnn_flash_te"` with `packing=True` and real data (`dataset_type != "synthetic"`) fails with: TypeError: SequenceDescriptor.from_segment_ids_and_pos() missing 2 required keyword-only arguments: 'is_thd' and 'is_segment_ids_reordered' --- src/MaxText/layers/attention_op.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/MaxText/layers/attention_op.py b/src/MaxText/layers/attention_op.py index dee93f0ef8..68cddbfd4f 100644 --- a/src/MaxText/layers/attention_op.py +++ b/src/MaxText/layers/attention_op.py @@ -1367,10 +1367,14 @@ def cudnn_flash_attention( qkv_layout = "THD_THD_THD" # Packed format: 'T3HD', 'THD_T2HD' or 'THD_THD_THD' if decoder_segment_ids is None: decoder_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=decoder_segment_ids, segment_pos=None) + attn_mask = SequenceDescriptor.from_segment_ids_and_pos( + segment_ids=decoder_segment_ids, segment_pos=None, is_thd=True, is_segment_ids_reordered=False + ) # Create dummy SequenceDescriptor for lazy_init dummy_segment_ids = jnp.ones(shape=query.shape[:2], dtype=jnp.int32) - dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos(segment_ids=dummy_segment_ids, segment_pos=None) + dummy_attn_mask = SequenceDescriptor.from_segment_ids_and_pos( + segment_ids=dummy_segment_ids, segment_pos=None, is_thd=True, is_segment_ids_reordered=False + ) max_segments_per_seq = self.config.max_segments_per_seq elif using_context_parallelism or self.config.dataset_type == "synthetic": if self.attention_type == AttentionType.LOCAL_SLIDING: