Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion nemo/collections/asr/models/rnnt_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
manifest_filepath = os.path.join(config['temp_dir'], 'manifest.json')
batch_size = min(config['batch_size'], len(config['paths2audio_files']))

validation_ds = self.cfg.get('validation_ds') or {}
dl_config = {
'use_lhotse': config.get('use_lhotse', True),
'manifest_filepath': manifest_filepath,
Expand All @@ -537,7 +538,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
'num_workers': config.get('num_workers', min(batch_size, os.cpu_count() - 1)),
'pin_memory': True,
'channel_selector': config.get('channel_selector', None),
'use_start_end_token': self.cfg.validation_ds.get('use_start_end_token', False),
'use_start_end_token': validation_ds.get('use_start_end_token', False),
}

if config.get("augmentor"):
Expand Down
27 changes: 26 additions & 1 deletion tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
from lhotse import CutSet, MonoCut
from lhotse.testing.dummies import DummyManifest
from omegaconf import DictConfig
from omegaconf import DictConfig, open_dict

from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
from nemo.collections.asr.models import ASRModel
Expand Down Expand Up @@ -194,6 +194,31 @@ def test_predict_step(self, asr_model):
assert isinstance(outputs[0][0], MonoCut)
assert isinstance(outputs[0][1], Hypothesis)

@pytest.mark.unit
def test_transcribe_dataloader_allows_missing_validation_ds(self, asr_model, monkeypatch):
    """Build the transcribe dataloader config when ``validation_ds`` is unset.

    Some pretrained RNNT BPE checkpoints are shipped without validation_ds,
    so ``_setup_transcribe_dataloader`` must not dereference it blindly.
    The real dataloader factory is replaced by a stub that records the
    config it is handed, so only the config assembly is exercised.
    """
    # Simulate a checkpoint whose config carries no validation dataset section.
    with open_dict(asr_model.cfg):
        asr_model.cfg.validation_ds = None

    seen = {}

    def record_dataloader_config(config):
        # Keep the received config for inspection and return a non-None placeholder.
        seen["config"] = config
        return object()

    monkeypatch.setattr(asr_model, "_setup_dataloader_from_config", record_dataloader_config)

    with tempfile.TemporaryDirectory() as tmpdir:
        transcribe_config = {
            "paths2audio_files": ["audio.wav"],
            "batch_size": 1,
            "temp_dir": tmpdir,
        }
        dataloader = asr_model._setup_transcribe_dataloader(transcribe_config)

    assert dataloader is not None
    # With no validation_ds to consult, the flag is expected to be False.
    assert seen["config"]["use_start_end_token"] is False

@pytest.mark.with_downloads()
@pytest.mark.skipif(
not NUMBA_RNNT_LOSS_AVAILABLE,
Expand Down
Loading