diff --git a/backends/webgpu/test/test_webgpu_native.cpp b/backends/webgpu/test/test_webgpu_native.cpp
index 338ecb39913..fea5b591a7e 100644
--- a/backends/webgpu/test/test_webgpu_native.cpp
+++ b/backends/webgpu/test/test_webgpu_native.cpp
@@ -662,32 +662,46 @@ static bool test_sdpa_config(
   }
 
   const auto& outputs = result.get();
-  // The mutating op returns [k_cache, v_cache, attn_output]; select the
-  // attention output (numel == S*Hq*D), not a mutated cache (numel Cmax*Hkv*D).
-  // Count matches and fail if ambiguous: a cache could share the same numel.
-  int attn_idx = -1;
-  int attn_matches = 0;
-  for (size_t i = 0; i < outputs.size(); i++) {
-    if (outputs[i].isTensor() && outputs[i].toTensor().numel() == on) {
-      attn_idx = static_cast<int>(i);
-      attn_matches++;
-    }
-  }
-  if (attn_idx < 0) {
+  // sdpa_with_kv_cache mutates k_cache/v_cache in place and returns the
+  // attention output. ExecuTorch emits program outputs as
+  // [*mutated_inputs, *user_outputs], so forward() returns exactly
+  // [k_cache, v_cache, attn_output]: three tensors, attention output last.
+  // Selecting by element count is unsafe: when S*Hq*D == Cmax*Hkv*D
+  // (e.g. llama1b_prefill, all 262144) the attention output and both caches
+  // share numel. Select by the documented position instead, and assert the
+  // output count and per-slot numels so a future change in output structure
+  // still fails loudly.
+  if (outputs.size() != 3) {
     printf(
-        "FAIL: no attention output (numel %d) among %zu outputs\n",
-        on,
+        "FAIL: expected 3 outputs [k_cache, v_cache, attn_output], got %zu\n",
         outputs.size());
     return false;
   }
-  if (attn_matches > 1) {
+  for (size_t i = 0; i < outputs.size(); i++) {
+    if (!outputs[i].isTensor()) {
+      printf("FAIL: output %zu is not a tensor\n", i);
+      return false;
+    }
+  }
+  // Outputs 0 and 1 are the mutated k/v caches (numel cn); output 2 is attn.
+  for (int i = 0; i < 2; i++) {
+    if (outputs[i].toTensor().numel() != cn) {
+      printf(
+          "FAIL: output %d (expected k/v cache) numel %zu != Cmax*Hkv*D %d\n",
+          i,
+          (size_t)outputs[i].toTensor().numel(),
+          cn);
+      return false;
+    }
+  }
+  const auto& out_tensor = outputs[2].toTensor();
+  if (out_tensor.numel() != on) {
     printf(
-        "FAIL: ambiguous attention output: %d tensors match numel %d\n",
-        attn_matches,
+        "FAIL: attention output numel %zu != S*Hq*D %d\n",
+        (size_t)out_tensor.numel(),
         on);
     return false;
   }
-  const auto& out_tensor = outputs[attn_idx].toTensor();
   const float* out_data = out_tensor.const_data_ptr<float>();
   std::vector<float> golden = load_golden(golden_path, on);
 
@@ -803,25 +817,47 @@ static bool test_sdpa_replay(const SdpaSequence& seq, const std::string& dir) {
     }
 
     const auto& outs = result.get();
-    // The op returns [k_cache, v_cache, attn_output]: attn has a unique numel;
-    // the two caches share numel cn, so identify them by content at step 0.
-    int attn_idx = -1;
-    std::vector<int> cache_idxs;
-    for (size_t i = 0; i < outs.size(); i++) {
-      if (!outs[i].isTensor()) {
-        continue;
-      }
-      const int ne = static_cast<int>(outs[i].toTensor().numel());
-      if (ne == qn) {
-        attn_idx = static_cast<int>(i);
-      } else if (ne == cn) {
-        cache_idxs.push_back(static_cast<int>(i));
-      }
+    // Outputs are [k_cache, v_cache, attn_output]: ExecuTorch emits
+    // [*mutated_inputs, *user_outputs], so the two mutated caches come first
+    // (signature order k, v) and the attention output last. Select by position;
+    // numel is ambiguous when the attn output and caches share numel. The k/v
+    // caches are still disambiguated by content at step 0 below.
+    if (outs.size() != 3 || !outs[0].isTensor() || !outs[1].isTensor() ||
+        !outs[2].isTensor()) {
+      printf(
+          "FAIL: %s step%zu: expected 3 tensor outputs "
+          "[k_cache, v_cache, attn_output], got %zu\n",
+          seq.name,
+          t,
+          outs.size());
+      return false;
     }
-    if (attn_idx < 0 || cache_idxs.size() != 2) {
-      printf("FAIL: %s step%zu: expected 1 attn + 2 caches\n", seq.name, t);
+    const int attn_idx = 2;
+    const std::vector<int> cache_idxs = {0, 1};
+    if (static_cast<int>(outs[attn_idx].toTensor().numel()) != qn) {
+      printf(
+          "FAIL: %s step%zu: attn output numel %zu != expected %d\n",
+          seq.name,
+          t,
+          (size_t)outs[attn_idx].toTensor().numel(),
+          qn);
       return false;
     }
+    // Caches must be full-size (numel cn): step-0 identification and cross-step
+    // threading read cn/kvn elements from them, so a short tensor would be an
+    // out-of-bounds read rather than a clean failure.
+    for (int ci : cache_idxs) {
+      if (static_cast<int>(outs[ci].toTensor().numel()) != cn) {
+        printf(
+            "FAIL: %s step%zu: cache output %d numel %zu != Cmax*Hkv*D %d\n",
+            seq.name,
+            t,
+            ci,
+            (size_t)outs[ci].toTensor().numel(),
+            cn);
+        return false;
+      }
+    }
 
     if (t == 0) {
       const float* c0 = outs[cache_idxs[0]].toTensor().const_data_ptr<float>();
@@ -974,23 +1010,47 @@ static bool test_sdpa_dynamic_decode(
     }
 
     const auto& outs = result.get();
-    int attn_idx = -1;
-    std::vector<int> cache_idxs;
-    for (size_t i = 0; i < outs.size(); i++) {
-      if (!outs[i].isTensor()) {
-        continue;
-      }
-      const int ne = static_cast<int>(outs[i].toTensor().numel());
-      if (ne == qn) {
-        attn_idx = static_cast<int>(i);
-      } else if (ne == cn) {
-        cache_idxs.push_back(static_cast<int>(i));
-      }
+    // Outputs are [k_cache, v_cache, attn_output]: ExecuTorch emits
+    // [*mutated_inputs, *user_outputs], so the two mutated caches come first
+    // (signature order k, v) and the attention output last. Select by position;
+    // numel is ambiguous when the attn output and caches share numel. The k/v
+    // caches are still disambiguated by content at step 0 below.
+    if (outs.size() != 3 || !outs[0].isTensor() || !outs[1].isTensor() ||
+        !outs[2].isTensor()) {
+      printf(
+          "FAIL: %s step%d: expected 3 tensor outputs "
+          "[k_cache, v_cache, attn_output], got %zu\n",
+          seq.name,
+          t,
+          outs.size());
+      return false;
     }
-    if (attn_idx < 0 || cache_idxs.size() != 2) {
-      printf("FAIL: %s step%d: expected 1 attn + 2 caches\n", seq.name, t);
+    const int attn_idx = 2;
+    const std::vector<int> cache_idxs = {0, 1};
+    if (static_cast<int>(outs[attn_idx].toTensor().numel()) != qn) {
+      printf(
+          "FAIL: %s step%d: attn output numel %zu != expected %d\n",
+          seq.name,
+          t,
+          (size_t)outs[attn_idx].toTensor().numel(),
+          qn);
       return false;
     }
+    // Caches must be full-size (numel cn): step-0 identification and cross-step
+    // threading read cn/kvn elements from them, so a short tensor would be an
+    // out-of-bounds read rather than a clean failure.
+    for (int ci : cache_idxs) {
+      if (static_cast<int>(outs[ci].toTensor().numel()) != cn) {
+        printf(
+            "FAIL: %s step%d: cache output %d numel %zu != Cmax*Hkv*D %d\n",
+            seq.name,
+            t,
+            ci,
+            (size_t)outs[ci].toTensor().numel(),
+            cn);
+        return false;
+      }
+    }
     if (t == 0) {
       const float* c0 = outs[cache_idxs[0]].toTensor().const_data_ptr<float>();
       const float* c1 = outs[cache_idxs[1]].toTensor().const_data_ptr<float>();