
fix: GQA kernel relax seqlens_k shape check to element count #1067

Open
ankitm3k wants to merge 2 commits into ovep-develop from ankit/gqa-seqlens-k-shape-relax

Conversation

@ankitm3k

Summary

PR microsoft#28031 tightened the seqlens_k shape check in GroupQueryAttention to require rank 1 AND dim[0] == batch_size. While the intent was sound, the rank-1 requirement is stricter than necessary for memory safety and collaterally rejects valid {batch_size, 1} tensors produced by common GQA exports — including Microsoft's own Phi-3 / Phi-3.5 ONNX models from HuggingFace (microsoft/Phi-3.5-mini-instruct-onnx), onnxruntime-genai model-builder output, and optimum-onnx-exported LLMs.

This PR replaces the rank check with an element-count check using TensorShape::Size(). The per-element bounds loop in group_query_attention.cc (the actual CVE fix for MSRC-108962) is unchanged, so all security guarantees from microsoft#28031 are preserved.

Reproducer before this fix

onnxruntime_perf_test.exe -I -e cpu -r 1 -o 0 -f "batch_size:1" \
  "Phi-3.5-mini-instruct-cpu-int4-awq-block-128-acc-level-4.onnx"

or any Python app running the same model via CPUExecutionProvider fails with:

[E] Non-zero status code returned while running GroupQueryAttention node.
    Status Message: seqlens_k must be shape (batch_size).

The seqlens_k tensor flowing into the op is {1, 1} (a legitimate export shape — ReduceSum(attention_mask, axis=-1) of a [batch, seq] mask keeps the trailing singleton). Pre-microsoft#28031, the buggy && predicate silently accepted it. Post-microsoft#28031, the || predicate rejects it even though it is memory-safe.

Change

// before
const auto& seqlens_k_dim = seqlens_k->Shape().GetDims();
if (seqlens_k_dim.size() != 1 || seqlens_k_dim[0] != batch_size) {
  return ORT_MAKE_STATUS(..., "seqlens_k must be shape (batch_size).");
}

// after
if (seqlens_k->Shape().Size() != static_cast<int64_t>(batch_size)) {
  return ORT_MAKE_STATUS(..., "seqlens_k must contain batch_size elements.");
}

This mirrors the existing IsScalarOr1ElementVector pattern used in the same helper for total_seqlen (L273) — acknowledging that ONNX exporters legitimately emit singleton tensors with varying rank.
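The old and new predicates can be compared side by side on a toy shape representation — a hedged sketch where ShapeSize stands in for TensorShape::Size() (product of dims, 1 for a scalar), not the real ORT types:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ORT's TensorShape::Size(): the element count
// is the product of all dimensions (1 for a rank-0 scalar).
int64_t ShapeSize(const std::vector<int64_t>& dims) {
  int64_t size = 1;
  for (int64_t d : dims) size *= d;  // real ORT guards this with SafeInt
  return size;
}

// Old check (post-#28031): rank must be exactly 1 and dim[0] == batch_size.
bool OldCheckAccepts(const std::vector<int64_t>& dims, int64_t batch_size) {
  return dims.size() == 1 && dims[0] == batch_size;
}

// New check: only the total element count must equal batch_size.
bool NewCheckAccepts(const std::vector<int64_t>& dims, int64_t batch_size) {
  return ShapeSize(dims) == batch_size;
}
```

The new predicate accepts {1, 1} for batch_size = 1 where the old one rejected it, while still rejecting any shape whose element count differs from batch_size.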

Security equivalence to microsoft#28031

microsoft#28031 defends against MSRC-108962, where crafted models supplied negative or oversized per-element seqlens_k values that, cast to size_t, produced OOB reads against the K/V cache. Its four parts:

| # | Defense | Status in this PR |
|---|---------|-------------------|
| 1 | Per-element lower/upper bounds at group_query_attention.cc:88-100 | Unchanged |
| 2 | Non-first-prompt underflow check | Unchanged |
| 3 | SafeInt casts in gqa_attention_base.h / attention_helper.h | Unchanged |
| 4 | Rank/shape check in group_query_attention_helper.h | Replaced with element-count check |

The CPU kernel only ever reads seqlens_k as a contiguous int32_t* in a for (b = 0; b < batch_size; ...) loop (group_query_attention.cc:86-101, gqa_attention_base.h:112-135). Shape is never read. The necessary-and-sufficient memory-safety invariant is therefore "the buffer holds exactly batch_size int32 elements", which is precisely what Shape().Size() == batch_size enforces. CUDA and WebGPU kernels follow the same pattern — they use parameters.batch_size (derived from query.Shape()[0]) to size device-side reads, never seqlens_k's own shape — so this relaxation is safe across all EPs sharing this helper.

Edge-case audit

| seqlens_k shape | batch_size | Old (post-microsoft#28031) | New | Safe? |
|---|---|---|---|---|
| {B} | B | accept | accept | same |
| {B, 1}, {1, B} | B | reject | accept | safe — same B contiguous int32s |
| {} (scalar) | 1 | reject | accept | safe — scalar int32 is 1 element |
| {} (scalar) | 2 | reject | reject (1 ≠ 2) | rejected correctly |
| {2} | 1 | reject | reject (2 ≠ 1) | length-mismatch rejection preserved |
| {B, 2} | B | reject | reject (2B ≠ B) | rejected correctly |
| overflowing product of dims | any | passed to bounds check later | SafeInt in SizeHelper throws synchronously | strictly safer |

The SafeInt overflow trap inside TensorShape::Size() is a minor upgrade over the old code, which never called Size() for this validation path.
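The overflow behavior can be sketched without the SafeInt library — a hypothetical CheckedShapeSize that returns nullopt where SafeInt would throw, assuming non-negative dims as ONNX shapes require:

```cpp
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Sketch of the overflow guard TensorShape::Size() gets from SafeInt:
// a checked product over the dims instead of raw int64_t multiplication.
// Assumes dims are non-negative (as ONNX tensor shapes are).
std::optional<int64_t> CheckedShapeSize(const std::vector<int64_t>& dims) {
  int64_t size = 1;
  for (int64_t d : dims) {
    if (d != 0 && size > std::numeric_limits<int64_t>::max() / d)
      return std::nullopt;  // product would overflow int64_t
    size *= d;
  }
  return size;
}
```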

Test changes

  • SeqlensKWrongRank → renamed to SeqlensKRank2SingletonAccepted and flipped from expect-failure to expect-success. Serves as a regression guard for Phi-3/3.5-style exports with {batch_size, 1} seqlens_k.
  • SeqlensKWrongLength kept; expected error-string updated to match the new message ("seqlens_k must contain batch_size elements").
  • Existing CVE regression tests (NegativeSeqlensK_OOB, OversizedSeqlensK_OOB, MultiBatchOneBadSeqlensK_OOB, Int32MaxSeqlensK_OOB, NonPromptSeqlensKUnderflow_OOB, etc.) are unaffected — all still expect failure against the per-element bounds loop.

Verified with

  • End-to-end Python inference app (phi_muffin_app) on PSU_MF (Phi-3-mini derived, 28-layer, int4-AWQ QDQ) via OpenVINO EP with device=CPU — GQA node falls through to the stock CPU kernel, runs correctly, generated output is coherent.

Related

@MayureshV1

MayureshV1 commented Apr 30, 2026

Might not be needed if microsoft#28259 is merged

