fix(glm_asr): honor sampling params in vLLM generate()#2997
fix(glm_asr): honor sampling params in vLLM generate()#2997SuperMarioYL wants to merge 1 commit into
Conversation
GLMASRVLLMEngine.generate() accepted **kwargs but hardcoded temperature=0 and never forwarded any sampling parameter, so callers could not control decoding. Its sibling Fun-ASR-Nano engine already exposes temperature/top_p/top_k/repetition_penalty. Mirror that surface on GLM-ASR: forward temperature, top_p and top_k into SamplingParams (preserving the previous greedy default), and guard repetition_penalty so a non-neutral value is coerced back to 1.0 (with a one-time warning). This engine runs vLLM in prompt-embeds mode, where a non-neutral repetition_penalty triggers the issue modelscope#2948 CUDA crash. Add CPU-only regression tests covering parameter forwarding, the greedy defaults, top_k normalization, and the repetition_penalty guard.
There was a problem hiding this comment.
Code Review
This pull request adds sampling parameters (temperature, top_p, top_k, and repetition_penalty) to the GLM-ASR vLLM inference engine, along with corresponding unit tests. It also introduces a safety check to prevent a CUDA crash when using repetition penalty in prompt-embeds mode. The review feedback points out a potential TypeError if top_k is passed as None and suggests adding a safe check and an associated unit test.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| temperature=0, | ||
| temperature=temperature, | ||
| top_p=top_p, | ||
| top_k=top_k if top_k > 0 else -1, |
There was a problem hiding this comment.
If top_k is passed as None (which is common when passing optional configuration dictionaries), the comparison top_k > 0 will raise a TypeError: '>' not supported between instances of 'NoneType' and 'int'. We should explicitly check if top_k is not None before comparing it.
| top_k=top_k if top_k > 0 else -1, | |
| top_k=top_k if (top_k is not None and top_k > 0) else -1, |
| def test_non_positive_top_k_is_normalized_to_disabled(self): | ||
| self.engine.generate("a.wav", top_k=0) | ||
| self.assertEqual(self.captured["top_k"], -1) |
There was a problem hiding this comment.
Let's also test the case where top_k is None to ensure it is correctly normalized to -1 without raising a TypeError.
| def test_non_positive_top_k_is_normalized_to_disabled(self): | |
| self.engine.generate("a.wav", top_k=0) | |
| self.assertEqual(self.captured["top_k"], -1) | |
| def test_non_positive_top_k_is_normalized_to_disabled(self): | |
| self.engine.generate("a.wav", top_k=0) | |
| self.assertEqual(self.captured["top_k"], -1) | |
| self.engine.generate("a.wav", top_k=None) | |
| self.assertEqual(self.captured["top_k"], -1) |
Summary
GLMASRVLLMEngine.generate()accepts**kwargsbut hardcodestemperature=0and never forwards any sampling parameter intoSamplingParams, so callers cannot control decoding — anything they pass (temperature, top_p, top_k, …) is silently dropped.The sibling Fun-ASR-Nano vLLM engine already exposes these parameters (
temperature/top_p/top_k/repetition_penalty). This PR brings GLM-ASR to parity.Changes
temperature,top_p,top_k,repetition_penaltyarguments togenerate()and forward them intoSamplingParams.temperature=0.0,top_p=1.0,top_k=-1,repetition_penalty=1.0) reproduce the previous greedy-decoding behavior exactly, so existing callers are unaffected.top_kto-1(disabled), mirroring Fun-ASR-Nano.repetition_penalty: this engine runs vLLM in prompt-embeds mode (enable_prompt_embeds=True), where a non-neutral repetition penalty has no prompt token IDs to scatter over and crashes the engine with a CUDAscatter gather index out of boundsassertion (the same failure mode fixed for Fun-ASR-Nano in vllm离线服务推理,长音频报错 #2948). A non-neutral value is coerced back to1.0with a one-time warning.Testing
Added
tests/test_glm_asr_vllm_sampling.py(CPU-only, vLLM stubbed) covering:SamplingParamstop_kis normalized to disabledrepetition_penaltyguard coerces non-neutral values and warns onceVerified red on
main(params dropped) and green on this branch.