Description
The BatchLimiter class in src/saev/utils/scheduling.py miscounts the samples seen during iteration whenever a yielded batch is smaller than the configured batch size. The counter runs ahead of the true sample count, so the iterator terminates at the wrong point.
Location
src/saev/utils/scheduling.py:114
Root Cause
In the __iter__ method, the code always increments self.n_seen by self.batch_size:
self.n_seen += self.batch_size
if self.n_seen > self.n_samples:
return
However, the actual batch yielded might have fewer samples than self.batch_size, particularly:
- For the last batch when drop_last=False
- For dataloaders with uneven dataset sizes
This causes the limiter to overcount samples, terminating the iterator at the wrong time.
Expected Behavior
The BatchLimiter should count the actual number of samples in each batch, not assume all batches have size self.batch_size.
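As a rough illustration of counting actual batch sizes, here is a minimal sketch. The helper name limit_samples is hypothetical and is not the saev implementation; it assumes each batch supports len(), and it uses >= rather than the > in the snippet above so that iteration stops as soon as the target is reached.

```python
from typing import Iterable, Iterator, Sequence


def limit_samples(batches: Iterable[Sequence], n_samples: int) -> Iterator[Sequence]:
    """Yield batches until at least n_samples samples have been yielded.

    Hypothetical helper for illustration only; not the saev BatchLimiter.
    """
    n_seen = 0
    for batch in batches:
        yield batch
        # Count what was actually yielded, not the nominal batch size.
        n_seen += len(batch)
        # Assumption: stop once the target is reached (>=), not exceeded (>).
        if n_seen >= n_samples:
            return


if __name__ == "__main__":
    batches = [[0] * 32, [0] * 32, [0] * 32, [0] * 9]
    print(sum(len(b) for b in limit_samples(batches, 64)))
    print(sum(len(b) for b in limit_samples(batches, 100)))
```

Note that counting whole batches can still overshoot the target: with batches of 32, 32, 32 and 9 and a target of 100, this sketch yields all 105 samples, because only 96 have been seen before the last batch. Whether the final batch should also be truncated is a separate design decision that this sketch does not make.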
Actual Behavior
The limiter terminates based on incorrect counts, yielding either too many or too few samples.
Example
If we have:
- A dataloader with 105 samples
- batch_size = 32
- drop_last = False
- n_samples = 100 (what we want from BatchLimiter)
The batches would be: [32, 32, 32, 9]
But the counter would be: [32, 64, 96, 128]
When the counter hits 128 > 100, it returns after yielding all 105 samples (not 100).
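The arithmetic in this example can be checked with a few lines of plain Python. This sketch only reproduces the numbers above; the variable names are illustrative and nothing here touches the saev code.

```python
from itertools import accumulate

n_total, batch_size = 105, 32

# Sizes of the batches a dataloader would produce with drop_last=False.
sizes = [min(batch_size, n_total - start) for start in range(0, n_total, batch_size)]

# What the limiter records (nominal size per batch) versus what was really yielded.
nominal = list(accumulate(batch_size for _ in sizes))
actual = list(accumulate(sizes))

print(sizes)    # [32, 32, 32, 9]
print(nominal)  # [32, 64, 96, 128]
print(actual)   # [32, 64, 96, 105]
```

After the fourth batch the recorded count (128) is 23 samples ahead of the true count (105).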
Reproduction
See the unit tests in tests/test_batch_limiter.py, which demonstrate this bug:
uv run --no-dev python -m pytest tests/test_batch_limiter.py -v
Test results:
- test_batch_limiter_with_uneven_batches: Expected ≤100 samples, got 105
- test_batch_limiter_early_termination: Expected 100 samples, got 160
Proposed Fix
Change line 114 to count the actual batch size instead of always using self.batch_size.
See PR for the implementation.