From 78ce0f117d3212bbf7371c0443b4b399a162d189 Mon Sep 17 00:00:00 2001 From: jkobject Date: Wed, 20 May 2026 07:49:26 +0000 Subject: [PATCH] fix(embedder): keep_all_cls_pred with heterogeneous head sizes (#16) Three related bugs reported in jkobject/scPRINT#16: 1. scPrint._predict: with keep_all_cls_pred=True, the code stacked the per-class logits with torch.stack, but classification heads have different n_classes (e.g. 424 for cell_type vs 62 for disease), so stacking raises 'stack expects each tensor to be equal size'. Store per-class logits in a dict {clsname: tensor} instead, and concatenate per head across batches. 2. on_validation_epoch_end: handle the dict shape when all_gather'ing self.pred. 3. Embedder.__call__ (cell_emb.py): - move logits to CPU/numpy before wrapping in pd.DataFrame (CUDA tensors cannot be converted directly). - fix pd.concat(adata.obs, allclspred) -> pd.concat([adata.obs, allclspred], axis=1) (positional misuse, second arg was being interpreted as 'axis'). The keep_all_cls_pred=False (argmax) path is unchanged. Co-authored-by: Prachi-Priyam --- pyproject.toml | 2 +- scprint/model/model.py | 67 ++++++++++++++++++++++----------------- scprint/tasks/cell_emb.py | 21 ++++++++---- 3 files changed, 53 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 356953e8..c0e8cead 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scprint" -version = "1.1.3" +version = "1.1.4" license = "MIT" description = "scPRINT is a Large Cell Model for Gene Network Inference, Denoising and more from scRNAseq data" authors = ["jeremie kalfon"] diff --git a/scprint/model/model.py b/scprint/model/model.py index bd65ad65..26bca6fc 100644 --- a/scprint/model/model.py +++ b/scprint/model/model.py @@ -1091,11 +1091,15 @@ def on_validation_epoch_end(self): """@see pl.LightningModule""" self.embs = self.all_gather(self.embs).view(-1, self.embs.shape[-1]) self.info = self.all_gather(self.info).view(-1, self.info.shape[-1]) - self.pred = ( - self.all_gather(self.pred).view(-1, self.pred.shape[-1]) - if self.pred is not None - else None - ) + if self.pred is None: + pass + elif isinstance(self.pred, dict): + self.pred = { + k: self.all_gather(v).view(-1, v.shape[-1]) + for k, v in self.pred.items() + } + else: + self.pred = self.all_gather(self.pred).view(-1, self.pred.shape[-1]) self.pos = self.all_gather(self.pos).view(-1, self.pos.shape[-1]) if not self.trainer.is_global_zero: # print("you are not on the main node. cancelling logging step") @@ -1258,20 +1262,23 @@ def _predict( if self.embs is None: self.embs = torch.mean(cell_embs[:, ind, :], dim=1) # self.embs = output["cls_output_" + "cell_type_ontology_term_id"] - self.pred = ( - torch.stack( + if len(self.classes) == 0: + self.pred = None + elif self.keep_all_cls_pred: + # Heads have different n_classes (e.g. 424 vs 62), so a + # stacked tensor is not well-defined. Store per-class logits + # in a dict so downstream code can concat per head. + self.pred = { + clsname: output["cls_output_" + clsname] + for clsname in self.classes + } + else: + self.pred = torch.stack( [ - ( - torch.argmax(output["cls_output_" + clsname], dim=1) - if not self.keep_all_cls_pred - else output["cls_output_" + clsname] - ) + torch.argmax(output["cls_output_" + clsname], dim=1) for clsname in self.classes ] ).transpose(0, 1) - if len(self.classes) > 0 - else None - ) self.pos = gene_pos self.expr_pred = ( [output["mean"], output["disp"], output["zero_logits"]] @@ -1283,25 +1290,27 @@ def _predict( # [self.embs, output["cls_output_" + "cell_type_ontology_term_id"]] [self.embs, torch.mean(cell_embs[:, ind, :], dim=1)] ) - self.pred = torch.cat( - [ - self.pred, - ( + if len(self.classes) == 0: + pass # keep self.pred = None + elif self.keep_all_cls_pred: + for clsname in self.classes: + self.pred[clsname] = torch.cat( + [self.pred[clsname], output["cls_output_" + clsname]] + ) + else: + self.pred = torch.cat( + [ + self.pred, torch.stack( [ - ( - torch.argmax(output["cls_output_" + clsname], dim=1) - if not self.keep_all_cls_pred - else output["cls_output_" + clsname] + torch.argmax( + output["cls_output_" + clsname], dim=1 ) for clsname in self.classes ] - ).transpose(0, 1) - if len(self.classes) > 0 - else None - ), - ], - ) + ).transpose(0, 1), + ], + ) self.pos = torch.cat([self.pos, gene_pos]) self.expr_pred = ( [ diff --git a/scprint/tasks/cell_emb.py b/scprint/tasks/cell_emb.py index 777adaff..d6e8da79 100644 --- a/scprint/tasks/cell_emb.py +++ b/scprint/tasks/cell_emb.py @@ -235,15 +235,22 @@ def __call__(self, model: torch.nn.Module, adata: AnnData, cache=False): pred_adata.obs.index = adata.obs.index adata.obs = pd.concat([adata.obs, pred_adata.obs], axis=1) if self.keep_all_cls_pred: - allclspred = model.pred - columns = [] + # model.pred is a dict[clsname -> tensor[n_cells, n_classes_cl]] + # (heads have different n_classes), so concatenate per head. + dfs = [] for cl in model.classes: n = model.label_counts[cl] - columns += [model.label_decoders[cl][i] for i in range(n)] - allclspred = pd.DataFrame( - allclspred, columns=columns, index=adata.obs.index - ) - adata.obs = pd.concat(adata.obs, allclspred) + columns = [model.label_decoders[cl][i] for i in range(n)] + tensor = model.pred[cl] + if hasattr(tensor, "detach"): + tensor = tensor.detach().cpu().numpy() + dfs.append( + pd.DataFrame( + tensor, columns=columns, index=adata.obs.index + ) + ) + allclspred = pd.concat(dfs, axis=1) + adata.obs = pd.concat([adata.obs, allclspred], axis=1) metrics = {} if self.doclass and not self.keep_all_cls_pred: