Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,13 @@ Status CheckInputs(const T* query,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}

const auto& seqlens_k_dim = seqlens_k->Shape().GetDims();
if (seqlens_k_dim.size() != 1 || seqlens_k_dim[0] != batch_size) {
// Accept any shape whose total element count equals batch_size (e.g. {batch_size},
// {batch_size, 1}, or scalar when batch_size == 1). The per-element bounds check in
// group_query_attention.cc still enforces 0 <= seqlens_k[b] < present_kv_seqlen,
// preserving the OOB protection added in PR #28031.
if (seqlens_k->Shape().Size() != static_cast<int64_t>(batch_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k must be shape (batch_size).");
"seqlens_k must contain batch_size elements.");
}

if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
Expand Down
13 changes: 8 additions & 5 deletions onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,10 @@ TEST(GroupQueryAttentionTest, TotalSeqLenNegative) {
"total_sequence_length must be positive");
}

// Shape validation: seqlens_k with wrong rank (2D instead of 1D) must be rejected.
TEST(GroupQueryAttentionTest, SeqlensKWrongRank) {
// Shape compatibility: seqlens_k with rank 2 ({batch_size, 1}) must be accepted because
// the element count matches batch_size. Per-element bounds are still enforced separately.
// Regression for GQA models exported with a trailing singleton dim (e.g. Phi-3 / Phi-3.5).
TEST(GroupQueryAttentionTest, SeqlensKRank2SingletonAccepted) {
constexpr int num_heads = 1;
constexpr int kv_num_heads = 1;
constexpr int head_size = 8;
Expand All @@ -248,7 +250,7 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongRank) {
tester.AddInput<float>("value", {1, 1, kv_hidden_size}, std::vector<float>(kv_hidden_size, 1.0f));
tester.AddOptionalInputEdge<float>(); // past_key
tester.AddOptionalInputEdge<float>(); // past_value
// 2D shape {1, 1} instead of {1}
// 2D shape {1, 1} — element count == batch_size, must be accepted.
tester.AddInput<int32_t>("seqlens_k", {1, 1}, {0});
tester.AddInput<int32_t>("total_sequence_length", {1}, {1});
tester.AddOptionalInputEdge<float>(); // cos_cache
Expand All @@ -265,7 +267,8 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongRank) {

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must be shape (batch_size)",
// Values are don't-care — just verify the shape check no longer rejects rank-2 seqlens_k.
tester.Run(OpTester::ExpectResult::kExpectSuccess, "",
{}, nullptr, &execution_providers);
}

Expand Down Expand Up @@ -303,7 +306,7 @@ TEST(GroupQueryAttentionTest, SeqlensKWrongLength) {

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must be shape (batch_size)",
tester.Run(OpTester::ExpectResult::kExpectFailure, "seqlens_k must contain batch_size elements",
{}, nullptr, &execution_providers);
}

Expand Down
Loading