From 16e72ea38ce2e976ab56bc4778e9c8cccbe844d8 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Apr 2026 07:50:56 +0000 Subject: [PATCH 1/2] Initial plan From 162f14878b66e2d1526be54d3d7d513c3b4d0c1b Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 10 Apr 2026 08:15:31 +0000 Subject: [PATCH 2/2] Fix discov_da plate structure: use contiguous dim=-2 plate to prevent [N,N,G] sampling shape mismatch Agent-Logs-Url: https://github.com/mtvector/scANTIPODE/sessions/953f3c1a-ee3c-4a27-97b7-a64cb390f060 Co-authored-by: mtvector <28308459+mtvector@users.noreply.github.com> --- antipode/antipode_model.py | 4 +- tests/test_obs_categorical_mode.py | 143 +++++++++++++++++++++++++++++ tests/test_training.py | 5 +- 3 files changed, 149 insertions(+), 3 deletions(-) create mode 100644 tests/test_obs_categorical_mode.py diff --git a/antipode/antipode_model.py b/antipode/antipode_model.py index 2859e4f..688feaa 100644 --- a/antipode/antipode_model.py +++ b/antipode/antipode_model.py @@ -124,6 +124,7 @@ def __init__(self, adata, discov_pair, batch_pair, layer, seccov_key='seccov_dum # Initialize plates to be used during sampling self.var_plate = pyro.plate('var_plate',self.num_var,dim=-1) self.discov_plate = pyro.plate('discov_plate',self.num_discov,dim=-3) + self.discov_da_plate = pyro.plate('discov_da_plate',self.num_discov,dim=-2) self.seccov_plate = pyro.plate('seccov_plate',self.num_seccov,dim=-3) self.batch_plate = pyro.plate('batch_plate',self.num_batch,dim=-3) self.latent_plate = pyro.plate('latent_plate',self.num_latent,dim=-1) @@ -147,7 +148,7 @@ def __init__(self, adata, discov_pair, batch_pair, layer, seccov_key='seccov_dum self.dc=MAPLaplaceModule(self,'discov_dc',[self.num_discov,self.num_latent,self.num_var], [self.discov_plate,self.latent_plate2,self.var_plate],scale_multiplier=self.anc_prior_scalar) self.discov_da=MAPLaplaceModule(self,'discov_da',[self.num_discov,self.num_var], - [self.discov_plate,self.var_plate],init_val=self.discov_da_prior_loc, + [self.discov_da_plate,self.var_plate],init_val=self.discov_da_prior_loc, prior_loc=self.discov_da_prior_loc,scale_multiplier=self.anc_prior_scalar) self.zdw=MAPLaplaceModule(self,'z_decoder_weight',[self.num_latent,self.num_var], [self.latent_plate2,self.var_plate], @@ -266,7 +267,6 @@ def model(self, s,discov_ind=torch.zeros(1),batch_ind=torch.zeros(1),seccov=torc self.intercept_init.to(s.device), ) discov_da = self.discov_da.model_sample(s, scale=fest([discov], -1)) - discov_da = discov_da.squeeze(-2) level_edges=self.tree_edges.model_sample(s,approx=self.approx) quality_genes=self.qg.model_sample(s) if self.use_q_score else 0. diff --git a/tests/test_obs_categorical_mode.py b/tests/test_obs_categorical_mode.py new file mode 100644 index 0000000..baabd9e --- /dev/null +++ b/tests/test_obs_categorical_mode.py @@ -0,0 +1,143 @@ +"""Smoke-tests confirming that ANTIPODE runs without shape errors in both +obs-categorical and obsm-design-matrix discovery modes. + +Regression guard for the RuntimeError at antipode_model.py caused by +discov_da sampling as [num_discov, num_discov, num_var] instead of +[num_discov, num_var] when the two plates were non-contiguous. +""" +import numpy as np +import pandas as pd +import anndata +import pyro +import torch +from pyro import poutine + +from antipode.antipode_model import ANTIPODE +from antipode.model_modules import SafeSVI + + +def _make_adata(n_obs=64, n_genes=50, n_discov=4, n_batch=3): + rng = np.random.default_rng(0) + x = rng.poisson(5, size=(n_obs, n_genes)).astype(np.float32) + adata = anndata.AnnData(x) + adata.layers["counts"] = x.copy() + discov_labels = [f"d{i}" for i in range(n_discov)] + batch_labels = [f"b{i}" for i in range(n_batch)] + adata.obs["discov"] = pd.Categorical( + [discov_labels[i % n_discov] for i in range(n_obs)] + ) + adata.obs["batch"] = pd.Categorical( + [batch_labels[i % n_batch] for i in range(n_obs)] + ) + # obsm one-hot views for design-matrix mode + adata.obsm["discov_onehot"] = np.eye(n_discov, dtype=np.float32)[ + adata.obs["discov"].cat.codes.to_numpy() + ] + adata.obsm["batch_onehot"] = np.eye(n_batch, dtype=np.float32)[ + adata.obs["batch"].cat.codes.to_numpy() + ] + return adata + + +def _discov_real_means(adata, n_discov, n_genes): + """Simple per-category mean expression used as the prior for discov_da.""" + return torch.ones(n_discov, n_genes) + + +def _run_steps(model, s, discov_ind, batch_ind, n_steps=3): + """Run n_steps of SVI and return the last loss.""" + seccov = torch.zeros((s.shape[0], 1)) + optim = pyro.optim.ClippedAdam({"lr": 1e-3}) + elbo = pyro.infer.Trace_ELBO(num_particles=1) + model_blocked = poutine.block(model.model, hide=["s"]) + svi = SafeSVI(model_blocked, model.guide, optim, elbo) + loss = None + for _ in range(n_steps): + loss = svi.step( + s, + discov_ind=discov_ind, + batch_ind=batch_ind, + seccov=seccov, + step=torch.ones(1), + ) + return loss + + +def test_obs_categorical_mode_no_shape_error(): + """obs categorical mode: discov_da must sample as [num_discov, num_var].""" + pyro.clear_param_store() + n_discov, n_genes = 4, 50 + adata = _make_adata(n_discov=n_discov, n_genes=n_genes) + + model = ANTIPODE( + adata, + discov_pair=("obs", "discov"), + batch_pair=("obs", "batch"), + layer="counts", + level_sizes=[1, 3], + num_latent=4, + num_batch_embed=2, + classifier_hidden=[4], + encoder_hidden=[4], + batch_embedder_hidden=[4], + use_q_score=False, + use_psi=False, + discov_real_means=_discov_real_means(adata, n_discov, n_genes), + ) + model.freeze_encoder = False + + # Confirm discov_da plate is at dim=-2 (the fix) + assert model.discov_da_plate.dim == -2, ( + f"discov_da_plate.dim should be -2, got {model.discov_da_plate.dim}" + ) + + s = torch.tensor(adata.layers["counts"]) + # The scvi CategoricalObsField loader returns shape [N, 1]; reproduce that here + discov_ind = torch.tensor( + adata.obs["discov"].cat.codes.to_numpy(), dtype=torch.long + ).unsqueeze(-1) + batch_ind = torch.tensor( + adata.obs["batch"].cat.codes.to_numpy(), dtype=torch.long + ).unsqueeze(-1) + + loss = _run_steps(model, s, discov_ind, batch_ind) + assert np.isfinite(loss), f"Expected finite loss in obs mode, got {loss}" + + +def test_obsm_design_matrix_mode_no_shape_error(): + """obsm design-matrix mode: discov_da must also sample as [num_discov, num_var].""" + pyro.clear_param_store() + n_discov, n_genes = 4, 50 + adata = _make_adata(n_discov=n_discov, n_genes=n_genes) + + model = ANTIPODE( + adata, + discov_pair=("obsm", "discov_onehot"), + batch_pair=("obsm", "batch_onehot"), + layer="counts", + level_sizes=[1, 3], + num_latent=4, + num_batch_embed=2, + classifier_hidden=[4], + encoder_hidden=[4], + batch_embedder_hidden=[4], + use_q_score=False, + use_psi=False, + discov_real_means=_discov_real_means(adata, n_discov, n_genes), + ) + model.freeze_encoder = False + + s = torch.tensor(adata.layers["counts"]) + discov_ind = torch.tensor(adata.obsm["discov_onehot"]) + batch_ind = torch.tensor(adata.obsm["batch_onehot"]) + + loss = _run_steps(model, s, discov_ind, batch_ind) + assert np.isfinite(loss), f"Expected finite loss in obsm mode, got {loss}" + + +if __name__ == "__main__": + test_obs_categorical_mode_no_shape_error() + print("obs categorical mode: OK") + test_obsm_design_matrix_mode_no_shape_error() + print("obsm design-matrix mode: OK") + print("All checks passed.") diff --git a/tests/test_training.py b/tests/test_training.py index 12f9289..8366797 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -21,7 +21,9 @@ def _make_minimal_adata(n_obs=4, n_genes=3): def test_train_phase_runs_on_minimal_data(): pyro.clear_param_store() adata = _make_minimal_adata() - adata.obsm["discov_onehot"] = np.eye(2, dtype=np.float32)[ + n_discov = 2 + n_genes = adata.n_vars + adata.obsm["discov_onehot"] = np.eye(n_discov, dtype=np.float32)[ adata.obs["discov"].cat.codes.to_numpy() ] adata.obsm["batch_onehot"] = np.eye(2, dtype=np.float32)[ @@ -40,6 +42,7 @@ def test_train_phase_runs_on_minimal_data(): batch_embedder_hidden=[2], use_q_score=False, use_psi=False, + discov_real_means=torch.ones(n_discov, n_genes), ) model.freeze_encoder = False