1 change: 1 addition & 0 deletions changelog.d/batch-calibration-memory-fix.added.md
@@ -0,0 +1 @@
Added `batch_size` parameter to `Calibration` and `reweight()` for gradient accumulation over record batches. When set, the chi-squared loss is accumulated under `no_grad` in a first pass and the backward pass is split into per-batch virtual-loss calls with pre-computed per-target coefficients. Peak autograd activation memory drops from O(n_records × n_targets) to O(batch_size × n_targets). The full-batch path is unchanged when `batch_size` is `None` (default) or greater than or equal to `n_records`. Not supported in combination with `regularize_with_l0=True` (raises `ValueError`).
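A minimal usage sketch of the new parameter, assuming a top-level `Calibration` export and `weights`/`targets` constructor argument names (only `estimate_matrix` and `batch_size` are confirmed by this diff):

```python
import numpy as np
import pandas as pd

from microcalibrate import Calibration  # import path assumed

# Toy survey: 1,000 records, 3 calibration targets.
rng = np.random.default_rng(0)
estimate_matrix = pd.DataFrame(
    rng.random((1_000, 3)), columns=["pop", "income", "benefits"]
)

cal = Calibration(
    estimate_matrix=estimate_matrix,
    weights=np.ones(1_000),                    # assumed argument name
    targets=np.array([500.0, 260.0, 240.0]),   # assumed argument name
    batch_size=256,  # accumulate gradients over 256-record batches
)
cal.calibrate()
```

For scale: at 1e6 records and 500 targets in fp32, the full-batch weighted-estimate activation is roughly 2 GB, while `batch_size=10_000` holds roughly 20 MB at a time.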
1 change: 1 addition & 0 deletions changelog.d/batch-calibration-memory-fix.changed.md
@@ -0,0 +1 @@
`Calibration` now converts the user-provided `estimate_matrix` DataFrame to a cached `float32` torch tensor on `estimate_matrix_tensor` during `__init__` and releases the pandas DataFrame reference by setting `original_estimate_matrix` to `None`. Downstream code (`hyperparameter_tuning`, `evaluation`, `assess_analytical_solution`) reads the cached tensor rather than re-materializing from `DataFrame.values`. This substantially reduces peak RSS during `calibrate()` at large record counts. External readers of `Calibration.original_estimate_matrix` will now see `None` after construction; the tensor equivalent is available on `Calibration.estimate_matrix_tensor`.
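The externally visible effect on the object, continuing the sketch above:

```python
import torch

# The DataFrame reference is released in __init__ ...
assert cal.original_estimate_matrix is None

# ... and the cached float32 tensor is the authoritative matrix.
tensor = cal.estimate_matrix_tensor
assert tensor.dtype == torch.float32
assert tensor.shape == (1_000, 3)
```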
77 changes: 47 additions & 30 deletions src/microcalibrate/calibration.py
@@ -42,6 +42,7 @@ def __init__(
sparse_learning_rate: Optional[float] = 0.2,
regularize_with_l0: Optional[bool] = False,
seed: Optional[int] = 42,
batch_size: Optional[int] = None,
):
"""Initialize the Calibration class.

@@ -64,6 +65,7 @@
temperature (float): Temperature parameter for L0 regularization, controlling the sparsity of the model. Defaults to 0.5.
sparse_learning_rate (float): Learning rate for the regularizing optimizer. Defaults to 0.2.
regularize_with_l0 (Optional[bool]): Whether to apply L0 regularization. Defaults to False.
batch_size (Optional[int]): If set, the per-epoch gradient is accumulated over disjoint record batches of this size (two passes: accumulate the chi-squared estimates under no_grad, then run per-batch backward passes with pre-computed per-target coefficients). This keeps peak activation memory O(batch_size * n_targets) instead of O(n_records * n_targets), at the cost of modest fp32 rounding during accumulation. Defaults to None, which preserves the full-batch path and prior behavior exactly.
"""
# Resolve the torch device exactly once. The fallback chain
# (cuda -> mps -> cpu) runs when ``device`` is None so callers
@@ -99,6 +101,13 @@ def __init__(
self.sparse_learning_rate = sparse_learning_rate
self.regularize_with_l0 = regularize_with_l0
self.seed = seed
self.batch_size = batch_size

# Authoritative float32 copy of the estimate matrix; the
# pandas DataFrame is released after __init__ so its storage is
# garbage-collectable and peak RSS during calibrate() is cut
# substantially at v7 scale (>1e6 rows).
self.estimate_matrix_tensor: Optional[torch.Tensor] = None

# Seed torch on every path, and CUDA as well when we actually
# resolved to a CUDA device, so stochastic CUDA kernels are
@@ -121,17 +130,26 @@ def __init__(
self.original_estimate_matrix.columns.to_numpy()
)

# Build a single float32 torch copy of the full estimate matrix
# and release the caller's pandas DataFrame. The tensor is the
# authoritative matrix from here on; downstream code (including
# exclude_targets and hyperparameter tuning) reads it instead of
# re-materializing from .values.
if self.original_estimate_matrix is not None:
self.estimate_matrix_tensor = torch.tensor(
self.original_estimate_matrix.values,
dtype=torch.float32,
device=self.device,
)
self.original_estimate_matrix = None

if self.excluded_targets is not None:
self.exclude_targets()
else:
self.targets = self.original_targets
self.target_names = self.original_target_names
if self.original_estimate_matrix is not None:
self.estimate_matrix = torch.tensor(
self.original_estimate_matrix.values,
dtype=torch.float32,
device=self.device,
)
if self.estimate_matrix_tensor is not None:
self.estimate_matrix = self.estimate_matrix_tensor
else:
self.estimate_matrix = None

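Why releasing the DataFrame reference cuts peak RSS, in isolation (sizes are illustrative):

```python
import numpy as np
import pandas as pd
import torch

df = pd.DataFrame(np.random.rand(1_000_000, 50))  # float64, ~400 MB

# Build the authoritative float32 copy once (~200 MB) ...
tensor = torch.tensor(df.values, dtype=torch.float32)

# ... then drop the only DataFrame reference so the float64 backing
# array (1e6 x 50 x 8 bytes) becomes garbage-collectable instead of
# staying resident for the whole calibrate() run.
df = None
```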
@@ -182,6 +200,8 @@ def calibrate(self) -> None:
regularize_with_l0=self.regularize_with_l0,
logger=self.logger,
seed=self.seed,
batch_size=self.batch_size,
estimate_matrix=self.estimate_matrix,
)

self.weights = new_weights
@@ -242,29 +262,25 @@ def exclude_targets(
.cpu()
.numpy()
)
elif self.original_estimate_matrix is not None:
# Get initial estimates using the original full matrix
original_estimate_matrix_tensor = torch.tensor(
self.original_estimate_matrix.values,
dtype=torch.float32,
device=self.device,
)
elif self.estimate_matrix_tensor is not None:
# Get initial estimates using the full matrix tensor
initial_estimates_all = (
(initial_weights_tensor @ original_estimate_matrix_tensor)
(initial_weights_tensor @ self.estimate_matrix_tensor)
.detach()
.cpu()
.numpy()
)

# Filter estimate matrix for calibration
filtered_estimate_matrix = self.original_estimate_matrix.iloc[
:, calibration_mask
]
self.estimate_matrix = torch.tensor(
filtered_estimate_matrix.values,
dtype=torch.float32,
# Filter estimate matrix for calibration via torch column
# indexing — no pandas round-trip, no extra materialized copy.
keep_idx = torch.as_tensor(
np.flatnonzero(calibration_mask),
dtype=torch.long,
device=self.device,
)
self.estimate_matrix = (
self.estimate_matrix_tensor.index_select(1, keep_idx)
)

self.estimate_function = (
lambda weights: weights @ self.estimate_matrix
@@ -284,12 +300,8 @@
)

else:
if self.original_estimate_matrix is not None:
self.estimate_matrix = torch.tensor(
self.original_estimate_matrix.values,
dtype=torch.float32,
device=self.device,
)
if self.estimate_matrix_tensor is not None:
self.estimate_matrix = self.estimate_matrix_tensor
if self.original_estimate_function is None:
self.estimate_function = (
lambda weights: weights @ self.estimate_matrix
@@ -451,14 +463,19 @@ def _get_linear_loss(metrics_matrix, target_vector, sparse=False):

return np.mean(((y - y_hat) ** 2) * normalization_factor)

X = self.original_estimate_matrix.values
if self.estimate_matrix_tensor is None:
raise ValueError(
"analytical_solution requires a dense estimate matrix; "
"Calibration was constructed without one."
)
X = self.estimate_matrix_tensor.cpu().numpy()
y = self.targets

results = []
slices = []
idx_dict = {
self.original_estimate_matrix.columns.to_list()[i]: i
for i in range(len(self.original_estimate_matrix.columns))
self.original_target_names[i]: i
for i in range(len(self.original_target_names))
}

self.logger.info(
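The `exclude_targets` hunk above replaces the pandas `.iloc` column filter with torch indexing; a self-contained sketch of that pattern:

```python
import numpy as np
import torch

# Toy estimate matrix: 3 records x 4 targets.
full = torch.arange(12, dtype=torch.float32).reshape(3, 4)
calibration_mask = np.array([True, False, True, False])

# Boolean mask -> integer column indices -> index_select along
# dim=1 (columns); no pandas round-trip, and the only new
# allocation is the kept columns themselves.
keep_idx = torch.as_tensor(np.flatnonzero(calibration_mask), dtype=torch.long)
filtered = full.index_select(1, keep_idx)
print(filtered.shape)  # torch.Size([3, 2])
```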
16 changes: 8 additions & 8 deletions src/microcalibrate/evaluation.py
@@ -175,15 +175,15 @@ def _evaluate_single_holdout_robustness(
final_weights, dtype=torch.float32, device=calibration.device
)

# Get estimates for all targets using original estimate function/matrix
if calibration.original_estimate_matrix is not None:
original_matrix_tensor = torch.tensor(
calibration.original_estimate_matrix.values,
dtype=torch.float32,
device=calibration.device,
)
# Get estimates for all targets using the cached full matrix
# tensor (built once in Calibration.__init__). Falls back to
# the user-supplied estimate_function for callers that passed
# an opaque function rather than a dense matrix.
if calibration.estimate_matrix_tensor is not None:
all_estimates = (
(weights_tensor @ original_matrix_tensor).cpu().numpy()
(weights_tensor @ calibration.estimate_matrix_tensor)
.cpu()
.numpy()
)
else:
all_estimates = (
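Both holdout evaluators now branch the same way; a condensed sketch (the helper itself is hypothetical, and the fallback callable is passed in explicitly since the attribute holding it is not shown in this diff):

```python
def estimates_for_all_targets(calibration, weights_tensor, fallback_fn):
    """Hypothetical helper condensing the branch used in both
    evaluation.py and hyperparameter_tuning.py above."""
    if calibration.estimate_matrix_tensor is not None:
        # Dense path: one matmul against the cached float32 tensor.
        return (
            (weights_tensor @ calibration.estimate_matrix_tensor)
            .cpu()
            .numpy()
        )
    # Opaque path: defer to the user-supplied estimate function.
    return fallback_fn(weights_tensor).detach().cpu().numpy()
```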
11 changes: 4 additions & 7 deletions src/microcalibrate/hyperparameter_tuning.py
@@ -62,14 +62,11 @@ def _evaluate_single_holdout(
sparse_weights, dtype=torch.float32, device=calibration.device
)

if calibration.original_estimate_matrix is not None:
original_matrix_tensor = torch.tensor(
calibration.original_estimate_matrix.values,
dtype=torch.float32,
device=calibration.device,
)
if calibration.estimate_matrix_tensor is not None:
all_estimates = (
(weights_tensor @ original_matrix_tensor).cpu().numpy()
(weights_tensor @ calibration.estimate_matrix_tensor)
.cpu()
.numpy()
)
else:
all_estimates = (
100 changes: 94 additions & 6 deletions src/microcalibrate/reweight.py
@@ -10,7 +10,7 @@
from tqdm import tqdm

from .utils.log_performance import log_performance_over_epochs
from .utils.metrics import loss, pct_close
from .utils.metrics import _safe_denominator, loss, pct_close


def dropout_weights(weights: torch.Tensor, p: float) -> torch.Tensor:
@@ -70,6 +70,8 @@ def reweight(
device: Optional[str] = None,
logger: Optional[logging.Logger] = None,
seed: Optional[int] = None,
batch_size: Optional[int] = None,
estimate_matrix: Optional[torch.Tensor] = None,
) -> tuple[np.ndarray, Union[np.ndarray, None], pd.DataFrame]:
"""Reweight the original weights based on the loss matrix and targets.

@@ -97,6 +99,20 @@ def reweight(
draws the initial weight noise and torch's generator. When
None, a non-deterministic draw is used (preserving the
historical behaviour).
batch_size (Optional[int]): If set, the per-epoch gradient is
accumulated over disjoint record batches of this size. This
keeps the autograd activation O(batch_size * n_targets)
instead of O(n_records * n_targets) — critical at v7 scale
(n_records > 1e6). Requires ``estimate_matrix`` to be
provided; not supported for arbitrary ``estimate_function``.
None (default) preserves the existing full-batch path bit-
for-bit. batch_size >= n_records degenerates to full-batch.
estimate_matrix (Optional[torch.Tensor]): The float32 estimate
matrix of shape (n_records, n_targets) backing the
``estimate_function``. Required when ``batch_size`` is set;
ignored otherwise. Callers passing a custom
``estimate_function`` that does not correspond to a dense
matrix must use full-batch mode.

Returns:
np.ndarray: Reweighted weights.
@@ -149,6 +165,23 @@ def reweight(

optimizer = torch.optim.Adam([weights], lr=learning_rate)

n_records = original_weights.shape[0]
use_batched = batch_size is not None and batch_size < n_records
if use_batched and estimate_matrix is None:
raise ValueError(
"batch_size requires `estimate_matrix` to be provided so the "
"reweight loop can index per-batch rows. Pass the torch "
"estimate tensor explicitly, or leave batch_size=None to use "
"the full-batch path with an arbitrary estimate_function."
)
if use_batched and regularize_with_l0:
raise ValueError(
"batch_size is not yet supported with regularize_with_l0=True. "
"The L0 sparse-reweighting loop uses a different objective and "
"is not yet batched. Choose one: disable L0 for the dense "
"calibration, or leave batch_size=None."
)

iterator = tqdm(range(epochs), desc="Reweighting progress", unit="epoch")
tracking_n = max(1, epochs // 10) if epochs > 10 else 1
progress_update_interval = 10
@@ -161,9 +194,61 @@
for i in iterator:
optimizer.zero_grad()
weights_ = dropout_weights(weights, dropout_rate)
estimate = estimate_function(torch.exp(weights_))
l = loss(estimate, targets, normalization_factor)
close = pct_close(estimate, targets)

if use_batched:
# Two-pass batched gradient accumulation.
#
# The chi-squared loss is separable across record batches
# given the per-target coefficient c_j = d(loss)/d(S_j),
# because S_j (the weighted sum of estimate_matrix column j)
# is itself a sum over records. Phase 1 accumulates S under
# no_grad; Phase 2 computes, per batch,
# virtual_loss_batch = c · (exp(w_log[batch]) @ A[batch])
# and calls .backward() to accumulate gradients into weights.
# The sum of virtual_loss_batch over batches has exactly the
# same gradient as the full-batch loss; peak autograd
# activation is O(batch_size * n_targets).
n_targets = targets.shape[0]
with torch.no_grad():
exp_w_ = torch.exp(weights_)
S = torch.zeros(n_targets, dtype=torch.float32, device=device)
for start in range(0, n_records, batch_size):
end = min(start + batch_size, n_records)
S += exp_w_[start:end] @ estimate_matrix[start:end]
# Coefficient c_j = d(loss)/d(S_j). Using the same
# clamped denominator as the reference loss so batched
# and full-batch paths agree on targets near -1.
# loss = mean((((S - t) + 1) / _safe_denominator(t)) ** 2 * normalization_factor)
# => d(loss)/d(S_j) = 2 * ((S_j - t_j + 1) / denom_j^2) / n_targets * normalization_factor_j
denominator = _safe_denominator(targets)
rel_error_unrooted = ((S - targets) + 1) / denominator
coef = 2.0 * rel_error_unrooted / denominator / n_targets
if normalization_factor is not None:
coef = coef * normalization_factor

# Phase 2: per-batch backward with retain_graph until the
# final batch, so weights_ → weights graph persists across
# the multiple .backward() calls within this epoch.
batch_starts = list(range(0, n_records, batch_size))
for batch_idx, start in enumerate(batch_starts):
end = min(start + batch_size, n_records)
batch_estimate = (
torch.exp(weights_[start:end]) @ estimate_matrix[start:end]
)
virtual_loss = (coef * batch_estimate).sum()
retain = batch_idx < len(batch_starts) - 1
virtual_loss.backward(retain_graph=retain)

# For logging only: full-batch-equivalent loss value,
# computed from S (no additional activation memory).
with torch.no_grad():
estimate = S
l = loss(estimate, targets, normalization_factor)
close = pct_close(estimate, targets)
else:
estimate = estimate_function(torch.exp(weights_))
l = loss(estimate, targets, normalization_factor)
close = pct_close(estimate, targets)

if i % progress_update_interval == 0:
iterator.set_postfix(
@@ -197,8 +282,11 @@

# Step every epoch. The returned final_weights reflect the state
# after the last step; the final logged row above reflects the
# pre-step state of the same (last) epoch.
l.backward()
# pre-step state of the same (last) epoch. In the batched path
# gradients were already accumulated above, so we only call
# l.backward() on the full-batch path.
if not use_batched:
l.backward()
optimizer.step()

tracker_dict = {
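The separability argument in the batched loop can be verified numerically. A self-contained sketch with a simplified loss (plain division by the targets instead of `_safe_denominator`, and no normalization factor), checking that the accumulated virtual-loss gradients match the full-batch gradient:

```python
import torch

torch.manual_seed(0)
n_records, n_targets, batch_size = 100, 4, 32
A = torch.rand(n_records, n_targets)
targets = torch.rand(n_targets) + 1.0

def chi2(estimate):
    # Simplified stand-in for the reference loss: mean squared
    # relative error with the +1 offset used in the diff.
    return ((((estimate - targets) + 1.0) / targets) ** 2).mean()

# Full-batch reference gradient.
w_full = torch.zeros(n_records, requires_grad=True)
chi2(torch.exp(w_full) @ A).backward()

# Phase 1: accumulate S under no_grad and derive the per-target
# coefficient c_j = d(chi2)/d(S_j) = 2 * (S_j - t_j + 1) / t_j^2 / n.
w = torch.zeros(n_records, requires_grad=True)
with torch.no_grad():
    S = torch.zeros(n_targets)
    for start in range(0, n_records, batch_size):
        end = min(start + batch_size, n_records)
        S += torch.exp(w[start:end]) @ A[start:end]
    coef = 2.0 * ((S - targets) + 1.0) / targets**2 / n_targets

# Phase 2: per-batch virtual losses; their summed gradients equal
# the full-batch gradient because S is a sum over record batches.
starts = list(range(0, n_records, batch_size))
for k, start in enumerate(starts):
    end = min(start + batch_size, n_records)
    virtual_loss = (coef * (torch.exp(w[start:end]) @ A[start:end])).sum()
    virtual_loss.backward(retain_graph=k < len(starts) - 1)

print(torch.allclose(w.grad, w_full.grad, atol=1e-6))  # True
```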