From f93e1deb74814803aa0c3001077945f4d6f8ad3f Mon Sep 17 00:00:00 2001 From: ankitm3k Date: Wed, 29 Apr 2026 20:48:02 +0530 Subject: [PATCH] fix: GQA kernel relax seqlens_k shape check to element count (follow-up #28031) --- .../cpu/bert/group_query_attention_helper.h | 9 ++++++--- .../contrib_ops/group_query_attention_op_test.cc | 13 ++++++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f5399e307fbca..0a2e817fe8ac1 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -261,10 +261,13 @@ Status CheckInputs(const T* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - const auto& seqlens_k_dim = seqlens_k->Shape().GetDims(); - if (seqlens_k_dim.size() != 1 || seqlens_k_dim[0] != batch_size) { + // Accept any shape whose total element count equals batch_size (e.g. {batch_size}, + // {batch_size, 1}, or scalar when batch_size == 1). The per-element bounds check in + // group_query_attention.cc still enforces 0 <= seqlens_k[b] < present_kv_seqlen, + // preserving the OOB protection added in PR #28031. 
+ if (seqlens_k->Shape().Size() != static_cast<int64_t>(batch_size)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "seqlens_k must be shape (batch_size)." + "seqlens_k must contain batch_size elements."); } if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index 0690094031bb8..cc966449bd147 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -231,8 +231,10 @@ TEST(GroupQueryAttentionTest, TotalSeqLenNegative) { "total_sequence_length must be positive"); } -// Shape validation: seqlens_k with wrong rank (2D instead of 1D) must be rejected. -TEST(GroupQueryAttentionTest, SeqlensKWrongRank) { +// Shape compatibility: seqlens_k with rank 2 ({batch_size, 1}) must be accepted because +// the element count matches batch_size. Per-element bounds are still enforced separately. +// Regression for GQA models exported with a trailing singleton dim (e.g. Phi-3 / Phi-3.5). +TEST(GroupQueryAttentionTest, SeqlensKRank2SingletonAccepted) { constexpr int num_heads = 1; constexpr int kv_num_heads = 1; constexpr int head_size = 8; @@ -248,7 +250,7 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongRank) { tester.AddInput<float>("value", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f)); tester.AddOptionalInputEdge<float>(); // past_key tester.AddOptionalInputEdge<float>(); // past_value - // 2D shape {1, 1} instead of {1} + // 2D shape {1, 1} — element count == batch_size, must be accepted. 
tester.AddInput<int32_t>("seqlens_k", {1, 1}, {0}); tester.AddInput<int32_t>("total_sequence_length", {1}, {1}); tester.AddOptionalInputEdge<float>(); // cos_cache @@ -265,7 +267,8 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongRank) { std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must be shape (batch_size)", + // Values are don't-care — just verify the shape check no longer rejects rank-2 seqlens_k. + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -303,7 +306,7 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) { std::vector<std::unique_ptr<IExecutionProvider>> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must be shape (batch_size)", + tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must contain batch_size elements", {}, nullptr, &execution_providers); }