From e57bd9dae5ed94160520e07d0a8dcb1086ae4d4d Mon Sep 17 00:00:00 2001 From: supermario_leo Date: Sat, 20 Jun 2026 13:07:59 +0800 Subject: [PATCH] fix(glm_asr): honor sampling params in vLLM generate() GLMASRVLLMEngine.generate() accepted **kwargs but hardcoded temperature=0 and never forwarded any sampling parameter, so callers could not control decoding. Its sibling Fun-ASR-Nano engine already exposes temperature/top_p/top_k/repetition_penalty. Mirror that surface on GLM-ASR: forward temperature, top_p and top_k into SamplingParams (preserving the previous greedy default), and guard repetition_penalty so a non-neutral value is coerced back to 1.0 (with a one-time warning). This engine runs vLLM in prompt-embeds mode, where a non-neutral repetition_penalty triggers the issue #2948 CUDA crash. Add CPU-only regression tests covering parameter forwarding, the greedy defaults, top_k normalization, and the repetition_penalty guard. --- funasr/models/glm_asr/inference_vllm.py | 46 ++++++++- tests/test_glm_asr_vllm_sampling.py | 119 ++++++++++++++++++++++++ 2 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 tests/test_glm_asr_vllm_sampling.py diff --git a/funasr/models/glm_asr/inference_vllm.py b/funasr/models/glm_asr/inference_vllm.py index 73687a335..7360a73ab 100644 --- a/funasr/models/glm_asr/inference_vllm.py +++ b/funasr/models/glm_asr/inference_vllm.py @@ -26,6 +26,36 @@ logger = logging.getLogger(__name__) dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} +# Warn only once per process so batch loops do not spam the log. +_warned_rep_penalty = False + + +def _safe_repetition_penalty(repetition_penalty): + """Force ``repetition_penalty`` to the neutral value for prompt-embeds mode. + + GLM-ASR feeds vLLM precomputed embeddings (``enable_prompt_embeds=True``), so + a request carries no prompt token IDs. vLLM applies ``repetition_penalty`` by + scattering over those IDs, so any value other than 1.0 indexes an empty + token-id tensor and aborts the engine with a CUDA + ``scatter gather index out of bounds`` assertion (issue #2948). We therefore + warn once and fall back to the neutral value of 1.0. + """ + global _warned_rep_penalty + + if repetition_penalty is None or repetition_penalty == 1.0: + return 1.0 + + if not _warned_rep_penalty: + logger.warning( + "repetition_penalty=%s is not supported in vLLM prompt-embeds mode " + "(no prompt token IDs to penalize) and would trigger a CUDA scatter " + "index-out-of-bounds crash; using repetition_penalty=1.0 instead. " + "See https://github.com/modelscope/FunASR/issues/2948.", + repetition_penalty, + ) + _warned_rep_penalty = True + return 1.0 + def prepare_glmasr_vllm_dir(model_dir: str) -> str: """Extract language_model weights into vLLM-compatible Llama format.""" @@ -180,13 +210,22 @@ def _build_prompt_embeds(self, audio_embeds, prompt="转录以下音频内容"): audio_emb = audio_embeds[0] if audio_embeds.dim() == 3 else audio_embeds return torch.cat([prefix_emb, audio_emb, suffix_emb], dim=0) - def generate(self, inputs, prompt="转录以下音频内容", max_new_tokens=500, **kwargs): + def generate(self, inputs, prompt="转录以下音频内容", max_new_tokens=500, + temperature=0.0, top_p=1.0, top_k=-1, repetition_penalty=1.0, + **kwargs): """Run batch ASR inference. Args: inputs: Audio file path(s), numpy arrays, or tensors. prompt: Instruction prompt for ASR. max_new_tokens: Maximum tokens to generate per sample. + temperature: Sampling temperature (0 = greedy decoding). + top_p: Nucleus sampling parameter. + top_k: Top-k sampling (-1 = disabled). + repetition_penalty: Repetition penalty factor. Non-neutral values are + forced back to 1.0 here because this engine feeds vLLM precomputed + embeddings (``enable_prompt_embeds=True``); see + ``resolve_repetition_penalty`` and issue #2948. Returns: List of {"key": str, "text": str} @@ -202,7 +241,10 @@ def generate(self, inputs, prompt="转录以下音频内容", max_new_tokens=500 sampling_params = SamplingParams( max_tokens=max_new_tokens, - temperature=0, + temperature=temperature, + top_p=top_p, + top_k=top_k if top_k > 0 else -1, + repetition_penalty=_safe_repetition_penalty(repetition_penalty), skip_special_tokens=True, ) diff --git a/tests/test_glm_asr_vllm_sampling.py b/tests/test_glm_asr_vllm_sampling.py new file mode 100644 index 000000000..bd92c1c07 --- /dev/null +++ b/tests/test_glm_asr_vllm_sampling.py @@ -0,0 +1,119 @@ +"""Unit tests for GLM-ASR vLLM sampling-parameter handling. + +These tests exercise ``GLMASRVLLMEngine.generate`` without a GPU or a real +vLLM installation: the vLLM entry points are stubbed in ``sys.modules`` and the +audio/encoder/engine collaborators are mocked, so only the sampling-parameter +wiring is under test. +""" + +import re +import sys +import types +import unittest +from unittest import mock + + +def _install_vllm_stub(): + """Install a minimal ``vllm`` stub whose SamplingParams records kwargs.""" + + captured = {} + + class _RecordingSamplingParams: + def __init__(self, **kwargs): + captured.clear() + captured.update(kwargs) + + class _EmbedsPrompt: + def __init__(self, **kwargs): + self.kwargs = kwargs + + vllm_mod = types.ModuleType("vllm") + vllm_mod.SamplingParams = _RecordingSamplingParams + vllm_mod.LLM = object + inputs_mod = types.ModuleType("vllm.inputs") + inputs_mod.EmbedsPrompt = _EmbedsPrompt + data_mod = types.ModuleType("vllm.inputs.data") + data_mod.EmbedsPrompt = _EmbedsPrompt + + sys.modules["vllm"] = vllm_mod + sys.modules["vllm.inputs"] = inputs_mod + sys.modules["vllm.inputs.data"] = data_mod + return captured + + +class GLMASRSamplingParamsTest(unittest.TestCase): + def setUp(self): + self.captured = _install_vllm_stub() + from funasr.models.glm_asr.inference_vllm import GLMASRVLLMEngine + + # Build an engine without running __init__ (no model load / GPU needed). + engine = GLMASRVLLMEngine.__new__(GLMASRVLLMEngine) + engine.device = "cpu" + engine._encode_audio = mock.Mock(return_value="audio_embeds") + engine._build_prompt_embeds = mock.Mock( + return_value=mock.Mock(float=lambda: "embeds") + ) + + token_out = types.SimpleNamespace(token_ids=[1, 2, 3]) + vllm_output = types.SimpleNamespace(outputs=[token_out]) + engine.vllm_engine = mock.Mock() + engine.vllm_engine.generate = mock.Mock(return_value=[vllm_output]) + engine.tokenizer = mock.Mock() + engine.tokenizer.decode = mock.Mock(return_value="hello world") + self.engine = engine + + def test_defaults_preserve_greedy_behavior(self): + results = self.engine.generate("a.wav") + self.assertEqual(results, [{"key": "a", "text": "hello world"}]) + self.assertEqual(self.captured["temperature"], 0.0) + self.assertEqual(self.captured["top_p"], 1.0) + self.assertEqual(self.captured["top_k"], -1) + self.assertEqual(self.captured["repetition_penalty"], 1.0) + self.assertEqual(self.captured["max_tokens"], 500) + + def test_caller_sampling_params_are_forwarded(self): + self.engine.generate( + "a.wav", max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=20 + ) + self.assertEqual(self.captured["max_tokens"], 128) + self.assertEqual(self.captured["temperature"], 0.7) + self.assertEqual(self.captured["top_p"], 0.9) + self.assertEqual(self.captured["top_k"], 20) + + def test_non_positive_top_k_is_normalized_to_disabled(self): + self.engine.generate("a.wav", top_k=0) + self.assertEqual(self.captured["top_k"], -1) + + def test_repetition_penalty_is_forced_neutral_in_prompt_embeds_mode(self): + # A non-neutral repetition_penalty would crash vLLM prompt-embeds mode + # (issue #2948), so it must be coerced back to 1.0 rather than forwarded. + self.engine.generate("a.wav", repetition_penalty=1.3) + self.assertEqual(self.captured["repetition_penalty"], 1.0) + + def test_neutral_repetition_penalty_passes_through(self): + self.engine.generate("a.wav", repetition_penalty=1.0) + self.assertEqual(self.captured["repetition_penalty"], 1.0) + + +class SafeRepetitionPenaltyTest(unittest.TestCase): + def setUp(self): + _install_vllm_stub() + import funasr.models.glm_asr.inference_vllm as mod + + self.mod = mod + # Reset the process-wide warn-once flag between tests. + mod._warned_rep_penalty = False + + def test_neutral_and_none_map_to_one(self): + self.assertEqual(self.mod._safe_repetition_penalty(1.0), 1.0) + self.assertEqual(self.mod._safe_repetition_penalty(None), 1.0) + + def test_non_neutral_is_coerced_and_warns_once(self): + with self.assertLogs(self.mod.logger, level="WARNING") as cm: + self.assertEqual(self.mod._safe_repetition_penalty(1.5), 1.0) + self.assertTrue(any("2948" in line for line in cm.output)) + self.assertTrue(self.mod._warned_rep_penalty) + + +if __name__ == "__main__": + unittest.main()