Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions antipode/antipode_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(self, adata, discov_pair, batch_pair, layer, seccov_key='seccov_dum
# Initialize plates to be used during sampling
self.var_plate = pyro.plate('var_plate',self.num_var,dim=-1)
self.discov_plate = pyro.plate('discov_plate',self.num_discov,dim=-3)
self.discov_da_plate = pyro.plate('discov_da_plate',self.num_discov,dim=-2)
self.seccov_plate = pyro.plate('seccov_plate',self.num_seccov,dim=-3)
self.batch_plate = pyro.plate('batch_plate',self.num_batch,dim=-3)
self.latent_plate = pyro.plate('latent_plate',self.num_latent,dim=-1)
Expand All @@ -147,7 +148,7 @@ def __init__(self, adata, discov_pair, batch_pair, layer, seccov_key='seccov_dum
self.dc=MAPLaplaceModule(self,'discov_dc',[self.num_discov,self.num_latent,self.num_var],
[self.discov_plate,self.latent_plate2,self.var_plate],scale_multiplier=self.anc_prior_scalar)
self.discov_da=MAPLaplaceModule(self,'discov_da',[self.num_discov,self.num_var],
[self.discov_plate,self.var_plate],init_val=self.discov_da_prior_loc,
[self.discov_da_plate,self.var_plate],init_val=self.discov_da_prior_loc,
prior_loc=self.discov_da_prior_loc,scale_multiplier=self.anc_prior_scalar)
self.zdw=MAPLaplaceModule(self,'z_decoder_weight',[self.num_latent,self.num_var],
[self.latent_plate2,self.var_plate],
Expand Down Expand Up @@ -266,7 +267,6 @@ def model(self, s,discov_ind=torch.zeros(1),batch_ind=torch.zeros(1),seccov=torc
self.intercept_init.to(s.device),
)
discov_da = self.discov_da.model_sample(s, scale=fest([discov], -1))
discov_da = discov_da.squeeze(-2)
level_edges=self.tree_edges.model_sample(s,approx=self.approx)
quality_genes=self.qg.model_sample(s) if self.use_q_score else 0.

Expand Down
143 changes: 143 additions & 0 deletions tests/test_obs_categorical_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""Smoke-tests confirming that ANTIPODE runs without shape errors in both
obs-categorical and obsm-design-matrix discovery modes.

Regression guard for the RuntimeError at antipode_model.py caused by
discov_da sampling as [num_discov, num_discov, num_var] instead of
[num_discov, num_var] when the two plates were non-contiguous.
"""
import numpy as np
import pandas as pd
import anndata
import pyro
import torch
from pyro import poutine

from antipode.antipode_model import ANTIPODE
from antipode.model_modules import SafeSVI


def _make_adata(n_obs=64, n_genes=50, n_discov=4, n_batch=3):
rng = np.random.default_rng(0)
x = rng.poisson(5, size=(n_obs, n_genes)).astype(np.float32)
adata = anndata.AnnData(x)
adata.layers["counts"] = x.copy()
discov_labels = [f"d{i}" for i in range(n_discov)]
batch_labels = [f"b{i}" for i in range(n_batch)]
adata.obs["discov"] = pd.Categorical(
[discov_labels[i % n_discov] for i in range(n_obs)]
)
adata.obs["batch"] = pd.Categorical(
[batch_labels[i % n_batch] for i in range(n_obs)]
)
# obsm one-hot views for design-matrix mode
adata.obsm["discov_onehot"] = np.eye(n_discov, dtype=np.float32)[
adata.obs["discov"].cat.codes.to_numpy()
]
adata.obsm["batch_onehot"] = np.eye(n_batch, dtype=np.float32)[
adata.obs["batch"].cat.codes.to_numpy()
]
return adata


def _discov_real_means(adata, n_discov, n_genes):
"""Simple per-category mean expression used as the prior for discov_da."""
return torch.ones(n_discov, n_genes)


def _run_steps(model, s, discov_ind, batch_ind, n_steps=3):
"""Run n_steps of SVI and return the last loss."""
seccov = torch.zeros((s.shape[0], 1))
optim = pyro.optim.ClippedAdam({"lr": 1e-3})
elbo = pyro.infer.Trace_ELBO(num_particles=1)
model_blocked = poutine.block(model.model, hide=["s"])
svi = SafeSVI(model_blocked, model.guide, optim, elbo)
loss = None
for _ in range(n_steps):
loss = svi.step(
s,
discov_ind=discov_ind,
batch_ind=batch_ind,
seccov=seccov,
step=torch.ones(1),
)
return loss


def test_obs_categorical_mode_no_shape_error():
"""obs categorical mode: discov_da must sample as [num_discov, num_var]."""
pyro.clear_param_store()
n_discov, n_genes = 4, 50
adata = _make_adata(n_discov=n_discov, n_genes=n_genes)

model = ANTIPODE(
adata,
discov_pair=("obs", "discov"),
batch_pair=("obs", "batch"),
layer="counts",
level_sizes=[1, 3],
num_latent=4,
num_batch_embed=2,
classifier_hidden=[4],
encoder_hidden=[4],
batch_embedder_hidden=[4],
use_q_score=False,
use_psi=False,
discov_real_means=_discov_real_means(adata, n_discov, n_genes),
)
model.freeze_encoder = False

# Confirm discov_da plate is at dim=-2 (the fix)
assert model.discov_da_plate.dim == -2, (
f"discov_da_plate.dim should be -2, got {model.discov_da_plate.dim}"
)

s = torch.tensor(adata.layers["counts"])
# The scvi CategoricalObsField loader returns shape [N, 1]; reproduce that here
discov_ind = torch.tensor(
adata.obs["discov"].cat.codes.to_numpy(), dtype=torch.long
).unsqueeze(-1)
batch_ind = torch.tensor(
adata.obs["batch"].cat.codes.to_numpy(), dtype=torch.long
).unsqueeze(-1)

loss = _run_steps(model, s, discov_ind, batch_ind)
assert np.isfinite(loss), f"Expected finite loss in obs mode, got {loss}"


def test_obsm_design_matrix_mode_no_shape_error():
"""obsm design-matrix mode: discov_da must also sample as [num_discov, num_var]."""
pyro.clear_param_store()
n_discov, n_genes = 4, 50
adata = _make_adata(n_discov=n_discov, n_genes=n_genes)

model = ANTIPODE(
adata,
discov_pair=("obsm", "discov_onehot"),
batch_pair=("obsm", "batch_onehot"),
layer="counts",
level_sizes=[1, 3],
num_latent=4,
num_batch_embed=2,
classifier_hidden=[4],
encoder_hidden=[4],
batch_embedder_hidden=[4],
use_q_score=False,
use_psi=False,
discov_real_means=_discov_real_means(adata, n_discov, n_genes),
)
model.freeze_encoder = False

s = torch.tensor(adata.layers["counts"])
discov_ind = torch.tensor(adata.obsm["discov_onehot"])
batch_ind = torch.tensor(adata.obsm["batch_onehot"])

loss = _run_steps(model, s, discov_ind, batch_ind)
assert np.isfinite(loss), f"Expected finite loss in obsm mode, got {loss}"


if __name__ == "__main__":
test_obs_categorical_mode_no_shape_error()
print("obs categorical mode: OK")
test_obsm_design_matrix_mode_no_shape_error()
print("obsm design-matrix mode: OK")
print("All checks passed.")
5 changes: 4 additions & 1 deletion tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def _make_minimal_adata(n_obs=4, n_genes=3):
def test_train_phase_runs_on_minimal_data():
pyro.clear_param_store()
adata = _make_minimal_adata()
adata.obsm["discov_onehot"] = np.eye(2, dtype=np.float32)[
n_discov = 2
n_genes = adata.n_vars
adata.obsm["discov_onehot"] = np.eye(n_discov, dtype=np.float32)[
adata.obs["discov"].cat.codes.to_numpy()
]
adata.obsm["batch_onehot"] = np.eye(2, dtype=np.float32)[
Expand All @@ -40,6 +42,7 @@ def test_train_phase_runs_on_minimal_data():
batch_embedder_hidden=[2],
use_q_score=False,
use_psi=False,
discov_real_means=torch.ones(n_discov, n_genes),
)
model.freeze_encoder = False

Expand Down