Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions funasr/models/glm_asr/inference_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,36 @@
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Maps short precision names ("bf16", "fp16", "fp32") to the matching torch dtypes.
dtype_map = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}

# Warn only once per process so batch loops do not spam the log.
# Read and set by ``_safe_repetition_penalty``.
_warned_rep_penalty = False


def _safe_repetition_penalty(repetition_penalty):
"""Force ``repetition_penalty`` to the neutral value for prompt-embeds mode.

GLM-ASR feeds vLLM precomputed embeddings (``enable_prompt_embeds=True``), so
a request carries no prompt token IDs. vLLM applies ``repetition_penalty`` by
scattering over those IDs, so any value other than 1.0 indexes an empty
token-id tensor and aborts the engine with a CUDA
``scatter gather index out of bounds`` assertion (issue #2948). We therefore
warn once and fall back to the neutral value of 1.0.
"""
global _warned_rep_penalty

if repetition_penalty is None or repetition_penalty == 1.0:
return 1.0

if not _warned_rep_penalty:
logger.warning(
"repetition_penalty=%s is not supported in vLLM prompt-embeds mode "
"(no prompt token IDs to penalize) and would trigger a CUDA scatter "
"index-out-of-bounds crash; using repetition_penalty=1.0 instead. "
"See https://github.com/modelscope/FunASR/issues/2948.",
repetition_penalty,
)
_warned_rep_penalty = True
return 1.0


def prepare_glmasr_vllm_dir(model_dir: str) -> str:
"""Extract language_model weights into vLLM-compatible Llama format."""
Expand Down Expand Up @@ -180,13 +210,22 @@ def _build_prompt_embeds(self, audio_embeds, prompt="转录以下音频内容"):
audio_emb = audio_embeds[0] if audio_embeds.dim() == 3 else audio_embeds
return torch.cat([prefix_emb, audio_emb, suffix_emb], dim=0)

def generate(self, inputs, prompt="转录以下音频内容", max_new_tokens=500, **kwargs):
def generate(self, inputs, prompt="转录以下音频内容", max_new_tokens=500,
temperature=0.0, top_p=1.0, top_k=-1, repetition_penalty=1.0,
**kwargs):
"""Run batch ASR inference.

Args:
inputs: Audio file path(s), numpy arrays, or tensors.
prompt: Instruction prompt for ASR.
max_new_tokens: Maximum tokens to generate per sample.
temperature: Sampling temperature (0 = greedy decoding).
top_p: Nucleus sampling parameter.
top_k: Top-k sampling (-1 = disabled).
repetition_penalty: Repetition penalty factor. Non-neutral values are
forced back to 1.0 here because this engine feeds vLLM precomputed
embeddings (``enable_prompt_embeds=True``); see
``_safe_repetition_penalty`` and issue #2948.

Returns:
List of {"key": str, "text": str}
Expand All @@ -202,7 +241,10 @@ def generate(self, inputs, prompt="转录以下音频内容", max_new_tokens=500

sampling_params = SamplingParams(
max_tokens=max_new_tokens,
temperature=0,
temperature=temperature,
top_p=top_p,
top_k=top_k if top_k > 0 else -1,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If top_k is passed as None (which is common when passing optional configuration dictionaries), the comparison top_k > 0 will raise a TypeError: '>' not supported between instances of 'NoneType' and 'int'. We should explicitly check if top_k is not None before comparing it.

Suggested change
top_k=top_k if top_k > 0 else -1,
top_k=top_k if (top_k is not None and top_k > 0) else -1,

repetition_penalty=_safe_repetition_penalty(repetition_penalty),
skip_special_tokens=True,
)

Expand Down
119 changes: 119 additions & 0 deletions tests/test_glm_asr_vllm_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Unit tests for GLM-ASR vLLM sampling-parameter handling.

These tests exercise ``GLMASRVLLMEngine.generate`` without a GPU or a real
vLLM installation: the vLLM entry points are stubbed in ``sys.modules`` and the
audio/encoder/engine collaborators are mocked, so only the sampling-parameter
wiring is under test.
"""

import re
import sys
import types
import unittest
from unittest import mock


def _install_vllm_stub():
"""Install a minimal ``vllm`` stub whose SamplingParams records kwargs."""

captured = {}

class _RecordingSamplingParams:
def __init__(self, **kwargs):
captured.clear()
captured.update(kwargs)

class _EmbedsPrompt:
def __init__(self, **kwargs):
self.kwargs = kwargs

vllm_mod = types.ModuleType("vllm")
vllm_mod.SamplingParams = _RecordingSamplingParams
vllm_mod.LLM = object
inputs_mod = types.ModuleType("vllm.inputs")
inputs_mod.EmbedsPrompt = _EmbedsPrompt
data_mod = types.ModuleType("vllm.inputs.data")
data_mod.EmbedsPrompt = _EmbedsPrompt

sys.modules["vllm"] = vllm_mod
sys.modules["vllm.inputs"] = inputs_mod
sys.modules["vllm.inputs.data"] = data_mod
return captured


class GLMASRSamplingParamsTest(unittest.TestCase):
def setUp(self):
    """Install the vLLM stub and assemble an engine made entirely of mocks."""
    self.captured = _install_vllm_stub()
    from funasr.models.glm_asr.inference_vllm import GLMASRVLLMEngine

    # __new__ bypasses __init__, so no model load / GPU is needed.
    self.engine = GLMASRVLLMEngine.__new__(GLMASRVLLMEngine)
    self.engine.device = "cpu"

    # Audio encoding and prompt-embedding construction are replaced by mocks.
    self.engine._encode_audio = mock.Mock(return_value="audio_embeds")
    prompt_embeds = mock.Mock(float=lambda: "embeds")
    self.engine._build_prompt_embeds = mock.Mock(return_value=prompt_embeds)

    # Fake vLLM result: one request output carrying one completion.
    completion = types.SimpleNamespace(token_ids=[1, 2, 3])
    request_output = types.SimpleNamespace(outputs=[completion])
    self.engine.vllm_engine = mock.Mock()
    self.engine.vllm_engine.generate = mock.Mock(return_value=[request_output])

    # Whatever token ids come back, decoding yields a fixed transcript.
    self.engine.tokenizer = mock.Mock()
    self.engine.tokenizer.decode = mock.Mock(return_value="hello world")

def test_defaults_preserve_greedy_behavior(self):
    """With no overrides, the recorded SamplingParams kwargs are the defaults."""
    results = self.engine.generate("a.wav")
    self.assertEqual(results, [{"key": "a", "text": "hello world"}])

    expected_defaults = {
        "temperature": 0.0,
        "top_p": 1.0,
        "top_k": -1,
        "repetition_penalty": 1.0,
        "max_tokens": 500,
    }
    for name, value in expected_defaults.items():
        self.assertEqual(self.captured[name], value)

def test_caller_sampling_params_are_forwarded(self):
    """Caller-supplied sampling values reach SamplingParams unchanged."""
    self.engine.generate(
        "a.wav", max_new_tokens=128, temperature=0.7, top_p=0.9, top_k=20
    )

    # ``max_new_tokens`` is recorded under the ``max_tokens`` keyword.
    expected_forwarded = {
        "max_tokens": 128,
        "temperature": 0.7,
        "top_p": 0.9,
        "top_k": 20,
    }
    for name, value in expected_forwarded.items():
        self.assertEqual(self.captured[name], value)

def test_non_positive_top_k_is_normalized_to_disabled(self):
    """``top_k=0`` is recorded as -1, the "disabled" value."""
    self.engine.generate("a.wav", top_k=0)
    recorded_top_k = self.captured["top_k"]
    self.assertEqual(recorded_top_k, -1)
Comment on lines +83 to +85

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Let's also test the case where top_k is None to ensure it is correctly normalized to -1 without raising a TypeError.

Suggested change
def test_non_positive_top_k_is_normalized_to_disabled(self):
self.engine.generate("a.wav", top_k=0)
self.assertEqual(self.captured["top_k"], -1)
def test_non_positive_top_k_is_normalized_to_disabled(self):
self.engine.generate("a.wav", top_k=0)
self.assertEqual(self.captured["top_k"], -1)
self.engine.generate("a.wav", top_k=None)
self.assertEqual(self.captured["top_k"], -1)


def test_repetition_penalty_is_forced_neutral_in_prompt_embeds_mode(self):
    """A non-neutral penalty is recorded as 1.0 instead of being forwarded."""
    # Forwarding a non-neutral value would crash vLLM prompt-embeds mode
    # (issue #2948), so the engine has to coerce it back to 1.0.
    requested_penalty = 1.3
    self.engine.generate("a.wav", repetition_penalty=requested_penalty)
    self.assertEqual(self.captured["repetition_penalty"], 1.0)

def test_neutral_repetition_penalty_passes_through(self):
    """An explicit neutral penalty of 1.0 is recorded as 1.0."""
    self.engine.generate("a.wav", repetition_penalty=1.0)
    recorded_penalty = self.captured["repetition_penalty"]
    self.assertEqual(recorded_penalty, 1.0)


class SafeRepetitionPenaltyTest(unittest.TestCase):
    """Direct unit tests for the module-level ``_safe_repetition_penalty`` helper."""

    def setUp(self):
        # Install the vLLM stub before importing the module under test.
        _install_vllm_stub()
        import funasr.models.glm_asr.inference_vllm as mod

        self.mod = mod
        # Reset the process-wide warn-once flag between tests.
        mod._warned_rep_penalty = False

    def test_neutral_and_none_map_to_one(self):
        """Neutral inputs (1.0 and None) both come back as 1.0."""
        self.assertEqual(self.mod._safe_repetition_penalty(1.0), 1.0)
        self.assertEqual(self.mod._safe_repetition_penalty(None), 1.0)

    def test_non_neutral_is_coerced_and_warns_once(self):
        """Non-neutral inputs are coerced to 1.0 and only the first one warns."""
        with self.assertLogs(self.mod.logger, level="WARNING") as cm:
            self.assertEqual(self.mod._safe_repetition_penalty(1.5), 1.0)
            # The flag flips on the first non-neutral call...
            self.assertTrue(self.mod._warned_rep_penalty)
            # ...so a second non-neutral call is still coerced but must stay silent.
            self.assertEqual(self.mod._safe_repetition_penalty(1.7), 1.0)

        # Exactly one warning across both calls proves the "once" behaviour,
        # not merely that some warning was emitted.
        issue_warnings = [line for line in cm.output if "2948" in line]
        self.assertEqual(len(issue_warnings), 1)
        self.assertTrue(self.mod._warned_rep_penalty)


if __name__ == "__main__":
unittest.main()