From 5ce4c9a4012776ee9e42d04b8b1ea2eaac358d5e Mon Sep 17 00:00:00 2001 From: supermario_leo Date: Fri, 19 Jun 2026 20:00:25 +0800 Subject: [PATCH] fix(glm_asr): warn when vLLM dtype=fp16 (degraded output) The GLM-ASR vLLM engine accepts dtype='fp16' but, like Fun-ASR-Nano, fp16 can produce degraded or garbage transcription due to numerical overflow in the audio embedding path. Fun-ASR-Nano already warns about this; GLM-ASR did not, so users hitting it got silently poor output with no hint to switch to bf16/fp32. Add a small dependency-free helper that emits a one-time warning when fp16 is requested (the value is still honoured for GPUs that only support fp16), wire it into GLMASRVLLMEngine.__init__, and cover it with unit tests that run without a GPU or vLLM. --- funasr/models/glm_asr/inference_vllm.py | 4 +- funasr/models/glm_asr/vllm_utils.py | 43 ++++++++++++++++++ tests/test_glm_asr_vllm_dtype.py | 60 +++++++++++++++++++++++++ 3 files changed, 106 insertions(+), 1 deletion(-) create mode 100644 funasr/models/glm_asr/vllm_utils.py create mode 100644 tests/test_glm_asr_vllm_dtype.py diff --git a/funasr/models/glm_asr/inference_vllm.py b/funasr/models/glm_asr/inference_vllm.py index 73687a335..d0a7250de 100644 --- a/funasr/models/glm_asr/inference_vllm.py +++ b/funasr/models/glm_asr/inference_vllm.py @@ -101,8 +101,10 @@ def __init__(self, model_dir, device="cuda:0", dtype="bf16", from vllm import LLM from transformers import AutoProcessor, AutoConfig, AutoModel as HFAutoModel + from funasr.models.glm_asr.vllm_utils import warn_if_degraded_dtype + self.device = device - self.torch_dtype = dtype_map.get(dtype, torch.bfloat16) + self.torch_dtype = dtype_map.get(warn_if_degraded_dtype(dtype), torch.bfloat16) self.model_dir = model_dir logger.info(f"Loading GLM-ASR audio components from {model_dir}") diff --git a/funasr/models/glm_asr/vllm_utils.py b/funasr/models/glm_asr/vllm_utils.py new file mode 100644 index 000000000..2828e9be5 --- /dev/null +++ b/funasr/models/glm_asr/vllm_utils.py @@ -0,0 +1,43 @@ +"""Helpers for the GLM-ASR vLLM serving path. + +Kept dependency-free (standard library only) so the dtype guard can be unit +tested without a CUDA device, a torch build, or a vLLM installation. +""" + +import logging + +logger = logging.getLogger("funasr.glm_asr.vllm") + +# Compute dtype that is known to degrade GLM-ASR transcription quality. +DEGRADED_DTYPE = "fp16" + +# Warn only once per process so batch loops do not spam the log. +_warned_fp16 = False + + +def warn_if_degraded_dtype(dtype): + """Warn once when a compute dtype is known to degrade GLM-ASR output. + + ``fp16`` can produce degraded or garbage transcription for GLM-ASR + (numerical overflow in the audio embedding path), matching the documented + Fun-ASR-Nano behaviour. The value is still honoured -- some GPUs only + support fp16 -- but the caller is warned once about why output may be poor. + + Args: + dtype: Requested compute dtype string ("bf16", "fp16", "fp32"). + + Returns: + ``dtype`` unchanged, so callers can wrap the value inline. + """ + global _warned_fp16 + + if dtype == DEGRADED_DTYPE and not _warned_fp16: + logger.warning( + "dtype='fp16' can produce degraded or garbage transcription for " + "GLM-ASR (numerical overflow in the audio embedding path). " + "Use dtype='bf16' (recommended) or dtype='fp32'. On GPUs without " + "bfloat16 support (e.g. NVIDIA V100), use 'fp32'." + ) + _warned_fp16 = True + + return dtype diff --git a/tests/test_glm_asr_vllm_dtype.py b/tests/test_glm_asr_vllm_dtype.py new file mode 100644 index 000000000..d426b4663 --- /dev/null +++ b/tests/test_glm_asr_vllm_dtype.py @@ -0,0 +1,60 @@ +"""Unit tests for the GLM-ASR vLLM fp16 degraded-output guard. + +Regression guard: ``fp16`` can produce degraded or garbage transcription for +GLM-ASR (numerical overflow in the audio embedding path), mirroring the +documented Fun-ASR-Nano behaviour. Requesting it must warn once so users are +not silently handed poor output, while leaving the requested value untouched. + +The helper is dependency-free, so these tests run without a GPU, torch, or +vLLM. +""" + +import logging +import unittest + +from funasr.models.glm_asr import vllm_utils +from funasr.models.glm_asr.vllm_utils import warn_if_degraded_dtype + + +class TestWarnIfDegradedDtype(unittest.TestCase): + def setUp(self): + # Reset the once-per-process warning flag so each test is independent. + vllm_utils._warned_fp16 = False + + def test_returns_value_unchanged(self): + for dtype in ("bf16", "fp16", "fp32", "something-else"): + self.assertEqual(warn_if_degraded_dtype(dtype), dtype) + + def test_fp16_warns(self): + with self.assertLogs(vllm_utils.logger, level=logging.WARNING) as cm: + warn_if_degraded_dtype("fp16") + self.assertEqual(len(cm.records), 1) + self.assertIn("fp16", cm.output[0]) + + def test_fp16_warns_only_once(self): + with self.assertLogs(vllm_utils.logger, level=logging.WARNING) as cm: + warn_if_degraded_dtype("fp16") + # Subsequent calls must not emit additional warnings. + warn_if_degraded_dtype("fp16") + self.assertEqual(len(cm.records), 1) + + def test_safe_values_do_not_warn(self): + # Capture records directly (assertNoLogs is only available on 3.10+). + records = [] + + class _Collect(logging.Handler): + def emit(self, record): + records.append(record) + + handler = _Collect(level=logging.WARNING) + vllm_utils.logger.addHandler(handler) + try: + warn_if_degraded_dtype("bf16") + warn_if_degraded_dtype("fp32") + finally: + vllm_utils.logger.removeHandler(handler) + self.assertEqual(records, []) + + +if __name__ == "__main__": + unittest.main()