
feat: Implement batch processing for the MDXC separator #262

Closed
pedroalmeida415 wants to merge 6 commits into nomadkaraoke:main from pedroalmeida415:main

Conversation

@pedroalmeida415 pedroalmeida415 commented Mar 12, 2026

Motivation

This PR properly implements batch processing for the MDXC architecture models, addressing a major processing bottleneck. Previously, setting batch_size for the MDXC separator did not actually influence the execution loop for Roformer models, leaving the GPU severely underutilized during high-overlap sliding-window segment processing.

By restructuring the iteration to use torch.utils.data.DataLoader over a custom RoformerDataset, this implementation lazily computes slice bounds instead of eagerly materializing large sections of audio in RAM.
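As a sketch of the lazy-slicing idea (illustrative names only, not the PR's exact code), a Dataset can precompute just the start indices and slice the audio on demand:

```python
import torch
from torch.utils.data import Dataset


class LazyChunkDataset(Dataset):
    """Illustrative sketch: yields fixed-size windows of a (channels, samples)
    tensor by computing slice bounds on demand, instead of materializing
    every chunk up front."""

    def __init__(self, mix: torch.Tensor, chunk_size: int, step: int):
        self.mix = mix
        self.chunk_size = chunk_size
        # Precompute only the start indices; the audio itself is sliced lazily.
        self.starts = list(range(0, max(mix.shape[1] - chunk_size, 0) + 1, step))

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, idx):
        start = self.starts[idx]
        part = self.mix[:, start : start + self.chunk_size]
        # Return the chunk plus its coordinates so overlap-add can place it.
        return part, start, part.shape[-1]


mix = torch.randn(2, 10_000)
ds = LazyChunkDataset(mix, chunk_size=4_000, step=2_000)
part, start, length = ds[1]
```

A DataLoader over such a dataset then handles batching and worker-based prefetching without ever holding all chunks in memory at once.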

Why this approach?

  1. Massive Throughput Boost: Processing one slice at a time on modern GPUs leaves thousands of CUDA cores idle. Batching lets the model compute over many STFT representations in parallel.
  2. Fixed Memory Footprint: The previous overlap-add approach could spike RAM or VRAM depending on audio length and overlap amount. A lazy-loading Dataset resolves slice indices on demand, ensuring deterministic memory bounds regardless of track size.
  3. PCIe Offloading: Transferring grouped tensors asynchronously reduces the slow host-to-device copy overhead that plagued the frame-by-frame loop.
  4. Weighted Overlap-Add Reassembly: As in established implementations such as MONAI's SlidingWindowInferer and Asteroid's LambdaOverlapAdd, overlapping regions are blended through a Hamming-window weighting buffer, eliminating clipping and audible "seams" at segment boundaries.
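The weighted overlap-add in point 4 can be sketched in 1-D as follows (a minimal illustration; function and variable names are not from the PR):

```python
import torch


def weighted_overlap_add(chunks, starts, total_length, chunk_size):
    """Accumulate windowed chunks and a matching weight buffer, then
    normalize so overlapping regions blend smoothly instead of clipping."""
    window = torch.hamming_window(chunk_size, periodic=False)
    out = torch.zeros(total_length)
    weights = torch.zeros(total_length)
    for chunk, start in zip(chunks, starts):
        n = chunk.shape[-1]
        out[start : start + n] += chunk * window[:n]
        weights[start : start + n] += window[:n]
    # Dividing by the accumulated window weights normalizes the overlaps.
    return out / weights.clamp(min=1e-8)


chunks = [torch.ones(4), torch.ones(4)]
result = weighted_overlap_add(chunks, [0, 2], total_length=6, chunk_size=4)
```

Because every output sample is divided by the same window weights that scaled it, constant input comes back unchanged, which is exactly the seam-free property the PR relies on.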

Benchmarks

File Details: ~22 minutes of PCM_16 audio
Model Details: mel_band_roformer_kim_ft_unwa.ckpt

Using standard inference parameters

| Metric | Previous Approach | New DataLoader Approach |
|---|---|---|
| Execution Time | ~5-6 minutes | ~2 minutes 1 second |
| Memory Usage | N/A | 6.8 GB VRAM (batch_size=4) |

Speedup is ~2.75x

Inference with use_autocast=True

| Metric | Previous Approach | New DataLoader Approach |
|---|---|---|
| Execution Time | ~2 minutes 34 seconds | ~1 minute 9 seconds |
| Memory Usage | N/A | 5.5 GB VRAM (batch_size=4) |

Speedup is ~2.2x

System Environment

  • Kernel: Linux 6.19.6-2-cachyos
  • CPU: Intel(R) Core(TM) i7-14700HX (28) @ 5.50 GHz
  • GPU 1: NVIDIA GeForce RTX 4070 Max-Q / Mobile [Discrete]
  • Memory: 32 GiB

Summary by CodeRabbit

  • New Features

    • Added --mdxc_num_workers CLI option to configure background DataLoader worker processes for MDXC processing (default: 0).
  • Refactor

    • MDXC processing moved to batched, worker-backed flow for better throughput and memory use; logs now report configured worker count.
    • Improved handling of audio chunking and accumulation for more reliable results.
  • Tests

    • Added unit tests for MDXC chunking behavior and the new CLI worker parameter.
  • Documentation

    • CLI and README updated to document the new worker option.

coderabbitai bot commented Mar 12, 2026

Walkthrough

Replaces per-step Roformer demux with a DataLoader-based batching path: adds RoformerDataset for chunked audio, MDXC separator records/logs num_workers, CLI exposes --mdxc_num_workers, tests and README updated to cover the new parameter and dataset behavior.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Roformer batching & dataset<br>`audio_separator/separator/architectures/mdxc_separator.py` | Add `RoformerDataset(Dataset)` with `__init__`, `__len__`, `__getitem__`. Replace manual per-step loop with DataLoader batching (uses `batch_size`, `num_workers`, `pin_memory`), propagate start indices/lengths, and perform overlap-add accumulation on CPU. |
| Separator defaults & logging<br>`audio_separator/separator/separator.py` | Include `"num_workers": 0` in default `mdxc_params`; MDXCSeparator records/logs `num_workers`. Minor log message formatting changes. |
| CLI parameter<br>`audio_separator/utils/cli.py` | Add `--mdxc_num_workers` CLI argument (default 0) and pass `mdxc_num_workers` into `mdxc_params` for Separator. |
| Unit tests: CLI<br>`tests/unit/test_cli.py` | Add test(s) asserting `--mdxc_num_workers` is forwarded to Separator; update `common_expected_args` to include `"num_workers": 0`; remove an unused import. |
| Unit tests: Roformer chunking<br>`tests/unit/test_mdxc_roformer_chunking.py` | Add `TestRoformerDataset` tests (index deduplication, tail remapping, short-audio handling, exact overlap); import and exercise `RoformerDataset`; tighten some existing assertions. |
| Docs / README<br>`README.md` | Document new `--mdxc_num_workers` option and show default `mdxc_params` including `"num_workers": 0`. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Sep as Separator
    participant DS as RoformerDataset
    participant DL as DataLoader
    participant Model as RoformerModel
    participant OA as OverlapAdd

    Sep->>DS: Create(mix, chunk_size, step)
    Sep->>DL: Create(DS, batch_size, num_workers, pin_memory)
    loop per-batch
      DL->>DS: __getitem__ (batch fetch)
      DS-->>DL: (chunks, start_indices, lengths)
      DL-->>Sep: Batch of chunks
      Sep->>Model: Move batch to device and forward
      Model-->>Sep: Predictions (batch)
      Sep->>OA: Overlap-add using start_indices and lengths (accumulate on CPU)
    end
    OA-->>Sep: Final demixed outputs
```

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐇 I nibble timestamps, hop by hop,

Chunks in baskets, never stop,
Workers hum and batches play,
Overlap-add folds night to day,
A rabbit cheers the new delay.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 77.78%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
✅ Passed checks (2 passed)
| Check name | Status | Explanation |
|---|---|---|
| Description Check | ✅ Passed | Check skipped because CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The PR title accurately describes the main change: batch processing for the MDXC separator using DataLoader and RoformerDataset, the core change across multiple files. |


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
audio_separator/separator/architectures/mdxc_separator.py (1)

347-364: Good implementation of batched inference with DataLoader.

The DataLoader integration with pin_memory for CUDA and the batch-wise overlap-add logic is well structured.

Minor suggestion: Consider using xs.shape[0] instead of len(xs) at line 354 for clearer tensor semantics.

♻️ Optional: Use tensor shape accessor
```diff
-                for b in range(len(xs)):
+                for b in range(xs.shape[0]):
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/architectures/mdxc_separator.py` around lines 347 -
364, Replace the Python built-in len(xs) with the tensor shape accessor
xs.shape[0] in the batch loop to use tensor semantics; in mdxc_separator.py
within the batched inference loop that calls self.model_run(parts) and then
iterates over outputs, change "for b in range(len(xs)):" to "for b in
range(xs.shape[0]):" so the loop uses the tensor's first-dimension size reliably
when xs is a tensor returned by self.model_run.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@audio_separator/separator/architectures/mdxc_separator.py`:
- Around line 17-40: The RoformerDataset.__getitem__ last-chunk handling can
produce a negative start_idx and an incorrect length when mix.shape[1] <
chunk_size; update the block to clamp start_idx with start_idx = max(0,
self.mix.shape[1] - self.chunk_size) and set length from the actual slice
(length = part.shape[-1]) after computing part, so that for very short audio the
returned part, start_idx, and length are consistent and non-negative (refer to
RoformerDataset, __getitem__, self.mix, chunk_size, and length).


ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: fc0d3223-134e-403d-8fbc-c4caa174747a

📥 Commits

Reviewing files that changed from the base of the PR and between 12f8fc6 and ec71756.

📒 Files selected for processing (3)
  • audio_separator/separator/architectures/mdxc_separator.py
  • audio_separator/separator/separator.py
  • audio_separator/utils/cli.py

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (3)
tests/unit/test_cli.py (1)

31-31: Add one non-default --mdxc_num_workers test.

Line 31 only covers the default 0. If the CLI accepted --mdxc_num_workers but still always forwarded 0, this suite would still pass. A small targeted test with a non-default value would close that gap.

```python
def test_cli_mdxc_num_workers_argument(common_expected_args):
    test_args = ["cli.py", "test_audio.mp3", "--mdxc_num_workers=2"]
    with patch("sys.argv", test_args):
        with patch("audio_separator.separator.Separator") as mock_separator:
            mock_separator.return_value.separate.return_value = ["output_file.mp3"]
            main()

            expected_args = common_expected_args.copy()
            expected_args["mdxc_params"] = {
                **common_expected_args["mdxc_params"],
                "num_workers": 2,
            }
            mock_separator.assert_called_once_with(**expected_args)
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/test_cli.py` at line 31, Add a new unit test that verifies the CLI
forwards a non-default --mdxc_num_workers value into the mdxc_params passed to
Separator: call main() with argv containing "--mdxc_num_workers=2", patch
audio_separator.separator.Separator, assert Separator was called once and that
its mdxc_params dict has "num_workers": 2 (keeping other mdxc_params from
common_expected_args); reference main(), Separator, and the mdxc_params key when
implementing the test.
audio_separator/separator/architectures/mdxc_separator.py (2)

35-35: Avoid scheduling the same tail window twice.

Line 35 schedules starts all the way to the raw tail, and Lines 60-65 remap any short tail back to mix.shape[1] - chunk_size. When that last full-window start is already on a step boundary, the same end chunk gets inferred twice.

♻️ Proposed fix
```diff
-        self.indices = list(range(0, mix.shape[1], step))
+        if mix.shape[1] <= chunk_size:
+            self.indices = [0]
+        else:
+            last_start = mix.shape[1] - chunk_size
+            self.indices = list(range(0, last_start + 1, step))
+            if self.indices[-1] != last_start:
+                self.indices.append(last_start)
-        # We need to handle the last chunk where part is smaller than chunk_size
-        if length < self.chunk_size and self.mix.shape[1] >= self.chunk_size:
-            # Take the last chunk_size from the end
-            part = self.mix[:, -self.chunk_size :]
-            length = self.chunk_size
-            start_idx = self.mix.shape[1] - self.chunk_size
-        # If mix is shorter than chunk_size, keep original part and length
-
         return part, start_idx, length
```

Also applies to: 60-65
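For illustration, the scheduling behaviour proposed above can be checked in isolation (a standalone sketch; `chunk_starts` is a hypothetical helper, not the PR's code):

```python
def chunk_starts(num_samples: int, chunk_size: int, step: int):
    """Start indices covering the signal exactly once at the tail,
    so the last window is never scheduled twice."""
    if num_samples <= chunk_size:
        return [0]
    last_start = num_samples - chunk_size
    starts = list(range(0, last_start + 1, step))
    # Append the remapped tail start only if it is not already scheduled.
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


# Tail falls exactly on a step boundary: no duplicate final start.
print(chunk_starts(12, 4, 4))  # [0, 4, 8]
# Tail needs remapping: the last window is pulled back to cover the end once.
print(chunk_starts(10, 4, 4))  # [0, 4, 6]
```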

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/architectures/mdxc_separator.py` at line 35, The
current scheduling can append a tail start that duplicates an existing start
when mix.shape[1]-chunk_size falls on a step boundary; update the logic that
builds self.indices so that the final list never contains duplicate starts:
after creating indices = list(range(0, mix.shape[1], step)) and before applying
the tail remap (the code that references chunk_size and mix.shape[1]), compute
last_start = mix.shape[1] - chunk_size and ensure you only append or remap to
last_start if it is not already present (or alternatively deduplicate
self.indices preserving order), so the tail window is not scheduled twice.

377-385: Keep the host/device transfers batched.

Line 377 pins the batch, but Line 380 still uses a blocking .to(device). Line 384 then copies each item back separately with .cpu(), which reintroduces per-sample D2H transfers immediately after batching.

⚡ Proposed fix
```diff
-                for parts, start_idxs, lengths in tqdm(dataloader):
-                    parts = parts.to(device)
-                    xs = self.model_run(parts)
+                for parts, start_idxs, lengths in tqdm(dataloader):
+                    parts = parts.to(device, non_blocking=(device.type == "cuda"))
+                    xs = self.model_run(parts).cpu()

                     for b in range(xs.shape[0]):
-                        x = xs[b].cpu()
+                        x = xs[b]
                         start_idx = start_idxs[b].item()
                         length = lengths[b].item()
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/architectures/mdxc_separator.py` around lines 377 -
385, The per-sample device-to-host transfers happen because you call .to(device)
on the whole batch (parts) but then call .cpu() inside the loop for each sample
(xs[b].cpu()); fix by doing the host copy once per batch: after xs =
self.model_run(parts) perform a single batch-level transfer (e.g., xs =
xs.detach().cpu()) and, if needed, ensure start_idxs is on CPU (e.g., start_idxs
= start_idxs.cpu()) before the inner loop, then iterate over xs[b] and
start_idxs[b].item() without per-sample .cpu() calls. This keeps transfers
batched and avoids repeated D2H copies.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: 1a40c725-b4a4-4590-89c1-8271d945ff54

📥 Commits

Reviewing files that changed from the base of the PR and between ec71756 and 0bb3d1b.

📒 Files selected for processing (2)
  • audio_separator/separator/architectures/mdxc_separator.py
  • tests/unit/test_cli.py

- Refactor RoformerDataset to calculate tail-window remapping during initialization.
- Fix an issue where the final audio chunk was duplicated when the tail start fell exactly on a step boundary.
- Optimize MDXCSeparator inference by moving device-to-host transfers (.cpu()) outside the per-sample loop.
- Batch the D2H copies to reduce synchronization overhead and improve processing speed.
- Simplify RoformerDataset.__getitem__ to remove redundant re-calculation logic.

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
audio_separator/separator/separator.py (1)

100-104: Consider using None as default for mutable dict parameters (pre-existing pattern).

The static analysis tool flags the use of mutable default arguments (dicts) for mdx_params, vr_params, demucs_params, and mdxc_params. While this is a pre-existing pattern and not introduced by this PR, mutable defaults can cause subtle bugs if the dict is mutated during runtime.

The standard Python idiom is to use None and initialize within the function:

♻️ Proposed refactor (optional, for future consideration)
```diff
     def __init__(
         self,
         ...
-        mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False},
-        vr_params={"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False},
-        demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True},
-        mdxc_params={"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0, "num_workers": 0},
+        mdx_params=None,
+        vr_params=None,
+        demucs_params=None,
+        mdxc_params=None,
         info_only=False,
     ):
         """Initialize the separator."""
+        if mdx_params is None:
+            mdx_params = {"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}
+        if vr_params is None:
+            vr_params = {"batch_size": 1, "window_size": 512, "aggression": 5, "enable_tta": False, "enable_post_process": False, "post_process_threshold": 0.2, "high_end_process": False}
+        if demucs_params is None:
+            demucs_params = {"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}
+        if mdxc_params is None:
+            mdxc_params = {"segment_size": 256, "override_model_segment_size": False, "batch_size": 1, "overlap": 8, "pitch_shift": 0, "num_workers": 0}
```
This is flagged by Ruff B006. Since this pattern is pre-existing and the current PR only adds a key, this can be addressed in a separate cleanup PR if desired.
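The pitfall behind Ruff B006 can be shown with a toy function (unrelated to the actual Separator code): the default dict is created once at definition time, so mutations leak across calls.

```python
def configure(params={"batch_size": 1}):
    """Anti-pattern: every call without an argument shares the SAME
    default dict object, so in-place changes persist between calls."""
    params["batch_size"] += 1
    return params


a = configure()
b = configure()  # same dict object as `a`, not a fresh default
print(a is b, b["batch_size"])  # True 3
```

Using `None` as the default and initializing inside the function, as in the refactor above, gives each call its own fresh dict.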

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@audio_separator/separator/separator.py` around lines 100 - 104, The function
currently uses mutable dict defaults for parameters mdx_params, vr_params,
demucs_params, and mdxc_params which triggers Ruff B006; change each default to
None and inside the function (where these params are used/merged) initialize
them with the existing literal dicts only when the argument is None (e.g., if
mdx_params is None: mdx_params = {...}); reference the parameter names
mdx_params, vr_params, demucs_params, and mdxc_params and ensure any downstream
usage expects those newly-initialized dicts so behavior remains unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: dcc3eebe-8d78-4f60-a0bc-83e4e62237ab

📥 Commits

Reviewing files that changed from the base of the PR and between d8b7538 and 8926f70.

📒 Files selected for processing (2)
  • README.md
  • audio_separator/separator/separator.py

pedroalmeida415 commented Mar 12, 2026

@beveradb let me know what you think.

@pedroalmeida415

After careful analysis and exhaustive benchmarks, I noticed that this approach was actually causing the overall process to slow down, with only marginal gains on large files (around 1 second).

I'll be closing the PR; please disregard these changes.

Benchmark results (the original implementation from the published PyPI package is simply faster out of the box):

| index | variant | batch_size | num_workers | autocast | separation_s_mean | separation_s_std | peak_rss_gb_mean | peak_vram_alloc_gb_mean | speedup_vs_b1_mean |
|---|---|---|---|---|---|---|---|---|---|
| 1 | pypi | 1 | 0 | true | 35.65 | NaN | 2.833 | 1.950 | 1.000 |
| 3 | repo | 1 | 0 | true | 38.20 | NaN | 2.850 | 1.950 | 1.000 |
| 9 | repo | 2 | 0 | true | 40.26 | NaN | 2.908 | 2.562 | 0.949 |
| 7 | repo | 1 | 2 | true | 45.25 | NaN | 2.927 | 1.950 | 1.000 |
| 15 | repo | 4 | 0 | true | 46.67 | NaN | 2.915 | 3.790 | 0.819 |
| 17 | repo | 4 | 1 | true | 46.73 | NaN | 2.930 | 3.790 | 1.001 |
| 5 | repo | 1 | 1 | true | 46.79 | NaN | 2.864 | 1.950 | 1.000 |
| 11 | repo | 2 | 1 | true | 48.10 | NaN | 2.855 | 2.562 | 0.973 |
| 13 | repo | 2 | 2 | true | 51.01 | NaN | 2.896 | 2.562 | 0.887 |
| 19 | repo | 4 | 2 | true | 51.32 | NaN | 2.999 | 3.790 | 0.882 |
| 0 | pypi | 1 | 0 | false | 89.63 | NaN | 2.825 | 1.667 | 1.000 |
| 2 | repo | 1 | 0 | false | 92.42 | NaN | 2.756 | 1.667 | 1.000 |
| 8 | repo | 2 | 0 | false | 94.94 | NaN | 2.845 | 2.469 | 0.973 |
| 12 | repo | 2 | 2 | false | 95.78 | NaN | 2.809 | 2.469 | 1.073 |
| 10 | repo | 2 | 1 | false | 98.15 | NaN | 2.762 | 2.469 | 1.029 |
| 16 | repo | 4 | 1 | false | 98.49 | NaN | 2.825 | 4.067 | 1.025 |
| 14 | repo | 4 | 0 | false | 99.64 | NaN | 2.888 | 4.067 | 0.928 |
| 4 | repo | 1 | 1 | false | 100.97 | NaN | 2.853 | 1.667 | 1.000 |
| 6 | repo | 1 | 2 | false | 102.81 | NaN | 2.852 | 1.667 | 1.000 |
| 18 | repo | 4 | 2 | false | 105.37 | NaN | 2.879 | 4.067 | 0.976 |
