From 37d1e8541bd7c8e8a185848572a7bf1ef4a25548 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Thu, 4 Jun 2026 08:53:06 -0700 Subject: [PATCH] Handle missing validation_ds in RNNT BPE transcribe Signed-off-by: Bas Nijholt --- .../collections/asr/models/rnnt_bpe_models.py | 3 ++- .../asr/test_asr_rnnt_encoder_model_bpe.py | 27 ++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index 779f03e3719d..eba07e8599ef 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -528,6 +528,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json') batch_size = min(config['batch_size'], len(config['paths2audio_files'])) + validation_ds = self.cfg.get('validation_ds') or {} dl_config = { 'use_lhotse': config.get('use_lhotse', True), 'manifest_filepath': manifest_filepath, @@ -537,7 +538,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo 'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)), 'pin_memory': True, 'channel_selector': config.get('channel_selector', None), - 'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False), + 'use_start_end_token': validation_ds.get('use_start_end_token', False), } if config.get("augmentor"): diff --git a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py index 344697b28d2a..301c1ce01d8b 100644 --- a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py @@ -20,7 +20,7 @@ import torch from lhotse import CutSet, MonoCut from lhotse.testing.dummies import DummyManifest -from omegaconf import DictConfig +from omegaconf import DictConfig, open_dict from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import ASRModel @@ -194,6 +194,31 @@ def test_predict_step(self, asr_model): assert isinstance(outputs[0][0], MonoCut) assert isinstance(outputs[0][1], Hypothesis) + @pytest.mark.unit + def test_transcribe_dataloader_allows_missing_validation_ds(self, asr_model, monkeypatch): + """Some pretrained RNNT BPE checkpoints are shipped without validation_ds.""" + with open_dict(asr_model.cfg): + asr_model.cfg.validation_ds = None + captured = {} + + def capture_dataloader_config(config): + captured["config"] = config + return object() + + monkeypatch.setattr(asr_model, "_setup_dataloader_from_config", capture_dataloader_config) + + with tempfile.TemporaryDirectory() as tmpdir: + dataloader = asr_model._setup_transcribe_dataloader( + { + "paths2audio_files": ["audio.wav"], + "batch_size": 1, + "temp_dir": tmpdir, + } + ) + + assert dataloader is not None + assert captured["config"]["use_start_end_token"] is False + @pytest.mark.with_downloads() @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE,