From d4c9df074693d8952bbcd00636a1ec260f59cc63 Mon Sep 17 00:00:00 2001 From: Topapec Date: Wed, 22 Apr 2026 18:28:10 +0300 Subject: [PATCH 1/7] feat: add fast_transformers module with FlatSASRec and UniSRec models Standalone sequential recommender package, mimics ModelBase interface without touching existing rectools code. FlatSASRec - plain ID-embedding SASRec encoder. UniSRec - pretrained text embeddings + PCA/BN adaptor, 3-phase training (ID emb -> adaptor only -> full finetune). Uses lightweight rank_topk instead of TorchRanker, reuses SASRecDataPreparator for the data pipeline. 30 tests, smoke scripts for both models. Fix: NaN*0=NaN in IEEE 754 breaks attention padding masking via multiplication, switched to masked_fill. --- rectools/fast_transformers/__init__.py | 23 + rectools/fast_transformers/lightning_wrap.py | 74 +++ rectools/fast_transformers/model.py | 325 +++++++++++++ rectools/fast_transformers/net.py | 187 +++++++ rectools/fast_transformers/ranking.py | 80 +++ .../fast_transformers/unisrec_lightning.py | 97 ++++ rectools/fast_transformers/unisrec_model.py | 458 ++++++++++++++++++ rectools/fast_transformers/unisrec_net.py | 296 +++++++++++ scripts/train_fast_sasrec.py | 77 +++ scripts/train_unisrec.py | 96 ++++ tests/fast_transformers/__init__.py | 0 tests/fast_transformers/conftest.py | 31 ++ tests/fast_transformers/test_model.py | 89 ++++ tests/fast_transformers/test_net.py | 49 ++ tests/fast_transformers/test_unisrec_model.py | 138 ++++++ tests/fast_transformers/test_unisrec_net.py | 115 +++++ 16 files changed, 2135 insertions(+) create mode 100644 rectools/fast_transformers/__init__.py create mode 100644 rectools/fast_transformers/lightning_wrap.py create mode 100644 rectools/fast_transformers/model.py create mode 100644 rectools/fast_transformers/net.py create mode 100644 rectools/fast_transformers/ranking.py create mode 100644 rectools/fast_transformers/unisrec_lightning.py create mode 100644 rectools/fast_transformers/unisrec_model.py create 
mode 100644 rectools/fast_transformers/unisrec_net.py create mode 100644 scripts/train_fast_sasrec.py create mode 100644 scripts/train_unisrec.py create mode 100644 tests/fast_transformers/__init__.py create mode 100644 tests/fast_transformers/conftest.py create mode 100644 tests/fast_transformers/test_model.py create mode 100644 tests/fast_transformers/test_net.py create mode 100644 tests/fast_transformers/test_unisrec_model.py create mode 100644 tests/fast_transformers/test_unisrec_net.py diff --git a/rectools/fast_transformers/__init__.py b/rectools/fast_transformers/__init__.py new file mode 100644 index 00000000..2a10affd --- /dev/null +++ b/rectools/fast_transformers/__init__.py @@ -0,0 +1,23 @@ +"""Fast Transformers: flat sequential recommenders without ItemNet hierarchy.""" + +from .lightning_wrap import FlatSASRecLightning +from .model import FlatSASRecConfig, FlatSASRecModel +from .net import FlatSASRec, SASRecBlock +from .ranking import rank_topk +from .unisrec_net import UniSRec, FeedForward +from .unisrec_lightning import UniSRecLightning +from .unisrec_model import UniSRecConfig, UniSRecModel + +__all__ = [ + "FlatSASRec", + "SASRecBlock", + "FlatSASRecLightning", + "FlatSASRecModel", + "FlatSASRecConfig", + "rank_topk", + "UniSRec", + "FeedForward", + "UniSRecLightning", + "UniSRecConfig", + "UniSRecModel", +] diff --git a/rectools/fast_transformers/lightning_wrap.py b/rectools/fast_transformers/lightning_wrap.py new file mode 100644 index 00000000..698afa10 --- /dev/null +++ b/rectools/fast_transformers/lightning_wrap.py @@ -0,0 +1,74 @@ +"""PyTorch Lightning wrapper for FlatSASRec.""" + +import typing as tp + +import torch +import pytorch_lightning as pl +from torch import nn + +from .net import FlatSASRec + + +class FlatSASRecLightning(pl.LightningModule): + """Lightning module wrapping FlatSASRec with softmax / BCE losses.""" + + SUPPORTED_LOSSES = ("softmax", "BCE") + + def __init__( + self, + net: FlatSASRec, + lr: float = 1e-3, + loss: str = 
"softmax", + n_negatives: int = 1, + ) -> None: + super().__init__() + self.net = net + self.lr = lr + self.loss_name = loss + self.n_negatives = n_negatives + + if loss == "softmax": + self.loss_fn = nn.CrossEntropyLoss(ignore_index=0) + elif loss == "BCE": + self.loss_fn = nn.BCEWithLogitsLoss(reduction="none") + else: + raise ValueError(f"Unsupported loss: {loss}. Use one of {self.SUPPORTED_LOSSES}") + + def on_train_start(self) -> None: + for p in self.net.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: + logits = self.net(batch) + y = batch["y"] # (B, L) + mask = y != FlatSASRec.PADDING_IDX # ignore padding positions + + if self.loss_name == "softmax": + # logits: (B, L, n_items) — full catalog + # targets need to be 0-indexed item ids (subtract 1 since item ids start from 1) + targets = y - 1 # shift to 0-based for CrossEntropyLoss; padding (0) becomes -1 -> ignore_index=0 won't work + # Actually, we set ignore_index=0 but padding maps to -1. 
+ # Let's use a different approach: set padding targets to 0 and use ignore_index=0 + targets = y.clone() + targets[~mask] = 0 + # For CE loss: targets should index into logits dim=-1 which is [0..n_items-1] + # Our item ids in y are 1..n_items, so subtract 1 + targets = targets - 1 + targets[~mask] = -100 # PyTorch ignore index + loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100) + else: + # BCE: logits shape (B, L, 1+N) + B, L, C = logits.shape + labels = torch.zeros(B, L, C, device=logits.device) + labels[:, :, 0] = 1.0 # first column is positive + loss_per_elem = self.loss_fn(logits, labels) # (B, L, C) + # Mask out padding positions + loss_per_elem = loss_per_elem * mask.unsqueeze(-1).float() + loss = loss_per_elem.sum() / mask.sum().clamp(min=1) / C + + self.log("train_loss", loss, prog_bar=True) + return loss + + def configure_optimizers(self) -> torch.optim.Optimizer: + return torch.optim.Adam(self.parameters(), lr=self.lr, betas=(0.9, 0.98)) diff --git a/rectools/fast_transformers/model.py b/rectools/fast_transformers/model.py new file mode 100644 index 00000000..e62f9943 --- /dev/null +++ b/rectools/fast_transformers/model.py @@ -0,0 +1,325 @@ +"""FlatSASRecModel: standalone flat sequential recommender built on ModelBase.""" + +import typing as tp + +import numpy as np +import pandas as pd +import torch +import pytorch_lightning as pl +from scipy import sparse + +from rectools import Columns +from rectools.dataset import Dataset +from rectools.dataset.identifiers import IdMap +from rectools.models.base import InternalRecoTriplet, ModelBase, ModelConfig +from rectools.models.nn.transformers.sasrec import SASRecDataPreparator +from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler +from rectools.types import InternalIdsArray +from rectools.utils.config import BaseConfig + +from .lightning_wrap import FlatSASRecLightning +from .net import FlatSASRec +from .ranking import 
rank_topk + + +class FlatSASRecConfig(BaseConfig): + """Configuration for FlatSASRecModel.""" + + n_factors: int = 64 + n_blocks: int = 2 + n_heads: int = 2 + session_max_len: int = 32 + dropout: float = 0.1 + loss: str = "softmax" + n_negatives: int = 1 + epochs: int = 5 + batch_size: int = 128 + lr: float = 1e-3 + recommend_batch_size: int = 256 + dataloader_num_workers: int = 0 + train_min_user_interactions: int = 2 + + +class FlatSASRecModelConfig(ModelConfig): + """Full model config including cls.""" + + model: FlatSASRecConfig = FlatSASRecConfig() + + +class FlatSASRecModel(ModelBase[FlatSASRecModelConfig]): + """ + Flat SASRec model: sequential recommender without the ItemNet hierarchy. + + Uses SASRecDataPreparator for data processing and a standalone FlatSASRec + network for encoding. + """ + + config_class = FlatSASRecModelConfig + recommends_for_warm = False + recommends_for_cold = False + + def __init__( + self, + n_factors: int = 64, + n_blocks: int = 2, + n_heads: int = 2, + session_max_len: int = 32, + dropout: float = 0.1, + loss: str = "softmax", + n_negatives: int = 1, + epochs: int = 5, + batch_size: int = 128, + lr: float = 1e-3, + recommend_batch_size: int = 256, + dataloader_num_workers: int = 0, + train_min_user_interactions: int = 2, + verbose: int = 0, + ) -> None: + super().__init__(verbose=verbose) + + if loss not in FlatSASRecLightning.SUPPORTED_LOSSES: + raise ValueError(f"Unsupported loss '{loss}'. 
Choose from {FlatSASRecLightning.SUPPORTED_LOSSES}") + + self.n_factors = n_factors + self.n_blocks = n_blocks + self.n_heads = n_heads + self.session_max_len = session_max_len + self.dropout = dropout + self.loss = loss + self.n_negatives = n_negatives + self.epochs = epochs + self.batch_size = batch_size + self.lr = lr + self.recommend_batch_size = recommend_batch_size + self.dataloader_num_workers = dataloader_num_workers + self.train_min_user_interactions = train_min_user_interactions + + self._net: tp.Optional[FlatSASRec] = None + self._lightning: tp.Optional[FlatSASRecLightning] = None + self._data_preparator: tp.Optional[SASRecDataPreparator] = None + + def _get_config(self) -> FlatSASRecModelConfig: + return FlatSASRecModelConfig( + cls=self.__class__, + verbose=self.verbose, + model=FlatSASRecConfig( + n_factors=self.n_factors, + n_blocks=self.n_blocks, + n_heads=self.n_heads, + session_max_len=self.session_max_len, + dropout=self.dropout, + loss=self.loss, + n_negatives=self.n_negatives, + epochs=self.epochs, + batch_size=self.batch_size, + lr=self.lr, + recommend_batch_size=self.recommend_batch_size, + dataloader_num_workers=self.dataloader_num_workers, + train_min_user_interactions=self.train_min_user_interactions, + ), + ) + + @classmethod + def _from_config(cls, config: FlatSASRecModelConfig) -> "FlatSASRecModel": + m = config.model + return cls( + n_factors=m.n_factors, + n_blocks=m.n_blocks, + n_heads=m.n_heads, + session_max_len=m.session_max_len, + dropout=m.dropout, + loss=m.loss, + n_negatives=m.n_negatives, + epochs=m.epochs, + batch_size=m.batch_size, + lr=m.lr, + recommend_batch_size=m.recommend_batch_size, + dataloader_num_workers=m.dataloader_num_workers, + train_min_user_interactions=m.train_min_user_interactions, + verbose=config.verbose, + ) + + def _fit(self, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> None: + negative_sampler = None + n_negatives_dp: tp.Optional[int] = None + if self.loss == "BCE": + negative_sampler = 
CatalogUniformSampler(n_negatives=self.n_negatives) + n_negatives_dp = self.n_negatives + + dp = SASRecDataPreparator( + session_max_len=self.session_max_len, + batch_size=self.batch_size, + dataloader_num_workers=self.dataloader_num_workers, + train_min_user_interactions=self.train_min_user_interactions, + n_negatives=n_negatives_dp, + negative_sampler=negative_sampler, + ) + dp.process_dataset_train(dataset) + self._data_preparator = dp + + n_items = dp.item_id_map.size # includes extra tokens (padding) + # item ids in the preparator go from 0 (padding) to n_items-1 + # FlatSASRec expects n_items = max real item count (embedding table = n_items+1 with padding at 0) + # The preparator's item_id_map.size includes the padding token, so real items = size - 1 + n_real_items = dp.item_id_map.size - dp.n_item_extra_tokens + + net = FlatSASRec( + n_items=n_real_items, + n_factors=self.n_factors, + n_blocks=self.n_blocks, + n_heads=self.n_heads, + session_max_len=self.session_max_len, + dropout=self.dropout, + ) + + lightning_model = FlatSASRecLightning( + net=net, + lr=self.lr, + loss=self.loss, + n_negatives=self.n_negatives, + ) + + train_dl = dp.get_dataloader_train() + val_dl = dp.get_dataloader_val() + + trainer = pl.Trainer( + max_epochs=self.epochs, + enable_checkpointing=False, + enable_model_summary=False, + logger=self.verbose > 0, + enable_progress_bar=self.verbose > 0, + ) + trainer.fit(lightning_model, train_dataloaders=train_dl, val_dataloaders=val_dl) + + self._net = net + self._lightning = lightning_model + + def _custom_transform_dataset_u2i( + self, + dataset: Dataset, + users: tp.Any, + on_unsupported_targets: tp.Any, + context: tp.Optional[pd.DataFrame] = None, + ) -> Dataset: + assert self._data_preparator is not None + return self._data_preparator.transform_dataset_u2i(dataset, users) + + def _custom_transform_dataset_i2i( + self, dataset: Dataset, target_items: tp.Any, on_unsupported_targets: tp.Any + ) -> Dataset: + assert self._data_preparator is 
not None + return self._data_preparator.transform_dataset_i2i(dataset) + + @torch.no_grad() + def _get_user_embeddings(self, dataset: Dataset) -> torch.Tensor: + """Compute user embeddings from their interaction sequences.""" + assert self._data_preparator is not None and self._net is not None + self._net.eval() + + recommend_dl = self._data_preparator.get_dataloader_recommend(dataset, self.recommend_batch_size) + device = next(self._net.parameters()).device + + all_embs = [] + for batch in recommend_dl: + x = batch["x"].to(device) + embs = self._net.encode_last(x) # (batch, D) + all_embs.append(embs) + return torch.cat(all_embs, dim=0) + + @torch.no_grad() + def _get_item_embeddings(self) -> torch.Tensor: + """Get all item embeddings from the network.""" + assert self._net is not None + self._net.eval() + return self._net.all_item_embeddings() + + def _recommend_u2i( + self, + user_ids: InternalIdsArray, + dataset: Dataset, + k: int, + filter_viewed: bool, + sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], + ) -> InternalRecoTriplet: + assert self._data_preparator is not None + device = next(self._net.parameters()).device # type: ignore + + user_embs = self._get_user_embeddings(dataset) # (n_users, D) + item_embs = self._get_item_embeddings() # (n_items, D) + + # Build filter matrix + filter_csr = None + if filter_viewed: + ui_mat = dataset.get_user_item_matrix(include_weights=False) + n_users_mat = ui_mat.shape[0] + n_items_emb = item_embs.shape[0] + n_extra = self._data_preparator.n_item_extra_tokens + # item_embs[i] corresponds to preparator internal item id (i + n_extra). + # ui_mat columns are dataset internal item ids which share the preparator's id_map. + # Slice out the extra-token columns and pad/trim to exactly n_items_emb cols. 
+ if ui_mat.shape[1] > n_extra: + sliced = ui_mat[:, n_extra:] + else: + sliced = sparse.csr_matrix((n_users_mat, 0)) + n_cols = sliced.shape[1] + if n_cols < n_items_emb: + pad = sparse.csr_matrix((n_users_mat, n_items_emb - n_cols)) + filter_csr = sparse.hstack([sliced, pad], format="csr") + elif n_cols > n_items_emb: + filter_csr = sliced[:, :n_items_emb] + else: + filter_csr = sliced + + # Map whitelist to item_embs indices (0-based, without extra tokens) + whitelist = None + if sorted_item_ids_to_recommend is not None: + n_extra = self._data_preparator.n_item_extra_tokens + wl = sorted_item_ids_to_recommend - n_extra + whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])] + + u_ids, i_ids, scores = rank_topk( + user_embs, item_embs, k, + filter_csr=filter_csr, + whitelist=whitelist, + batch_size=self.recommend_batch_size, + ) + + # Convert item indices back to preparator's internal ids + n_extra = self._data_preparator.n_item_extra_tokens + i_ids = i_ids + n_extra + + return u_ids, i_ids, scores + + def _recommend_i2i( + self, + target_ids: InternalIdsArray, + dataset: Dataset, + k: int, + sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], + ) -> InternalRecoTriplet: + assert self._data_preparator is not None and self._net is not None + device = next(self._net.parameters()).device + + item_embs = self._get_item_embeddings() # (n_items, D) + n_extra = self._data_preparator.n_item_extra_tokens + + # Target embeddings: target_ids are preparator internal ids + target_emb_idx = target_ids - n_extra + target_embs = item_embs[target_emb_idx] # (n_targets, D) + + whitelist = None + if sorted_item_ids_to_recommend is not None: + wl = sorted_item_ids_to_recommend - n_extra + whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])] + + t_ids, i_ids, scores = rank_topk( + target_embs, item_embs, k, + whitelist=whitelist, + batch_size=self.recommend_batch_size, + ) + + # Map back + result_target_ids = target_ids[t_ids] + result_item_ids = i_ids + n_extra + + return 
result_target_ids, result_item_ids, scores diff --git a/rectools/fast_transformers/net.py b/rectools/fast_transformers/net.py new file mode 100644 index 00000000..81d4dd7d --- /dev/null +++ b/rectools/fast_transformers/net.py @@ -0,0 +1,187 @@ +"""Flat SASRec network: pre-norm transformer encoder with plain id embeddings.""" + +import typing as tp + +import torch +from torch import nn + + +class SASRecBlock(nn.Module): + """Pre-norm transformer block: LayerNorm -> MHA -> residual -> LayerNorm -> FFN -> residual.""" + + def __init__(self, n_factors: int, n_heads: int, dropout: float = 0.1) -> None: + super().__init__() + self.ln1 = nn.LayerNorm(n_factors) + self.mha = nn.MultiheadAttention(n_factors, n_heads, dropout=dropout, batch_first=True) + self.ln2 = nn.LayerNorm(n_factors) + self.ffn = nn.Sequential( + nn.Linear(n_factors, n_factors * 4), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(n_factors * 4, n_factors), + nn.Dropout(dropout), + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: tp.Optional[torch.Tensor] = None, + key_padding_mask: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + h = self.ln1(x) + h, _ = self.mha(h, h, h, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) + x = x + h + h = self.ln2(x) + x = x + self.ffn(h) + return x + + +class FlatSASRec(nn.Module): + """ + Flat SASRec: sequential recommender with plain id-embedding table + (no ItemNet hierarchy). + + Parameters + ---------- + n_items : int + Total number of items (excluding padding token 0). + n_factors : int + Embedding / hidden dimension. + n_blocks : int + Number of transformer blocks. + n_heads : int + Number of attention heads. + session_max_len : int + Maximum sequence length. + dropout : float + Dropout rate. 
+ """ + + PADDING_IDX = 0 + + def __init__( + self, + n_items: int, + n_factors: int, + n_blocks: int, + n_heads: int, + session_max_len: int, + dropout: float = 0.1, + ) -> None: + super().__init__() + self.n_items = n_items + self.n_factors = n_factors + self.session_max_len = session_max_len + + # +1 for padding at index 0 + self.item_emb = nn.Embedding(n_items + 1, n_factors, padding_idx=self.PADDING_IDX) + self.pos_emb = nn.Embedding(session_max_len, n_factors) + self.emb_dropout = nn.Dropout(dropout) + + self.blocks = nn.ModuleList([SASRecBlock(n_factors, n_heads, dropout) for _ in range(n_blocks)]) + self.final_ln = nn.LayerNorm(n_factors) + + def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor: + return torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1) + + def encode(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode full sequence. + + Parameters + ---------- + x : LongTensor (B, L) + Item id sequences (0 = padding). + + Returns + ------- + Tensor (B, L, D) + """ + B, L = x.shape + positions = torch.arange(L, device=x.device).unsqueeze(0) + h = self.item_emb(x) + self.pos_emb(positions) + h = self.emb_dropout(h) + + # timeline_mask: zero out padding positions to prevent NaN from attention + timeline_mask = (x != self.PADDING_IDX).unsqueeze(-1).float() # (B, L, 1) + attn_mask = self._causal_mask(L, x.device) + key_padding_mask = x == self.PADDING_IDX + + for block in self.blocks: + h = h * timeline_mask + h = block(h, attn_mask=attn_mask, key_padding_mask=key_padding_mask) + h = h * timeline_mask + h = self.final_ln(h) + return h + + def encode_last(self, x: torch.Tensor) -> torch.Tensor: + """ + Encode and return only the last non-padding position representation. 
+ + Parameters + ---------- + x : LongTensor (B, L) + + Returns + ------- + Tensor (B, D) + """ + h = self.encode(x) # (B, L, D) + # Find last non-padding position per row + non_pad = (x != self.PADDING_IDX) # (B, L) + # lengths: number of non-pad tokens + lengths = non_pad.sum(dim=1) # (B,) + # Clamp to at least 1 to avoid index -1 for fully-padded rows + last_idx = (lengths - 1).clamp(min=0) + # We use left-padding, so last non-pad is at position (L - 1) if any token exists + # Actually with left padding, non-pad tokens are at the end, so the last position is L-1 + # But let's compute correctly: the last non-pad index + # With left-padding: first non-pad is at L - length, last non-pad is at L - 1 + B = x.shape[0] + last_pos = x.shape[1] - 1 # last position is always the last for left-padded sequences + return h[:, last_pos, :] # (B, D) + + def all_item_embeddings(self) -> torch.Tensor: + """ + Return embeddings for all items (1..n_items), excluding padding. + + Returns + ------- + Tensor (n_items, D) + """ + ids = torch.arange(1, self.n_items + 1, device=self.item_emb.weight.device) + return self.item_emb(ids) + + def forward(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor: + """ + Training forward pass. + + Parameters + ---------- + batch : dict + Must contain 'x' (B, L) and 'y' (B, L). + Optionally 'negatives' (B, L, N) for candidate-logits branch. + + Returns + ------- + logits : Tensor + If negatives present: (B, L, 1 + N) — positive + negative logits. + Otherwise: (B, L, n_items) — full catalog logits. 
+ """ + x = batch["x"] # (B, L) + y = batch["y"] # (B, L) + + h = self.encode(x) # (B, L, D) + + if "negatives" in batch: + negatives = batch["negatives"] # (B, L, N) + pos_emb = self.item_emb(y).unsqueeze(3) # (B, L, D, 1) + neg_emb = self.item_emb(negatives) # (B, L, N, D) + neg_emb = neg_emb.transpose(2, 3) # (B, L, D, N) + all_emb = torch.cat([pos_emb, neg_emb], dim=3) # (B, L, D, 1+N) + logits = (h.unsqueeze(2) @ all_emb).squeeze(2) # (B, L, 1+N) + # -> shape is (B, L, 1+N) where first column is positive logit + else: + item_embs = self.all_item_embeddings() # (n_items, D) + logits = h @ item_embs.T # (B, L, n_items) + return logits diff --git a/rectools/fast_transformers/ranking.py b/rectools/fast_transformers/ranking.py new file mode 100644 index 00000000..9825d763 --- /dev/null +++ b/rectools/fast_transformers/ranking.py @@ -0,0 +1,80 @@ +"""Batch top-k ranking with optional viewed-item filtering.""" + +import typing as tp + +import numpy as np +import torch +from scipy import sparse + + +def rank_topk( + user_embs: torch.Tensor, + item_embs: torch.Tensor, + k: int, + filter_csr: tp.Optional[sparse.csr_matrix] = None, + whitelist: tp.Optional[np.ndarray] = None, + batch_size: int = 256, +) -> tp.Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Batch-wise top-k ranking: user_embs @ item_embs.T with optional filtering. + + Parameters + ---------- + user_embs : Tensor (N, D) + User embeddings. + item_embs : Tensor (M, D) + Item embeddings. + k : int + Number of items to recommend per user. + filter_csr : csr_matrix (N, M), optional + Binary matrix of viewed items to mask out. + whitelist : ndarray, optional + Sorted array of item indices to consider. + batch_size : int + Batch size for processing users. + + Returns + ------- + all_user_ids, all_item_ids, all_scores : ndarray, ndarray, ndarray + Flattened arrays of recommendations. 
+ """ + device = user_embs.device + n_users = user_embs.shape[0] + + if whitelist is not None: + item_embs = item_embs[whitelist] + + all_user_ids = [] + all_item_ids = [] + all_scores = [] + + for start in range(0, n_users, batch_size): + end = min(start + batch_size, n_users) + scores = user_embs[start:end] @ item_embs.T # (batch, M) + + if filter_csr is not None: + batch_csr = filter_csr[start:end] + if whitelist is not None: + batch_csr = batch_csr[:, whitelist] + viewed_mask = torch.tensor(batch_csr.toarray(), dtype=torch.bool, device=device) + scores[viewed_mask] = -float("inf") + + actual_k = min(k, scores.shape[1]) + topk_scores, topk_idx = torch.topk(scores, actual_k, dim=1) # (batch, k) + + if whitelist is not None: + topk_idx_np = topk_idx.cpu().numpy() + topk_idx_mapped = whitelist[topk_idx_np] + else: + topk_idx_mapped = topk_idx.cpu().numpy() + + batch_users = np.arange(start, end) + user_ids = np.repeat(batch_users, actual_k) + item_ids = topk_idx_mapped.ravel() + s = topk_scores.cpu().numpy().ravel() + + all_user_ids.append(user_ids) + all_item_ids.append(item_ids) + all_scores.append(s) + + return np.concatenate(all_user_ids), np.concatenate(all_item_ids), np.concatenate(all_scores) diff --git a/rectools/fast_transformers/unisrec_lightning.py b/rectools/fast_transformers/unisrec_lightning.py new file mode 100644 index 00000000..c0c440f3 --- /dev/null +++ b/rectools/fast_transformers/unisrec_lightning.py @@ -0,0 +1,97 @@ +"""Lightning wrapper for UniSRec: supports full-softmax and sampled CE loss.""" + +import typing as tp + +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +from .unisrec_net import UniSRec + + +class UniSRecLightning(pl.LightningModule): + """ + Thin Lightning wrapper reused across all training phases. + + Each phase creates a fresh ``UniSRecLightning`` with appropriate + ``param_groups`` and ``use_id`` flag, sharing the same ``net`` instance. 
+ """ + + def __init__( + self, + net: UniSRec, + param_groups: tp.List[tp.Dict[str, tp.Any]], + use_id: bool = False, + ) -> None: + super().__init__() + self.net = net + self._param_groups = param_groups + self.use_id = use_id + + # ── helpers ── + + def _get_item_embs(self, item_ids: torch.Tensor) -> torch.Tensor: + if self.use_id: + return self.net.item_emb(item_ids) + return self.net._adapt_score(self.net._sample_frozen(item_ids)) + + # ── training step ── + + def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: + input_ids = batch["x"] + labels = batch["y"] + hidden = self.net(input_ids, use_id=self.use_id) # (B, L, D) + + if "negatives" in batch: + loss = self._sampled_ce_loss(hidden, labels, batch["negatives"]) + else: + loss = self._full_softmax_loss(hidden, labels) + + self.log("train_loss", loss, prog_bar=True) + return loss + + def _full_softmax_loss(self, hidden: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + if self.use_id: + all_emb = self.net.item_emb.weight # (n_items+1, D) + else: + all_emb = self.net.project_all() # (n_items+1, D) + + logits = hidden @ all_emb.T # (B, L, n_items+1) + logits[:, :, 0] = float("-inf") # never predict padding + + targets = labels.clone() + targets[targets == 0] = -100 # padding → ignore + return F.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + ignore_index=-100, + ) + + def _sampled_ce_loss( + self, + hidden: torch.Tensor, + labels: torch.Tensor, + negatives: torch.Tensor, + ) -> torch.Tensor: + emb_pos = self._get_item_embs(labels) # (B, L, D) + logits_pos = (hidden * emb_pos).sum(dim=-1) # (B, L) + + emb_neg = self._get_item_embs(negatives) # (B, L, N, D) + logits_neg = torch.matmul( # (B, L, N) + hidden.unsqueeze(2), emb_neg.transpose(2, 3), + ).squeeze(2) + + logits = torch.cat([logits_pos.unsqueeze(-1), logits_neg], dim=-1) # (B, L, 1+N) + + targets = torch.zeros_like(labels) # positive class = index 0 + targets[labels == 0] = -100 # 
padding → ignore + return F.cross_entropy( + logits.view(-1, logits.size(-1)), + targets.view(-1), + ignore_index=-100, + ) + + # ── optimizer ── + + def configure_optimizers(self) -> torch.optim.Optimizer: + return torch.optim.AdamW(self._param_groups) diff --git a/rectools/fast_transformers/unisrec_model.py b/rectools/fast_transformers/unisrec_model.py new file mode 100644 index 00000000..a1990884 --- /dev/null +++ b/rectools/fast_transformers/unisrec_model.py @@ -0,0 +1,458 @@ +"""UniSRecModel: ModelBase wrapper with three-phase training.""" + +import typing as tp + +import numpy as np +import torch +import pytorch_lightning as pl +from scipy import sparse + +from rectools.dataset import Dataset +from rectools.models.base import InternalRecoTriplet, ModelBase, ModelConfig +from rectools.models.nn.transformers.sasrec import SASRecDataPreparator +from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler +from rectools.types import InternalIdsArray +from rectools.utils.config import BaseConfig + +from .unisrec_net import UniSRec +from .unisrec_lightning import UniSRecLightning +from .ranking import rank_topk + + +class UniSRecConfig(BaseConfig): + """Hyperparameters for UniSRecModel (without pretrained embeddings).""" + + n_factors: int = 256 + projection_hidden: int = 512 + n_blocks: int = 2 + n_heads: int = 1 + session_max_len: int = 200 + dropout: float = 0.1 + adaptor_dropout: float = 0.2 + adaptor_type: str = "pca" + use_adaptor_ffn: bool = True + + phase1_epochs: int = 10 + phase2_epochs: int = 10 + phase3_epochs: int = 10 + phase1_lr: float = 1e-3 + phase2_lr: float = 3e-4 + phase3_lr: float = 1e-4 + lr_head: float = 0.3 + lr_wp: float = 0.1 + lr_transformer: float = 3.0 + + grad_clip: float = 1.0 + weight_decay: float = 0.01 + batch_size: int = 128 + recommend_batch_size: int = 256 + dataloader_num_workers: int = 0 + train_min_user_interactions: int = 2 + n_negatives: tp.Optional[int] = None + + +class 
class UniSRecModel(ModelBase[UniSRecModelConfig]):
    """
    UniSRec integrated into RecTools via ``ModelBase``.

    Three training phases
    ---------------------
    1. **Phase 1** — SASRec on ID embeddings (``item_emb`` + transformer).
    2. **Phase 2** — Adaptor only (transformer frozen, pretrained embeddings).
    3. **Phase 3** — Full fine-tune (adaptor + transformer, pretrained embeddings).

    Parameters
    ----------
    pretrained_item_embeddings : Tensor
        Shape ``(max_external_item_id + 1, D_text)`` or
        ``(max_external_item_id + 1, n_variants, D_text)``.
        Index *i* holds the text embedding for the item whose **external** ID
        equals *i*. Index 0 is padding (zeros).
        During ``fit`` the tensor is reindexed to match the internal ID map
        produced by ``SASRecDataPreparator``.
    """

    config_class = UniSRecModelConfig
    recommends_for_warm = False
    recommends_for_cold = False

    def __init__(
        self,
        pretrained_item_embeddings: torch.Tensor,
        n_factors: int = 256,
        projection_hidden: int = 512,
        n_blocks: int = 2,
        n_heads: int = 1,
        session_max_len: int = 200,
        dropout: float = 0.1,
        adaptor_dropout: float = 0.2,
        adaptor_type: str = "pca",
        use_adaptor_ffn: bool = True,
        phase1_epochs: int = 10,
        phase2_epochs: int = 10,
        phase3_epochs: int = 10,
        phase1_lr: float = 1e-3,
        phase2_lr: float = 3e-4,
        phase3_lr: float = 1e-4,
        lr_head: float = 0.3,
        lr_wp: float = 0.1,
        lr_transformer: float = 3.0,
        grad_clip: float = 1.0,
        weight_decay: float = 0.01,
        batch_size: int = 128,
        recommend_batch_size: int = 256,
        dataloader_num_workers: int = 0,
        train_min_user_interactions: int = 2,
        n_negatives: tp.Optional[int] = None,
        verbose: int = 0,
    ) -> None:
        super().__init__(verbose=verbose)
        self.pretrained_item_embeddings = pretrained_item_embeddings
        self.n_factors = n_factors
        self.projection_hidden = projection_hidden
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.session_max_len = session_max_len
        self.dropout = dropout
        self.adaptor_dropout = adaptor_dropout
        self.adaptor_type = adaptor_type
        self.use_adaptor_ffn = use_adaptor_ffn
        self.phase1_epochs = phase1_epochs
        self.phase2_epochs = phase2_epochs
        self.phase3_epochs = phase3_epochs
        self.phase1_lr = phase1_lr
        self.phase2_lr = phase2_lr
        self.phase3_lr = phase3_lr
        self.lr_head = lr_head
        self.lr_wp = lr_wp
        self.lr_transformer = lr_transformer
        self.grad_clip = grad_clip
        self.weight_decay = weight_decay
        self.batch_size = batch_size
        self.recommend_batch_size = recommend_batch_size
        self.dataloader_num_workers = dataloader_num_workers
        self.train_min_user_interactions = train_min_user_interactions
        self.n_negatives = n_negatives

        # Populated during `_fit`.
        self._net: tp.Optional[UniSRec] = None
        self._data_preparator: tp.Optional[SASRecDataPreparator] = None

    # ── config boilerplate (embeddings are not serialised) ──

    def _get_config(self) -> UniSRecModelConfig:
        """Export hyper-parameters; the embedding tensor itself is deliberately excluded."""
        return UniSRecModelConfig(
            cls=self.__class__,
            verbose=self.verbose,
            model=UniSRecConfig(
                n_factors=self.n_factors,
                projection_hidden=self.projection_hidden,
                n_blocks=self.n_blocks,
                n_heads=self.n_heads,
                session_max_len=self.session_max_len,
                dropout=self.dropout,
                adaptor_dropout=self.adaptor_dropout,
                adaptor_type=self.adaptor_type,
                use_adaptor_ffn=self.use_adaptor_ffn,
                phase1_epochs=self.phase1_epochs,
                phase2_epochs=self.phase2_epochs,
                phase3_epochs=self.phase3_epochs,
                phase1_lr=self.phase1_lr,
                phase2_lr=self.phase2_lr,
                phase3_lr=self.phase3_lr,
                lr_head=self.lr_head,
                lr_wp=self.lr_wp,
                lr_transformer=self.lr_transformer,
                grad_clip=self.grad_clip,
                weight_decay=self.weight_decay,
                batch_size=self.batch_size,
                recommend_batch_size=self.recommend_batch_size,
                dataloader_num_workers=self.dataloader_num_workers,
                train_min_user_interactions=self.train_min_user_interactions,
                n_negatives=self.n_negatives,
            ),
        )

    @classmethod
    def _from_config(cls, config: UniSRecModelConfig) -> "UniSRecModel":
        raise NotImplementedError(
            "UniSRecModel cannot be restored from config alone — "
            "pretrained_item_embeddings must be supplied at construction time."
        )

    # ── helpers ──

    def _align_embeddings(self, dp: SASRecDataPreparator) -> torch.Tensor:
        """Reindex ``pretrained_item_embeddings`` to the preparator's internal IDs.

        Rows for extra tokens (e.g. padding) and for external IDs that fall
        outside the embedding table remain zero. The gather is vectorised
        instead of assigning one tensor row per item in a Python loop.
        """
        ext_ids = dp.item_id_map.to_external.values  # array[internal_id] → external_id
        n_internal = dp.item_id_map.size
        n_extra = dp.n_item_extra_tokens

        emb = self.pretrained_item_embeddings
        # Works for both (N, D) and (N, n_variants, D) embedding tables.
        aligned = torch.zeros(n_internal, *emb.shape[1:])

        # External IDs for real items only; int() mirrors the previous per-row cast.
        ext = torch.tensor([int(e) for e in ext_ids[n_extra:n_internal]], dtype=torch.long)
        valid = (ext >= 0) & (ext < emb.shape[0])
        internal = torch.arange(n_extra, n_internal)[valid]
        aligned[internal] = emb[ext[valid]].to(aligned.dtype)
        return aligned

    def _make_trainer(self, max_epochs: int) -> pl.Trainer:
        """One-phase Trainer; logging/progress only when the model is verbose."""
        return pl.Trainer(
            max_epochs=max_epochs,
            gradient_clip_val=self.grad_clip,
            enable_checkpointing=False,
            enable_model_summary=False,
            logger=self.verbose > 0,
            enable_progress_bar=self.verbose > 0,
        )

    # ── Phase param-groups ──

    def _phase2_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]:
        """Optimizer param groups for Phase 2 (adaptor only, transformer frozen)."""
        if self.adaptor_type == "pca":
            groups: tp.List[tp.Dict[str, tp.Any]] = [
                {"params": [net.whitening_proj], "lr": self.phase2_lr * self.lr_wp, "weight_decay": 0.0},
                # Bias gets a boosted LR; no decay on the whitening terms.
                {"params": [net.whitening_bias], "lr": self.phase2_lr * 10.0, "weight_decay": 0.0},
            ]
            if net.head is not None:
                groups.append(
                    {
                        "params": list(net.head.parameters()),
                        "lr": self.phase2_lr * self.lr_head,
                        "weight_decay": self.weight_decay,
                    }
                )
        else:
            groups = [
                {"params": list(net.bn_input.parameters()), "lr": self.phase2_lr, "weight_decay": 0.0},
                {"params": list(net.bn_score.parameters()), "lr": self.phase2_lr, "weight_decay": 0.0},
                {
                    "params": list(net.head.parameters()),
                    "lr": self.phase2_lr * self.lr_head,
                    "weight_decay": self.weight_decay,
                },
            ]
        return groups

    def _phase3_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]:
        """Optimizer param groups for Phase 3 (adaptor + head + transformer)."""
        # adaptor
        if self.adaptor_type == "pca":
            adaptor: tp.List[tp.Dict[str, tp.Any]] = [
                {"params": [net.whitening_proj], "lr": self.phase3_lr * self.lr_wp, "weight_decay": 0.0},
                {"params": [net.whitening_bias], "lr": self.phase3_lr * 10.0, "weight_decay": 0.0},
            ]
        else:
            adaptor = [
                {"params": list(net.bn_input.parameters()), "lr": self.phase3_lr, "weight_decay": 0.0},
                {"params": list(net.bn_score.parameters()), "lr": self.phase3_lr, "weight_decay": 0.0},
            ]
        # head
        head: tp.List[tp.Dict[str, tp.Any]] = []
        if net.head is not None:
            head = [
                {
                    "params": list(net.head.parameters()),
                    "lr": self.phase3_lr * self.lr_head,
                    "weight_decay": self.weight_decay,
                }
            ]
        # transformer: layernorms get no weight decay, attention/FFN do.
        transformer = [
            {"params": list(net.pos_emb.parameters()), "lr": self.phase3_lr * self.lr_transformer, "weight_decay": 0.0},
            {
                "params": (
                    [p for l in net.attention_layers for p in l.parameters()]
                    + [p for l in net.forward_layers for p in l.parameters()]
                ),
                "lr": self.phase3_lr * self.lr_transformer,
                "weight_decay": self.weight_decay,
            },
            {
                "params": (
                    [p for l in net.attention_layernorms for p in l.parameters()]
                    + [p for l in net.forward_layernorms for p in l.parameters()]
                    + list(net.last_layernorm.parameters())
                ),
                "lr": self.phase3_lr,
                "weight_decay": 0.0,
            },
        ]
        return adaptor + head + transformer

    # ── fit ──

    def _fit(self, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> None:
        """Prepare data, build the net and run the (up to) three training phases."""
        # Data preparation: sampled negatives only when explicitly requested.
        negative_sampler = None
        n_negatives_dp: tp.Optional[int] = None
        if self.n_negatives is not None:
            negative_sampler = CatalogUniformSampler(n_negatives=self.n_negatives)
            n_negatives_dp = self.n_negatives

        dp = SASRecDataPreparator(
            session_max_len=self.session_max_len,
            batch_size=self.batch_size,
            dataloader_num_workers=self.dataloader_num_workers,
            train_min_user_interactions=self.train_min_user_interactions,
            n_negatives=n_negatives_dp,
            negative_sampler=negative_sampler,
        )
        dp.process_dataset_train(dataset)
        self._data_preparator = dp

        n_real_items = dp.item_id_map.size - dp.n_item_extra_tokens
        aligned_emb = self._align_embeddings(dp)

        net = UniSRec(
            n_items=n_real_items,
            pretrained_embeddings=aligned_emb,
            n_factors=self.n_factors,
            projection_hidden=self.projection_hidden,
            n_blocks=self.n_blocks,
            n_heads=self.n_heads,
            session_max_len=self.session_max_len,
            dropout=self.dropout,
            adaptor_dropout=self.adaptor_dropout,
            adaptor_type=self.adaptor_type,
            use_adaptor_ffn=self.use_adaptor_ffn,
        )

        train_dl = dp.get_dataloader_train()

        # ── Phase 1: ID embeddings ──
        if self.phase1_epochs > 0:
            p1_params = [{"params": list(net.item_emb.parameters()) + net.transformer_params, "lr": self.phase1_lr}]
            lm = UniSRecLightning(net, p1_params, use_id=True)
            self._make_trainer(self.phase1_epochs).fit(lm, train_dl)

        # ── Phase 2: adaptor only (transformer frozen) ──
        if self.phase2_epochs > 0 and self.use_adaptor_ffn:
            net.freeze_transformer()
            lm = UniSRecLightning(net, self._phase2_params(net), use_id=False)
            self._make_trainer(self.phase2_epochs).fit(lm, train_dl)

        # ── Phase 3: full fine-tune ──
        if self.phase3_epochs > 0:
            net.unfreeze_transformer()
            lm = UniSRecLightning(net, self._phase3_params(net), use_id=False)
            self._make_trainer(self.phase3_epochs).fit(lm, train_dl)

        self._net = net

    # ── dataset transforms ──

    def _custom_transform_dataset_u2i(
        self,
        dataset: Dataset,
        users: tp.Any,
        on_unsupported_targets: tp.Any,
        context: tp.Optional["pd.DataFrame"] = None,
    ) -> Dataset:
        assert self._data_preparator is not None
        return self._data_preparator.transform_dataset_u2i(dataset, users)

    def _custom_transform_dataset_i2i(
        self, dataset: Dataset, target_items: tp.Any, on_unsupported_targets: tp.Any
    ) -> Dataset:
        assert self._data_preparator is not None
        return self._data_preparator.transform_dataset_i2i(dataset)

    # ── embeddings for ranking ──

    @torch.no_grad()
    def _get_user_embeddings(self, dataset: Dataset) -> torch.Tensor:
        """Encode every recommend-dataloader session into its last-position vector."""
        assert self._data_preparator is not None and self._net is not None
        self._net.eval()
        device = next(self._net.parameters()).device
        recommend_dl = self._data_preparator.get_dataloader_recommend(dataset, self.recommend_batch_size)
        all_embs = []
        for batch in recommend_dl:
            x = batch["x"].to(device)
            all_embs.append(self._net.encode_last(x, use_id=False))
        return torch.cat(all_embs, dim=0)

    @torch.no_grad()
    def _get_item_embeddings(self) -> torch.Tensor:
        """Adapted item embeddings without the padding row: shape (n_items, D)."""
        assert self._net is not None
        self._net.eval()
        all_emb = self._net.project_all()  # (n_items+1, D)
        return all_emb[1:]  # skip padding → (n_items, D)

    # ── recommend ──

    def _recommend_u2i(
        self,
        user_ids: InternalIdsArray,
        dataset: Dataset,
        k: int,
        filter_viewed: bool,
        sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],
    ) -> InternalRecoTriplet:
        """Top-k u2i ranking over adapted item embeddings.

        Item indices inside the rankers are offset by ``n_item_extra_tokens``
        relative to internal dataset IDs; the offset is removed on input
        (whitelist, viewed matrix) and re-added on output.
        """
        assert self._data_preparator is not None

        user_embs = self._get_user_embeddings(dataset)
        item_embs = self._get_item_embeddings()
        n_extra = self._data_preparator.n_item_extra_tokens

        # viewed-item filter: align matrix columns with the embedding rows
        filter_csr = None
        if filter_viewed:
            ui_mat = dataset.get_user_item_matrix(include_weights=False)
            n_users_mat = ui_mat.shape[0]
            n_items_emb = item_embs.shape[0]

            sliced = ui_mat[:, n_extra:] if ui_mat.shape[1] > n_extra else sparse.csr_matrix((n_users_mat, 0))
            n_cols = sliced.shape[1]
            if n_cols < n_items_emb:
                filter_csr = sparse.hstack(
                    [sliced, sparse.csr_matrix((n_users_mat, n_items_emb - n_cols))], format="csr"
                )
            elif n_cols > n_items_emb:
                filter_csr = sliced[:, :n_items_emb]
            else:
                filter_csr = sliced

        # whitelist: shift into embedding index space, drop out-of-range IDs
        whitelist = None
        if sorted_item_ids_to_recommend is not None:
            wl = sorted_item_ids_to_recommend - n_extra
            whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])]

        u_ids, i_ids, scores = rank_topk(
            user_embs, item_embs, k,
            filter_csr=filter_csr,
            whitelist=whitelist,
            batch_size=self.recommend_batch_size,
        )

        return u_ids, i_ids + n_extra, scores

    def _recommend_i2i(
        self,
        target_ids: InternalIdsArray,
        dataset: Dataset,
        k: int,
        sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray],
    ) -> InternalRecoTriplet:
        """Top-k i2i ranking: target item vectors scored against all item vectors."""
        assert self._data_preparator is not None and self._net is not None

        item_embs = self._get_item_embeddings()
        n_extra = self._data_preparator.n_item_extra_tokens

        target_emb_idx = target_ids - n_extra
        target_embs = item_embs[target_emb_idx]

        whitelist = None
        if sorted_item_ids_to_recommend is not None:
            wl = sorted_item_ids_to_recommend - n_extra
            whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])]

        t_ids, i_ids, scores = rank_topk(
            target_embs, item_embs, k,
            whitelist=whitelist,
            batch_size=self.recommend_batch_size,
        )

        result_target_ids = target_ids[t_ids]
        result_item_ids = i_ids + n_extra
        return result_target_ids, result_item_ids, scores
def _make_mlp(in_dim: int, hidden_dim: int, out_dim: int, dropout: float) -> nn.Sequential:
    """2-layer GELU MLP used as the adaptor head."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, out_dim),
    )


class FeedForward(nn.Module):
    """Point-wise FFN via Conv1d (kernel_size=1), matching the reference UniSRec."""

    def __init__(self, hidden_units: int, dropout_rate: float) -> None:
        super().__init__()
        self.conv1 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout1 = nn.Dropout(p=dropout_rate)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(hidden_units, hidden_units, kernel_size=1)
        self.dropout2 = nn.Dropout(p=dropout_rate)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Conv1d expects (B, C, L), hence the transposes around the block.
        outputs = self.conv1(inputs.transpose(-1, -2))
        outputs = self.relu(self.dropout1(outputs))
        outputs = self.conv2(outputs)
        outputs = self.dropout2(outputs)
        return outputs.transpose(-1, -2)


class UniSRec(nn.Module):
    """
    UniSRec: sequential recommender with pretrained text embeddings + adaptor.

    Architecture:
        frozen_emb --> adaptor (PCA/BN + optional MLP) --> SASRec encoder
        item_emb   --> SASRec encoder (Phase 1, ID-based)

    Parameters
    ----------
    n_items : int
        Number of real items (excluding padding token at index 0).
    pretrained_embeddings : Tensor
        Shape ``(n_items + 1, D_text)`` or ``(n_items + 1, n_variants, D_text)``.
        Index 0 = padding (zeros), indices 1..n_items = item text embeddings.
    n_factors : int
        Hidden / output dimension of the transformer.
    projection_hidden : int
        Intermediate dimension for the PCA adaptor head.
    n_blocks : int
        Number of transformer blocks.
    n_heads : int
        Number of attention heads.
    session_max_len : int
        Maximum sequence length (positional embedding size).
    dropout : float
        Dropout in transformer blocks.
    adaptor_dropout : float
        Dropout inside the adaptor MLP.
    adaptor_type : ``"pca"`` | ``"bn"``
        Type of adaptor for projecting pretrained embeddings.
    use_adaptor_ffn : bool
        Whether to use a 2-layer MLP head after the linear projection.
    initializer_range : float
        Std for normal weight initialisation.
    """

    PADDING_IDX = 0

    def __init__(
        self,
        n_items: int,
        pretrained_embeddings: torch.Tensor,
        n_factors: int = 256,
        projection_hidden: int = 512,
        n_blocks: int = 2,
        n_heads: int = 1,
        session_max_len: int = 200,
        dropout: float = 0.1,
        adaptor_dropout: float = 0.2,
        adaptor_type: str = "pca",
        use_adaptor_ffn: bool = True,
        initializer_range: float = 0.02,
    ) -> None:
        super().__init__()
        self.n_items = n_items
        self.n_factors = n_factors
        self.session_max_len = session_max_len
        self.n_blocks = n_blocks
        self.adaptor_type = adaptor_type
        self.use_adaptor_ffn = use_adaptor_ffn
        self.initializer_range = initializer_range

        if not use_adaptor_ffn and adaptor_type != "pca":
            raise ValueError("use_adaptor_ffn=False is only supported with adaptor_type='pca'")

        # ── ID embedding (Phase 1) ──
        self.item_emb = nn.Embedding(n_items + 1, n_factors, padding_idx=self.PADDING_IDX)

        # ── Frozen pretrained embeddings (buffer, not a parameter) ──
        if pretrained_embeddings.ndim == 2:
            pretrained_embeddings = pretrained_embeddings.unsqueeze(1)
        self.register_buffer("frozen_emb", pretrained_embeddings)
        self.n_variants = pretrained_embeddings.shape[1]

        qwen_dim = pretrained_embeddings.shape[2]
        emb_for_init = pretrained_embeddings[1:, 0, :]  # skip padding row

        # ── Adaptor ──
        if adaptor_type == "pca":
            # Whitening: subtract data mean, project onto PCA directions.
            self.whitening_bias = nn.Parameter(emb_for_init.mean(dim=0))
            if use_adaptor_ffn:
                self.whitening_proj = nn.Parameter(self._pca_init(emb_for_init, projection_hidden))
                proj_dim = self.whitening_proj.shape[1]
                self.head = _make_mlp(proj_dim, proj_dim, n_factors, adaptor_dropout)
            else:
                self.whitening_proj = nn.Parameter(self._pca_init(emb_for_init, n_factors))
                self.head = None
        elif adaptor_type == "bn":
            # Separate BatchNorms for the input path and the scoring path.
            self.bn_input = nn.BatchNorm1d(qwen_dim)
            self.bn_score = nn.BatchNorm1d(qwen_dim)
            self.head = _make_mlp(qwen_dim, n_factors, n_factors, adaptor_dropout)
        else:
            raise ValueError(f"Unknown adaptor_type: {adaptor_type}")

        # ── Positional embedding + dropout ──
        self.pos_emb = nn.Embedding(session_max_len, n_factors)
        self.emb_dropout = nn.Dropout(dropout)

        # ── Transformer blocks (pre-norm) ──
        self.attention_layernorms = nn.ModuleList()
        self.attention_layers = nn.ModuleList()
        self.forward_layernorms = nn.ModuleList()
        self.forward_layers = nn.ModuleList()
        self.last_layernorm = nn.LayerNorm(n_factors, eps=1e-12)

        for _ in range(n_blocks):
            self.attention_layernorms.append(nn.LayerNorm(n_factors, eps=1e-12))
            self.attention_layers.append(nn.MultiheadAttention(n_factors, n_heads, dropout, batch_first=True))
            self.forward_layernorms.append(nn.LayerNorm(n_factors, eps=1e-12))
            self.forward_layers.append(FeedForward(n_factors, dropout))

        # Note: `apply` only touches Modules, so whitening_* Parameters keep
        # their PCA initialisation.
        self.apply(self._init_weights)

    # ── Init helpers ──

    @staticmethod
    def _pca_init(embeddings: torch.Tensor, out_dim: int) -> torch.Tensor:
        """Top-`out_dim` right singular vectors of the centered embeddings, as columns."""
        centered = embeddings - embeddings.mean(dim=0)
        _, _, Vh = torch.linalg.svd(centered, full_matrices=False)
        out_dim = min(out_dim, Vh.shape[0])  # cannot exceed the matrix rank
        return Vh[:out_dim].T.contiguous()

    def _init_weights(self, module: nn.Module) -> None:
        """BERT-style init: N(0, initializer_range) weights, zero biases."""
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    # ── Adaptor ──

    def _adapt(self, x: torch.Tensor, bn: tp.Optional[nn.BatchNorm1d]) -> torch.Tensor:
        """Shared adaptor body; `bn` selects the BatchNorm for the "bn" variant.

        Previously duplicated verbatim in ``_adapt_input`` / ``_adapt_score``
        (differing only in which BatchNorm was used).
        """
        if self.adaptor_type == "pca":
            projected = (x - self.whitening_bias) @ self.whitening_proj
            return self.head(projected) if self.head is not None else projected
        shape = x.shape
        flat = x.view(-1, shape[-1])  # BatchNorm1d needs a 2-D input
        return self.head(bn(flat)).view(*shape[:-1], self.n_factors)

    def _adapt_input(self, x: torch.Tensor) -> torch.Tensor:
        """Adapt pretrained embeddings for the encoder input path."""
        return self._adapt(x, self.bn_input if self.adaptor_type == "bn" else None)

    def _adapt_score(self, x: torch.Tensor) -> torch.Tensor:
        """Adapt pretrained embeddings for the scoring path."""
        return self._adapt(x, self.bn_score if self.adaptor_type == "bn" else None)

    def _sample_frozen(self, item_ids: torch.Tensor) -> torch.Tensor:
        """Look up pretrained embeddings, sampling a random variant during training."""
        if self.n_variants == 1 or not self.training:
            return self.frozen_emb[item_ids, 0]
        vi = torch.randint(self.n_variants, item_ids.shape, device=item_ids.device)
        vi = vi * (item_ids != 0).long()  # padding always uses variant 0
        return self.frozen_emb[item_ids, vi]

    def project_all(self) -> torch.Tensor:
        """Project all frozen embeddings (variant 0) through the score adaptor.

        Returns shape ``(n_items + 1, n_factors)``.
        """
        return self._adapt_score(self.frozen_emb[:, 0])

    # ── Param-group helpers for multi-phase training ──

    @property
    def transformer_params(self) -> tp.List[nn.Parameter]:
        """All parameters of the transformer stack (incl. positional embedding)."""
        modules = (
            list(self.attention_layernorms) + list(self.attention_layers)
            + list(self.forward_layernorms) + list(self.forward_layers)
            + [self.last_layernorm, self.pos_emb]
        )
        return [p for m in modules for p in m.parameters()]

    @property
    def adaptor_params(self) -> tp.List[nn.Parameter]:
        """All adaptor parameters (head + whitening or BatchNorms)."""
        params: tp.List[nn.Parameter] = list(self.head.parameters()) if self.head is not None else []
        if self.adaptor_type == "pca":
            params += [self.whitening_proj, self.whitening_bias]
        else:
            params += list(self.bn_input.parameters()) + list(self.bn_score.parameters())
        return params

    def freeze_transformer(self) -> None:
        for p in self.transformer_params:
            p.requires_grad = False

    def unfreeze_transformer(self) -> None:
        for p in self.transformer_params:
            p.requires_grad = True

    # ── Encoder ──

    def _causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Boolean upper-triangular mask: True = position may NOT be attended."""
        return torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1)

    def _encode(self, seqs: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        """Run the transformer stack over already-embedded sequences."""
        B, L = input_ids.shape
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0)
        seqs = seqs + self.pos_emb(positions)
        seqs = self.emb_dropout(seqs)

        pad_mask = (input_ids == self.PADDING_IDX)  # (B, L)
        pad_mask_3d = pad_mask.unsqueeze(-1)        # (B, L, 1)
        seqs = seqs.masked_fill(pad_mask_3d, 0.0)   # zero out padding

        attn_mask = self._causal_mask(L, seqs.device)
        key_padding_mask = pad_mask

        for i in range(self.n_blocks):
            normed = self.attention_layernorms[i](seqs)
            # Zero padding in Q/K/V so NaN can never appear in dot-products
            normed = normed.masked_fill(pad_mask_3d, 0.0)
            mha_out, _ = self.attention_layers[i](
                normed, normed, normed,
                attn_mask=attn_mask,
                key_padding_mask=key_padding_mask,
                need_weights=False,
            )
            # masked_fill handles NaN*0 correctly (unlike multiplication)
            seqs = (seqs + mha_out).masked_fill(pad_mask_3d, 0.0)
            seqs = seqs + self.forward_layers[i](self.forward_layernorms[i](seqs))
            seqs = seqs.masked_fill(pad_mask_3d, 0.0)

        return self.last_layernorm(seqs)

    # ── Public forward / encode ──

    def forward(self, input_ids: torch.Tensor, use_id: bool = False) -> torch.Tensor:
        """
        Encode a sequence of item IDs.

        Parameters
        ----------
        input_ids : LongTensor (B, L)
            Left-padded item ID sequences (0 = padding).
        use_id : bool
            If True use the trainable ``item_emb`` (Phase 1).
            If False use the adapted pretrained embeddings (Phase 2/3).

        Returns
        -------
        Tensor (B, L, n_factors)
        """
        if use_id:
            seqs = self.item_emb(input_ids)
        else:
            seqs = self._adapt_input(self._sample_frozen(input_ids))
        return self._encode(seqs, input_ids)

    def encode_last(self, input_ids: torch.Tensor, use_id: bool = False) -> torch.Tensor:
        """Encode and return the last-position representation (B, D)."""
        h = self.forward(input_ids, use_id=use_id)  # (B, L, D)
        return h[:, -1, :]  # left-padded → last position is always the rightmost
def main() -> None:
    """End-to-end smoke test: fit UniSRecModel on synthetic data, recommend, check scores."""
    # Synthetic dataset: 80 users x 60 items, 4-14 interactions each.
    rng = np.random.RandomState(123)
    n_users, n_items = 80, 60

    rows = []
    for user in range(n_users):
        picked = rng.choice(n_items, size=rng.randint(4, 15), replace=False)
        for step, item in enumerate(picked):
            rows.append(
                {
                    Columns.User: user,
                    Columns.Item: item,
                    Columns.Weight: 1.0,
                    Columns.Datetime: pd.Timestamp("2024-01-01") + pd.Timedelta(hours=step),
                }
            )
    df = pd.DataFrame(rows)
    dataset = Dataset.construct(df)
    print(f"Dataset: {n_users} users, {n_items} items, {len(df)} interactions")

    # Fake pretrained embeddings (random, shape [n_items, 64]).
    torch.manual_seed(42)
    pretrained = torch.randn(n_items, 64)

    # Train all three phases with tiny settings.
    model = UniSRecModel(
        pretrained_item_embeddings=pretrained,
        n_factors=32,
        projection_hidden=64,
        n_blocks=2,
        n_heads=2,
        session_max_len=16,
        phase1_epochs=2,
        phase2_epochs=2,
        phase3_epochs=2,
        phase1_lr=1e-3,
        phase2_lr=3e-4,
        phase3_lr=1e-4,
        batch_size=32,
        verbose=1,
    )
    model.fit(dataset)
    print("Training done (3 phases).")

    users = list(range(n_users))
    reco = model.recommend(users=users, dataset=dataset, k=5, filter_viewed=True)
    print("\nTop-5 recommendations (first 3 users):")
    print(reco[reco[Columns.User].isin(range(3))].to_string(index=False))

    # Train-overlap metrics (no train/test split in a smoke test: "relevant"
    # simply means the user interacted with the item during training).
    interactions = dataset.get_raw_interactions()
    hit_count = 0
    shown = 0
    ap_total = 0.0
    for user in users:
        seen = set(interactions[interactions[Columns.User] == user][Columns.Item])
        ranked = reco[reco[Columns.User] == user][Columns.Item].tolist()
        rel = [1 if item in seen else 0 for item in ranked]
        hit_count += sum(rel)
        shown += len(rel)
        if sum(rel) > 0:
            prec_at = np.cumsum(rel) / np.arange(1, len(rel) + 1)
            ap_total += np.sum(prec_at * rel) / sum(rel)
    recall = hit_count / max(shown, 1)
    map_at_k = ap_total / len(users)
    print(f"\nRecall@5 (train overlap): {recall:.4f}")
    print(f"MAP@5 (train overlap): {map_at_k:.4f}")

    # Every recommendation score must be a real number.
    nan_count = reco[Columns.Score].isna().sum()
    print(f"NaN scores: {nan_count} / {len(reco)}")
    assert nan_count == 0, "Found NaN scores!"

    # Item-to-item recommendations.
    target_items = list(range(10))
    i2i = model.recommend_to_items(target_items=target_items, dataset=dataset, k=5)
    print("\nI2I recommendations (first 3 target items):")
    print(i2i[i2i[Columns.TargetItem].isin(range(3))].to_string(index=False))

    print("\nSmoke test passed!")


if __name__ == "__main__":
    main()
def _make_model(**kwargs) -> FlatSASRecModel:
    """Build a tiny FlatSASRecModel; keyword overrides win over the fast defaults."""
    params = dict(
        n_factors=16, n_blocks=1, n_heads=2, session_max_len=8,
        epochs=1, batch_size=16, lr=1e-3, verbose=0,
    )
    params.update(kwargs)
    return FlatSASRecModel(**params)


class TestFitRecommend:
    def test_recommend_columns(self, tiny_dataset: Dataset) -> None:
        model = _make_model()
        model.fit(tiny_dataset)
        target_users = list(range(5))
        recs = model.recommend(users=target_users, dataset=tiny_dataset, k=3, filter_viewed=False)
        assert set(recs.columns) == {Columns.User, Columns.Item, Columns.Score, Columns.Rank}
        assert recs[Columns.User].nunique() == 5

    def test_filter_viewed(self, tiny_dataset: Dataset) -> None:
        model = _make_model()
        model.fit(tiny_dataset)
        target_users = list(range(5))
        recs = model.recommend(users=target_users, dataset=tiny_dataset, k=5, filter_viewed=True)
        history = tiny_dataset.get_raw_interactions()
        # No recommended item may appear in the user's interaction history.
        for uid in target_users:
            seen = set(history[history[Columns.User] == uid][Columns.Item])
            suggested = set(recs[recs[Columns.User] == uid][Columns.Item])
            assert seen.isdisjoint(suggested), f"User {uid} got viewed items in recommendations"

    def test_i2i(self, tiny_dataset: Dataset) -> None:
        model = _make_model()
        model.fit(tiny_dataset)
        recs = model.recommend_to_items(target_items=list(range(5)), dataset=tiny_dataset, k=3)
        assert set(recs.columns) == {Columns.TargetItem, Columns.Item, Columns.Score, Columns.Rank}
        assert recs[Columns.TargetItem].nunique() == 5

    def test_metrics_positive(self, tiny_dataset: Dataset) -> None:
        model = _make_model(epochs=3)
        model.fit(tiny_dataset)
        all_users = list(range(tiny_dataset.user_id_map.size))
        recs = model.recommend(users=all_users, dataset=tiny_dataset, k=5, filter_viewed=False)
        assert len(recs) > 0
        assert recs[Columns.Score].notna().all()


class TestConfig:
    def test_config_roundtrip(self) -> None:
        # get_config -> from_config must preserve hyper-parameters.
        source = _make_model(n_factors=32, n_blocks=3)
        restored = FlatSASRecModel.from_config(source.get_config(mode="pydantic"))
        assert restored.n_factors == 32
        assert restored.n_blocks == 3

    def test_pickle_roundtrip(self, tiny_dataset: Dataset) -> None:
        # A fitted model must survive pickling and still recommend.
        model = _make_model()
        model.fit(tiny_dataset)
        clone = pickle.loads(pickle.dumps(model))
        assert clone.is_fitted
        recs = clone.recommend(users=list(range(3)), dataset=tiny_dataset, k=3, filter_viewed=False)
        assert len(recs) > 0


class TestLosses:
    def test_bce_training(self, tiny_dataset: Dataset) -> None:
        model = _make_model(loss="BCE", n_negatives=2)
        model.fit(tiny_dataset)
        recs = model.recommend(users=list(range(5)), dataset=tiny_dataset, k=3, filter_viewed=False)
        assert len(recs) > 0

    def test_invalid_loss(self) -> None:
        with pytest.raises(ValueError, match="Unsupported loss"):
            _make_model(loss="invalid_loss_name")
def _make_dataset(n_users: int = 20, n_items: int = 25, seed: int = 42) -> Dataset:
    """Synthetic interactions: each user gets 3-7 distinct random items."""
    rng = np.random.RandomState(seed)
    rows = []
    for user in range(n_users):
        picked = rng.choice(n_items, size=rng.randint(3, 8), replace=False)
        for step, item in enumerate(picked):
            rows.append(
                {
                    Columns.User: user,
                    Columns.Item: item,
                    Columns.Weight: 1.0,
                    Columns.Datetime: pd.Timestamp("2024-01-01") + pd.Timedelta(hours=step),
                }
            )
    return Dataset.construct(pd.DataFrame(rows))


def _make_embeddings(n_items: int = 25, dim: int = 64) -> torch.Tensor:
    """Deterministic fake text embeddings; row 0 is the zero padding row."""
    torch.manual_seed(0)
    emb = torch.randn(n_items, dim)
    emb[0] = 0.0
    return emb


def _make_model(**kwargs) -> UniSRecModel:
    """Tiny UniSRecModel with fast defaults; keyword overrides win."""
    params = dict(
        pretrained_item_embeddings=_make_embeddings(),
        n_factors=16,
        projection_hidden=32,
        n_blocks=1,
        n_heads=2,
        session_max_len=8,
        phase1_epochs=1,
        phase2_epochs=1,
        phase3_epochs=1,
        batch_size=16,
        verbose=0,
    )
    params.update(kwargs)
    return UniSRecModel(**params)


class TestFitRecommend:
    def test_recommend_columns(self) -> None:
        ds = _make_dataset()
        model = _make_model()
        model.fit(ds)
        recs = model.recommend(users=list(range(5)), dataset=ds, k=3, filter_viewed=False)
        assert set(recs.columns) == {Columns.User, Columns.Item, Columns.Score, Columns.Rank}
        assert recs[Columns.User].nunique() == 5

    def test_filter_viewed(self) -> None:
        ds = _make_dataset()
        model = _make_model()
        model.fit(ds)
        target_users = list(range(5))
        recs = model.recommend(users=target_users, dataset=ds, k=5, filter_viewed=True)
        history = ds.get_raw_interactions()
        for uid in target_users:
            seen = set(history[history[Columns.User] == uid][Columns.Item])
            suggested = set(recs[recs[Columns.User] == uid][Columns.Item])
            assert seen.isdisjoint(suggested), f"User {uid} got viewed items"

    def test_i2i(self) -> None:
        ds = _make_dataset()
        model = _make_model()
        model.fit(ds)
        recs = model.recommend_to_items(target_items=list(range(5)), dataset=ds, k=3)
        assert set(recs.columns) == {Columns.TargetItem, Columns.Item, Columns.Score, Columns.Rank}
        assert recs[Columns.TargetItem].nunique() == 5

    def test_scores_not_nan(self) -> None:
        ds = _make_dataset()
        model = _make_model(phase1_epochs=2, phase3_epochs=2)
        model.fit(ds)
        all_users = list(range(ds.user_id_map.size))
        recs = model.recommend(users=all_users, dataset=ds, k=5, filter_viewed=False)
        assert len(recs) > 0
        assert recs[Columns.Score].notna().all()


class TestPhaseSkipping:
    def test_skip_phase1(self) -> None:
        ds = _make_dataset()
        model = _make_model(phase1_epochs=0)
        model.fit(ds)
        assert len(model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False)) > 0

    def test_skip_phase2(self) -> None:
        ds = _make_dataset()
        model = _make_model(phase2_epochs=0)
        model.fit(ds)
        assert len(model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False)) > 0

    def test_only_phase3(self) -> None:
        ds = _make_dataset()
        model = _make_model(phase1_epochs=0, phase2_epochs=0, phase3_epochs=2)
        model.fit(ds)
        assert len(model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False)) > 0


class TestWithNegatives:
    def test_sampled_loss(self) -> None:
        ds = _make_dataset()
        model = _make_model(n_negatives=4)
        model.fit(ds)
        assert len(model.recommend(users=[0, 1, 2], dataset=ds, k=3, filter_viewed=False)) > 0


class TestConfig:
    def test_get_config(self) -> None:
        config = _make_model().get_config(mode="pydantic")
        assert config.model.n_factors == 16
        assert config.model.n_blocks == 1

    def test_from_config_raises(self) -> None:
        # Config alone is not enough: embeddings are required at construction.
        config = _make_model().get_config(mode="pydantic")
        with pytest.raises(NotImplementedError, match="pretrained_item_embeddings"):
            UniSRecModel.from_config(config)
torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]) + h = net(x, use_id=True) + assert h.shape == (2, 5, 16) + + def test_forward_adapted_shape(self, net: UniSRec) -> None: + x = torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]) + h = net(x, use_id=False) + assert h.shape == (2, 5, 16) + + def test_encode_last_shape(self, net: UniSRec) -> None: + x = torch.tensor([[0, 0, 1, 2, 3]]) + emb = net.encode_last(x, use_id=False) + assert emb.shape == (1, 16) + + def test_project_all_shape(self, net: UniSRec) -> None: + proj = net.project_all() + assert proj.shape == (31, 16) # n_items + 1 (with padding) + + def test_item_emb_shape(self, net: UniSRec) -> None: + assert net.item_emb.weight.shape == (31, 16) + + +class TestUniSRecAdaptor: + def test_pca_no_ffn(self, pretrained_emb: torch.Tensor) -> None: + net = UniSRec( + n_items=30, + pretrained_embeddings=pretrained_emb, + n_factors=16, + n_blocks=1, + n_heads=2, + session_max_len=8, + adaptor_type="pca", + use_adaptor_ffn=False, + ) + proj = net.project_all() + assert proj.shape == (31, 16) + assert net.head is None + + def test_multi_variant(self) -> None: + torch.manual_seed(0) + emb = torch.randn(31, 3, 64) # 3 variants + emb[0] = 0.0 + net = UniSRec( + n_items=30, + pretrained_embeddings=emb, + n_factors=16, + projection_hidden=32, + n_blocks=1, + n_heads=2, + session_max_len=8, + ) + assert net.n_variants == 3 + x = torch.tensor([[0, 0, 1, 2, 3]]) + h = net(x, use_id=False) + assert h.shape == (1, 5, 16) + + +class TestFreezeUnfreeze: + def test_freeze_transformer(self, net: UniSRec) -> None: + net.freeze_transformer() + for p in net.transformer_params: + assert not p.requires_grad + for p in net.adaptor_params: + assert p.requires_grad + + def test_unfreeze_transformer(self, net: UniSRec) -> None: + net.freeze_transformer() + net.unfreeze_transformer() + for p in net.transformer_params: + assert p.requires_grad + + +class TestPaddingInvariance: + def test_same_input_same_output(self, net: UniSRec) -> None: + net.eval() 
+ x_a = torch.tensor([[0, 0, 0, 5, 10]]) + x_b = torch.tensor([[0, 0, 0, 5, 10]]) + with torch.no_grad(): + e_a = net.encode_last(x_a, use_id=False) + e_b = net.encode_last(x_b, use_id=False) + torch.testing.assert_close(e_a, e_b) From 6c875b3700ec2074f5c7d0b2072130113fe8b18a Mon Sep 17 00:00:00 2001 From: Topapec Date: Wed, 22 Apr 2026 19:16:31 +0300 Subject: [PATCH 2/7] feat: make UniSRec fully configurable New config options: - ffn_type: conv1d / linear_gelu / linear_relu + ffn_expansion - optimizer: adam / adamw - scheduler: cosine_warmup (with warmup_ratio, min_lr_ratio) - loss: softmax / BCE / gBCE / sampled_softmax (with gbce_t) - patience: early stopping via EarlyStopping callback + val split - data_preparator: accept custom preparator instance 31 tests passing. --- .../fast_transformers/unisrec_lightning.py | 198 +++++++++++++---- rectools/fast_transformers/unisrec_model.py | 208 +++++++++++++----- rectools/fast_transformers/unisrec_net.py | 34 ++- tests/fast_transformers/test_unisrec_model.py | 76 ++++++- 4 files changed, 413 insertions(+), 103 deletions(-) diff --git a/rectools/fast_transformers/unisrec_lightning.py b/rectools/fast_transformers/unisrec_lightning.py index c0c440f3..640b574d 100644 --- a/rectools/fast_transformers/unisrec_lightning.py +++ b/rectools/fast_transformers/unisrec_lightning.py @@ -1,13 +1,19 @@ -"""Lightning wrapper for UniSRec: supports full-softmax and sampled CE loss.""" +"""Lightning wrapper for UniSRec with configurable loss, optimizer, scheduler.""" +import math import typing as tp import torch import torch.nn.functional as F import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR from .unisrec_net import UniSRec +SUPPORTED_LOSSES = ("softmax", "BCE", "gBCE", "sampled_softmax") +SUPPORTED_OPTIMIZERS = ("adam", "adamw") +SUPPORTED_SCHEDULERS = (None, "cosine_warmup") + class UniSRecLightning(pl.LightningModule): """ @@ -22,11 +28,27 @@ def __init__( net: UniSRec, param_groups: tp.List[tp.Dict[str, 
tp.Any]], use_id: bool = False, + loss: str = "softmax", + n_negatives: tp.Optional[int] = None, + gbce_t: float = 0.2, + optimizer: str = "adamw", + scheduler: tp.Optional[str] = None, + warmup_ratio: float = 0.05, + min_lr_ratio: float = 0.1, + total_steps: tp.Optional[int] = None, ) -> None: super().__init__() self.net = net self._param_groups = param_groups self.use_id = use_id + self.loss_name = loss + self.n_negatives = n_negatives + self.gbce_t = gbce_t + self.optimizer_name = optimizer + self.scheduler_name = scheduler + self.warmup_ratio = warmup_ratio + self.min_lr_ratio = min_lr_ratio + self.total_steps = total_steps # ── helpers ── @@ -35,63 +57,149 @@ def _get_item_embs(self, item_ids: torch.Tensor) -> torch.Tensor: return self.net.item_emb(item_ids) return self.net._adapt_score(self.net._sample_frozen(item_ids)) - # ── training step ── + def _get_all_embs(self) -> torch.Tensor: + if self.use_id: + return self.net.item_emb.weight + return self.net.project_all() - def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: - input_ids = batch["x"] + def _get_pos_neg_logits( + self, hidden: torch.Tensor, labels: torch.Tensor, negatives: torch.Tensor, + ) -> torch.Tensor: + """Compute (B, L, 1+N) logits where index 0 = positive.""" + emb_pos = self._get_item_embs(labels) + logits_pos = (hidden * emb_pos).sum(dim=-1) + + emb_neg = self._get_item_embs(negatives) + logits_neg = torch.matmul( + hidden.unsqueeze(2), emb_neg.transpose(2, 3), + ).squeeze(2) + + return torch.cat([logits_pos.unsqueeze(-1), logits_neg], dim=-1) + + # ── losses ── + + def _calc_loss( + self, hidden: torch.Tensor, batch: tp.Dict[str, torch.Tensor], + ) -> torch.Tensor: labels = batch["y"] - hidden = self.net(input_ids, use_id=self.use_id) # (B, L, D) + has_neg = "negatives" in batch - if "negatives" in batch: - loss = self._sampled_ce_loss(hidden, labels, batch["negatives"]) - else: - loss = self._full_softmax_loss(hidden, labels) + if self.loss_name 
== "softmax" and not has_neg: + return self._full_softmax_loss(hidden, labels) - self.log("train_loss", loss, prog_bar=True) - return loss + if self.loss_name == "softmax" and has_neg: + # full softmax even if negatives are available + return self._full_softmax_loss(hidden, labels) - def _full_softmax_loss(self, hidden: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - if self.use_id: - all_emb = self.net.item_emb.weight # (n_items+1, D) - else: - all_emb = self.net.project_all() # (n_items+1, D) + if not has_neg: + raise ValueError(f"Loss '{self.loss_name}' requires negatives but batch has none") + + logits = self._get_pos_neg_logits(hidden, labels, batch["negatives"]) + mask = labels != 0 - logits = hidden @ all_emb.T # (B, L, n_items+1) - logits[:, :, 0] = float("-inf") # never predict padding + if self.loss_name == "sampled_softmax": + return self._sampled_softmax_loss(logits, mask) + if self.loss_name == "BCE": + return self._bce_loss(logits, mask) + if self.loss_name == "gBCE": + return self._gbce_loss(logits, mask) + + raise ValueError(f"Unknown loss: {self.loss_name}") + + def _full_softmax_loss(self, hidden: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + all_emb = self._get_all_embs() + logits = hidden @ all_emb.T + logits[:, :, 0] = float("-inf") targets = labels.clone() - targets[targets == 0] = -100 # padding → ignore + targets[targets == 0] = -100 return F.cross_entropy( - logits.view(-1, logits.size(-1)), - targets.view(-1), - ignore_index=-100, + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100, ) - def _sampled_ce_loss( - self, - hidden: torch.Tensor, - labels: torch.Tensor, - negatives: torch.Tensor, - ) -> torch.Tensor: - emb_pos = self._get_item_embs(labels) # (B, L, D) - logits_pos = (hidden * emb_pos).sum(dim=-1) # (B, L) + def _sampled_softmax_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Sampled softmax: positive at index 0, swap to index 1 so index 0 can be ignored.""" + logits = 
logits.clone() + logits[:, :, [0, 1]] = logits[:, :, [1, 0]] + targets = mask.long() # 1 where non-padding, 0 where padding + return F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0, + ) - emb_neg = self._get_item_embs(negatives) # (B, L, N, D) - logits_neg = torch.matmul( # (B, L, N) - hidden.unsqueeze(2), emb_neg.transpose(2, 3), - ).squeeze(2) + def _bce_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + target = torch.zeros_like(logits) + target[:, :, 0] = 1.0 + loss = F.binary_cross_entropy_with_logits(logits, target, reduction="none") + loss = loss.mean(-1) * mask + return loss.sum() / mask.sum().clamp(min=1) - logits = torch.cat([logits_pos.unsqueeze(-1), logits_neg], dim=-1) # (B, L, 1+N) + def _gbce_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + n_items = self.net.n_items + n_neg = self.n_negatives or logits.size(-1) - 1 + alpha = n_neg / max(n_items - 1, 1) + beta = alpha * (self.gbce_t * (1 - 1 / alpha) + 1 / alpha) + + dtype = torch.float64 + pos_logits = logits[:, :, 0:1].to(dtype) + neg_logits = logits[:, :, 1:] + + eps = 1e-10 + pos_probs = torch.clamp(torch.sigmoid(pos_logits), eps, 1 - eps) + pos_adjusted = torch.clamp(pos_probs.pow(-beta), 1 + eps, torch.finfo(dtype).max) + pos_adjusted = torch.clamp(1.0 / (pos_adjusted - 1), eps, torch.finfo(dtype).max) + pos_transformed = torch.log(pos_adjusted).to(logits.dtype) + + adjusted_logits = torch.cat([pos_transformed, neg_logits], dim=-1) + return self._bce_loss(adjusted_logits, mask) + + # ── training / validation ── + + def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: + hidden = self.net(batch["x"], use_id=self.use_id) + loss = self._calc_loss(hidden, batch) + self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True) + return loss + + def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: + hidden = self.net(batch["x"], 
use_id=self.use_id) + # Validation batch has y of shape (B, 1) -- take last hidden position only + hidden = hidden[:, -1:, :] + loss = self._calc_loss(hidden, batch) + self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True) + return loss + + # ── optimizer / scheduler ── + + def configure_optimizers(self) -> tp.Any: + if self.optimizer_name == "adamw": + opt = torch.optim.AdamW(self._param_groups) + elif self.optimizer_name == "adam": + opt = torch.optim.Adam(self._param_groups) + else: + raise ValueError(f"Unknown optimizer: {self.optimizer_name}") + + if self.scheduler_name is None: + return opt + + if self.scheduler_name == "cosine_warmup": + total = self.total_steps or 1 + warmup = int(total * self.warmup_ratio) + scheduler = _cosine_warmup_scheduler(opt, warmup, total, self.min_lr_ratio) + return {"optimizer": opt, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}} + + raise ValueError(f"Unknown scheduler: {self.scheduler_name}") - targets = torch.zeros_like(labels) # positive class = index 0 - targets[labels == 0] = -100 # padding → ignore - return F.cross_entropy( - logits.view(-1, logits.size(-1)), - targets.view(-1), - ignore_index=-100, - ) - # ── optimizer ── +def _cosine_warmup_scheduler( + optimizer: torch.optim.Optimizer, + warmup_steps: int, + total_steps: int, + min_lr_ratio: float = 0.0, +) -> LambdaLR: + def lr_lambda(step: int) -> float: + if step < warmup_steps: + return step / max(1, warmup_steps) + progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) + return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress)) - def configure_optimizers(self) -> torch.optim.Optimizer: - return torch.optim.AdamW(self._param_groups) + return LambdaLR(optimizer, lr_lambda) diff --git a/rectools/fast_transformers/unisrec_model.py b/rectools/fast_transformers/unisrec_model.py index a1990884..ac93ebc9 100644 --- a/rectools/fast_transformers/unisrec_model.py +++ 
b/rectools/fast_transformers/unisrec_model.py @@ -1,12 +1,15 @@ -"""UniSRecModel: ModelBase wrapper with three-phase training.""" +"""UniSRecModel: ModelBase wrapper with configurable three-phase training.""" import typing as tp import numpy as np +import pandas as pd import torch import pytorch_lightning as pl +from pytorch_lightning.callbacks import EarlyStopping from scipy import sparse +from rectools import Columns from rectools.dataset import Dataset from rectools.models.base import InternalRecoTriplet, ModelBase, ModelConfig from rectools.models.nn.transformers.sasrec import SASRecDataPreparator @@ -15,13 +18,14 @@ from rectools.utils.config import BaseConfig from .unisrec_net import UniSRec -from .unisrec_lightning import UniSRecLightning +from .unisrec_lightning import UniSRecLightning, SUPPORTED_LOSSES, SUPPORTED_OPTIMIZERS, SUPPORTED_SCHEDULERS from .ranking import rank_topk class UniSRecConfig(BaseConfig): """Hyperparameters for UniSRecModel (without pretrained embeddings).""" + # architecture n_factors: int = 256 projection_hidden: int = 512 n_blocks: int = 2 @@ -31,7 +35,10 @@ class UniSRecConfig(BaseConfig): adaptor_dropout: float = 0.2 adaptor_type: str = "pca" use_adaptor_ffn: bool = True + ffn_type: str = "conv1d" + ffn_expansion: int = 1 + # training phases phase1_epochs: int = 10 phase2_epochs: int = 10 phase3_epochs: int = 10 @@ -42,13 +49,27 @@ class UniSRecConfig(BaseConfig): lr_wp: float = 0.1 lr_transformer: float = 3.0 + # optimizer / scheduler + optimizer: str = "adamw" + scheduler: tp.Optional[str] = None + warmup_ratio: float = 0.05 + min_lr_ratio: float = 0.1 grad_clip: float = 1.0 weight_decay: float = 0.01 + + # loss + loss: str = "softmax" + gbce_t: float = 0.2 + n_negatives: tp.Optional[int] = None + + # early stopping + patience: tp.Optional[int] = None + + # data batch_size: int = 128 recommend_batch_size: int = 256 dataloader_num_workers: int = 0 train_min_user_interactions: int = 2 - n_negatives: tp.Optional[int] = None class 
UniSRecModelConfig(ModelConfig): @@ -57,15 +78,20 @@ class UniSRecModelConfig(ModelConfig): model: UniSRecConfig = UniSRecConfig() +def _leave_last_out_mask(interactions: pd.DataFrame, **kwargs: tp.Any) -> pd.Series: + """Default validation mask: last interaction per user.""" + return interactions.groupby(Columns.User).cumcount(ascending=False) == 0 + + class UniSRecModel(ModelBase[UniSRecModelConfig]): """ UniSRec integrated into RecTools via ``ModelBase``. Three training phases --------------------- - 1. **Phase 1** — SASRec on ID embeddings (``item_emb`` + transformer). - 2. **Phase 2** — Adaptor only (transformer frozen, pretrained embeddings). - 3. **Phase 3** — Full fine-tune (adaptor + transformer, pretrained embeddings). + 1. **Phase 1** - SASRec on ID embeddings (``item_emb`` + transformer). + 2. **Phase 2** - Adaptor only (transformer frozen, pretrained embeddings). + 3. **Phase 3** - Full fine-tune (adaptor + transformer, pretrained embeddings). Parameters ---------- @@ -74,8 +100,9 @@ class UniSRecModel(ModelBase[UniSRecModelConfig]): ``(max_external_item_id + 1, n_variants, D_text)``. Index *i* holds the text embedding for the item whose **external** ID equals *i*. Index 0 is padding (zeros). - During ``fit`` the tensor is reindexed to match the internal ID map - produced by ``SASRecDataPreparator``. + data_preparator : object, optional + Custom data preparator. Must implement the same interface as + ``SASRecDataPreparator``. If None, one is created automatically. 
""" config_class = UniSRecModelConfig @@ -85,6 +112,7 @@ class UniSRecModel(ModelBase[UniSRecModelConfig]): def __init__( self, pretrained_item_embeddings: torch.Tensor, + # architecture n_factors: int = 256, projection_hidden: int = 512, n_blocks: int = 2, @@ -94,6 +122,9 @@ def __init__( adaptor_dropout: float = 0.2, adaptor_type: str = "pca", use_adaptor_ffn: bool = True, + ffn_type: str = "conv1d", + ffn_expansion: int = 1, + # training phases phase1_epochs: int = 10, phase2_epochs: int = 10, phase3_epochs: int = 10, @@ -103,16 +134,39 @@ def __init__( lr_head: float = 0.3, lr_wp: float = 0.1, lr_transformer: float = 3.0, + # optimizer / scheduler + optimizer: str = "adamw", + scheduler: tp.Optional[str] = None, + warmup_ratio: float = 0.05, + min_lr_ratio: float = 0.1, grad_clip: float = 1.0, weight_decay: float = 0.01, + # loss + loss: str = "softmax", + gbce_t: float = 0.2, + n_negatives: tp.Optional[int] = None, + # early stopping + patience: tp.Optional[int] = None, + # data batch_size: int = 128, recommend_batch_size: int = 256, dataloader_num_workers: int = 0, train_min_user_interactions: int = 2, - n_negatives: tp.Optional[int] = None, + # misc + data_preparator: tp.Any = None, verbose: int = 0, ) -> None: super().__init__(verbose=verbose) + + if loss not in SUPPORTED_LOSSES: + raise ValueError(f"Unsupported loss '{loss}'. Choose from {SUPPORTED_LOSSES}") + if loss in ("BCE", "gBCE", "sampled_softmax") and n_negatives is None: + raise ValueError(f"Loss '{loss}' requires n_negatives to be set") + if optimizer not in SUPPORTED_OPTIMIZERS: + raise ValueError(f"Unsupported optimizer '{optimizer}'. Choose from {SUPPORTED_OPTIMIZERS}") + if scheduler not in SUPPORTED_SCHEDULERS: + raise ValueError(f"Unsupported scheduler '{scheduler}'. 
Choose from {SUPPORTED_SCHEDULERS}") + self.pretrained_item_embeddings = pretrained_item_embeddings self.n_factors = n_factors self.projection_hidden = projection_hidden @@ -123,6 +177,8 @@ def __init__( self.adaptor_dropout = adaptor_dropout self.adaptor_type = adaptor_type self.use_adaptor_ffn = use_adaptor_ffn + self.ffn_type = ffn_type + self.ffn_expansion = ffn_expansion self.phase1_epochs = phase1_epochs self.phase2_epochs = phase2_epochs self.phase3_epochs = phase3_epochs @@ -132,18 +188,26 @@ def __init__( self.lr_head = lr_head self.lr_wp = lr_wp self.lr_transformer = lr_transformer + self.optimizer = optimizer + self.scheduler = scheduler + self.warmup_ratio = warmup_ratio + self.min_lr_ratio = min_lr_ratio self.grad_clip = grad_clip self.weight_decay = weight_decay + self.loss = loss + self.gbce_t = gbce_t + self.n_negatives = n_negatives + self.patience = patience self.batch_size = batch_size self.recommend_batch_size = recommend_batch_size self.dataloader_num_workers = dataloader_num_workers self.train_min_user_interactions = train_min_user_interactions - self.n_negatives = n_negatives + self._custom_data_preparator = data_preparator self._net: tp.Optional[UniSRec] = None - self._data_preparator: tp.Optional[SASRecDataPreparator] = None + self._data_preparator: tp.Optional[tp.Any] = None - # ── config boilerplate (embeddings are not serialised) ── + # ── config (embeddings + data_preparator not serialised) ── def _get_config(self) -> UniSRecModelConfig: return UniSRecModelConfig( @@ -159,6 +223,8 @@ def _get_config(self) -> UniSRecModelConfig: adaptor_dropout=self.adaptor_dropout, adaptor_type=self.adaptor_type, use_adaptor_ffn=self.use_adaptor_ffn, + ffn_type=self.ffn_type, + ffn_expansion=self.ffn_expansion, phase1_epochs=self.phase1_epochs, phase2_epochs=self.phase2_epochs, phase3_epochs=self.phase3_epochs, @@ -168,28 +234,35 @@ def _get_config(self) -> UniSRecModelConfig: lr_head=self.lr_head, lr_wp=self.lr_wp, lr_transformer=self.lr_transformer, + 
optimizer=self.optimizer, + scheduler=self.scheduler, + warmup_ratio=self.warmup_ratio, + min_lr_ratio=self.min_lr_ratio, grad_clip=self.grad_clip, weight_decay=self.weight_decay, + loss=self.loss, + gbce_t=self.gbce_t, + n_negatives=self.n_negatives, + patience=self.patience, batch_size=self.batch_size, recommend_batch_size=self.recommend_batch_size, dataloader_num_workers=self.dataloader_num_workers, train_min_user_interactions=self.train_min_user_interactions, - n_negatives=self.n_negatives, ), ) @classmethod def _from_config(cls, config: UniSRecModelConfig) -> "UniSRecModel": raise NotImplementedError( - "UniSRecModel cannot be restored from config alone — " + "UniSRecModel cannot be restored from config alone -- " "pretrained_item_embeddings must be supplied at construction time." ) # ── helpers ── - def _align_embeddings(self, dp: SASRecDataPreparator) -> torch.Tensor: - """Reindex ``pretrained_item_embeddings`` to the preparator's internal IDs.""" - ext_ids = dp.item_id_map.to_external.values # array[internal_id] → external_id + def _align_embeddings(self, dp: tp.Any) -> torch.Tensor: + """Reindex pretrained_item_embeddings to the preparator's internal IDs.""" + ext_ids = dp.item_id_map.to_external.values n_internal = dp.item_id_map.size n_extra = dp.n_item_extra_tokens @@ -206,18 +279,44 @@ def _align_embeddings(self, dp: SASRecDataPreparator) -> torch.Tensor: return aligned - def _make_trainer(self, max_epochs: int) -> pl.Trainer: + def _make_trainer(self, max_epochs: int, val_dl: tp.Any = None) -> pl.Trainer: + callbacks = [] + if self.patience is not None and val_dl is not None: + callbacks.append(EarlyStopping(monitor="val_loss", patience=self.patience, mode="min")) + return pl.Trainer( max_epochs=max_epochs, gradient_clip_val=self.grad_clip, + callbacks=callbacks or None, enable_checkpointing=False, enable_model_summary=False, logger=self.verbose > 0, enable_progress_bar=self.verbose > 0, ) + def _make_lightning( + self, net: UniSRec, param_groups: 
tp.List[tp.Dict], use_id: bool, max_epochs: int, train_dl: tp.Any, + ) -> UniSRecLightning: + total_steps = len(train_dl) * max_epochs if self.scheduler else None + return UniSRecLightning( + net=net, + param_groups=param_groups, + use_id=use_id, + loss=self.loss, + n_negatives=self.n_negatives, + gbce_t=self.gbce_t, + optimizer=self.optimizer, + scheduler=self.scheduler, + warmup_ratio=self.warmup_ratio, + min_lr_ratio=self.min_lr_ratio, + total_steps=total_steps, + ) + # ── Phase param-groups ── + def _phase1_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: + return [{"params": list(net.item_emb.parameters()) + net.transformer_params, "lr": self.phase1_lr}] + def _phase2_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: if self.adaptor_type == "pca": groups: tp.List[tp.Dict[str, tp.Any]] = [ @@ -239,7 +338,6 @@ def _phase2_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: return groups def _phase3_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: - # adaptor if self.adaptor_type == "pca": adaptor: tp.List[tp.Dict[str, tp.Any]] = [ {"params": [net.whitening_proj], "lr": self.phase3_lr * self.lr_wp, "weight_decay": 0.0}, @@ -250,11 +348,9 @@ def _phase3_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: {"params": list(net.bn_input.parameters()), "lr": self.phase3_lr, "weight_decay": 0.0}, {"params": list(net.bn_score.parameters()), "lr": self.phase3_lr, "weight_decay": 0.0}, ] - # head head: tp.List[tp.Dict[str, tp.Any]] = [] if net.head is not None: head = [{"params": list(net.head.parameters()), "lr": self.phase3_lr * self.lr_head, "weight_decay": self.weight_decay}] - # transformer transformer = [ {"params": list(net.pos_emb.parameters()), "lr": self.phase3_lr * self.lr_transformer, "weight_decay": 0.0}, { @@ -281,20 +377,25 @@ def _phase3_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: def _fit(self, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> None: # Data preparation - 
negative_sampler = None - n_negatives_dp: tp.Optional[int] = None - if self.n_negatives is not None: - negative_sampler = CatalogUniformSampler(n_negatives=self.n_negatives) - n_negatives_dp = self.n_negatives + if self._custom_data_preparator is not None: + dp = self._custom_data_preparator + else: + requires_neg = self.loss in ("BCE", "gBCE", "sampled_softmax") or self.n_negatives is not None + negative_sampler = CatalogUniformSampler(n_negatives=self.n_negatives) if requires_neg else None + n_negatives_dp = self.n_negatives if requires_neg else None + + dp_kwargs: tp.Dict[str, tp.Any] = dict( + session_max_len=self.session_max_len, + batch_size=self.batch_size, + dataloader_num_workers=self.dataloader_num_workers, + train_min_user_interactions=self.train_min_user_interactions, + n_negatives=n_negatives_dp, + negative_sampler=negative_sampler, + ) + if self.patience is not None: + dp_kwargs["get_val_mask_func"] = _leave_last_out_mask + dp = SASRecDataPreparator(**dp_kwargs) - dp = SASRecDataPreparator( - session_max_len=self.session_max_len, - batch_size=self.batch_size, - dataloader_num_workers=self.dataloader_num_workers, - train_min_user_interactions=self.train_min_user_interactions, - n_negatives=n_negatives_dp, - negative_sampler=negative_sampler, - ) dp.process_dataset_train(dataset) self._data_preparator = dp @@ -313,27 +414,31 @@ def _fit(self, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> None: adaptor_dropout=self.adaptor_dropout, adaptor_type=self.adaptor_type, use_adaptor_ffn=self.use_adaptor_ffn, + ffn_type=self.ffn_type, + ffn_expansion=self.ffn_expansion, ) train_dl = dp.get_dataloader_train() + val_dl = dp.get_dataloader_val() if self.patience is not None else None + + def _run_phase(param_groups: tp.List[tp.Dict], use_id: bool, max_epochs: int) -> None: + lm = self._make_lightning(net, param_groups, use_id, max_epochs, train_dl) + trainer = self._make_trainer(max_epochs, val_dl) + trainer.fit(lm, train_dl, val_dl) - # ── Phase 1: ID 
embeddings ── + # Phase 1: ID embeddings if self.phase1_epochs > 0: - p1_params = [{"params": list(net.item_emb.parameters()) + net.transformer_params, "lr": self.phase1_lr}] - lm = UniSRecLightning(net, p1_params, use_id=True) - self._make_trainer(self.phase1_epochs).fit(lm, train_dl) + _run_phase(self._phase1_params(net), use_id=True, max_epochs=self.phase1_epochs) - # ── Phase 2: adaptor only (transformer frozen) ── + # Phase 2: adaptor only (transformer frozen) if self.phase2_epochs > 0 and self.use_adaptor_ffn: net.freeze_transformer() - lm = UniSRecLightning(net, self._phase2_params(net), use_id=False) - self._make_trainer(self.phase2_epochs).fit(lm, train_dl) + _run_phase(self._phase2_params(net), use_id=False, max_epochs=self.phase2_epochs) - # ── Phase 3: full fine-tune ── + # Phase 3: full fine-tune if self.phase3_epochs > 0: net.unfreeze_transformer() - lm = UniSRecLightning(net, self._phase3_params(net), use_id=False) - self._make_trainer(self.phase3_epochs).fit(lm, train_dl) + _run_phase(self._phase3_params(net), use_id=False, max_epochs=self.phase3_epochs) self._net = net @@ -344,7 +449,7 @@ def _custom_transform_dataset_u2i( dataset: Dataset, users: tp.Any, on_unsupported_targets: tp.Any, - context: tp.Optional["pd.DataFrame"] = None, + context: tp.Optional[pd.DataFrame] = None, ) -> Dataset: assert self._data_preparator is not None return self._data_preparator.transform_dataset_u2i(dataset, users) @@ -373,8 +478,8 @@ def _get_user_embeddings(self, dataset: Dataset) -> torch.Tensor: def _get_item_embeddings(self) -> torch.Tensor: assert self._net is not None self._net.eval() - all_emb = self._net.project_all() # (n_items+1, D) - return all_emb[1:] # skip padding → (n_items, D) + all_emb = self._net.project_all() + return all_emb[1:] # ── recommend ── @@ -392,7 +497,6 @@ def _recommend_u2i( user_embs = self._get_user_embeddings(dataset) item_embs = self._get_item_embeddings() - # viewed-item filter filter_csr = None if filter_viewed: ui_mat = 
dataset.get_user_item_matrix(include_weights=False) @@ -409,7 +513,6 @@ def _recommend_u2i( else: filter_csr = sliced - # whitelist whitelist = None if sorted_item_ids_to_recommend is not None: n_extra = self._data_preparator.n_item_extra_tokens @@ -418,9 +521,7 @@ def _recommend_u2i( u_ids, i_ids, scores = rank_topk( user_embs, item_embs, k, - filter_csr=filter_csr, - whitelist=whitelist, - batch_size=self.recommend_batch_size, + filter_csr=filter_csr, whitelist=whitelist, batch_size=self.recommend_batch_size, ) n_extra = self._data_preparator.n_item_extra_tokens @@ -449,8 +550,7 @@ def _recommend_i2i( t_ids, i_ids, scores = rank_topk( target_embs, item_embs, k, - whitelist=whitelist, - batch_size=self.recommend_batch_size, + whitelist=whitelist, batch_size=self.recommend_batch_size, ) result_target_ids = target_ids[t_ids] diff --git a/rectools/fast_transformers/unisrec_net.py b/rectools/fast_transformers/unisrec_net.py index 2e83b5e8..d1329b20 100644 --- a/rectools/fast_transformers/unisrec_net.py +++ b/rectools/fast_transformers/unisrec_net.py @@ -15,7 +15,7 @@ def _make_mlp(in_dim: int, hidden_dim: int, out_dim: int, dropout: float) -> nn. ) -class FeedForward(nn.Module): +class FeedForwardConv1d(nn.Module): """Point-wise FFN via Conv1d (kernel_size=1), matching the reference UniSRec.""" def __init__(self, hidden_units: int, dropout_rate: float) -> None: @@ -34,6 +34,34 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor: return outputs.transpose(-1, -2) +# keep old name as alias +FeedForward = FeedForwardConv1d + + +def make_ffn(n_factors: int, ffn_type: str, expansion: int, dropout: float) -> nn.Module: + """Create a feed-forward block. + + Parameters + ---------- + ffn_type : ``"conv1d"`` | ``"linear_gelu"`` | ``"linear_relu"`` + expansion : hidden-dim multiplier (e.g. 1 or 4). 
+ """ + if ffn_type == "conv1d": + return FeedForwardConv1d(n_factors, dropout) + hidden = n_factors * expansion + if ffn_type == "linear_gelu": + return nn.Sequential( + nn.Linear(n_factors, hidden), nn.GELU(), nn.Dropout(dropout), + nn.Linear(hidden, n_factors), nn.Dropout(dropout), + ) + if ffn_type == "linear_relu": + return nn.Sequential( + nn.Linear(n_factors, hidden), nn.ReLU(), nn.Dropout(dropout), + nn.Linear(hidden, n_factors), + ) + raise ValueError(f"Unknown ffn_type: {ffn_type}. Choose from: conv1d, linear_gelu, linear_relu") + + class UniSRec(nn.Module): """ UniSRec: sequential recommender with pretrained text embeddings + adaptor. @@ -87,6 +115,8 @@ def __init__( adaptor_type: str = "pca", use_adaptor_ffn: bool = True, initializer_range: float = 0.02, + ffn_type: str = "conv1d", + ffn_expansion: int = 1, ) -> None: super().__init__() self.n_items = n_items @@ -144,7 +174,7 @@ def __init__( self.attention_layernorms.append(nn.LayerNorm(n_factors, eps=1e-12)) self.attention_layers.append(nn.MultiheadAttention(n_factors, n_heads, dropout, batch_first=True)) self.forward_layernorms.append(nn.LayerNorm(n_factors, eps=1e-12)) - self.forward_layers.append(FeedForward(n_factors, dropout)) + self.forward_layers.append(make_ffn(n_factors, ffn_type, ffn_expansion, dropout)) self.apply(self._init_weights) diff --git a/tests/fast_transformers/test_unisrec_model.py b/tests/fast_transformers/test_unisrec_model.py index ff0b11ed..98dc3e94 100644 --- a/tests/fast_transformers/test_unisrec_model.py +++ b/tests/fast_transformers/test_unisrec_model.py @@ -124,12 +124,84 @@ def test_sampled_loss(self) -> None: assert len(reco) > 0 +class TestFFNTypes: + @pytest.mark.parametrize("ffn_type", ["conv1d", "linear_gelu", "linear_relu"]) + def test_ffn_type(self, ffn_type: str) -> None: + ds = _make_dataset() + model = _make_model(ffn_type=ffn_type, ffn_expansion=2, phase1_epochs=0, phase2_epochs=0, phase3_epochs=1) + model.fit(ds) + reco = model.recommend(users=[0, 1], 
dataset=ds, k=3, filter_viewed=False) + assert len(reco) > 0 + + +class TestLosses: + def test_bce_loss(self) -> None: + ds = _make_dataset() + model = _make_model(loss="BCE", n_negatives=4) + model.fit(ds) + reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) + assert len(reco) > 0 + + def test_gbce_loss(self) -> None: + ds = _make_dataset() + model = _make_model(loss="gBCE", n_negatives=4, gbce_t=0.2) + model.fit(ds) + reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) + assert len(reco) > 0 + + def test_sampled_softmax_loss(self) -> None: + ds = _make_dataset() + model = _make_model(loss="sampled_softmax", n_negatives=4) + model.fit(ds) + reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) + assert len(reco) > 0 + + def test_invalid_loss(self) -> None: + with pytest.raises(ValueError, match="Unsupported loss"): + _make_model(loss="invalid") + + +class TestOptimizer: + def test_adam_optimizer(self) -> None: + ds = _make_dataset() + model = _make_model(optimizer="adam", phase1_epochs=0, phase2_epochs=0, phase3_epochs=1) + model.fit(ds) + reco = model.recommend(users=[0], dataset=ds, k=3, filter_viewed=False) + assert len(reco) > 0 + + def test_invalid_optimizer(self) -> None: + with pytest.raises(ValueError, match="Unsupported optimizer"): + _make_model(optimizer="sgd") + + +class TestScheduler: + def test_cosine_warmup(self) -> None: + ds = _make_dataset() + model = _make_model(scheduler="cosine_warmup", warmup_ratio=0.1, phase1_epochs=0, phase2_epochs=0, phase3_epochs=2) + model.fit(ds) + reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) + assert len(reco) > 0 + + +class TestEarlyStopping: + def test_patience(self) -> None: + ds = _make_dataset() + model = _make_model(patience=2, phase1_epochs=0, phase2_epochs=0, phase3_epochs=5) + model.fit(ds) + reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) + assert len(reco) > 0 + + class TestConfig: def 
test_get_config(self) -> None: - model = _make_model() + model = _make_model(ffn_type="linear_gelu", loss="BCE", n_negatives=4, optimizer="adam", scheduler="cosine_warmup", patience=5) config = model.get_config(mode="pydantic") assert config.model.n_factors == 16 - assert config.model.n_blocks == 1 + assert config.model.ffn_type == "linear_gelu" + assert config.model.loss == "BCE" + assert config.model.optimizer == "adam" + assert config.model.scheduler == "cosine_warmup" + assert config.model.patience == 5 def test_from_config_raises(self) -> None: model = _make_model() From 3cec1e06e0a323bd3b493e03d30c11e651076614 Mon Sep 17 00:00:00 2001 From: TOPAPEC Date: Fri, 24 Apr 2026 15:48:59 +0000 Subject: [PATCH 3/7] Fast gpu preprocessing and good metrics --- .gitignore | 7 +- rectools/fast_transformers/__init__.py | 8 +- rectools/fast_transformers/gpu_data.py | 112 ++++++ rectools/fast_transformers/unisrec_model.py | 401 +++++--------------- scripts/profile_build_sequences.py | 142 +++++++ scripts/test_1epoch.py | 88 +++++ scripts/train_unisrec_ml20m.py | 293 ++++++++++++++ 7 files changed, 742 insertions(+), 309 deletions(-) create mode 100644 rectools/fast_transformers/gpu_data.py create mode 100644 scripts/profile_build_sequences.py create mode 100644 scripts/test_1epoch.py create mode 100644 scripts/train_unisrec_ml20m.py diff --git a/.gitignore b/.gitignore index c5b1c9f3..13082042 100644 --- a/.gitignore +++ b/.gitignore @@ -95,4 +95,9 @@ benchmark_results/ *.dat # CatBoost -catboost_info/ \ No newline at end of file +catboost_info/ + +# Dev testing folder +training_folder/ +*.pt +data/* \ No newline at end of file diff --git a/rectools/fast_transformers/__init__.py b/rectools/fast_transformers/__init__.py index 2a10affd..c074130f 100644 --- a/rectools/fast_transformers/__init__.py +++ b/rectools/fast_transformers/__init__.py @@ -1,14 +1,19 @@ """Fast Transformers: flat sequential recommenders without ItemNet hierarchy.""" +from .gpu_data import build_sequences, 
align_embeddings, GPUBatchDataset, make_dataloader from .lightning_wrap import FlatSASRecLightning from .model import FlatSASRecConfig, FlatSASRecModel from .net import FlatSASRec, SASRecBlock from .ranking import rank_topk from .unisrec_net import UniSRec, FeedForward from .unisrec_lightning import UniSRecLightning -from .unisrec_model import UniSRecConfig, UniSRecModel +from .unisrec_model import UniSRecModel __all__ = [ + "build_sequences", + "align_embeddings", + "GPUBatchDataset", + "make_dataloader", "FlatSASRec", "SASRecBlock", "FlatSASRecLightning", @@ -18,6 +23,5 @@ "UniSRec", "FeedForward", "UniSRecLightning", - "UniSRecConfig", "UniSRecModel", ] diff --git a/rectools/fast_transformers/gpu_data.py b/rectools/fast_transformers/gpu_data.py new file mode 100644 index 00000000..c4e67852 --- /dev/null +++ b/rectools/fast_transformers/gpu_data.py @@ -0,0 +1,112 @@ +"""GPU-native sequence building for transformer training. Pure torch, no pandas/numpy.""" + +import typing as tp + +import torch +from torch.utils.data import Dataset as TorchDataset, DataLoader + + +def build_sequences( + user_ids: torch.Tensor, + item_ids: torch.Tensor, + timestamps: torch.Tensor, + max_len: int, + min_interactions: int = 2, + device: str = "cuda", +) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + user_ids = user_ids.to(device) + item_ids = item_ids.to(device) + timestamps = timestamps.to(device) + + unique_items, item_inv = torch.unique(item_ids, return_inverse=True) + internal_items = item_inv + 1 + + unique_users, user_inv = torch.unique(user_ids, return_inverse=True) + + order1 = torch.argsort(timestamps, stable=True) + order2 = torch.argsort(user_inv[order1], stable=True) + order = order1[order2] + + sorted_user_inv = user_inv[order] + sorted_items = internal_items[order] + + changes = torch.where(sorted_user_inv[1:] != sorted_user_inv[:-1])[0] + 1 + starts = torch.cat([torch.tensor([0], device=device), changes]) + ends = torch.cat([changes, 
torch.tensor([len(sorted_user_inv)], device=device)]) + lengths = ends - starts + + mask = lengths >= min_interactions + starts = starts[mask] + ends = ends[mask] + lengths = lengths[mask] + n_users = len(starts) + + capped_lens = torch.clamp(lengths, max=max_len + 1) + + effective_lens = torch.clamp(capped_lens - 1, min=0) + total_elements = effective_lens.sum().item() + + x = torch.zeros(n_users, max_len, dtype=torch.long, device=device) + y = torch.zeros(n_users, max_len, dtype=torch.long, device=device) + + if total_elements > 0: + user_indices = torch.repeat_interleave(torch.arange(n_users, device=device), effective_lens) + cumsum = effective_lens.cumsum(0) + offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave(cumsum - effective_lens, effective_lens) + + x_src = torch.repeat_interleave(ends - capped_lens, effective_lens) + offsets + y_src = x_src + 1 + col_indices = max_len - torch.repeat_interleave(effective_lens, effective_lens) + offsets + + x[user_indices, col_indices] = sorted_items[x_src] + y[user_indices, col_indices] = sorted_items[y_src] + + valid_user_indices = torch.where(mask)[0] + result_users = unique_users[valid_user_indices] if len(valid_user_indices) < len(unique_users) else unique_users + + return x, y, unique_items, result_users + + +def align_embeddings( + pretrained: torch.Tensor, + unique_items: torch.Tensor, + n_items: int, +) -> torch.Tensor: + idx = unique_items.long().cpu() + valid = (idx >= 0) & (idx < pretrained.shape[0]) + + if pretrained.ndim == 2: + aligned = torch.zeros(n_items + 1, pretrained.shape[1]) + aligned[1:][valid] = pretrained[idx[valid]] + else: + aligned = torch.zeros(n_items + 1, pretrained.shape[1], pretrained.shape[2]) + aligned[1:][valid] = pretrained[idx[valid]] + + return aligned + + +class GPUBatchDataset(TorchDataset): + def __init__(self, x: torch.Tensor, y: torch.Tensor, transform: tp.Optional[tp.Callable] = None): + self.x = x + self.y = y + self.transform = transform + + def 
__len__(self) -> int: + return len(self.x) + + def __getitem__(self, idx: int) -> tp.Dict[str, torch.Tensor]: + batch = {"x": self.x[idx], "y": self.y[idx]} + if self.transform: + batch = self.transform(batch) + return batch + + +def make_dataloader( + x: torch.Tensor, + y: torch.Tensor, + batch_size: int, + shuffle: bool = True, + transform: tp.Optional[tp.Callable] = None, +) -> DataLoader: + ds = GPUBatchDataset(x, y, transform=transform) + return DataLoader(ds, batch_size=batch_size, shuffle=shuffle, num_workers=0) diff --git a/rectools/fast_transformers/unisrec_model.py b/rectools/fast_transformers/unisrec_model.py index ac93ebc9..d3a136d9 100644 --- a/rectools/fast_transformers/unisrec_model.py +++ b/rectools/fast_transformers/unisrec_model.py @@ -1,91 +1,20 @@ -"""UniSRecModel: ModelBase wrapper with configurable three-phase training.""" +"""UniSRecModel: standalone model with configurable three-phase training.""" import typing as tp +from pathlib import Path -import numpy as np -import pandas as pd import torch import pytorch_lightning as pl from pytorch_lightning.callbacks import EarlyStopping -from scipy import sparse - -from rectools import Columns -from rectools.dataset import Dataset -from rectools.models.base import InternalRecoTriplet, ModelBase, ModelConfig -from rectools.models.nn.transformers.sasrec import SASRecDataPreparator -from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler -from rectools.types import InternalIdsArray -from rectools.utils.config import BaseConfig from .unisrec_net import UniSRec from .unisrec_lightning import UniSRecLightning, SUPPORTED_LOSSES, SUPPORTED_OPTIMIZERS, SUPPORTED_SCHEDULERS -from .ranking import rank_topk - - -class UniSRecConfig(BaseConfig): - """Hyperparameters for UniSRecModel (without pretrained embeddings).""" - - # architecture - n_factors: int = 256 - projection_hidden: int = 512 - n_blocks: int = 2 - n_heads: int = 1 - session_max_len: int = 200 - dropout: float = 0.1 - 
adaptor_dropout: float = 0.2 - adaptor_type: str = "pca" - use_adaptor_ffn: bool = True - ffn_type: str = "conv1d" - ffn_expansion: int = 1 - - # training phases - phase1_epochs: int = 10 - phase2_epochs: int = 10 - phase3_epochs: int = 10 - phase1_lr: float = 1e-3 - phase2_lr: float = 3e-4 - phase3_lr: float = 1e-4 - lr_head: float = 0.3 - lr_wp: float = 0.1 - lr_transformer: float = 3.0 - - # optimizer / scheduler - optimizer: str = "adamw" - scheduler: tp.Optional[str] = None - warmup_ratio: float = 0.05 - min_lr_ratio: float = 0.1 - grad_clip: float = 1.0 - weight_decay: float = 0.01 - - # loss - loss: str = "softmax" - gbce_t: float = 0.2 - n_negatives: tp.Optional[int] = None - - # early stopping - patience: tp.Optional[int] = None - - # data - batch_size: int = 128 - recommend_batch_size: int = 256 - dataloader_num_workers: int = 0 - train_min_user_interactions: int = 2 - - -class UniSRecModelConfig(ModelConfig): - """Full model config (cls + verbose + hyper-params).""" - - model: UniSRecConfig = UniSRecConfig() - - -def _leave_last_out_mask(interactions: pd.DataFrame, **kwargs: tp.Any) -> pd.Series: - """Default validation mask: last interaction per user.""" - return interactions.groupby(Columns.User).cumcount(ascending=False) == 0 - - -class UniSRecModel(ModelBase[UniSRecModelConfig]): +from .gpu_data import build_sequences, align_embeddings, make_dataloader + + +class UniSRecModel: """ - UniSRec integrated into RecTools via ``ModelBase``. + UniSRec sequential recommender with pretrained text embeddings. Three training phases --------------------- @@ -100,15 +29,8 @@ class UniSRecModel(ModelBase[UniSRecModelConfig]): ``(max_external_item_id + 1, n_variants, D_text)``. Index *i* holds the text embedding for the item whose **external** ID equals *i*. Index 0 is padding (zeros). - data_preparator : object, optional - Custom data preparator. Must implement the same interface as - ``SASRecDataPreparator``. If None, one is created automatically. 
""" - config_class = UniSRecModelConfig - recommends_for_warm = False - recommends_for_cold = False - def __init__( self, pretrained_item_embeddings: torch.Tensor, @@ -149,15 +71,10 @@ def __init__( patience: tp.Optional[int] = None, # data batch_size: int = 128, - recommend_batch_size: int = 256, dataloader_num_workers: int = 0, train_min_user_interactions: int = 2, - # misc - data_preparator: tp.Any = None, verbose: int = 0, ) -> None: - super().__init__(verbose=verbose) - if loss not in SUPPORTED_LOSSES: raise ValueError(f"Unsupported loss '{loss}'. Choose from {SUPPORTED_LOSSES}") if loss in ("BCE", "gBCE", "sampled_softmax") and n_negatives is None: @@ -199,86 +116,17 @@ def __init__( self.n_negatives = n_negatives self.patience = patience self.batch_size = batch_size - self.recommend_batch_size = recommend_batch_size self.dataloader_num_workers = dataloader_num_workers self.train_min_user_interactions = train_min_user_interactions - self._custom_data_preparator = data_preparator + self.verbose = verbose self._net: tp.Optional[UniSRec] = None - self._data_preparator: tp.Optional[tp.Any] = None - - # ── config (embeddings + data_preparator not serialised) ── - - def _get_config(self) -> UniSRecModelConfig: - return UniSRecModelConfig( - cls=self.__class__, - verbose=self.verbose, - model=UniSRecConfig( - n_factors=self.n_factors, - projection_hidden=self.projection_hidden, - n_blocks=self.n_blocks, - n_heads=self.n_heads, - session_max_len=self.session_max_len, - dropout=self.dropout, - adaptor_dropout=self.adaptor_dropout, - adaptor_type=self.adaptor_type, - use_adaptor_ffn=self.use_adaptor_ffn, - ffn_type=self.ffn_type, - ffn_expansion=self.ffn_expansion, - phase1_epochs=self.phase1_epochs, - phase2_epochs=self.phase2_epochs, - phase3_epochs=self.phase3_epochs, - phase1_lr=self.phase1_lr, - phase2_lr=self.phase2_lr, - phase3_lr=self.phase3_lr, - lr_head=self.lr_head, - lr_wp=self.lr_wp, - lr_transformer=self.lr_transformer, - optimizer=self.optimizer, - 
scheduler=self.scheduler, - warmup_ratio=self.warmup_ratio, - min_lr_ratio=self.min_lr_ratio, - grad_clip=self.grad_clip, - weight_decay=self.weight_decay, - loss=self.loss, - gbce_t=self.gbce_t, - n_negatives=self.n_negatives, - patience=self.patience, - batch_size=self.batch_size, - recommend_batch_size=self.recommend_batch_size, - dataloader_num_workers=self.dataloader_num_workers, - train_min_user_interactions=self.train_min_user_interactions, - ), - ) - - @classmethod - def _from_config(cls, config: UniSRecModelConfig) -> "UniSRecModel": - raise NotImplementedError( - "UniSRecModel cannot be restored from config alone -- " - "pretrained_item_embeddings must be supplied at construction time." - ) + self._unique_items: tp.Optional[torch.Tensor] = None + self._unique_users: tp.Optional[torch.Tensor] = None + self.is_fitted: bool = False # ── helpers ── - def _align_embeddings(self, dp: tp.Any) -> torch.Tensor: - """Reindex pretrained_item_embeddings to the preparator's internal IDs.""" - ext_ids = dp.item_id_map.to_external.values - n_internal = dp.item_id_map.size - n_extra = dp.n_item_extra_tokens - - emb = self.pretrained_item_embeddings - if emb.ndim == 2: - aligned = torch.zeros(n_internal, emb.shape[1]) - else: - aligned = torch.zeros(n_internal, emb.shape[1], emb.shape[2]) - - for int_id in range(n_extra, n_internal): - ext_id = int(ext_ids[int_id]) - if 0 <= ext_id < emb.shape[0]: - aligned[int_id] = emb[ext_id] - - return aligned - def _make_trainer(self, max_epochs: int, val_dl: tp.Any = None) -> pl.Trainer: callbacks = [] if self.patience is not None and val_dl is not None: @@ -375,35 +223,41 @@ def _phase3_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: # ── fit ── - def _fit(self, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> None: - # Data preparation - if self._custom_data_preparator is not None: - dp = self._custom_data_preparator - else: - requires_neg = self.loss in ("BCE", "gBCE", "sampled_softmax") or self.n_negatives is 
not None - negative_sampler = CatalogUniformSampler(n_negatives=self.n_negatives) if requires_neg else None - n_negatives_dp = self.n_negatives if requires_neg else None - - dp_kwargs: tp.Dict[str, tp.Any] = dict( - session_max_len=self.session_max_len, - batch_size=self.batch_size, - dataloader_num_workers=self.dataloader_num_workers, - train_min_user_interactions=self.train_min_user_interactions, - n_negatives=n_negatives_dp, - negative_sampler=negative_sampler, - ) - if self.patience is not None: - dp_kwargs["get_val_mask_func"] = _leave_last_out_mask - dp = SASRecDataPreparator(**dp_kwargs) - - dp.process_dataset_train(dataset) - self._data_preparator = dp - - n_real_items = dp.item_id_map.size - dp.n_item_extra_tokens - aligned_emb = self._align_embeddings(dp) + def fit( + self, + user_ids: torch.Tensor, + item_ids: torch.Tensor, + timestamps: torch.Tensor, + ) -> "UniSRecModel": + """ + Train the model on interaction data. + + Parameters + ---------- + user_ids : LongTensor (N,) + External user IDs for each interaction. + item_ids : LongTensor (N,) + External item IDs for each interaction. + timestamps : LongTensor (N,) + Timestamps (any monotonic int64 values). 
+ + Returns + ------- + self + """ + x, y, unique_items, unique_users = build_sequences( + user_ids, item_ids, timestamps, + max_len=self.session_max_len, + min_interactions=self.train_min_user_interactions, + ) + self._unique_items = unique_items.cpu() + self._unique_users = unique_users.cpu() + n_items = len(unique_items) + + aligned_emb = align_embeddings(self.pretrained_item_embeddings, unique_items, n_items) net = UniSRec( - n_items=n_real_items, + n_items=n_items, pretrained_embeddings=aligned_emb, n_factors=self.n_factors, projection_hidden=self.projection_hidden, @@ -418,141 +272,76 @@ def _fit(self, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> None: ffn_expansion=self.ffn_expansion, ) - train_dl = dp.get_dataloader_train() - val_dl = dp.get_dataloader_val() if self.patience is not None else None + train_dl = make_dataloader(x, y, batch_size=self.batch_size, shuffle=True) + + val_dl = None + if self.patience is not None: + val_y_last = y[:, -1:] + val_dl = make_dataloader(x, val_y_last, batch_size=self.batch_size, shuffle=False) def _run_phase(param_groups: tp.List[tp.Dict], use_id: bool, max_epochs: int) -> None: lm = self._make_lightning(net, param_groups, use_id, max_epochs, train_dl) trainer = self._make_trainer(max_epochs, val_dl) trainer.fit(lm, train_dl, val_dl) - # Phase 1: ID embeddings if self.phase1_epochs > 0: _run_phase(self._phase1_params(net), use_id=True, max_epochs=self.phase1_epochs) - # Phase 2: adaptor only (transformer frozen) if self.phase2_epochs > 0 and self.use_adaptor_ffn: net.freeze_transformer() _run_phase(self._phase2_params(net), use_id=False, max_epochs=self.phase2_epochs) - # Phase 3: full fine-tune if self.phase3_epochs > 0: net.unfreeze_transformer() _run_phase(self._phase3_params(net), use_id=False, max_epochs=self.phase3_epochs) self._net = net + self.is_fitted = True + return self - # ── dataset transforms ── + # ── save / load ── - def _custom_transform_dataset_u2i( - self, - dataset: Dataset, - users: tp.Any, 
- on_unsupported_targets: tp.Any, - context: tp.Optional[pd.DataFrame] = None, - ) -> Dataset: - assert self._data_preparator is not None - return self._data_preparator.transform_dataset_u2i(dataset, users) - - def _custom_transform_dataset_i2i( - self, dataset: Dataset, target_items: tp.Any, on_unsupported_targets: tp.Any - ) -> Dataset: - assert self._data_preparator is not None - return self._data_preparator.transform_dataset_i2i(dataset) - - # ── embeddings for ranking ── - - @torch.no_grad() - def _get_user_embeddings(self, dataset: Dataset) -> torch.Tensor: - assert self._data_preparator is not None and self._net is not None - self._net.eval() - device = next(self._net.parameters()).device - recommend_dl = self._data_preparator.get_dataloader_recommend(dataset, self.recommend_batch_size) - all_embs = [] - for batch in recommend_dl: - x = batch["x"].to(device) - all_embs.append(self._net.encode_last(x, use_id=False)) - return torch.cat(all_embs, dim=0) - - @torch.no_grad() - def _get_item_embeddings(self) -> torch.Tensor: + def save_checkpoint(self, path: tp.Union[str, Path]) -> None: assert self._net is not None - self._net.eval() - all_emb = self._net.project_all() - return all_emb[1:] - - # ── recommend ── - - def _recommend_u2i( - self, - user_ids: InternalIdsArray, - dataset: Dataset, - k: int, - filter_viewed: bool, - sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], - ) -> InternalRecoTriplet: - assert self._data_preparator is not None - device = next(self._net.parameters()).device # type: ignore[union-attr] - - user_embs = self._get_user_embeddings(dataset) - item_embs = self._get_item_embeddings() - - filter_csr = None - if filter_viewed: - ui_mat = dataset.get_user_item_matrix(include_weights=False) - n_users_mat = ui_mat.shape[0] - n_items_emb = item_embs.shape[0] - n_extra = self._data_preparator.n_item_extra_tokens - - sliced = ui_mat[:, n_extra:] if ui_mat.shape[1] > n_extra else sparse.csr_matrix((n_users_mat, 0)) - n_cols = 
sliced.shape[1] - if n_cols < n_items_emb: - filter_csr = sparse.hstack([sliced, sparse.csr_matrix((n_users_mat, n_items_emb - n_cols))], format="csr") - elif n_cols > n_items_emb: - filter_csr = sliced[:, :n_items_emb] - else: - filter_csr = sliced - - whitelist = None - if sorted_item_ids_to_recommend is not None: - n_extra = self._data_preparator.n_item_extra_tokens - wl = sorted_item_ids_to_recommend - n_extra - whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])] - - u_ids, i_ids, scores = rank_topk( - user_embs, item_embs, k, - filter_csr=filter_csr, whitelist=whitelist, batch_size=self.recommend_batch_size, - ) - - n_extra = self._data_preparator.n_item_extra_tokens - i_ids = i_ids + n_extra - return u_ids, i_ids, scores - - def _recommend_i2i( - self, - target_ids: InternalIdsArray, - dataset: Dataset, - k: int, - sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], - ) -> InternalRecoTriplet: - assert self._data_preparator is not None and self._net is not None - - item_embs = self._get_item_embeddings() - n_extra = self._data_preparator.n_item_extra_tokens - - target_emb_idx = target_ids - n_extra - target_embs = item_embs[target_emb_idx] - - whitelist = None - if sorted_item_ids_to_recommend is not None: - wl = sorted_item_ids_to_recommend - n_extra - whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])] - - t_ids, i_ids, scores = rank_topk( - target_embs, item_embs, k, - whitelist=whitelist, batch_size=self.recommend_batch_size, + torch.save({ + "net": self._net.state_dict(), + "unique_items": self._unique_items, + "unique_users": self._unique_users, + "n_items": len(self._unique_items), + }, path) + + def load_checkpoint(self, path: tp.Union[str, Path], device: str = "cuda") -> None: + ckpt = torch.load(path, map_location=device, weights_only=False) + self._unique_items = ckpt["unique_items"] + self._unique_users = ckpt["unique_users"] + n_items = ckpt["n_items"] + + aligned_emb = align_embeddings(self.pretrained_item_embeddings, 
self._unique_items, n_items) + + self._net = UniSRec( + n_items=n_items, + pretrained_embeddings=aligned_emb, + n_factors=self.n_factors, + projection_hidden=self.projection_hidden, + n_blocks=self.n_blocks, + n_heads=self.n_heads, + session_max_len=self.session_max_len, + dropout=self.dropout, + adaptor_dropout=self.adaptor_dropout, + adaptor_type=self.adaptor_type, + use_adaptor_ffn=self.use_adaptor_ffn, + ffn_type=self.ffn_type, + ffn_expansion=self.ffn_expansion, ) - - result_target_ids = target_ids[t_ids] - result_item_ids = i_ids + n_extra - return result_target_ids, result_item_ids, scores + self._net.load_state_dict(ckpt["net"]) + self._net.to(device).eval() + self.is_fitted = True + + @property + def net(self) -> UniSRec: + assert self._net is not None, "Model not fitted or loaded" + return self._net + + @property + def item_id_mapping(self) -> torch.Tensor: + return self._unique_items diff --git a/scripts/profile_build_sequences.py b/scripts/profile_build_sequences.py new file mode 100644 index 00000000..9325b1df --- /dev/null +++ b/scripts/profile_build_sequences.py @@ -0,0 +1,142 @@ +"""Profile build_sequences on synthetic data matching ML-20M scale.""" + +import time +import torch + +def build_sequences_profiled( + user_ids, item_ids, timestamps, max_len, min_interactions=2, device="cuda", +): + t0 = time.time() + user_ids = user_ids.to(device) + item_ids = item_ids.to(device) + timestamps = timestamps.to(device) + torch.cuda.synchronize() + t_transfer = time.time() - t0 + + t0 = time.time() + unique_items, item_inv = torch.unique(item_ids, return_inverse=True) + internal_items = item_inv + 1 + unique_users, user_inv = torch.unique(user_ids, return_inverse=True) + torch.cuda.synchronize() + t_unique = time.time() - t0 + + t0 = time.time() + order1 = torch.argsort(timestamps, stable=True) + order2 = torch.argsort(user_inv[order1], stable=True) + order = order1[order2] + sorted_user_inv = user_inv[order] + sorted_items = internal_items[order] + 
torch.cuda.synchronize() + t_sort = time.time() - t0 + + t0 = time.time() + changes = torch.where(sorted_user_inv[1:] != sorted_user_inv[:-1])[0] + 1 + starts = torch.cat([torch.tensor([0], device=device), changes]) + ends = torch.cat([changes, torch.tensor([len(sorted_user_inv)], device=device)]) + lengths = ends - starts + mask = lengths >= min_interactions + starts = starts[mask] + ends = ends[mask] + lengths = lengths[mask] + n_users = len(starts) + capped_lens = torch.clamp(lengths, max=max_len + 1) + torch.cuda.synchronize() + t_boundaries = time.time() - t0 + + t0 = time.time() + effective_lens = torch.clamp(capped_lens - 1, min=0) + total_elements = effective_lens.sum().item() + x = torch.zeros(n_users, max_len, dtype=torch.long, device=device) + y = torch.zeros(n_users, max_len, dtype=torch.long, device=device) + + if total_elements > 0: + user_indices = torch.repeat_interleave(torch.arange(n_users, device=device), effective_lens) + cumsum = effective_lens.cumsum(0) + offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave(cumsum - effective_lens, effective_lens) + x_src = torch.repeat_interleave(ends - capped_lens, effective_lens) + offsets + y_src = x_src + 1 + col_indices = max_len - torch.repeat_interleave(effective_lens, effective_lens) + offsets + x[user_indices, col_indices] = sorted_items[x_src] + y[user_indices, col_indices] = sorted_items[y_src] + torch.cuda.synchronize() + t_scatter = time.time() - t0 + + valid_user_indices = torch.where(mask)[0] + result_users = unique_users[valid_user_indices] if len(valid_user_indices) < len(unique_users) else unique_users + + print(f" transfer to GPU: {t_transfer:.3f}s") + print(f" unique: {t_unique:.3f}s") + print(f" sort (2x argsort): {t_sort:.3f}s") + print(f" boundaries: {t_boundaries:.3f}s") + print(f" scatter (vectorized): {t_scatter:.3f}s") + print(f" TOTAL: {t_transfer + t_unique + t_sort + t_boundaries + t_scatter:.3f}s") + print(f" n_users={n_users}, 
total_elements={total_elements}") + + return x, y, unique_items, result_users + + +def verify_correctness(): + """Small test to verify vectorized scatter produces correct results.""" + torch.manual_seed(42) + n = 50 + user_ids = torch.tensor([0,0,0,0,0, 1,1,1, 2,2,2,2]) + item_ids = torch.tensor([10,20,30,40,50, 60,70,80, 90,100,110,120]) + timestamps = torch.arange(n := len(user_ids)) + + from rectools.fast_transformers.gpu_data import build_sequences + x, y, ui, uu = build_sequences(user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device="cuda") + + x_cpu = x.cpu() + y_cpu = y.cpu() + + print("\n=== Correctness check ===") + print(f"x:\n{x_cpu}") + print(f"y:\n{y_cpu}") + + # User 0: items [1,2,3,4,5], capped to 5 (max_len+1=5), effective=4 + # x row: [2, 3, 4, 5] wait, max_len=4 so x[0] should be [1,2,3,4], y[0]=[2,3,4,5] + # Actually: capped = min(5, 4+1=5) = 5, effective = 4 + # seq = items[-5:] = [1,2,3,4,5] + # x: seq[:-1] = [1,2,3,4] placed at cols 0..3 + # y: seq[1:] = [2,3,4,5] placed at cols 0..3 + assert x_cpu[0].tolist() == [1,2,3,4], f"Got {x_cpu[0].tolist()}" + assert y_cpu[0].tolist() == [2,3,4,5], f"Got {y_cpu[0].tolist()}" + + # User 1: items [6,7,8], capped=3, effective=2 + # seq = [6,7,8], x: [6,7] at cols 2..3, y: [7,8] at cols 2..3 + assert x_cpu[1].tolist() == [0,0,6,7], f"Got {x_cpu[1].tolist()}" + assert y_cpu[1].tolist() == [0,0,7,8], f"Got {y_cpu[1].tolist()}" + + # User 2: items [9,10,11,12], capped=4, effective=3 + # seq = [9,10,11,12], x: [9,10,11] at cols 1..3, y: [10,11,12] at cols 1..3 + assert x_cpu[2].tolist() == [0,9,10,11], f"Got {x_cpu[2].tolist()}" + assert y_cpu[2].tolist() == [0,10,11,12], f"Got {y_cpu[2].tolist()}" + + print("All assertions passed!") + + +def profile_ml20m_scale(): + """Generate data at ML-20M scale and profile.""" + print("\n=== ML-20M scale profile ===") + torch.manual_seed(0) + N = 5_000_000 + n_users_approx = 136_000 + n_items_approx = 7_000 + + user_ids = torch.randint(0, n_users_approx, 
(N,)) + item_ids = torch.randint(0, n_items_approx, (N,)) + timestamps = torch.randint(0, 10**9, (N,), dtype=torch.long) + + # warmup + print("Warmup...") + _ = build_sequences_profiled(user_ids[:1000], item_ids[:1000], timestamps[:1000], max_len=200, device="cuda") + + print("\nFull run:") + x, y, ui, uu = build_sequences_profiled(user_ids, item_ids, timestamps, max_len=200, device="cuda") + print(f"Output shape: x={x.shape}, y={y.shape}") + print(f"GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB") + + +if __name__ == "__main__": + verify_correctness() + profile_ml20m_scale() diff --git a/scripts/test_1epoch.py b/scripts/test_1epoch.py new file mode 100644 index 00000000..76d283ae --- /dev/null +++ b/scripts/test_1epoch.py @@ -0,0 +1,88 @@ +"""Quick 1-epoch smoke test of the full pipeline.""" + +import time +from pathlib import Path + +import pandas as pd +import torch + +from rectools.fast_transformers import UniSRecModel + +DATA_DIR = Path("data/ml-20m") +MIN_RATING = 4.0 +MIN_ITEM_INTERACTIONS = 50 +MIN_USER_INTERACTIONS = 5 + + +def load_data(): + ratings = pd.read_csv(DATA_DIR / "ml-20m" / "ratings.csv") + ratings.columns = ["user_id", "item_id", "rating", "timestamp"] + ratings = ratings[ratings["rating"] >= MIN_RATING] + item_counts = ratings.groupby("item_id").size() + popular = item_counts[item_counts >= MIN_ITEM_INTERACTIONS].index + ratings = ratings[ratings["item_id"].isin(popular)] + user_counts = ratings.groupby("user_id").size() + valid = user_counts[user_counts >= MIN_USER_INTERACTIONS].index + ratings = ratings[ratings["user_id"].isin(valid)] + return ratings + + +def main(): + print("Loading data...") + ratings = load_data() + print(f" {len(ratings):,} interactions, {ratings['user_id'].nunique():,} users, {ratings['item_id'].nunique():,} items") + + pretrained = torch.load(DATA_DIR / "qwen_embeddings.pt", weights_only=True) + print(f" Pretrained embeddings: {pretrained.shape}") + + user_ids = torch.tensor(ratings["user_id"].values, 
dtype=torch.long) + item_ids = torch.tensor(ratings["item_id"].values, dtype=torch.long) + timestamps = torch.tensor(ratings["timestamp"].values, dtype=torch.long) + + model = UniSRecModel( + pretrained_item_embeddings=pretrained, + n_factors=512, + projection_hidden=512, + n_blocks=2, + n_heads=1, + session_max_len=200, + dropout=0.1, + adaptor_dropout=0.2, + adaptor_type="pca", + use_adaptor_ffn=True, + phase1_epochs=0, + phase2_epochs=0, + phase3_epochs=1, + phase3_lr=1e-4, + lr_head=0.3, + lr_wp=0.1, + lr_transformer=3.0, + optimizer="adamw", + scheduler="cosine_warmup", + warmup_ratio=0.05, + min_lr_ratio=1.0, + grad_clip=1.0, + weight_decay=0.01, + loss="softmax", + batch_size=128, + dataloader_num_workers=0, + train_min_user_interactions=2, + verbose=1, + ) + + print("\nStarting 1-epoch training...") + t0 = time.time() + model.fit(user_ids, item_ids, timestamps) + elapsed = time.time() - t0 + print(f"\n1-epoch training complete in {elapsed:.1f}s") + + # Verify item_id_mapping contains original IDs + unique_items = model.item_id_mapping + print(f"unique_items range: [{unique_items.min().item()}, {unique_items.max().item()}]") + print(f"Original item_id range: [{ratings['item_id'].min()}, {ratings['item_id'].max()}]") + assert unique_items.max().item() > 100, "IDs should be original MovieLens IDs, not 0-based reindexed" + print("ID mapping verified — original external IDs preserved!") + + +if __name__ == "__main__": + main() diff --git a/scripts/train_unisrec_ml20m.py b/scripts/train_unisrec_ml20m.py new file mode 100644 index 00000000..388ee9a4 --- /dev/null +++ b/scripts/train_unisrec_ml20m.py @@ -0,0 +1,293 @@ +"""Train UniSRec on ML-20M with Qwen embeddings.""" + +import json +import zipfile +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm + +from rectools.fast_transformers import UniSRecModel + +DESCRIPTIONS_PATH = "training_folder/uniSRec/item_descriptions_compact.json" +QWEN_MODEL_NAME = 
"Qwen/Qwen3-Embedding-0.6B" +QWEN_DIM = 1024 +DATA_DIR = Path("data/ml-20m") +CACHE_EMB_PATH = DATA_DIR / "qwen_embeddings.pt" +ML20M_URL = "https://files.grouplens.org/datasets/movielens/ml-20m.zip" + +MIN_RATING = 4.0 +MIN_ITEM_INTERACTIONS = 50 +MIN_USER_INTERACTIONS = 5 +PHASE3_EPOCHS = 30 + + +def download_ml20m(): + DATA_DIR.mkdir(parents=True, exist_ok=True) + ratings_path = DATA_DIR / "ml-20m" / "ratings.csv" + if ratings_path.exists(): + return + zip_path = DATA_DIR / "ml-20m.zip" + if not zip_path.exists(): + print(f"Downloading ML-20M...") + import urllib.request + urllib.request.urlretrieve(ML20M_URL, zip_path) + print("Extracting...") + with zipfile.ZipFile(zip_path, "r") as zf: + zf.extractall(DATA_DIR) + + +def load_and_preprocess(): + download_ml20m() + ratings = pd.read_csv(DATA_DIR / "ml-20m" / "ratings.csv") + ratings.columns = ["user_id", "item_id", "rating", "timestamp"] + + if MIN_RATING > 0: + ratings = ratings[ratings["rating"] >= MIN_RATING] + print(f"After rating filter (>={MIN_RATING}): {len(ratings):,} interactions") + + if MIN_ITEM_INTERACTIONS > 0: + item_counts = ratings.groupby("item_id").size() + popular = item_counts[item_counts >= MIN_ITEM_INTERACTIONS].index + ratings = ratings[ratings["item_id"].isin(popular)] + print(f"After item filter (>={MIN_ITEM_INTERACTIONS}): {ratings['item_id'].nunique():,} items") + + user_counts = ratings.groupby("user_id").size() + valid = user_counts[user_counts >= MIN_USER_INTERACTIONS].index + ratings = ratings[ratings["user_id"].isin(valid)] + print(f"Final: {len(ratings):,} interactions, {ratings['user_id'].nunique():,} users, {ratings['item_id'].nunique():,} items") + + movies = pd.read_csv(DATA_DIR / "ml-20m" / "movies.csv") + movies.columns = ["movieId", "title", "genres"] + return ratings, movies + + +def _last_token_pool(hidden_states, attention_mask): + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + return hidden_states[:, -1] + seq_lengths = 
attention_mask.sum(dim=1) - 1 + return hidden_states[torch.arange(hidden_states.shape[0], device=hidden_states.device), seq_lengths] + + +@torch.no_grad() +def encode_qwen(texts, device="cuda", batch_size=1024): + from transformers import AutoModel, AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_NAME, padding_side="left") + model = AutoModel.from_pretrained(QWEN_MODEL_NAME, torch_dtype=torch.float16).to(device).eval() + + embeddings = torch.zeros(len(texts), QWEN_DIM) + for start in tqdm(range(0, len(texts), batch_size), desc="Qwen encode"): + batch = texts[start:start + batch_size] + inputs = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device) + hidden = model(**inputs).last_hidden_state + out = _last_token_pool(hidden, inputs["attention_mask"]) + embeddings[start:start + len(batch)] = out.float().cpu() + + del model, tokenizer + torch.cuda.empty_cache() + return embeddings + + +def build_pretrained_embeddings(movies, descriptions): + all_movie_ids = sorted(movies["movieId"].unique()) + max_id = max(all_movie_ids) + texts_by_id = {} + + for mid in all_movie_ids: + key = str(mid) + if key in descriptions: + val = descriptions[key] + texts_by_id[mid] = val[0] if isinstance(val, list) else val + else: + row = movies[movies["movieId"] == mid] + if len(row) > 0: + texts_by_id[mid] = f"{row.iloc[0]['title']} {row.iloc[0]['genres']}" + else: + texts_by_id[mid] = f"movie {mid}" + + ordered_ids = sorted(texts_by_id.keys()) + ordered_texts = [texts_by_id[mid] for mid in ordered_ids] + + if CACHE_EMB_PATH.exists(): + print(f"Loading cached embeddings from {CACHE_EMB_PATH}") + return torch.load(CACHE_EMB_PATH, weights_only=True) + + raw_embs = encode_qwen(ordered_texts, batch_size=512) + + embeddings = torch.zeros(max_id + 1, QWEN_DIM) + for i, mid in enumerate(ordered_ids): + embeddings[mid] = raw_embs[i] + + torch.save(embeddings, CACHE_EMB_PATH) + print(f"Saved embeddings to {CACHE_EMB_PATH}, 
shape={embeddings.shape}") + return embeddings + + +def split_eval(ratings): + ratings = ratings.sort_values(["user_id", "timestamp"]) + grouped = ratings.groupby("user_id") + test_idx = grouped.tail(1).index + remaining = ratings.drop(test_idx) + val_idx = remaining.groupby("user_id").tail(1).index + train_idx = remaining.drop(val_idx).index + + train = ratings.loc[train_idx] + val = ratings.loc[val_idx] + test = ratings.loc[test_idx] + return train, val, test + + +def to_tensors(df): + """Convert a ratings DataFrame to (user_ids, item_ids, timestamps) tensors.""" + return ( + torch.tensor(df["user_id"].values, dtype=torch.long), + torch.tensor(df["item_id"].values, dtype=torch.long), + torch.tensor(df["timestamp"].values, dtype=torch.long), + ) + + +@torch.no_grad() +def evaluate_fast(model, train_ratings_df, test_df, k=10, batch_size=256): + net = model.net + net.cuda().eval() + device = torch.device("cuda") + maxlen = net.session_max_len + + item_embs = net.project_all() + unique_items = model.item_id_mapping + + ext_to_int = {} + for i in range(len(unique_items)): + ext_to_int[int(unique_items[i].item())] = i + 1 + + train_grouped = train_ratings_df.sort_values("timestamp").groupby("user_id")["item_id"].agg(list).to_dict() + test_grouped = test_df.groupby("user_id")["item_id"].first().to_dict() + test_users = list(test_grouped.keys()) + + hits, ndcg_sum, mrr_sum, total = 0, 0.0, 0.0, 0 + + for start in tqdm(range(0, len(test_users), batch_size), desc="Evaluating"): + batch_users = test_users[start:start + batch_size] + seqs, targets = [], [] + for uid in batch_users: + history = train_grouped.get(uid, []) + mapped = [ext_to_int[iid] for iid in history if iid in ext_to_int] + if not mapped: + continue + seq = mapped[-maxlen:] + seqs.append([0] * (maxlen - len(seq)) + seq) + targets.append(ext_to_int.get(test_grouped[uid])) + + if not seqs: + continue + + x = torch.tensor(seqs, dtype=torch.long, device=device) + h = net.encode_last(x, use_id=False) + scores = h 
@ item_embs.T + scores[:, 0] = float("-inf") + + for i, target_int in enumerate(targets): + if target_int is None: + continue + _, topk_idx = scores[i].topk(k) + topk = topk_idx.cpu().tolist() + if target_int in topk: + rank = topk.index(target_int) + hits += 1 + ndcg_sum += 1.0 / np.log2(rank + 2) + mrr_sum += 1.0 / (rank + 1) + total += 1 + + return { + f"HR@{k}": hits / total if total else 0, + f"NDCG@{k}": ndcg_sum / total if total else 0, + f"MRR@{k}": mrr_sum / total if total else 0, + "n_users": total, + } + + +def main(): + print("=" * 60) + print("UniSRec Training on ML-20M") + print("=" * 60) + + ratings, movies = load_and_preprocess() + descriptions = json.loads(Path(DESCRIPTIONS_PATH).read_text()) + print(f"Loaded {len(descriptions)} descriptions") + + pretrained = build_pretrained_embeddings(movies, descriptions) + print(f"Pretrained embeddings: {pretrained.shape}") + + train_ratings, val_ratings, test_ratings = split_eval(ratings) + print(f"Split: train={len(train_ratings):,}, val={len(val_ratings):,}, test={len(test_ratings):,}") + + train_with_val = pd.concat([train_ratings, val_ratings]) + + checkpoint_path = DATA_DIR / "unisrec_v3.pt" + + model = UniSRecModel( + pretrained_item_embeddings=pretrained, + n_factors=512, + projection_hidden=512, + n_blocks=2, + n_heads=1, + session_max_len=200, + dropout=0.1, + adaptor_dropout=0.2, + adaptor_type="pca", + use_adaptor_ffn=True, + phase1_epochs=0, + phase2_epochs=0, + phase3_epochs=PHASE3_EPOCHS, + phase1_lr=1e-3, + phase2_lr=3e-4, + phase3_lr=1e-4, + lr_head=0.3, + lr_wp=0.1, + lr_transformer=3.0, + optimizer="adamw", + scheduler="cosine_warmup", + warmup_ratio=0.05, + min_lr_ratio=1.0, + grad_clip=1.0, + weight_decay=0.01, + loss="softmax", + patience=10, + batch_size=128, + dataloader_num_workers=0, + train_min_user_interactions=2, + verbose=1, + ) + + if checkpoint_path.exists(): + print(f"Loading checkpoint from {checkpoint_path}") + model.load_checkpoint(checkpoint_path) + else: + 
print("\nStarting training...") + user_ids, item_ids, timestamps = to_tensors(train_with_val) + model.fit(user_ids, item_ids, timestamps) + model.save_checkpoint(checkpoint_path) + print(f"Saved checkpoint to {checkpoint_path}") + + print("Training complete!") + + print("\n--- Validation Metrics ---") + val_results = evaluate_fast(model, train_ratings, val_ratings) + for m, v in val_results.items(): + print(f" {m}: {v}") + + print("\n--- Test Metrics ---") + test_results = evaluate_fast(model, train_with_val, test_ratings) + for m, v in test_results.items(): + print(f" {m}: {v}") + + print("\n--- Expected Metrics ---") + print(" val: HR@10=0.2431 NDCG@10=0.1335") + print(" test: HR@10=0.2218 NDCG@10=0.1251 MRR@10=0.0957") + + +if __name__ == "__main__": + main() From aa015f858252c2e61659f72a9c3c260176db3dc0 Mon Sep 17 00:00:00 2001 From: TOPAPEC Date: Fri, 24 Apr 2026 21:16:50 +0000 Subject: [PATCH 4/7] Tests + comparison --- scripts/compare_sasrec_unisrec.py | 421 +++++++++++++++ scripts/comparison_report.md | 58 +++ tests/fast_transformers/test_gpu_data.py | 460 +++++++++++++++++ .../fast_transformers/test_lightning_wrap.py | 176 +++++++ tests/fast_transformers/test_ranking.py | 331 ++++++++++++ .../test_unisrec_lightning.py | 482 ++++++++++++++++++ tests/fast_transformers/test_unisrec_model.py | 263 +++++----- 7 files changed, 2048 insertions(+), 143 deletions(-) create mode 100644 scripts/compare_sasrec_unisrec.py create mode 100644 scripts/comparison_report.md create mode 100644 tests/fast_transformers/test_gpu_data.py create mode 100644 tests/fast_transformers/test_lightning_wrap.py create mode 100644 tests/fast_transformers/test_ranking.py create mode 100644 tests/fast_transformers/test_unisrec_lightning.py diff --git a/scripts/compare_sasrec_unisrec.py b/scripts/compare_sasrec_unisrec.py new file mode 100644 index 00000000..bf6ee18a --- /dev/null +++ b/scripts/compare_sasrec_unisrec.py @@ -0,0 +1,421 @@ +"""Compare RecTools SASRec vs UniSRec-ID on ML-20M. 
+ +Both use full softmax, Adam, n_factors=256, 10 epochs. +MIN_RATING=-1 (no filter), MIN_ITEM_INTERACTIONS=5, MIN_USER_INTERACTIONS=2. +Writes results to scripts/comparison_report.md. +""" + +import gc +import time +from datetime import datetime +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from tqdm import tqdm + +from rectools import Columns +from rectools.dataset import Dataset +from rectools.models import SASRecModel +from rectools.fast_transformers import UniSRecModel +from rectools.fast_transformers.gpu_data import build_sequences + +DATA_DIR = Path("data/ml-20m") +CACHE_EMB_PATH = DATA_DIR / "qwen_embeddings.pt" +REPORT_PATH = Path("scripts/comparison_report.md") + +MIN_RATING = -1 +MIN_ITEM_INTERACTIONS = 5 +MIN_USER_INTERACTIONS = 2 + +EPOCHS = 10 +PATIENCE = None +BATCH_SIZE = 128 +SESSION_MAX_LEN = 200 +N_FACTORS = 256 +N_BLOCKS = 2 +N_HEADS = 1 +LR = 1e-3 + + +def load_and_preprocess(): + ratings = pd.read_csv(DATA_DIR / "ml-20m" / "ratings.csv") + ratings.columns = ["user_id", "item_id", "rating", "timestamp"] + + if MIN_RATING > 0: + ratings = ratings[ratings["rating"] >= MIN_RATING] + + if MIN_ITEM_INTERACTIONS > 0: + item_counts = ratings.groupby("item_id").size() + popular = item_counts[item_counts >= MIN_ITEM_INTERACTIONS].index + ratings = ratings[ratings["item_id"].isin(popular)] + + if MIN_USER_INTERACTIONS > 0: + user_counts = ratings.groupby("user_id").size() + valid = user_counts[user_counts >= MIN_USER_INTERACTIONS].index + ratings = ratings[ratings["user_id"].isin(valid)] + + return ratings + + +def split_eval(ratings): + ratings = ratings.sort_values(["user_id", "timestamp"]) + grouped = ratings.groupby("user_id") + test_idx = grouped.tail(1).index + remaining = ratings.drop(test_idx) + val_idx = remaining.groupby("user_id").tail(1).index + train_idx = remaining.drop(val_idx).index + return ratings.loc[train_idx], ratings.loc[val_idx], ratings.loc[test_idx] + + +def to_tensors(df): + return ( + 
torch.tensor(df["user_id"].values, dtype=torch.long), + torch.tensor(df["item_id"].values, dtype=torch.long), + torch.tensor(df["timestamp"].values, dtype=torch.long), + ) + + +@torch.no_grad() +def evaluate_unisrec(model, train_df, test_df, k=10, batch_size=256, use_id=False): + net = model.net + net.cuda().eval() + device = torch.device("cuda") + maxlen = net.session_max_len + + item_embs = net.item_emb.weight if use_id else net.project_all() + unique_items = model.item_id_mapping + ext_to_int = {int(unique_items[i].item()): i + 1 for i in range(len(unique_items))} + + train_grouped = train_df.sort_values("timestamp").groupby("user_id")["item_id"].agg(list).to_dict() + test_grouped = test_df.groupby("user_id")["item_id"].first().to_dict() + test_users = list(test_grouped.keys()) + + hits, ndcg_sum, mrr_sum, total = 0, 0.0, 0.0, 0 + for start in tqdm(range(0, len(test_users), batch_size), desc="Eval UniSRec"): + batch_users = test_users[start:start + batch_size] + seqs, targets = [], [] + for uid in batch_users: + history = train_grouped.get(uid, []) + mapped = [ext_to_int[iid] for iid in history if iid in ext_to_int] + if not mapped: + continue + seq = mapped[-maxlen:] + seqs.append([0] * (maxlen - len(seq)) + seq) + targets.append(ext_to_int.get(test_grouped[uid])) + if not seqs: + continue + x = torch.tensor(seqs, dtype=torch.long, device=device) + h = net.encode_last(x, use_id=use_id) + scores = h @ item_embs.T + scores[:, 0] = float("-inf") + for i, target_int in enumerate(targets): + if target_int is None: + continue + _, topk_idx = scores[i].topk(k) + topk = topk_idx.cpu().tolist() + if target_int in topk: + rank = topk.index(target_int) + hits += 1 + ndcg_sum += 1.0 / np.log2(rank + 2) + mrr_sum += 1.0 / (rank + 1) + total += 1 + return {"HR@10": hits / total, "NDCG@10": ndcg_sum / total, "MRR@10": mrr_sum / total, "n_users": total} + + +def evaluate_sasrec(model, dataset_for_recommend, test_df, k=10): + test_users = test_df["user_id"].unique() + reco = 
model.recommend(users=test_users, dataset=dataset_for_recommend, k=k, filter_viewed=False) + + test_targets = test_df.groupby("user_id")["item_id"].first().to_dict() + hits, ndcg_sum, mrr_sum, total = 0, 0.0, 0.0, 0 + for uid, group in reco.groupby(Columns.User): + target = test_targets.get(uid) + if target is None: + continue + items = group[Columns.Item].tolist() + if target in items: + rank = items.index(target) + hits += 1 + ndcg_sum += 1.0 / np.log2(rank + 2) + mrr_sum += 1.0 / (rank + 1) + total += 1 + return {"HR@10": hits / total, "NDCG@10": ndcg_sum / total, "MRR@10": mrr_sum / total, "n_users": total} + + +def cleanup(): + gc.collect() + torch.cuda.empty_cache() + + +def write_report(timings: dict, metrics: dict, data_info: dict): + gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A" + lines = [ + f"# SASRec vs UniSRec-ID Comparison", + f"", + f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M')} ", + f"**GPU:** {gpu_name} ", + f"**Dataset:** ML-20M (min_rating={MIN_RATING}, min_item={MIN_ITEM_INTERACTIONS}, min_user={MIN_USER_INTERACTIONS})", + f"", + f"## Data", + f"", + f"| | Count |", + f"|---|---:|", + f"| Interactions | {data_info['n_interactions']:,} |", + f"| Users | {data_info['n_users']:,} |", + f"| Items | {data_info['n_items']:,} |", + f"| Train | {data_info['n_train']:,} |", + f"| Val | {data_info['n_val']:,} |", + f"| Test | {data_info['n_test']:,} |", + f"", + f"## Config", + f"", + f"| Parameter | Value |", + f"|---|---|", + f"| n_factors | {N_FACTORS} |", + f"| n_blocks | {N_BLOCKS} |", + f"| n_heads | {N_HEADS} |", + f"| session_max_len | {SESSION_MAX_LEN} |", + f"| batch_size | {BATCH_SIZE} |", + f"| lr | {LR} |", + f"| loss | softmax |", + f"| optimizer | Adam |", + f"| epochs | {EPOCHS} |", + f"| patience | {PATIENCE} |", + f"| dropout | 0.1 |", + f"", + f"## Timing", + f"", + f"| Stage | SASRec | UniSRec ID |", + f"|---|---:|---:|", + ] + + for stage in ["data_load", "preprocessing", "model_init", 
"training", "eval"]: + s = timings.get(f"sasrec_{stage}", 0) + u = timings.get(f"unisrec_{stage}", 0) + label = { + "data_load": "Data load & split", + "preprocessing": "Preprocessing", + "model_init": "Model init", + "training": f"Training ({EPOCHS} epochs)", + "eval": "Evaluation", + }[stage] + lines.append(f"| {label} | {s:.1f}s | {u:.1f}s |") + + s_total = sum(timings.get(f"sasrec_{s}", 0) for s in ["preprocessing", "model_init", "training", "eval"]) + u_total = sum(timings.get(f"unisrec_{s}", 0) for s in ["preprocessing", "model_init", "training", "eval"]) + lines.append(f"| **Total** | **{s_total:.1f}s** | **{u_total:.1f}s** |") + + s_epoch = timings.get("sasrec_training", 0) / max(timings.get("sasrec_epochs_done", 1), 1) + u_epoch = timings.get("unisrec_training", 0) / max(timings.get("unisrec_epochs_done", 1), 1) + lines.extend([ + f"", + f"| | SASRec | UniSRec ID |", + f"|---|---:|---:|", + f"| Epochs completed | {timings.get('sasrec_epochs_done', EPOCHS)} | {timings.get('unisrec_epochs_done', EPOCHS)} |", + f"| Time per epoch | {s_epoch:.1f}s | {u_epoch:.1f}s |", + f"| Preprocessing speedup | — | {timings.get('prep_speedup', 0):.0f}x |", + ]) + + lines.extend([ + f"", + f"## Quality (test set, {metrics['sasrec']['n_users']:,} users)", + f"", + f"| Model | HR@10 | NDCG@10 | MRR@10 |", + f"|---|---:|---:|---:|", + ]) + for name, key in [("SASRec", "sasrec"), ("UniSRec ID", "unisrec")]: + m = metrics[key] + lines.append(f"| {name} | {m['HR@10']:.4f} | {m['NDCG@10']:.4f} | {m['MRR@10']:.4f} |") + + hr_diff = (metrics["unisrec"]["HR@10"] / metrics["sasrec"]["HR@10"] - 1) * 100 + ndcg_diff = (metrics["unisrec"]["NDCG@10"] / metrics["sasrec"]["NDCG@10"] - 1) * 100 + lines.extend([ + f"", + f"UniSRec vs SASRec: HR@10 {hr_diff:+.1f}%, NDCG@10 {ndcg_diff:+.1f}%", + ]) + + report = "\n".join(lines) + "\n" + REPORT_PATH.write_text(report) + print(f"\nReport written to {REPORT_PATH}") + return report + + +def main(): + torch.set_float32_matmul_precision("high") + 
timings = {} + + print(f"SASRec vs UniSRec-ID | {EPOCHS} epochs | n_factors={N_FACTORS} | Adam | softmax") + print("=" * 70) + + # ── Data ── + t0 = time.time() + ratings = load_and_preprocess() + train_ratings, val_ratings, test_ratings = split_eval(ratings) + train_with_val = pd.concat([train_ratings, val_ratings]) + timings["data_load"] = time.time() - t0 + + data_info = { + "n_interactions": len(ratings), + "n_users": ratings["user_id"].nunique(), + "n_items": ratings["item_id"].nunique(), + "n_train": len(train_ratings), + "n_val": len(val_ratings), + "n_test": len(test_ratings), + } + print(f"Data: {data_info['n_interactions']:,} interactions, {data_info['n_users']:,} users, {data_info['n_items']:,} items") + print(f"Split: train={data_info['n_train']:,}, val={data_info['n_val']:,}, test={data_info['n_test']:,}") + + user_ids_t, item_ids_t, timestamps_t = to_tensors(train_with_val) + pretrained = torch.load(CACHE_EMB_PATH, weights_only=True) + + # ══════════════════════════════════════════════════════════════ + # 1. SASRec (RecTools) + # ══════════════════════════════════════════════════════════════ + print(f"\n{'='*70}") + print(f"1. 
SASRec (RecTools) — {EPOCHS} epochs") + print(f"{'='*70}") + + # Preprocessing + t0 = time.time() + df_rectools = pd.DataFrame({ + Columns.User: train_with_val["user_id"].values, + Columns.Item: train_with_val["item_id"].values, + Columns.Weight: 1.0, + Columns.Datetime: pd.to_datetime(train_with_val["timestamp"], unit="s"), + }) + dataset = Dataset.construct(df_rectools) + timings["sasrec_preprocessing"] = time.time() - t0 + print(f" Preprocessing (Dataset.construct): {timings['sasrec_preprocessing']:.2f}s") + + # Model init + training + def sasrec_trainer(**kwargs): + import pytorch_lightning as pl + callbacks = [] + if PATIENCE is not None: + from pytorch_lightning.callbacks import EarlyStopping + callbacks.append(EarlyStopping(monitor="val_loss", patience=PATIENCE, mode="min")) + return pl.Trainer( + max_epochs=EPOCHS, + min_epochs=1, + callbacks=callbacks or None, + enable_checkpointing=False, + enable_model_summary=False, + logger=True, + enable_progress_bar=True, + devices=1, + ) + + sasrec_kwargs = dict( + n_factors=N_FACTORS, + n_blocks=N_BLOCKS, + n_heads=N_HEADS, + session_max_len=SESSION_MAX_LEN, + dropout_rate=0.1, + loss="softmax", + lr=LR, + batch_size=BATCH_SIZE, + epochs=EPOCHS, + train_min_user_interactions=MIN_USER_INTERACTIONS, + dataloader_num_workers=0, + verbose=1, + get_trainer_func=sasrec_trainer, + ) + if PATIENCE is not None: + def sasrec_val_mask(interactions_df, **kwargs): + idx = interactions_df.groupby(Columns.User).tail(1).index + mask = pd.Series(False, index=interactions_df.index) + mask.loc[idx] = True + return mask + sasrec_kwargs["get_val_mask_func"] = sasrec_val_mask + + t0 = time.time() + sasrec = SASRecModel(**sasrec_kwargs) + timings["sasrec_model_init"] = time.time() - t0 + + t0 = time.time() + sasrec.fit(dataset) + timings["sasrec_training"] = time.time() - t0 + timings["sasrec_epochs_done"] = sasrec.fit_trainer.current_epoch + 1 + print(f" Training: {timings['sasrec_training']:.1f}s, {timings['sasrec_epochs_done']} 
epochs") + + # Eval + print(" Evaluating...") + t0 = time.time() + sasrec_metrics = evaluate_sasrec(sasrec, dataset, test_ratings) + timings["sasrec_eval"] = time.time() - t0 + print(f" Eval: {timings['sasrec_eval']:.1f}s") + print(f" HR@10={sasrec_metrics['HR@10']:.4f} NDCG@10={sasrec_metrics['NDCG@10']:.4f} MRR@10={sasrec_metrics['MRR@10']:.4f}") + del sasrec; cleanup() + + # ══════════════════════════════════════════════════════════════ + # 2. UniSRec ID + # ══════════════════════════════════════════════════════════════ + print(f"\n{'='*70}") + print(f"2. UniSRec ID — {EPOCHS} epochs") + print(f"{'='*70}") + + # Preprocessing + torch.cuda.synchronize() + t0 = time.time() + _ = build_sequences(user_ids_t, item_ids_t, timestamps_t, max_len=SESSION_MAX_LEN) + torch.cuda.synchronize() + timings["unisrec_preprocessing"] = time.time() - t0 + print(f" Preprocessing (build_sequences): {timings['unisrec_preprocessing']:.4f}s") + timings["prep_speedup"] = timings["sasrec_preprocessing"] / timings["unisrec_preprocessing"] + print(f" Speedup vs Dataset.construct: {timings['prep_speedup']:.0f}x") + + # Model init + t0 = time.time() + unisrec_id = UniSRecModel( + pretrained_item_embeddings=pretrained, + n_factors=N_FACTORS, + projection_hidden=N_FACTORS, + n_blocks=N_BLOCKS, + n_heads=N_HEADS, + session_max_len=SESSION_MAX_LEN, + dropout=0.1, + adaptor_dropout=0.2, + adaptor_type="pca", + use_adaptor_ffn=True, + phase1_epochs=EPOCHS, + phase2_epochs=0, + phase3_epochs=0, + phase1_lr=LR, + optimizer="adam", + grad_clip=1.0, + weight_decay=0.0, + loss="softmax", + patience=PATIENCE, + batch_size=BATCH_SIZE, + dataloader_num_workers=0, + train_min_user_interactions=MIN_USER_INTERACTIONS, + verbose=1, + ) + timings["unisrec_model_init"] = time.time() - t0 + + # Training (fit includes build_sequences internally, but we already measured preprocessing separately) + t0 = time.time() + unisrec_id.fit(user_ids_t, item_ids_t, timestamps_t) + timings["unisrec_training"] = time.time() - 
t0 + timings["unisrec_epochs_done"] = EPOCHS + print(f" Training (total fit): {timings['unisrec_training']:.1f}s") + + # Eval + print(" Evaluating...") + t0 = time.time() + unisrec_metrics = evaluate_unisrec(unisrec_id, train_with_val, test_ratings, use_id=True) + timings["unisrec_eval"] = time.time() - t0 + print(f" Eval: {timings['unisrec_eval']:.1f}s") + print(f" HR@10={unisrec_metrics['HR@10']:.4f} NDCG@10={unisrec_metrics['NDCG@10']:.4f} MRR@10={unisrec_metrics['MRR@10']:.4f}") + del unisrec_id; cleanup() + + # ── Report ── + metrics = {"sasrec": sasrec_metrics, "unisrec": unisrec_metrics} + report = write_report(timings, metrics, data_info) + print("\n" + report) + + +if __name__ == "__main__": + main() diff --git a/scripts/comparison_report.md b/scripts/comparison_report.md new file mode 100644 index 00000000..fd136387 --- /dev/null +++ b/scripts/comparison_report.md @@ -0,0 +1,58 @@ +# SASRec vs UniSRec-ID Comparison + +**Date:** 2026-04-24 19:59 +**GPU:** NVIDIA GeForce RTX 4090 +**Dataset:** ML-20M (min_rating=-1, min_item=5, min_user=2) + +## Data + +| | Count | +|---|---:| +| Interactions | 19,984,024 | +| Users | 138,493 | +| Items | 18,345 | +| Train | 19,707,038 | +| Val | 138,493 | +| Test | 138,493 | + +## Config + +| Parameter | Value | +|---|---| +| n_factors | 256 | +| n_blocks | 2 | +| n_heads | 1 | +| session_max_len | 200 | +| batch_size | 128 | +| lr | 0.001 | +| loss | softmax | +| optimizer | Adam | +| epochs | 10 | +| patience | None | +| dropout | 0.1 | + +## Timing + +| Stage | SASRec | UniSRec ID | +|---|---:|---:| +| Data load & split | 0.0s | 0.0s | +| Preprocessing | 14.6s | 0.5s | +| Model init | 0.0s | 0.0s | +| Training (10 epochs) | 911.8s | 639.5s | +| Evaluation | 175.6s | 28.0s | +| **Total** | **1102.1s** | **668.0s** | + +| | SASRec | UniSRec ID | +|---|---:|---:| +| Epochs completed | 11 | 10 | +| Time per epoch | 82.9s | 63.9s | +| Preprocessing speedup | — | 29x | + +## Quality (test set, 138,493 users) + +| Model | 
HR@10 | NDCG@10 | MRR@10 | +|---|---:|---:|---:| +| SASRec | 0.2417 | 0.1410 | 0.1103 | +| UniSRec ID | 0.2528 | 0.1495 | 0.1179 | + +UniSRec vs SASRec: HR@10 +4.6%, NDCG@10 +6.0% diff --git a/tests/fast_transformers/test_gpu_data.py b/tests/fast_transformers/test_gpu_data.py new file mode 100644 index 00000000..c3938e6f --- /dev/null +++ b/tests/fast_transformers/test_gpu_data.py @@ -0,0 +1,460 @@ +"""Tests for GPU-native sequence building and data utilities.""" + +import torch +import pytest + +from rectools.fast_transformers.gpu_data import ( + build_sequences, + align_embeddings, + GPUBatchDataset, + make_dataloader, +) + +DEVICE = "cpu" + + +class TestBuildSequences: + """Tests for the build_sequences function.""" + + def test_basic_two_users(self) -> None: + """Two users with 3 interactions each, max_len=4.""" + user_ids = torch.tensor([0, 0, 0, 1, 1, 1]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + x, y, unique_items, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + assert x.shape == (2, 4) + assert y.shape == (2, 4) + + # Items are mapped to internal 1-based IDs; 0 = padding + # unique_items is sorted, so: [10, 20, 30, 40, 50, 60] + # internal IDs: 10->1, 20->2, 30->3, 40->4, 50->5, 60->6 + + # User 0: items [10, 20, 30] in order => internal [1, 2, 3] + # x = [0, 1, 2] left-padded to len 4 => [0, 0, 1, 2] + # y = [0, 2, 3] left-padded to len 4 => [0, 0, 2, 3] + assert x[0].tolist() == [0, 0, 1, 2] + assert y[0].tolist() == [0, 0, 2, 3] + + # User 1: items [40, 50, 60] in order => internal [4, 5, 6] + # x = [0, 4, 5] => [0, 0, 4, 5] + # y = [0, 5, 6] => [0, 0, 5, 6] + assert x[1].tolist() == [0, 0, 4, 5] + assert y[1].tolist() == [0, 0, 5, 6] + + assert result_users.tolist() == [0, 1] + + def test_unique_items_mapping(self) -> None: + """unique_items should map internal_id - 1 => external_id.""" + user_ids = torch.tensor([0, 0, 0]) 
+ item_ids = torch.tensor([100, 50, 200]) + timestamps = torch.tensor([1, 2, 3]) + + _, _, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=5, min_interactions=2, device=DEVICE + ) + + # torch.unique sorts, so unique_items = [50, 100, 200] + assert unique_items.tolist() == [50, 100, 200] + + def test_min_interactions_filtering(self) -> None: + """Users with fewer than min_interactions should be dropped.""" + user_ids = torch.tensor([0, 0, 0, 1, 2, 2]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + x, y, _, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + # User 1 has only 1 interaction => dropped + assert x.shape[0] == 2 + assert result_users.tolist() == [0, 2] + + def test_min_interactions_higher_threshold(self) -> None: + """Higher min_interactions threshold filters more aggressively.""" + user_ids = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80, 90]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]) + + x, y, _, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=5, min_interactions=3, device=DEVICE + ) + + # User 0 has 3, User 1 has 2 (dropped), User 2 has 4 + assert x.shape[0] == 2 + assert result_users.tolist() == [0, 2] + + def test_all_users_filtered_out(self) -> None: + """When all users have fewer than min_interactions, return empty tensors.""" + user_ids = torch.tensor([0, 1, 2]) + item_ids = torch.tensor([10, 20, 30]) + timestamps = torch.tensor([1, 2, 3]) + + x, y, _, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + assert x.shape == (0, 4) + assert y.shape == (0, 4) + assert len(result_users) == 0 + + def test_max_len_truncation(self) -> None: + """Sequences longer than max_len should be truncated, keeping the most recent items.""" + user_ids = 
torch.tensor([0, 0, 0, 0, 0]) + item_ids = torch.tensor([10, 20, 30, 40, 50]) + timestamps = torch.tensor([1, 2, 3, 4, 5]) + + x, y, _, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE + ) + + # 5 items total. capped_lens = min(5, 3+1) = 4, effective = 3 + # Sorted items: 10->1, 20->2, 30->3, 40->4, 50->5 + # last 4 items for x/y windowing: items at positions [1..4] + # x takes [1,2,3] => internal [2,3,4]; y takes [2,3,4] => internal [3,4,5] + assert x.shape == (1, 3) + assert y.shape == (1, 3) + assert x[0].tolist() == [2, 3, 4] + assert y[0].tolist() == [3, 4, 5] + + def test_timestamp_ordering(self) -> None: + """Items should be ordered by timestamp regardless of input order.""" + user_ids = torch.tensor([0, 0, 0]) + item_ids = torch.tensor([30, 10, 20]) + timestamps = torch.tensor([3, 1, 2]) + + x, y, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + # unique_items (sorted by value): [10, 20, 30] => internal 1, 2, 3 + # By timestamp: 10(t=1), 20(t=2), 30(t=3) => internal [1, 2, 3] + # x = [0, 0, 1, 2] + # y = [0, 0, 2, 3] + assert unique_items.tolist() == [10, 20, 30] + assert x[0].tolist() == [0, 0, 1, 2] + assert y[0].tolist() == [0, 0, 2, 3] + + def test_left_padding(self) -> None: + """Sequences shorter than max_len should be left-padded with zeros.""" + user_ids = torch.tensor([0, 0]) + item_ids = torch.tensor([10, 20]) + timestamps = torch.tensor([1, 2]) + + x, y, _, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=5, min_interactions=2, device=DEVICE + ) + + # 2 items => effective_len = 1 (capped_lens = 2, effective = 1) + # x = [0, 0, 0, 0, 1], y = [0, 0, 0, 0, 2] + assert x[0].tolist() == [0, 0, 0, 0, 1] + assert y[0].tolist() == [0, 0, 0, 0, 2] + + def test_result_users_preserves_external_ids(self) -> None: + """result_users should contain external user IDs, not internal indices.""" + user_ids = torch.tensor([100, 
100, 100, 200, 200, 200]) + item_ids = torch.tensor([1, 2, 3, 4, 5, 6]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + _, _, _, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + assert result_users.tolist() == [100, 200] + + def test_shared_items_across_users(self) -> None: + """Same items used by different users should share internal IDs.""" + user_ids = torch.tensor([0, 0, 0, 1, 1, 1]) + item_ids = torch.tensor([10, 20, 30, 20, 30, 40]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + x, y, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE + ) + + # unique_items: [10, 20, 30, 40] => internal 1, 2, 3, 4 + assert unique_items.tolist() == [10, 20, 30, 40] + + # User 0: 10(1), 20(2), 30(3) => x=[0, 0, 1, 2], y=[0, 0, 2, 3] + assert x[0].tolist() == [0, 0, 1, 2] + assert y[0].tolist() == [0, 0, 2, 3] + + # User 1: 20(2), 30(3), 40(4) => x=[0, 0, 2, 3], y=[0, 0, 3, 4] + assert x[1].tolist() == [0, 0, 2, 3] + assert y[1].tolist() == [0, 0, 3, 4] + + def test_output_device(self) -> None: + """All output tensors should be on the specified device.""" + user_ids = torch.tensor([0, 0]) + item_ids = torch.tensor([1, 2]) + timestamps = torch.tensor([1, 2]) + + x, y, unique_items, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE + ) + + assert x.device.type == DEVICE + assert y.device.type == DEVICE + assert unique_items.device.type == DEVICE + assert result_users.device.type == DEVICE + + def test_output_dtypes(self) -> None: + """x and y should be long tensors.""" + user_ids = torch.tensor([0, 0]) + item_ids = torch.tensor([1, 2]) + timestamps = torch.tensor([1, 2]) + + x, y, _, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE + ) + + assert x.dtype == torch.long + assert y.dtype == torch.long + + def test_exact_max_len_sequence(self) -> 
None: + """Sequence with exactly max_len + 1 items should fill entire x and y.""" + user_ids = torch.tensor([0, 0, 0, 0]) + item_ids = torch.tensor([10, 20, 30, 40]) + timestamps = torch.tensor([1, 2, 3, 4]) + + x, y, _, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE + ) + + # 4 items, max_len=3 => capped_lens = min(4, 4) = 4, effective = 3 + # No padding needed + assert 0 not in x[0].tolist() + assert 0 not in y[0].tolist() + + def test_multiple_users_different_lengths(self) -> None: + """Users with different sequence lengths should be properly handled.""" + user_ids = torch.tensor([0, 0, 1, 1, 1, 1]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + + x, y, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=5, min_interactions=2, device=DEVICE + ) + + # unique_items: [10, 20, 30, 40, 50, 60] => internal 1..6 + # User 0: 2 items => effective=1 + # x[0] = [0, 0, 0, 0, 1], y[0] = [0, 0, 0, 0, 2] + assert x[0].tolist() == [0, 0, 0, 0, 1] + assert y[0].tolist() == [0, 0, 0, 0, 2] + + # User 1: 4 items => effective=3 + # x[1] = [0, 0, 3, 4, 5], y[1] = [0, 0, 4, 5, 6] + assert x[1].tolist() == [0, 0, 3, 4, 5] + assert y[1].tolist() == [0, 0, 4, 5, 6] + + +class TestAlignEmbeddings: + """Tests for the align_embeddings function.""" + + def test_2d_pretrained(self) -> None: + """Align 2D pretrained embeddings to internal ID order.""" + pretrained = torch.tensor([ + [1.0, 2.0], # external item 0 + [3.0, 4.0], # external item 1 + [5.0, 6.0], # external item 2 + [7.0, 8.0], # external item 3 + ]) + # unique_items: external IDs that map to internal IDs 1, 2, 3 + unique_items = torch.tensor([2, 0, 3]) + n_items = 3 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (4, 2) # n_items + 1 + # Row 0 (padding) should be zeros + assert aligned[0].tolist() == [0.0, 0.0] + # Internal ID 1 => external ID 2 => 
pretrained[2] = [5, 6] + assert aligned[1].tolist() == [5.0, 6.0] + # Internal ID 2 => external ID 0 => pretrained[0] = [1, 2] + assert aligned[2].tolist() == [1.0, 2.0] + # Internal ID 3 => external ID 3 => pretrained[3] = [7, 8] + assert aligned[3].tolist() == [7.0, 8.0] + + def test_3d_pretrained(self) -> None: + """Align 3D pretrained embeddings (multi-variant).""" + pretrained = torch.tensor([ + [[1.0, 2.0], [3.0, 4.0]], # item 0, 2 variants + [[5.0, 6.0], [7.0, 8.0]], # item 1 + ]) + unique_items = torch.tensor([1, 0]) + n_items = 2 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (3, 2, 2) # (n_items+1, n_variants, dim) + # Row 0 (padding) should be zeros + torch.testing.assert_close(aligned[0], torch.zeros(2, 2)) + # Internal ID 1 => external ID 1 + torch.testing.assert_close(aligned[1], pretrained[1]) + # Internal ID 2 => external ID 0 + torch.testing.assert_close(aligned[2], pretrained[0]) + + def test_padding_row_is_zero(self) -> None: + """The first row (padding, internal ID 0) should always be zeros.""" + pretrained = torch.randn(10, 8) + unique_items = torch.tensor([0, 1, 2]) + n_items = 3 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + torch.testing.assert_close(aligned[0], torch.zeros(8)) + + def test_out_of_range_indices(self) -> None: + """Items with external IDs outside pretrained range should get zero embeddings.""" + pretrained = torch.tensor([ + [1.0, 2.0], # external 0 + [3.0, 4.0], # external 1 + ]) + # External ID 5 is out of range (pretrained has only 2 rows) + unique_items = torch.tensor([0, 5, 1]) + n_items = 3 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (4, 2) + # Internal 1 => external 0 => valid + assert aligned[1].tolist() == [1.0, 2.0] + # Internal 2 => external 5 => out of range => zeros + assert aligned[2].tolist() == [0.0, 0.0] + # Internal 3 => external 1 => valid + assert aligned[3].tolist() == [3.0, 4.0] + + def 
test_negative_indices_handled(self) -> None: + """Negative external IDs should be treated as invalid and get zeros.""" + pretrained = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + unique_items = torch.tensor([-1, 0]) + n_items = 2 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (3, 2) + # Internal 1 => external -1 => invalid => zeros + assert aligned[1].tolist() == [0.0, 0.0] + # Internal 2 => external 0 => valid + assert aligned[2].tolist() == [1.0, 2.0] + + def test_output_shape_matches_n_items_plus_one(self) -> None: + """Output shape should be (n_items + 1, D) regardless of unique_items length.""" + pretrained = torch.randn(20, 4) + unique_items = torch.tensor([3, 7, 15]) + n_items = 3 + + aligned = align_embeddings(pretrained, unique_items, n_items) + + assert aligned.shape == (4, 4) + + +class TestGPUBatchDataset: + """Tests for GPUBatchDataset.""" + + def test_length(self) -> None: + x = torch.zeros(5, 3) + y = torch.zeros(5, 3) + ds = GPUBatchDataset(x, y) + assert len(ds) == 5 + + def test_getitem_returns_dict(self) -> None: + x = torch.tensor([[1, 2, 3], [4, 5, 6]]) + y = torch.tensor([[7, 8, 9], [10, 11, 12]]) + ds = GPUBatchDataset(x, y) + + batch = ds[0] + assert isinstance(batch, dict) + assert "x" in batch + assert "y" in batch + assert batch["x"].tolist() == [1, 2, 3] + assert batch["y"].tolist() == [7, 8, 9] + + def test_getitem_second_element(self) -> None: + x = torch.tensor([[1, 2], [3, 4]]) + y = torch.tensor([[5, 6], [7, 8]]) + ds = GPUBatchDataset(x, y) + + batch = ds[1] + assert batch["x"].tolist() == [3, 4] + assert batch["y"].tolist() == [7, 8] + + def test_transform_applied(self) -> None: + x = torch.tensor([[1, 2]]) + y = torch.tensor([[3, 4]]) + + def double_x(batch: dict) -> dict: + batch["x"] = batch["x"] * 2 + return batch + + ds = GPUBatchDataset(x, y, transform=double_x) + batch = ds[0] + assert batch["x"].tolist() == [2, 4] + assert batch["y"].tolist() == [3, 4] + + def 
test_no_transform(self) -> None: + x = torch.tensor([[10, 20]]) + y = torch.tensor([[30, 40]]) + ds = GPUBatchDataset(x, y, transform=None) + + batch = ds[0] + assert batch["x"].tolist() == [10, 20] + assert batch["y"].tolist() == [30, 40] + + +class TestMakeDataloader: + """Tests for make_dataloader.""" + + def test_returns_dataloader(self) -> None: + x = torch.zeros(10, 3) + y = torch.zeros(10, 3) + dl = make_dataloader(x, y, batch_size=4, shuffle=False) + assert isinstance(dl, torch.utils.data.DataLoader) + + def test_batch_size(self) -> None: + x = torch.zeros(10, 3) + y = torch.zeros(10, 3) + dl = make_dataloader(x, y, batch_size=4, shuffle=False) + + batches = list(dl) + # 10 samples, batch_size 4 => 3 batches: 4, 4, 2 + assert len(batches) == 3 + assert batches[0]["x"].shape[0] == 4 + assert batches[2]["x"].shape[0] == 2 + + def test_batch_content(self) -> None: + x = torch.tensor([[1, 2], [3, 4], [5, 6]]) + y = torch.tensor([[7, 8], [9, 10], [11, 12]]) + dl = make_dataloader(x, y, batch_size=3, shuffle=False) + + batch = next(iter(dl)) + assert batch["x"].shape == (3, 2) + assert batch["y"].shape == (3, 2) + torch.testing.assert_close(batch["x"], x) + torch.testing.assert_close(batch["y"], y) + + def test_transform_in_dataloader(self) -> None: + x = torch.tensor([[1, 2], [3, 4]]) + y = torch.tensor([[5, 6], [7, 8]]) + + def add_key(batch: dict) -> dict: + batch["mask"] = (batch["x"] > 0).long() + return batch + + dl = make_dataloader(x, y, batch_size=2, shuffle=False, transform=add_key) + batch = next(iter(dl)) + assert "mask" in batch + assert batch["mask"].tolist() == [[1, 1], [1, 1]] + + def test_single_sample_batch(self) -> None: + x = torch.tensor([[1, 2, 3]]) + y = torch.tensor([[4, 5, 6]]) + dl = make_dataloader(x, y, batch_size=1, shuffle=False) + + batch = next(iter(dl)) + assert batch["x"].shape == (1, 3) + assert batch["y"].shape == (1, 3) diff --git a/tests/fast_transformers/test_lightning_wrap.py b/tests/fast_transformers/test_lightning_wrap.py 
new file mode 100644 index 00000000..ca3b5b30 --- /dev/null +++ b/tests/fast_transformers/test_lightning_wrap.py @@ -0,0 +1,176 @@ +"""Tests for FlatSASRecLightning wrapper.""" + +import torch +import pytest + +from rectools.fast_transformers.net import FlatSASRec +from rectools.fast_transformers.lightning_wrap import FlatSASRecLightning + + +@pytest.fixture() +def net() -> FlatSASRec: + return FlatSASRec( + n_items=10, + n_factors=8, + n_blocks=1, + n_heads=1, + session_max_len=5, + dropout=0.0, + ) + + +class TestFlatSASRecLightning: + # ---- constructor ---- + + def test_init_softmax_loss(self, net: FlatSASRec) -> None: + module = FlatSASRecLightning(net, loss="softmax") + assert module.loss_name == "softmax" + assert isinstance(module.loss_fn, torch.nn.CrossEntropyLoss) + + def test_init_bce_loss(self, net: FlatSASRec) -> None: + module = FlatSASRecLightning(net, loss="BCE") + assert module.loss_name == "BCE" + assert isinstance(module.loss_fn, torch.nn.BCEWithLogitsLoss) + + def test_init_invalid_loss_raises(self, net: FlatSASRec) -> None: + with pytest.raises(ValueError, match="Unsupported loss"): + FlatSASRecLightning(net, loss="mse") + + def test_init_stores_hyperparams(self, net: FlatSASRec) -> None: + module = FlatSASRecLightning(net, lr=0.005, n_negatives=4) + assert module.lr == 0.005 + assert module.n_negatives == 4 + + # ---- configure_optimizers ---- + + def test_configure_optimizers_type_and_lr(self, net: FlatSASRec) -> None: + lr = 2e-4 + module = FlatSASRecLightning(net, lr=lr) + optimizer = module.configure_optimizers() + assert isinstance(optimizer, torch.optim.Adam) + assert optimizer.defaults["lr"] == lr + + def test_configure_optimizers_betas(self, net: FlatSASRec) -> None: + module = FlatSASRecLightning(net) + optimizer = module.configure_optimizers() + assert optimizer.defaults["betas"] == (0.9, 0.98) + + # ---- on_train_start ---- + + def test_on_train_start_reinitializes_params(self, net: FlatSASRec) -> None: + module = 
FlatSASRecLightning(net) + + # Snapshot parameters with dim > 1 before reinit + snapshots_before = { + name: p.clone() for name, p in module.net.named_parameters() if p.dim() > 1 + } + assert len(snapshots_before) > 0, "Expected at least one param with dim > 1" + + # Force parameters to a constant value so reinit is detectable + with torch.no_grad(): + for p in module.net.parameters(): + if p.dim() > 1: + p.fill_(42.0) + + module.on_train_start() + + changed = False + for name, p in module.net.named_parameters(): + if p.dim() > 1 and not torch.all(p == 42.0): + changed = True + break + assert changed, "on_train_start should reinitialize parameters via xavier_uniform_" + + # ---- training_step with softmax ---- + + def test_training_step_softmax_returns_scalar(self, net: FlatSASRec) -> None: + module = FlatSASRecLightning(net, loss="softmax") + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0, "Loss should be a scalar" + assert not torch.isnan(loss), "Loss should not be NaN" + assert not torch.isinf(loss), "Loss should not be Inf" + + def test_training_step_softmax_positive_loss(self, net: FlatSASRec) -> None: + module = FlatSASRecLightning(net, loss="softmax") + batch = { + "x": torch.tensor([[1, 2, 3, 4, 5]]), + "y": torch.tensor([[2, 3, 4, 5, 6]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.item() > 0, "Cross-entropy loss should be positive" + + def test_training_step_softmax_all_padding_returns_nan(self, net: FlatSASRec) -> None: + """When all targets are padding (y=0), cross_entropy with ignore_index=-100 returns NaN.""" + module = FlatSASRecLightning(net, loss="softmax") + batch = { + "x": torch.tensor([[0, 0, 0, 0, 0]]), + "y": torch.tensor([[0, 0, 0, 0, 0]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0 + # PyTorch cross_entropy returns NaN when all 
targets are ignored + assert torch.isnan(loss) + + # ---- training_step with BCE ---- + + def test_training_step_bce_returns_scalar(self, net: FlatSASRec) -> None: + n_negatives = 3 + module = FlatSASRecLightning(net, loss="BCE", n_negatives=n_negatives) + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + "negatives": torch.randint(1, 10, (2, 5, n_negatives)), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0, "Loss should be a scalar" + assert not torch.isnan(loss), "Loss should not be NaN" + assert not torch.isinf(loss), "Loss should not be Inf" + + def test_training_step_bce_positive_loss(self, net: FlatSASRec) -> None: + n_negatives = 2 + module = FlatSASRecLightning(net, loss="BCE", n_negatives=n_negatives) + batch = { + "x": torch.tensor([[1, 2, 3, 4, 5]]), + "y": torch.tensor([[2, 3, 4, 5, 6]]), + "negatives": torch.randint(1, 10, (1, 5, n_negatives)), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.item() > 0, "BCE loss should be positive" + + def test_training_step_bce_mask_reduces_loss(self, net: FlatSASRec) -> None: + """Padding positions should not contribute to BCE loss.""" + n_negatives = 2 + module = FlatSASRecLightning(net, loss="BCE", n_negatives=n_negatives) + module.eval() + + torch.manual_seed(0) + negs = torch.randint(1, 10, (1, 5, n_negatives)) + + # Batch with no padding + batch_full = { + "x": torch.tensor([[1, 2, 3, 4, 5]]), + "y": torch.tensor([[2, 3, 4, 5, 6]]), + "negatives": negs.clone(), + } + # Batch with partial padding + batch_padded = { + "x": torch.tensor([[0, 0, 3, 4, 5]]), + "y": torch.tensor([[0, 0, 4, 5, 6]]), + "negatives": negs.clone(), + } + + with torch.no_grad(): + loss_full = module.training_step(batch_full, batch_idx=0) + loss_padded = module.training_step(batch_padded, batch_idx=0) + + # Losses should differ because the padded batch masks out some positions + assert loss_full.item() != 
pytest.approx(loss_padded.item(), abs=1e-6) + + # ---- supported losses constant ---- + + def test_supported_losses_tuple(self) -> None: + assert FlatSASRecLightning.SUPPORTED_LOSSES == ("softmax", "BCE") diff --git a/tests/fast_transformers/test_ranking.py b/tests/fast_transformers/test_ranking.py new file mode 100644 index 00000000..46a5066f --- /dev/null +++ b/tests/fast_transformers/test_ranking.py @@ -0,0 +1,331 @@ +"""Tests for rectools.fast_transformers.ranking.rank_topk.""" + +import numpy as np +import pytest +import torch +from scipy import sparse + +from rectools.fast_transformers.ranking import rank_topk + + +class TestRankTopk: + """Tests for rank_topk function.""" + + def _make_embeddings(self) -> tuple: + """Create deterministic user/item embeddings for testing. + + 3 users, 5 items, dimension 2. + Scores matrix (user_embs @ item_embs.T): + user0: [2, 5, 1, 4, 3] + user1: [3, 1, 5, 2, 4] + user2: [4, 3, 2, 5, 1] + """ + # Construct embeddings so the dot-product scores are easy to reason about. + # We use a trick: set item_embs to one-hot-ish vectors so each column + # of the score matrix is directly controlled. 
+ item_embs = torch.eye(5, dtype=torch.float32) + # user_embs rows are just the desired score rows + user_embs = torch.tensor( + [ + [2.0, 5.0, 1.0, 4.0, 3.0], + [3.0, 1.0, 5.0, 2.0, 4.0], + [4.0, 3.0, 2.0, 5.0, 1.0], + ], + dtype=torch.float32, + ) + return user_embs, item_embs + + def test_basic_topk(self): + """Top-k returns the correct items and scores for each user.""" + user_embs, item_embs = self._make_embeddings() + k = 3 + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k) + + # user0 top-3: item1(5), item3(4), item4(3) + # user1 top-3: item2(5), item4(4), item0(3) + # user2 top-3: item3(5), item0(4), item1(3) + expected_items = { + 0: [1, 3, 4], + 1: [2, 4, 0], + 2: [3, 0, 1], + } + expected_scores = { + 0: [5.0, 4.0, 3.0], + 1: [5.0, 4.0, 3.0], + 2: [5.0, 4.0, 3.0], + } + + for uid in range(3): + mask = user_ids == uid + assert mask.sum() == k + np.testing.assert_array_equal(item_ids[mask], expected_items[uid]) + np.testing.assert_array_almost_equal(scores[mask], expected_scores[uid]) + + def test_output_shapes(self): + """Output arrays all have length n_users * k.""" + user_embs, item_embs = self._make_embeddings() + k = 2 + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k) + + n_users = user_embs.shape[0] + expected_len = n_users * k + assert len(user_ids) == expected_len + assert len(item_ids) == expected_len + assert len(scores) == expected_len + + def test_scores_sorted_descending_per_user(self): + """Scores within each user block are in descending order.""" + user_embs, item_embs = self._make_embeddings() + k = 4 + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k) + + for uid in range(user_embs.shape[0]): + mask = user_ids == uid + user_scores = scores[mask] + assert np.all(user_scores[:-1] >= user_scores[1:]), ( + f"Scores for user {uid} are not in descending order: {user_scores}" + ) + + def test_filter_csr_excludes_viewed_items(self): + """Items present in filter_csr are excluded from recommendations.""" 
+ user_embs, item_embs = self._make_embeddings() + k = 3 + + # user0 has viewed item1 (their top item with score 5) + # user1 has viewed item2 (their top item with score 5) + filter_csr = sparse.csr_matrix( + ([1, 1], ([0, 1], [1, 2])), + shape=(3, 5), + ) + + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k, filter_csr=filter_csr) + + # user0: item1 excluded -> top-3: item3(4), item4(3), item0(2) + mask0 = user_ids == 0 + np.testing.assert_array_equal(item_ids[mask0], [3, 4, 0]) + np.testing.assert_array_almost_equal(scores[mask0], [4.0, 3.0, 2.0]) + + # user1: item2 excluded -> top-3: item4(4), item0(3), item3(2) + mask1 = user_ids == 1 + np.testing.assert_array_equal(item_ids[mask1], [4, 0, 3]) + np.testing.assert_array_almost_equal(scores[mask1], [4.0, 3.0, 2.0]) + + # user2: nothing excluded -> top-3: item3(5), item0(4), item1(3) + mask2 = user_ids == 2 + np.testing.assert_array_equal(item_ids[mask2], [3, 0, 1]) + np.testing.assert_array_almost_equal(scores[mask2], [5.0, 4.0, 3.0]) + + def test_whitelist_restricts_items(self): + """Only whitelisted items appear in results, but with original indices.""" + user_embs, item_embs = self._make_embeddings() + k = 2 + + # Only consider items 0, 2, 4 + whitelist = np.array([0, 2, 4]) + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k, whitelist=whitelist) + + for uid in range(3): + mask = user_ids == uid + # All returned items must be in the whitelist + assert set(item_ids[mask]).issubset(set(whitelist)) + + # user0 scores on [0,2,4]: [2,1,3] -> top-2: item4(3), item0(2) + mask0 = user_ids == 0 + np.testing.assert_array_equal(item_ids[mask0], [4, 0]) + np.testing.assert_array_almost_equal(scores[mask0], [3.0, 2.0]) + + # user1 scores on [0,2,4]: [3,5,4] -> top-2: item2(5), item4(4) + mask1 = user_ids == 1 + np.testing.assert_array_equal(item_ids[mask1], [2, 4]) + np.testing.assert_array_almost_equal(scores[mask1], [5.0, 4.0]) + + def test_filter_csr_and_whitelist_combined(self): + 
"""filter_csr and whitelist work correctly together.""" + user_embs, item_embs = self._make_embeddings() + k = 2 + + # Whitelist: items 0, 1, 3 + whitelist = np.array([0, 1, 3]) + + # user0 viewed item1 (top item in whitelist) + filter_csr = sparse.csr_matrix( + ([1], ([0], [1])), + shape=(3, 5), + ) + + user_ids, item_ids, scores = rank_topk( + user_embs, item_embs, k, filter_csr=filter_csr, whitelist=whitelist + ) + + # user0 whitelist scores: item0(2), item1(5), item3(4) + # After filter (item1 excluded): item0(2), item3(4) + # top-2: item3(4), item0(2) + mask0 = user_ids == 0 + np.testing.assert_array_equal(item_ids[mask0], [3, 0]) + np.testing.assert_array_almost_equal(scores[mask0], [4.0, 2.0]) + + # user1 no items filtered, whitelist scores: item0(3), item1(1), item3(2) + # top-2: item0(3), item3(2) + mask1 = user_ids == 1 + np.testing.assert_array_equal(item_ids[mask1], [0, 3]) + np.testing.assert_array_almost_equal(scores[mask1], [3.0, 2.0]) + + def test_k_greater_than_n_items(self): + """When k > n_items, returns all items per user.""" + user_embs, item_embs = self._make_embeddings() + n_items = item_embs.shape[0] + k = n_items + 10 # Much larger than n_items + + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k) + + # Should return n_items results per user, not k + n_users = user_embs.shape[0] + assert len(user_ids) == n_users * n_items + assert len(item_ids) == n_users * n_items + assert len(scores) == n_users * n_items + + # Check that all items appear for each user + for uid in range(n_users): + mask = user_ids == uid + assert sorted(item_ids[mask]) == list(range(n_items)) + + def test_k_greater_than_n_items_with_whitelist(self): + """When k > len(whitelist), returns len(whitelist) items per user.""" + user_embs, item_embs = self._make_embeddings() + whitelist = np.array([1, 3]) + k = 10 + + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k, whitelist=whitelist) + + n_users = user_embs.shape[0] + assert len(user_ids) == 
n_users * len(whitelist) + + for uid in range(n_users): + mask = user_ids == uid + assert set(item_ids[mask]) == set(whitelist) + + def test_batch_size_does_not_affect_results(self): + """Different batch sizes produce identical results.""" + user_embs, item_embs = self._make_embeddings() + k = 3 + + uid_full, iid_full, sc_full = rank_topk(user_embs, item_embs, k, batch_size=256) + uid_bs1, iid_bs1, sc_bs1 = rank_topk(user_embs, item_embs, k, batch_size=1) + uid_bs2, iid_bs2, sc_bs2 = rank_topk(user_embs, item_embs, k, batch_size=2) + + np.testing.assert_array_equal(uid_full, uid_bs1) + np.testing.assert_array_equal(iid_full, iid_bs1) + np.testing.assert_array_almost_equal(sc_full, sc_bs1) + + np.testing.assert_array_equal(uid_full, uid_bs2) + np.testing.assert_array_equal(iid_full, iid_bs2) + np.testing.assert_array_almost_equal(sc_full, sc_bs2) + + def test_batch_size_with_filter_and_whitelist(self): + """Batch processing gives same results with filter_csr and whitelist.""" + user_embs, item_embs = self._make_embeddings() + k = 2 + whitelist = np.array([0, 2, 4]) + filter_csr = sparse.csr_matrix( + ([1, 1], ([0, 2], [0, 4])), + shape=(3, 5), + ) + + uid_full, iid_full, sc_full = rank_topk( + user_embs, item_embs, k, filter_csr=filter_csr, whitelist=whitelist, batch_size=256 + ) + uid_bs1, iid_bs1, sc_bs1 = rank_topk( + user_embs, item_embs, k, filter_csr=filter_csr, whitelist=whitelist, batch_size=1 + ) + + np.testing.assert_array_equal(uid_full, uid_bs1) + np.testing.assert_array_equal(iid_full, iid_bs1) + np.testing.assert_array_almost_equal(sc_full, sc_bs1) + + def test_multiple_users_independent_topk(self): + """Each user gets their own independent top-k based on their embeddings.""" + user_embs, item_embs = self._make_embeddings() + k = 1 + + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k) + + # Each user should get exactly 1 result + assert len(user_ids) == 3 + np.testing.assert_array_equal(user_ids, [0, 1, 2]) + + # Best items: 
user0->item1(5), user1->item2(5), user2->item3(5) + np.testing.assert_array_equal(item_ids, [1, 2, 3]) + np.testing.assert_array_almost_equal(scores, [5.0, 5.0, 5.0]) + + def test_single_user(self): + """Works correctly with a single user.""" + user_embs = torch.tensor([[1.0, 0.0, 0.0]], dtype=torch.float32) + item_embs = torch.tensor( + [[3.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]], + dtype=torch.float32, + ) + k = 2 + + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k) + + np.testing.assert_array_equal(user_ids, [0, 0]) + np.testing.assert_array_equal(item_ids, [0, 2]) + np.testing.assert_array_almost_equal(scores, [3.0, 2.0]) + + def test_single_item(self): + """Works correctly with a single item.""" + user_embs = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32) + item_embs = torch.tensor([[1.0, 1.0]], dtype=torch.float32) + k = 5 # k > n_items + + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k) + + # Only 1 item, so each user gets 1 result + assert len(user_ids) == 2 + np.testing.assert_array_equal(user_ids, [0, 1]) + np.testing.assert_array_equal(item_ids, [0, 0]) + np.testing.assert_array_almost_equal(scores, [3.0, 7.0]) + + def test_user_ids_are_sequential_indices(self): + """Returned user_ids are sequential integer indices starting from 0.""" + user_embs, item_embs = self._make_embeddings() + k = 2 + + user_ids, _, _ = rank_topk(user_embs, item_embs, k) + + # user_ids should be [0,0, 1,1, 2,2] + expected = np.repeat(np.arange(3), k) + np.testing.assert_array_equal(user_ids, expected) + + def test_return_types_are_numpy(self): + """All returned arrays are numpy ndarrays.""" + user_embs, item_embs = self._make_embeddings() + k = 2 + + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k) + + assert isinstance(user_ids, np.ndarray) + assert isinstance(item_ids, np.ndarray) + assert isinstance(scores, np.ndarray) + + def test_filter_all_items_for_user(self): + """When all items are filtered for a user, 
scores are -inf.""" + user_embs = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) + item_embs = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float32) + k = 1 + + # Filter all items for user 0 + filter_csr = sparse.csr_matrix( + ([1, 1], ([0, 0], [0, 1])), + shape=(2, 2), + ) + + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k, filter_csr=filter_csr) + + # user0: all filtered -> score is -inf + mask0 = user_ids == 0 + assert np.all(np.isneginf(scores[mask0])) + + # user1: nothing filtered -> normal result + mask1 = user_ids == 1 + assert scores[mask1][0] == pytest.approx(1.0) diff --git a/tests/fast_transformers/test_unisrec_lightning.py b/tests/fast_transformers/test_unisrec_lightning.py new file mode 100644 index 00000000..855c0616 --- /dev/null +++ b/tests/fast_transformers/test_unisrec_lightning.py @@ -0,0 +1,482 @@ +"""Tests for UniSRecLightning wrapper and _cosine_warmup_scheduler.""" + +import math + +import torch +import pytest + +from rectools.fast_transformers.unisrec_net import UniSRec +from rectools.fast_transformers.unisrec_lightning import ( + UniSRecLightning, + _cosine_warmup_scheduler, + SUPPORTED_LOSSES, + SUPPORTED_OPTIMIZERS, + SUPPORTED_SCHEDULERS, +) + + +@pytest.fixture() +def pretrained_emb() -> torch.Tensor: + """Fake pretrained embeddings: (11, 32) -- 10 items + 1 padding.""" + torch.manual_seed(0) + emb = torch.randn(11, 32) + emb[0] = 0.0 # padding + return emb + + +@pytest.fixture() +def net(pretrained_emb: torch.Tensor) -> UniSRec: + return UniSRec( + n_items=10, + pretrained_embeddings=pretrained_emb, + n_factors=8, + projection_hidden=16, + n_blocks=1, + n_heads=1, + session_max_len=5, + dropout=0.0, + adaptor_dropout=0.0, + ) + + +def _make_module( + net: UniSRec, + use_id: bool = False, + loss: str = "softmax", + n_negatives: int | None = None, + optimizer: str = "adamw", + scheduler: str | None = None, + total_steps: int | None = None, + lr: float = 1e-3, + warmup_ratio: float = 0.05, + 
min_lr_ratio: float = 0.1, + gbce_t: float = 0.2, +) -> UniSRecLightning: + """Build a UniSRecLightning with a single param group.""" + param_groups = [{"params": list(net.parameters()), "lr": lr}] + return UniSRecLightning( + net=net, + param_groups=param_groups, + use_id=use_id, + loss=loss, + n_negatives=n_negatives, + gbce_t=gbce_t, + optimizer=optimizer, + scheduler=scheduler, + warmup_ratio=warmup_ratio, + min_lr_ratio=min_lr_ratio, + total_steps=total_steps, + ) + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + + +class TestConstants: + def test_supported_losses(self) -> None: + assert SUPPORTED_LOSSES == ("softmax", "BCE", "gBCE", "sampled_softmax") + + def test_supported_optimizers(self) -> None: + assert SUPPORTED_OPTIMIZERS == ("adam", "adamw") + + def test_supported_schedulers(self) -> None: + assert SUPPORTED_SCHEDULERS == (None, "cosine_warmup") + + +# --------------------------------------------------------------------------- +# configure_optimizers +# --------------------------------------------------------------------------- + + +class TestConfigureOptimizers: + def test_adam_returns_adam(self, net: UniSRec) -> None: + module = _make_module(net, optimizer="adam") + result = module.configure_optimizers() + assert isinstance(result, torch.optim.Adam) + + def test_adamw_returns_adamw(self, net: UniSRec) -> None: + module = _make_module(net, optimizer="adamw") + result = module.configure_optimizers() + assert isinstance(result, torch.optim.AdamW) + + def test_no_scheduler_returns_optimizer_only(self, net: UniSRec) -> None: + module = _make_module(net, scheduler=None) + result = module.configure_optimizers() + # When scheduler is None, returns just the optimizer (not a dict) + assert isinstance(result, torch.optim.Optimizer) + + def test_cosine_warmup_returns_dict(self, net: UniSRec) -> None: + module = _make_module(net, 
scheduler="cosine_warmup", total_steps=100) + result = module.configure_optimizers() + assert isinstance(result, dict) + assert "optimizer" in result + assert "lr_scheduler" in result + assert result["lr_scheduler"]["interval"] == "step" + + def test_unknown_optimizer_raises(self, net: UniSRec) -> None: + module = _make_module(net, optimizer="sgd") + with pytest.raises(ValueError, match="Unknown optimizer"): + module.configure_optimizers() + + def test_unknown_scheduler_raises(self, net: UniSRec) -> None: + module = _make_module(net, scheduler="step_lr") + with pytest.raises(ValueError, match="Unknown scheduler"): + module.configure_optimizers() + + def test_cosine_warmup_total_steps_default(self, net: UniSRec) -> None: + """When total_steps is None, it defaults to 1.""" + module = _make_module(net, scheduler="cosine_warmup", total_steps=None) + result = module.configure_optimizers() + assert isinstance(result, dict) + + def test_optimizer_lr(self, net: UniSRec) -> None: + lr = 5e-4 + module = _make_module(net, optimizer="adam", lr=lr) + opt = module.configure_optimizers() + assert opt.param_groups[0]["lr"] == lr + + +# --------------------------------------------------------------------------- +# _cosine_warmup_scheduler +# --------------------------------------------------------------------------- + + +class TestCosineWarmupScheduler: + def test_lr_at_step_zero_is_zero(self) -> None: + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler(opt, warmup_steps=10, total_steps=100, min_lr_ratio=0.0) + # LambdaLR stores the lambda; get factor for step 0 + lr_factor = scheduler.lr_lambdas[0](0) + assert lr_factor == 0.0 + + def test_lr_during_warmup_is_linear(self) -> None: + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + warmup_steps = 10 + scheduler = _cosine_warmup_scheduler(opt, warmup_steps=warmup_steps, total_steps=100) + lr_fn = scheduler.lr_lambdas[0] + for step in range(1, 
warmup_steps): + assert lr_fn(step) == pytest.approx(step / warmup_steps) + + def test_lr_at_warmup_end_is_one(self) -> None: + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler(opt, warmup_steps=10, total_steps=100) + lr_fn = scheduler.lr_lambdas[0] + # At warmup_steps, progress = 0, cos(0) = 1 => factor = 1.0 + assert lr_fn(10) == pytest.approx(1.0) + + def test_lr_at_end_equals_min_lr_ratio(self) -> None: + min_lr_ratio = 0.1 + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler( + opt, warmup_steps=10, total_steps=100, min_lr_ratio=min_lr_ratio, + ) + lr_fn = scheduler.lr_lambdas[0] + # At total_steps, progress = 1, cos(pi) = -1 => factor = min_lr_ratio + assert lr_fn(100) == pytest.approx(min_lr_ratio) + + def test_lr_at_cosine_midpoint(self) -> None: + """At the midpoint of the cosine phase, factor should be (1 + min_lr_ratio) / 2.""" + warmup_steps = 10 + total_steps = 110 + min_lr_ratio = 0.0 + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler( + opt, warmup_steps=warmup_steps, total_steps=total_steps, min_lr_ratio=min_lr_ratio, + ) + lr_fn = scheduler.lr_lambdas[0] + midpoint = warmup_steps + (total_steps - warmup_steps) // 2 # 60 + # progress = 0.5 => cos(pi/2) = 0 => factor = 0.5 + expected = min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * 0.5)) + assert lr_fn(midpoint) == pytest.approx(expected, abs=1e-6) + + def test_lr_with_nonzero_min_lr_ratio(self) -> None: + min_lr_ratio = 0.3 + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler( + opt, warmup_steps=0, total_steps=100, min_lr_ratio=min_lr_ratio, + ) + lr_fn = scheduler.lr_lambdas[0] + # At step 0 (warmup_steps=0, so cosine phase), progress=0, cos(0)=1 => factor=1.0 + assert lr_fn(0) == pytest.approx(1.0) + # At total_steps => factor = min_lr_ratio + 
assert lr_fn(100) == pytest.approx(min_lr_ratio) + + def test_returns_lambda_lr(self) -> None: + opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) + scheduler = _cosine_warmup_scheduler(opt, warmup_steps=5, total_steps=50) + assert isinstance(scheduler, torch.optim.lr_scheduler.LambdaLR) + + +# --------------------------------------------------------------------------- +# training_step +# --------------------------------------------------------------------------- + + +class TestTrainingStep: + def test_softmax_with_use_id_true(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True, loss="softmax") + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0, "Loss should be a scalar" + assert not torch.isnan(loss), "Loss should not be NaN" + assert not torch.isinf(loss), "Loss should not be Inf" + + def test_softmax_with_use_id_false(self, net: UniSRec) -> None: + module = _make_module(net, use_id=False, loss="softmax") + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0, "Loss should be a scalar" + assert not torch.isnan(loss), "Loss should not be NaN" + assert not torch.isinf(loss), "Loss should not be Inf" + + def test_softmax_positive_loss(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True, loss="softmax") + batch = { + "x": torch.tensor([[1, 2, 3, 4, 5]]), + "y": torch.tensor([[2, 3, 4, 5, 6]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.item() > 0, "Cross-entropy loss should be positive" + + def test_bce_loss_returns_scalar(self, net: UniSRec) -> None: + n_negatives = 3 + module = _make_module(net, use_id=True, loss="BCE", n_negatives=n_negatives) + batch = { + "x": torch.tensor([[0, 0, 1, 
2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + "negatives": torch.randint(1, 10, (2, 5, n_negatives)), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + def test_gbce_loss_returns_scalar(self, net: UniSRec) -> None: + n_negatives = 3 + module = _make_module(net, use_id=True, loss="gBCE", n_negatives=n_negatives) + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + "negatives": torch.randint(1, 10, (2, 5, n_negatives)), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + def test_sampled_softmax_loss_returns_scalar(self, net: UniSRec) -> None: + n_negatives = 3 + module = _make_module(net, use_id=True, loss="sampled_softmax", n_negatives=n_negatives) + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + "negatives": torch.randint(1, 10, (2, 5, n_negatives)), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + def test_softmax_ignores_negatives_when_present(self, net: UniSRec) -> None: + """Softmax loss uses full softmax even when negatives are provided.""" + module_no_neg = _make_module(net, use_id=True, loss="softmax") + module_with_neg = _make_module(net, use_id=True, loss="softmax") + net.eval() + + batch_no_neg = { + "x": torch.tensor([[1, 2, 3, 4, 5]]), + "y": torch.tensor([[2, 3, 4, 5, 6]]), + } + batch_with_neg = { + "x": torch.tensor([[1, 2, 3, 4, 5]]), + "y": torch.tensor([[2, 3, 4, 5, 6]]), + "negatives": torch.randint(1, 10, (1, 5, 3)), + } + with torch.no_grad(): + loss_no_neg = module_no_neg.training_step(batch_no_neg, batch_idx=0) + loss_with_neg = module_with_neg.training_step(batch_with_neg, 
batch_idx=0) + torch.testing.assert_close(loss_no_neg, loss_with_neg) + + def test_all_padding_softmax(self, net: UniSRec) -> None: + """When all targets are padding, cross_entropy with ignore_index returns NaN.""" + module = _make_module(net, use_id=True, loss="softmax") + batch = { + "x": torch.tensor([[0, 0, 0, 0, 0]]), + "y": torch.tensor([[0, 0, 0, 0, 0]]), + } + loss = module.training_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert torch.isnan(loss) + + +# --------------------------------------------------------------------------- +# validation_step +# --------------------------------------------------------------------------- + + +class TestValidationStep: + def test_validation_returns_scalar(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True, loss="softmax") + module.eval() + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[4], [8]]), # (B, 1) + } + with torch.no_grad(): + loss = module.validation_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + assert not torch.isinf(loss) + + def test_validation_uses_last_hidden(self, net: UniSRec) -> None: + """Validation slices hidden to [:, -1:, :], so y shape (B, 1) works.""" + module = _make_module(net, use_id=False, loss="softmax") + module.eval() + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3]]), + "y": torch.tensor([[4]]), # single target per sequence + } + with torch.no_grad(): + loss = module.validation_step(batch, batch_idx=0) + assert loss.dim() == 0 + assert not torch.isnan(loss) + + def test_validation_with_negatives(self, net: UniSRec) -> None: + n_negatives = 3 + module = _make_module(net, use_id=True, loss="BCE", n_negatives=n_negatives) + module.eval() + batch = { + "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), + "y": torch.tensor([[4], [8]]), + "negatives": torch.randint(1, 10, (2, 1, n_negatives)), + } + with torch.no_grad(): + loss = module.validation_step(batch, batch_idx=0) + assert 
loss.dim() == 0 + assert not torch.isnan(loss) + + +# --------------------------------------------------------------------------- +# _calc_loss dispatch +# --------------------------------------------------------------------------- + + +class TestCalcLossDispatch: + def test_softmax_without_negatives_uses_full_softmax(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True, loss="softmax") + hidden = torch.randn(2, 5, 8) + batch = { + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + } + loss = module._calc_loss(hidden, batch) + assert loss.dim() == 0 + assert not torch.isnan(loss) + + def test_bce_without_negatives_raises(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True, loss="BCE") + hidden = torch.randn(2, 5, 8) + batch = { + "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), + } + with pytest.raises(ValueError, match="requires negatives"): + module._calc_loss(hidden, batch) + + def test_gbce_without_negatives_raises(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True, loss="gBCE") + hidden = torch.randn(2, 5, 8) + batch = {"y": torch.tensor([[1, 2, 3, 4, 5]])} + with pytest.raises(ValueError, match="requires negatives"): + module._calc_loss(hidden, batch) + + def test_sampled_softmax_without_negatives_raises(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True, loss="sampled_softmax") + hidden = torch.randn(1, 5, 8) + batch = {"y": torch.tensor([[1, 2, 3, 4, 5]])} + with pytest.raises(ValueError, match="requires negatives"): + module._calc_loss(hidden, batch) + + def test_unknown_loss_raises(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True, loss="mse") + hidden = torch.randn(1, 5, 8) + batch = { + "y": torch.tensor([[1, 2, 3, 4, 5]]), + "negatives": torch.randint(1, 10, (1, 5, 3)), + } + with pytest.raises(ValueError, match="Unknown loss"): + module._calc_loss(hidden, batch) + + +# --------------------------------------------------------------------------- +# 
_get_item_embs / _get_all_embs +# --------------------------------------------------------------------------- + + +class TestEmbeddingHelpers: + def test_get_item_embs_id_mode(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True) + item_ids = torch.tensor([[1, 2, 3]]) + embs = module._get_item_embs(item_ids) + assert embs.shape == (1, 3, 8) # (B, L, n_factors) + + def test_get_item_embs_adapted_mode(self, net: UniSRec) -> None: + module = _make_module(net, use_id=False) + item_ids = torch.tensor([[1, 2, 3]]) + embs = module._get_item_embs(item_ids) + assert embs.shape == (1, 3, 8) + + def test_get_all_embs_id_mode(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True) + all_embs = module._get_all_embs() + assert all_embs.shape == (11, 8) # n_items + 1 + + def test_get_all_embs_adapted_mode(self, net: UniSRec) -> None: + module = _make_module(net, use_id=False) + all_embs = module._get_all_embs() + assert all_embs.shape == (11, 8) + + def test_get_pos_neg_logits_shape(self, net: UniSRec) -> None: + module = _make_module(net, use_id=True) + hidden = torch.randn(2, 5, 8) + labels = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) + negatives = torch.randint(1, 10, (2, 5, 3)) + logits = module._get_pos_neg_logits(hidden, labels, negatives) + assert logits.shape == (2, 5, 4) # 1 positive + 3 negatives + + +# --------------------------------------------------------------------------- +# Init stores params +# --------------------------------------------------------------------------- + + +class TestInit: + def test_stores_all_attributes(self, net: UniSRec) -> None: + module = _make_module( + net, + use_id=True, + loss="BCE", + n_negatives=5, + optimizer="adam", + scheduler="cosine_warmup", + total_steps=200, + warmup_ratio=0.1, + min_lr_ratio=0.05, + gbce_t=0.3, + ) + assert module.use_id is True + assert module.loss_name == "BCE" + assert module.n_negatives == 5 + assert module.optimizer_name == "adam" + assert module.scheduler_name == 
"cosine_warmup" + assert module.total_steps == 200 + assert module.warmup_ratio == 0.1 + assert module.min_lr_ratio == 0.05 + assert module.gbce_t == 0.3 + assert module.net is net diff --git a/tests/fast_transformers/test_unisrec_model.py b/tests/fast_transformers/test_unisrec_model.py index 98dc3e94..a3de7d7d 100644 --- a/tests/fast_transformers/test_unisrec_model.py +++ b/tests/fast_transformers/test_unisrec_model.py @@ -1,29 +1,9 @@ -"""Tests for UniSRecModel.""" +"""Tests for UniSRecModel (standalone, tensor-based API).""" -import numpy as np -import pandas as pd import pytest import torch -from rectools import Columns -from rectools.dataset import Dataset -from rectools.fast_transformers import UniSRecConfig, UniSRecModel - - -def _make_dataset(n_users: int = 20, n_items: int = 25, seed: int = 42) -> Dataset: - rng = np.random.RandomState(seed) - rows = [] - for u in range(n_users): - n_inter = rng.randint(3, 8) - items = rng.choice(n_items, size=n_inter, replace=False) - for rank, item in enumerate(items): - rows.append({ - Columns.User: u, - Columns.Item: item, - Columns.Weight: 1.0, - Columns.Datetime: pd.Timestamp("2024-01-01") + pd.Timedelta(hours=rank), - }) - return Dataset.construct(pd.DataFrame(rows)) +from rectools.fast_transformers import UniSRecModel def _make_embeddings(n_items: int = 25, dim: int = 64) -> torch.Tensor: @@ -33,6 +13,24 @@ def _make_embeddings(n_items: int = 25, dim: int = 64) -> torch.Tensor: return emb +def _make_interactions(n_users: int = 20, n_items: int = 25, seed: int = 42): + """Generate synthetic (user_ids, item_ids, timestamps) tensors.""" + rng = torch.Generator().manual_seed(seed) + users, items, timestamps = [], [], [] + for u in range(n_users): + n_inter = torch.randint(3, 8, (1,), generator=rng).item() + item_pool = torch.randperm(n_items, generator=rng)[:n_inter] + 1 # 1-based + for rank, item in enumerate(item_pool): + users.append(u) + items.append(item.item()) + timestamps.append(rank) + return ( + 
torch.tensor(users, dtype=torch.long), + torch.tensor(items, dtype=torch.long), + torch.tensor(timestamps, dtype=torch.long), + ) + + def _make_model(**kwargs) -> UniSRecModel: defaults = dict( pretrained_item_embeddings=_make_embeddings(), @@ -51,160 +49,139 @@ def _make_model(**kwargs) -> UniSRecModel: return UniSRecModel(**defaults) -class TestFitRecommend: - def test_recommend_columns(self) -> None: - ds = _make_dataset() +class TestFit: + def test_fit_returns_self(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() model = _make_model() - model.fit(ds) - users = list(range(5)) - reco = model.recommend(users=users, dataset=ds, k=3, filter_viewed=False) - assert set(reco.columns) == {Columns.User, Columns.Item, Columns.Score, Columns.Rank} - assert reco[Columns.User].nunique() == 5 - - def test_filter_viewed(self) -> None: - ds = _make_dataset() + result = model.fit(user_ids, item_ids, timestamps) + assert result is model + + def test_is_fitted_after_fit(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() model = _make_model() - model.fit(ds) - users = list(range(5)) - reco = model.recommend(users=users, dataset=ds, k=5, filter_viewed=True) - interactions = ds.get_raw_interactions() - for uid in users: - viewed = set(interactions[interactions[Columns.User] == uid][Columns.Item]) - recommended = set(reco[reco[Columns.User] == uid][Columns.Item]) - assert viewed.isdisjoint(recommended), f"User {uid} got viewed items" - - def test_i2i(self) -> None: - ds = _make_dataset() + assert not model.is_fitted + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_net_accessible_after_fit(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() model = _make_model() - model.fit(ds) - items = list(range(5)) - reco = model.recommend_to_items(target_items=items, dataset=ds, k=3) - assert set(reco.columns) == {Columns.TargetItem, Columns.Item, Columns.Score, Columns.Rank} - assert 
reco[Columns.TargetItem].nunique() == 5 - - def test_scores_not_nan(self) -> None: - ds = _make_dataset() - model = _make_model(phase1_epochs=2, phase3_epochs=2) - model.fit(ds) - users = list(range(ds.user_id_map.size)) - reco = model.recommend(users=users, dataset=ds, k=5, filter_viewed=False) - assert len(reco) > 0 - assert reco[Columns.Score].notna().all() + model.fit(user_ids, item_ids, timestamps) + net = model.net + assert net is not None + + def test_item_id_mapping_has_original_ids(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model() + model.fit(user_ids, item_ids, timestamps) + mapping = model.item_id_mapping + original_unique = torch.unique(item_ids) + assert set(mapping.tolist()) == set(original_unique.tolist()) + + def test_net_not_accessible_before_fit(self) -> None: + model = _make_model() + with pytest.raises(AssertionError): + _ = model.net class TestPhaseSkipping: def test_skip_phase1(self) -> None: - ds = _make_dataset() + user_ids, item_ids, timestamps = _make_interactions() model = _make_model(phase1_epochs=0) - model.fit(ds) - reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted def test_skip_phase2(self) -> None: - ds = _make_dataset() + user_ids, item_ids, timestamps = _make_interactions() model = _make_model(phase2_epochs=0) - model.fit(ds) - reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_only_phase1(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(phase1_epochs=2, phase2_epochs=0, phase3_epochs=0) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted def test_only_phase3(self) -> None: - ds = _make_dataset() + user_ids, item_ids, timestamps = _make_interactions() model = 
_make_model(phase1_epochs=0, phase2_epochs=0, phase3_epochs=2) - model.fit(ds) - reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 - - -class TestWithNegatives: - def test_sampled_loss(self) -> None: - ds = _make_dataset() - model = _make_model(n_negatives=4) - model.fit(ds) - reco = model.recommend(users=[0, 1, 2], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 - - -class TestFFNTypes: - @pytest.mark.parametrize("ffn_type", ["conv1d", "linear_gelu", "linear_relu"]) - def test_ffn_type(self, ffn_type: str) -> None: - ds = _make_dataset() - model = _make_model(ffn_type=ffn_type, ffn_expansion=2, phase1_epochs=0, phase2_epochs=0, phase3_epochs=1) - model.fit(ds) - reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted class TestLosses: - def test_bce_loss(self) -> None: - ds = _make_dataset() - model = _make_model(loss="BCE", n_negatives=4) - model.fit(ds) - reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 - - def test_gbce_loss(self) -> None: - ds = _make_dataset() - model = _make_model(loss="gBCE", n_negatives=4, gbce_t=0.2) - model.fit(ds) - reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 - - def test_sampled_softmax_loss(self) -> None: - ds = _make_dataset() - model = _make_model(loss="sampled_softmax", n_negatives=4) - model.fit(ds) - reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 - - def test_invalid_loss(self) -> None: + def test_softmax_loss(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(loss="softmax", phase1_epochs=0, phase2_epochs=0, phase3_epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_invalid_loss_raises(self) -> None: with 
pytest.raises(ValueError, match="Unsupported loss"): _make_model(loss="invalid") class TestOptimizer: - def test_adam_optimizer(self) -> None: - ds = _make_dataset() + def test_adam(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() model = _make_model(optimizer="adam", phase1_epochs=0, phase2_epochs=0, phase3_epochs=1) - model.fit(ds) - reco = model.recommend(users=[0], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_adamw(self) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(optimizer="adamw", phase1_epochs=0, phase2_epochs=0, phase3_epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted - def test_invalid_optimizer(self) -> None: + def test_invalid_optimizer_raises(self) -> None: with pytest.raises(ValueError, match="Unsupported optimizer"): _make_model(optimizer="sgd") class TestScheduler: def test_cosine_warmup(self) -> None: - ds = _make_dataset() + user_ids, item_ids, timestamps = _make_interactions() model = _make_model(scheduler="cosine_warmup", warmup_ratio=0.1, phase1_epochs=0, phase2_epochs=0, phase3_epochs=2) - model.fit(ds) - reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted + + def test_invalid_scheduler_raises(self) -> None: + with pytest.raises(ValueError, match="Unsupported scheduler"): + _make_model(scheduler="step") + + +class TestCheckpoint: + def test_save_load_roundtrip(self, tmp_path) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(phase1_epochs=1, phase2_epochs=0, phase3_epochs=0) + model.fit(user_ids, item_ids, timestamps) + + ckpt_path = tmp_path / "model.pt" + model.save_checkpoint(ckpt_path) + + model2 = _make_model(phase1_epochs=1, phase2_epochs=0, phase3_epochs=0) + model2.load_checkpoint(ckpt_path, 
device="cpu") + assert model2.is_fitted + + mapping1 = model.item_id_mapping + mapping2 = model2.item_id_mapping + assert torch.equal(mapping1, mapping2) + + +class TestFFNTypes: + @pytest.mark.parametrize("ffn_type", ["conv1d", "linear_gelu", "linear_relu"]) + def test_ffn_type(self, ffn_type: str) -> None: + user_ids, item_ids, timestamps = _make_interactions() + model = _make_model(ffn_type=ffn_type, ffn_expansion=2, phase1_epochs=0, phase2_epochs=0, phase3_epochs=1) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted class TestEarlyStopping: def test_patience(self) -> None: - ds = _make_dataset() + user_ids, item_ids, timestamps = _make_interactions() model = _make_model(patience=2, phase1_epochs=0, phase2_epochs=0, phase3_epochs=5) - model.fit(ds) - reco = model.recommend(users=[0, 1], dataset=ds, k=3, filter_viewed=False) - assert len(reco) > 0 - - -class TestConfig: - def test_get_config(self) -> None: - model = _make_model(ffn_type="linear_gelu", loss="BCE", n_negatives=4, optimizer="adam", scheduler="cosine_warmup", patience=5) - config = model.get_config(mode="pydantic") - assert config.model.n_factors == 16 - assert config.model.ffn_type == "linear_gelu" - assert config.model.loss == "BCE" - assert config.model.optimizer == "adam" - assert config.model.scheduler == "cosine_warmup" - assert config.model.patience == 5 - - def test_from_config_raises(self) -> None: - model = _make_model() - config = model.get_config(mode="pydantic") - with pytest.raises(NotImplementedError, match="pretrained_item_embeddings"): - UniSRecModel.from_config(config) + model.fit(user_ids, item_ids, timestamps) + assert model.is_fitted From e24fec380bf380c878495844ca3cb409c44cc8a6 Mon Sep 17 00:00:00 2001 From: TOPAPEC Date: Fri, 24 Apr 2026 22:17:27 +0000 Subject: [PATCH 5/7] add changelog, fixed gpu model load --- .gitignore | 4 +- CHANGELOG.md | 13 + rectools/fast_transformers/unisrec_model.py | 4 +- scripts/profile_build_sequences.py | 142 ---------- 
scripts/test_1epoch.py | 88 ------ scripts/train_fast_sasrec.py | 77 ----- scripts/train_unisrec.py | 96 ------- scripts/train_unisrec_ml20m.py | 293 -------------------- 8 files changed, 17 insertions(+), 700 deletions(-) delete mode 100644 scripts/profile_build_sequences.py delete mode 100644 scripts/test_1epoch.py delete mode 100644 scripts/train_fast_sasrec.py delete mode 100644 scripts/train_unisrec.py delete mode 100644 scripts/train_unisrec_ml20m.py diff --git a/.gitignore b/.gitignore index 13082042..d63a776b 100644 --- a/.gitignore +++ b/.gitignore @@ -97,7 +97,7 @@ benchmark_results/ # CatBoost catboost_info/ -# Dev testing folder +# Dev artifacts training_folder/ *.pt -data/* \ No newline at end of file +data/* diff --git a/CHANGELOG.md b/CHANGELOG.md index 15e77808..285ee45a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added +- `rectools.fast_transformers` module — standalone transformer-based sequential recommenders that work directly with torch tensors, bypassing the `Dataset`/pandas pipeline. GPU-native sequence building via `build_sequences()` gives ~30x preprocessing speedup over `SASRecDataPreparator` on ML-20M +- `FlatSASRec` network and `FlatSASRecModel` — flat SASRec implementation without the ItemNet hierarchy. Pre-norm transformer encoder with id-embeddings, causal masking, softmax and BCE losses. Integrates with RecTools `ModelBase` for compatibility with the standard `fit`/`recommend` API +- `UniSRec` network and `UniSRecModel` — sequential recommender with pretrained text embeddings (e.g. Qwen) and a learnable PCA/BN adaptor. Three-phase training: (1) SASRec warm-up on ID embeddings, (2) adaptor-only with frozen transformer, (3) full fine-tune on pretrained embeddings. 
Configurable losses (softmax, BCE, gBCE, sampled_softmax), optimizers (Adam, AdamW), cosine warmup scheduler, early stopping, checkpoint save/load. `UniSRecModel.fit()` accepts raw `(user_ids, item_ids, timestamps)` tensors +- `rank_topk()` utility for batched top-k scoring with CSR-based viewed-item filtering and item whitelist support +- `align_embeddings()` for mapping pretrained embedding matrices to internal item ID order +- `GPUBatchDataset` and `make_dataloader()` — lightweight torch Dataset/DataLoader wrappers for sequence training data +- Configurable FFN blocks in `UniSRec`: `conv1d` (original paper), `linear_gelu`, `linear_relu` with adjustable expansion factor +- Tests for all `fast_transformers` submodules (143 tests) + + ## [0.18.0] - 21.02.2026 ### Added diff --git a/rectools/fast_transformers/unisrec_model.py b/rectools/fast_transformers/unisrec_model.py index d3a136d9..c737900e 100644 --- a/rectools/fast_transformers/unisrec_model.py +++ b/rectools/fast_transformers/unisrec_model.py @@ -312,8 +312,8 @@ def save_checkpoint(self, path: tp.Union[str, Path]) -> None: def load_checkpoint(self, path: tp.Union[str, Path], device: str = "cuda") -> None: ckpt = torch.load(path, map_location=device, weights_only=False) - self._unique_items = ckpt["unique_items"] - self._unique_users = ckpt["unique_users"] + self._unique_items = ckpt["unique_items"].cpu() + self._unique_users = ckpt["unique_users"].cpu() n_items = ckpt["n_items"] aligned_emb = align_embeddings(self.pretrained_item_embeddings, self._unique_items, n_items) diff --git a/scripts/profile_build_sequences.py b/scripts/profile_build_sequences.py deleted file mode 100644 index 9325b1df..00000000 --- a/scripts/profile_build_sequences.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Profile build_sequences on synthetic data matching ML-20M scale.""" - -import time -import torch - -def build_sequences_profiled( - user_ids, item_ids, timestamps, max_len, min_interactions=2, device="cuda", -): - t0 = time.time() - 
user_ids = user_ids.to(device) - item_ids = item_ids.to(device) - timestamps = timestamps.to(device) - torch.cuda.synchronize() - t_transfer = time.time() - t0 - - t0 = time.time() - unique_items, item_inv = torch.unique(item_ids, return_inverse=True) - internal_items = item_inv + 1 - unique_users, user_inv = torch.unique(user_ids, return_inverse=True) - torch.cuda.synchronize() - t_unique = time.time() - t0 - - t0 = time.time() - order1 = torch.argsort(timestamps, stable=True) - order2 = torch.argsort(user_inv[order1], stable=True) - order = order1[order2] - sorted_user_inv = user_inv[order] - sorted_items = internal_items[order] - torch.cuda.synchronize() - t_sort = time.time() - t0 - - t0 = time.time() - changes = torch.where(sorted_user_inv[1:] != sorted_user_inv[:-1])[0] + 1 - starts = torch.cat([torch.tensor([0], device=device), changes]) - ends = torch.cat([changes, torch.tensor([len(sorted_user_inv)], device=device)]) - lengths = ends - starts - mask = lengths >= min_interactions - starts = starts[mask] - ends = ends[mask] - lengths = lengths[mask] - n_users = len(starts) - capped_lens = torch.clamp(lengths, max=max_len + 1) - torch.cuda.synchronize() - t_boundaries = time.time() - t0 - - t0 = time.time() - effective_lens = torch.clamp(capped_lens - 1, min=0) - total_elements = effective_lens.sum().item() - x = torch.zeros(n_users, max_len, dtype=torch.long, device=device) - y = torch.zeros(n_users, max_len, dtype=torch.long, device=device) - - if total_elements > 0: - user_indices = torch.repeat_interleave(torch.arange(n_users, device=device), effective_lens) - cumsum = effective_lens.cumsum(0) - offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave(cumsum - effective_lens, effective_lens) - x_src = torch.repeat_interleave(ends - capped_lens, effective_lens) + offsets - y_src = x_src + 1 - col_indices = max_len - torch.repeat_interleave(effective_lens, effective_lens) + offsets - x[user_indices, col_indices] = sorted_items[x_src] 
- y[user_indices, col_indices] = sorted_items[y_src] - torch.cuda.synchronize() - t_scatter = time.time() - t0 - - valid_user_indices = torch.where(mask)[0] - result_users = unique_users[valid_user_indices] if len(valid_user_indices) < len(unique_users) else unique_users - - print(f" transfer to GPU: {t_transfer:.3f}s") - print(f" unique: {t_unique:.3f}s") - print(f" sort (2x argsort): {t_sort:.3f}s") - print(f" boundaries: {t_boundaries:.3f}s") - print(f" scatter (vectorized): {t_scatter:.3f}s") - print(f" TOTAL: {t_transfer + t_unique + t_sort + t_boundaries + t_scatter:.3f}s") - print(f" n_users={n_users}, total_elements={total_elements}") - - return x, y, unique_items, result_users - - -def verify_correctness(): - """Small test to verify vectorized scatter produces correct results.""" - torch.manual_seed(42) - n = 50 - user_ids = torch.tensor([0,0,0,0,0, 1,1,1, 2,2,2,2]) - item_ids = torch.tensor([10,20,30,40,50, 60,70,80, 90,100,110,120]) - timestamps = torch.arange(n := len(user_ids)) - - from rectools.fast_transformers.gpu_data import build_sequences - x, y, ui, uu = build_sequences(user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device="cuda") - - x_cpu = x.cpu() - y_cpu = y.cpu() - - print("\n=== Correctness check ===") - print(f"x:\n{x_cpu}") - print(f"y:\n{y_cpu}") - - # User 0: items [1,2,3,4,5], capped to 5 (max_len+1=5), effective=4 - # x row: [2, 3, 4, 5] wait, max_len=4 so x[0] should be [1,2,3,4], y[0]=[2,3,4,5] - # Actually: capped = min(5, 4+1=5) = 5, effective = 4 - # seq = items[-5:] = [1,2,3,4,5] - # x: seq[:-1] = [1,2,3,4] placed at cols 0..3 - # y: seq[1:] = [2,3,4,5] placed at cols 0..3 - assert x_cpu[0].tolist() == [1,2,3,4], f"Got {x_cpu[0].tolist()}" - assert y_cpu[0].tolist() == [2,3,4,5], f"Got {y_cpu[0].tolist()}" - - # User 1: items [6,7,8], capped=3, effective=2 - # seq = [6,7,8], x: [6,7] at cols 2..3, y: [7,8] at cols 2..3 - assert x_cpu[1].tolist() == [0,0,6,7], f"Got {x_cpu[1].tolist()}" - assert y_cpu[1].tolist() 
== [0,0,7,8], f"Got {y_cpu[1].tolist()}" - - # User 2: items [9,10,11,12], capped=4, effective=3 - # seq = [9,10,11,12], x: [9,10,11] at cols 1..3, y: [10,11,12] at cols 1..3 - assert x_cpu[2].tolist() == [0,9,10,11], f"Got {x_cpu[2].tolist()}" - assert y_cpu[2].tolist() == [0,10,11,12], f"Got {y_cpu[2].tolist()}" - - print("All assertions passed!") - - -def profile_ml20m_scale(): - """Generate data at ML-20M scale and profile.""" - print("\n=== ML-20M scale profile ===") - torch.manual_seed(0) - N = 5_000_000 - n_users_approx = 136_000 - n_items_approx = 7_000 - - user_ids = torch.randint(0, n_users_approx, (N,)) - item_ids = torch.randint(0, n_items_approx, (N,)) - timestamps = torch.randint(0, 10**9, (N,), dtype=torch.long) - - # warmup - print("Warmup...") - _ = build_sequences_profiled(user_ids[:1000], item_ids[:1000], timestamps[:1000], max_len=200, device="cuda") - - print("\nFull run:") - x, y, ui, uu = build_sequences_profiled(user_ids, item_ids, timestamps, max_len=200, device="cuda") - print(f"Output shape: x={x.shape}, y={y.shape}") - print(f"GPU memory: {torch.cuda.memory_allocated()/1e9:.2f} GB") - - -if __name__ == "__main__": - verify_correctness() - profile_ml20m_scale() diff --git a/scripts/test_1epoch.py b/scripts/test_1epoch.py deleted file mode 100644 index 76d283ae..00000000 --- a/scripts/test_1epoch.py +++ /dev/null @@ -1,88 +0,0 @@ -"""Quick 1-epoch smoke test of the full pipeline.""" - -import time -from pathlib import Path - -import pandas as pd -import torch - -from rectools.fast_transformers import UniSRecModel - -DATA_DIR = Path("data/ml-20m") -MIN_RATING = 4.0 -MIN_ITEM_INTERACTIONS = 50 -MIN_USER_INTERACTIONS = 5 - - -def load_data(): - ratings = pd.read_csv(DATA_DIR / "ml-20m" / "ratings.csv") - ratings.columns = ["user_id", "item_id", "rating", "timestamp"] - ratings = ratings[ratings["rating"] >= MIN_RATING] - item_counts = ratings.groupby("item_id").size() - popular = item_counts[item_counts >= MIN_ITEM_INTERACTIONS].index - 
ratings = ratings[ratings["item_id"].isin(popular)] - user_counts = ratings.groupby("user_id").size() - valid = user_counts[user_counts >= MIN_USER_INTERACTIONS].index - ratings = ratings[ratings["user_id"].isin(valid)] - return ratings - - -def main(): - print("Loading data...") - ratings = load_data() - print(f" {len(ratings):,} interactions, {ratings['user_id'].nunique():,} users, {ratings['item_id'].nunique():,} items") - - pretrained = torch.load(DATA_DIR / "qwen_embeddings.pt", weights_only=True) - print(f" Pretrained embeddings: {pretrained.shape}") - - user_ids = torch.tensor(ratings["user_id"].values, dtype=torch.long) - item_ids = torch.tensor(ratings["item_id"].values, dtype=torch.long) - timestamps = torch.tensor(ratings["timestamp"].values, dtype=torch.long) - - model = UniSRecModel( - pretrained_item_embeddings=pretrained, - n_factors=512, - projection_hidden=512, - n_blocks=2, - n_heads=1, - session_max_len=200, - dropout=0.1, - adaptor_dropout=0.2, - adaptor_type="pca", - use_adaptor_ffn=True, - phase1_epochs=0, - phase2_epochs=0, - phase3_epochs=1, - phase3_lr=1e-4, - lr_head=0.3, - lr_wp=0.1, - lr_transformer=3.0, - optimizer="adamw", - scheduler="cosine_warmup", - warmup_ratio=0.05, - min_lr_ratio=1.0, - grad_clip=1.0, - weight_decay=0.01, - loss="softmax", - batch_size=128, - dataloader_num_workers=0, - train_min_user_interactions=2, - verbose=1, - ) - - print("\nStarting 1-epoch training...") - t0 = time.time() - model.fit(user_ids, item_ids, timestamps) - elapsed = time.time() - t0 - print(f"\n1-epoch training complete in {elapsed:.1f}s") - - # Verify item_id_mapping contains original IDs - unique_items = model.item_id_mapping - print(f"unique_items range: [{unique_items.min().item()}, {unique_items.max().item()}]") - print(f"Original item_id range: [{ratings['item_id'].min()}, {ratings['item_id'].max()}]") - assert unique_items.max().item() > 100, "IDs should be original MovieLens IDs, not 0-based reindexed" - print("ID mapping verified — 
original external IDs preserved!") - - -if __name__ == "__main__": - main() diff --git a/scripts/train_fast_sasrec.py b/scripts/train_fast_sasrec.py deleted file mode 100644 index f0608504..00000000 --- a/scripts/train_fast_sasrec.py +++ /dev/null @@ -1,77 +0,0 @@ -"""End-to-end smoke test: synthetic dataset, train, recommend, metrics, i2i.""" - -import numpy as np -import pandas as pd - -from rectools import Columns -from rectools.dataset import Dataset -from rectools.fast_transformers import FlatSASRecModel - - -def main() -> None: - # --- Synthetic dataset: 80 users x 60 items --- - rng = np.random.RandomState(123) - n_users, n_items = 80, 60 - - rows = [] - for u in range(n_users): - n_inter = rng.randint(4, 15) - items = rng.choice(n_items, size=n_inter, replace=False) - for rank, item in enumerate(items): - rows.append({ - Columns.User: u, - Columns.Item: item, - Columns.Weight: 1.0, - Columns.Datetime: pd.Timestamp("2024-01-01") + pd.Timedelta(hours=rank), - }) - df = pd.DataFrame(rows) - dataset = Dataset.construct(df) - print(f"Dataset: {n_users} users, {n_items} items, {len(df)} interactions") - - # --- Train --- - model = FlatSASRecModel( - n_factors=32, n_blocks=2, n_heads=2, session_max_len=16, - loss="softmax", epochs=2, batch_size=32, lr=1e-3, verbose=1, - ) - model.fit(dataset) - print("Training done.") - - # --- Recommend --- - users = list(range(n_users)) - reco = model.recommend(users=users, dataset=dataset, k=5, filter_viewed=True) - print(f"\nTop-5 recommendations (first 3 users):") - print(reco[reco[Columns.User].isin(range(3))].to_string(index=False)) - - # --- Simple metrics --- - interactions = dataset.get_raw_interactions() - hits = 0 - total = 0 - ap_sum = 0.0 - for u in users: - viewed = set(interactions[interactions[Columns.User] == u][Columns.Item]) - rec_items = reco[reco[Columns.User] == u][Columns.Item].tolist() - # For this smoke test, "relevance" = items the user actually interacted with - # (training set overlap is expected since 
we don't do train/test split here) - rel = [1 if i in viewed else 0 for i in rec_items] - hits += sum(rel) - total += len(rec_items) - # AP - if sum(rel) > 0: - precision_at = np.cumsum(rel) / np.arange(1, len(rel) + 1) - ap_sum += np.sum(precision_at * rel) / sum(rel) - recall = hits / max(total, 1) - map_at_k = ap_sum / len(users) - print(f"\nRecall@5 (train overlap): {recall:.4f}") - print(f"MAP@5 (train overlap): {map_at_k:.4f}") - - # --- I2I --- - target_items = list(range(10)) - i2i = model.recommend_to_items(target_items=target_items, dataset=dataset, k=5) - print(f"\nI2I recommendations (first 3 target items):") - print(i2i[i2i[Columns.TargetItem].isin(range(3))].to_string(index=False)) - - print("\nSmoke test passed!") - - -if __name__ == "__main__": - main() diff --git a/scripts/train_unisrec.py b/scripts/train_unisrec.py deleted file mode 100644 index 5720ff7a..00000000 --- a/scripts/train_unisrec.py +++ /dev/null @@ -1,96 +0,0 @@ -"""End-to-end smoke test for UniSRecModel with synthetic data and fake embeddings.""" - -import numpy as np -import pandas as pd -import torch - -from rectools import Columns -from rectools.dataset import Dataset -from rectools.fast_transformers import UniSRecModel - - -def main() -> None: - # --- Synthetic dataset: 80 users x 60 items --- - rng = np.random.RandomState(123) - n_users, n_items = 80, 60 - - rows = [] - for u in range(n_users): - n_inter = rng.randint(4, 15) - items = rng.choice(n_items, size=n_inter, replace=False) - for rank, item in enumerate(items): - rows.append({ - Columns.User: u, - Columns.Item: item, - Columns.Weight: 1.0, - Columns.Datetime: pd.Timestamp("2024-01-01") + pd.Timedelta(hours=rank), - }) - df = pd.DataFrame(rows) - dataset = Dataset.construct(df) - print(f"Dataset: {n_users} users, {n_items} items, {len(df)} interactions") - - # --- Fake pretrained embeddings (random, shape [n_items, 64]) --- - torch.manual_seed(42) - pretrained = torch.randn(n_items, 64) - - # --- Train --- - model = 
UniSRecModel( - pretrained_item_embeddings=pretrained, - n_factors=32, - projection_hidden=64, - n_blocks=2, - n_heads=2, - session_max_len=16, - phase1_epochs=2, - phase2_epochs=2, - phase3_epochs=2, - phase1_lr=1e-3, - phase2_lr=3e-4, - phase3_lr=1e-4, - batch_size=32, - verbose=1, - ) - model.fit(dataset) - print("Training done (3 phases).") - - # --- Recommend --- - users = list(range(n_users)) - reco = model.recommend(users=users, dataset=dataset, k=5, filter_viewed=True) - print(f"\nTop-5 recommendations (first 3 users):") - print(reco[reco[Columns.User].isin(range(3))].to_string(index=False)) - - # --- Simple metrics --- - interactions = dataset.get_raw_interactions() - hits = 0 - total = 0 - ap_sum = 0.0 - for u in users: - viewed = set(interactions[interactions[Columns.User] == u][Columns.Item]) - rec_items = reco[reco[Columns.User] == u][Columns.Item].tolist() - rel = [1 if i in viewed else 0 for i in rec_items] - hits += sum(rel) - total += len(rec_items) - if sum(rel) > 0: - precision_at = np.cumsum(rel) / np.arange(1, len(rel) + 1) - ap_sum += np.sum(precision_at * rel) / sum(rel) - recall = hits / max(total, 1) - map_at_k = ap_sum / len(users) - print(f"\nRecall@5 (train overlap): {recall:.4f}") - print(f"MAP@5 (train overlap): {map_at_k:.4f}") - - # --- NaN check --- - nan_count = reco[Columns.Score].isna().sum() - print(f"NaN scores: {nan_count} / {len(reco)}") - assert nan_count == 0, "Found NaN scores!" 
- - # --- I2I --- - target_items = list(range(10)) - i2i = model.recommend_to_items(target_items=target_items, dataset=dataset, k=5) - print(f"\nI2I recommendations (first 3 target items):") - print(i2i[i2i[Columns.TargetItem].isin(range(3))].to_string(index=False)) - - print("\nSmoke test passed!") - - -if __name__ == "__main__": - main() diff --git a/scripts/train_unisrec_ml20m.py b/scripts/train_unisrec_ml20m.py deleted file mode 100644 index 388ee9a4..00000000 --- a/scripts/train_unisrec_ml20m.py +++ /dev/null @@ -1,293 +0,0 @@ -"""Train UniSRec on ML-20M with Qwen embeddings.""" - -import json -import zipfile -from pathlib import Path - -import numpy as np -import pandas as pd -import torch -from tqdm import tqdm - -from rectools.fast_transformers import UniSRecModel - -DESCRIPTIONS_PATH = "training_folder/uniSRec/item_descriptions_compact.json" -QWEN_MODEL_NAME = "Qwen/Qwen3-Embedding-0.6B" -QWEN_DIM = 1024 -DATA_DIR = Path("data/ml-20m") -CACHE_EMB_PATH = DATA_DIR / "qwen_embeddings.pt" -ML20M_URL = "https://files.grouplens.org/datasets/movielens/ml-20m.zip" - -MIN_RATING = 4.0 -MIN_ITEM_INTERACTIONS = 50 -MIN_USER_INTERACTIONS = 5 -PHASE3_EPOCHS = 30 - - -def download_ml20m(): - DATA_DIR.mkdir(parents=True, exist_ok=True) - ratings_path = DATA_DIR / "ml-20m" / "ratings.csv" - if ratings_path.exists(): - return - zip_path = DATA_DIR / "ml-20m.zip" - if not zip_path.exists(): - print(f"Downloading ML-20M...") - import urllib.request - urllib.request.urlretrieve(ML20M_URL, zip_path) - print("Extracting...") - with zipfile.ZipFile(zip_path, "r") as zf: - zf.extractall(DATA_DIR) - - -def load_and_preprocess(): - download_ml20m() - ratings = pd.read_csv(DATA_DIR / "ml-20m" / "ratings.csv") - ratings.columns = ["user_id", "item_id", "rating", "timestamp"] - - if MIN_RATING > 0: - ratings = ratings[ratings["rating"] >= MIN_RATING] - print(f"After rating filter (>={MIN_RATING}): {len(ratings):,} interactions") - - if MIN_ITEM_INTERACTIONS > 0: - item_counts = 
ratings.groupby("item_id").size() - popular = item_counts[item_counts >= MIN_ITEM_INTERACTIONS].index - ratings = ratings[ratings["item_id"].isin(popular)] - print(f"After item filter (>={MIN_ITEM_INTERACTIONS}): {ratings['item_id'].nunique():,} items") - - user_counts = ratings.groupby("user_id").size() - valid = user_counts[user_counts >= MIN_USER_INTERACTIONS].index - ratings = ratings[ratings["user_id"].isin(valid)] - print(f"Final: {len(ratings):,} interactions, {ratings['user_id'].nunique():,} users, {ratings['item_id'].nunique():,} items") - - movies = pd.read_csv(DATA_DIR / "ml-20m" / "movies.csv") - movies.columns = ["movieId", "title", "genres"] - return ratings, movies - - -def _last_token_pool(hidden_states, attention_mask): - left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] - if left_padding: - return hidden_states[:, -1] - seq_lengths = attention_mask.sum(dim=1) - 1 - return hidden_states[torch.arange(hidden_states.shape[0], device=hidden_states.device), seq_lengths] - - -@torch.no_grad() -def encode_qwen(texts, device="cuda", batch_size=1024): - from transformers import AutoModel, AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_NAME, padding_side="left") - model = AutoModel.from_pretrained(QWEN_MODEL_NAME, torch_dtype=torch.float16).to(device).eval() - - embeddings = torch.zeros(len(texts), QWEN_DIM) - for start in tqdm(range(0, len(texts), batch_size), desc="Qwen encode"): - batch = texts[start:start + batch_size] - inputs = tokenizer(batch, padding=True, truncation=True, max_length=512, return_tensors="pt").to(device) - hidden = model(**inputs).last_hidden_state - out = _last_token_pool(hidden, inputs["attention_mask"]) - embeddings[start:start + len(batch)] = out.float().cpu() - - del model, tokenizer - torch.cuda.empty_cache() - return embeddings - - -def build_pretrained_embeddings(movies, descriptions): - all_movie_ids = sorted(movies["movieId"].unique()) - max_id = max(all_movie_ids) - texts_by_id = {} 
- - for mid in all_movie_ids: - key = str(mid) - if key in descriptions: - val = descriptions[key] - texts_by_id[mid] = val[0] if isinstance(val, list) else val - else: - row = movies[movies["movieId"] == mid] - if len(row) > 0: - texts_by_id[mid] = f"{row.iloc[0]['title']} {row.iloc[0]['genres']}" - else: - texts_by_id[mid] = f"movie {mid}" - - ordered_ids = sorted(texts_by_id.keys()) - ordered_texts = [texts_by_id[mid] for mid in ordered_ids] - - if CACHE_EMB_PATH.exists(): - print(f"Loading cached embeddings from {CACHE_EMB_PATH}") - return torch.load(CACHE_EMB_PATH, weights_only=True) - - raw_embs = encode_qwen(ordered_texts, batch_size=512) - - embeddings = torch.zeros(max_id + 1, QWEN_DIM) - for i, mid in enumerate(ordered_ids): - embeddings[mid] = raw_embs[i] - - torch.save(embeddings, CACHE_EMB_PATH) - print(f"Saved embeddings to {CACHE_EMB_PATH}, shape={embeddings.shape}") - return embeddings - - -def split_eval(ratings): - ratings = ratings.sort_values(["user_id", "timestamp"]) - grouped = ratings.groupby("user_id") - test_idx = grouped.tail(1).index - remaining = ratings.drop(test_idx) - val_idx = remaining.groupby("user_id").tail(1).index - train_idx = remaining.drop(val_idx).index - - train = ratings.loc[train_idx] - val = ratings.loc[val_idx] - test = ratings.loc[test_idx] - return train, val, test - - -def to_tensors(df): - """Convert a ratings DataFrame to (user_ids, item_ids, timestamps) tensors.""" - return ( - torch.tensor(df["user_id"].values, dtype=torch.long), - torch.tensor(df["item_id"].values, dtype=torch.long), - torch.tensor(df["timestamp"].values, dtype=torch.long), - ) - - -@torch.no_grad() -def evaluate_fast(model, train_ratings_df, test_df, k=10, batch_size=256): - net = model.net - net.cuda().eval() - device = torch.device("cuda") - maxlen = net.session_max_len - - item_embs = net.project_all() - unique_items = model.item_id_mapping - - ext_to_int = {} - for i in range(len(unique_items)): - ext_to_int[int(unique_items[i].item())] = i 
+ 1 - - train_grouped = train_ratings_df.sort_values("timestamp").groupby("user_id")["item_id"].agg(list).to_dict() - test_grouped = test_df.groupby("user_id")["item_id"].first().to_dict() - test_users = list(test_grouped.keys()) - - hits, ndcg_sum, mrr_sum, total = 0, 0.0, 0.0, 0 - - for start in tqdm(range(0, len(test_users), batch_size), desc="Evaluating"): - batch_users = test_users[start:start + batch_size] - seqs, targets = [], [] - for uid in batch_users: - history = train_grouped.get(uid, []) - mapped = [ext_to_int[iid] for iid in history if iid in ext_to_int] - if not mapped: - continue - seq = mapped[-maxlen:] - seqs.append([0] * (maxlen - len(seq)) + seq) - targets.append(ext_to_int.get(test_grouped[uid])) - - if not seqs: - continue - - x = torch.tensor(seqs, dtype=torch.long, device=device) - h = net.encode_last(x, use_id=False) - scores = h @ item_embs.T - scores[:, 0] = float("-inf") - - for i, target_int in enumerate(targets): - if target_int is None: - continue - _, topk_idx = scores[i].topk(k) - topk = topk_idx.cpu().tolist() - if target_int in topk: - rank = topk.index(target_int) - hits += 1 - ndcg_sum += 1.0 / np.log2(rank + 2) - mrr_sum += 1.0 / (rank + 1) - total += 1 - - return { - f"HR@{k}": hits / total if total else 0, - f"NDCG@{k}": ndcg_sum / total if total else 0, - f"MRR@{k}": mrr_sum / total if total else 0, - "n_users": total, - } - - -def main(): - print("=" * 60) - print("UniSRec Training on ML-20M") - print("=" * 60) - - ratings, movies = load_and_preprocess() - descriptions = json.loads(Path(DESCRIPTIONS_PATH).read_text()) - print(f"Loaded {len(descriptions)} descriptions") - - pretrained = build_pretrained_embeddings(movies, descriptions) - print(f"Pretrained embeddings: {pretrained.shape}") - - train_ratings, val_ratings, test_ratings = split_eval(ratings) - print(f"Split: train={len(train_ratings):,}, val={len(val_ratings):,}, test={len(test_ratings):,}") - - train_with_val = pd.concat([train_ratings, val_ratings]) - - 
checkpoint_path = DATA_DIR / "unisrec_v3.pt" - - model = UniSRecModel( - pretrained_item_embeddings=pretrained, - n_factors=512, - projection_hidden=512, - n_blocks=2, - n_heads=1, - session_max_len=200, - dropout=0.1, - adaptor_dropout=0.2, - adaptor_type="pca", - use_adaptor_ffn=True, - phase1_epochs=0, - phase2_epochs=0, - phase3_epochs=PHASE3_EPOCHS, - phase1_lr=1e-3, - phase2_lr=3e-4, - phase3_lr=1e-4, - lr_head=0.3, - lr_wp=0.1, - lr_transformer=3.0, - optimizer="adamw", - scheduler="cosine_warmup", - warmup_ratio=0.05, - min_lr_ratio=1.0, - grad_clip=1.0, - weight_decay=0.01, - loss="softmax", - patience=10, - batch_size=128, - dataloader_num_workers=0, - train_min_user_interactions=2, - verbose=1, - ) - - if checkpoint_path.exists(): - print(f"Loading checkpoint from {checkpoint_path}") - model.load_checkpoint(checkpoint_path) - else: - print("\nStarting training...") - user_ids, item_ids, timestamps = to_tensors(train_with_val) - model.fit(user_ids, item_ids, timestamps) - model.save_checkpoint(checkpoint_path) - print(f"Saved checkpoint to {checkpoint_path}") - - print("Training complete!") - - print("\n--- Validation Metrics ---") - val_results = evaluate_fast(model, train_ratings, val_ratings) - for m, v in val_results.items(): - print(f" {m}: {v}") - - print("\n--- Test Metrics ---") - test_results = evaluate_fast(model, train_with_val, test_ratings) - for m, v in test_results.items(): - print(f" {m}: {v}") - - print("\n--- Expected Metrics ---") - print(" val: HR@10=0.2431 NDCG@10=0.1335") - print(" test: HR@10=0.2218 NDCG@10=0.1251 MRR@10=0.0957") - - -if __name__ == "__main__": - main() From 7d3850b70aa58d794cc295c33fc5f27abd8f81fd Mon Sep 17 00:00:00 2001 From: TOPAPEC Date: Fri, 24 Apr 2026 22:23:51 +0000 Subject: [PATCH 6/7] Formatting --- rectools/fast_transformers/__init__.py | 4 +- rectools/fast_transformers/gpu_data.py | 7 +- rectools/fast_transformers/lightning_wrap.py | 6 +- rectools/fast_transformers/model.py | 21 +-- 
rectools/fast_transformers/net.py | 14 +- .../fast_transformers/unisrec_lightning.py | 22 ++- rectools/fast_transformers/unisrec_model.py | 66 +++++--- rectools/fast_transformers/unisrec_net.py | 27 ++-- scripts/compare_sasrec_unisrec.py | 149 +++++++++++------- tests/fast_transformers/test_gpu_data.py | 55 +++---- .../fast_transformers/test_lightning_wrap.py | 8 +- tests/fast_transformers/test_model.py | 14 +- tests/fast_transformers/test_net.py | 7 +- tests/fast_transformers/test_ranking.py | 10 +- .../test_unisrec_lightning.py | 23 ++- tests/fast_transformers/test_unisrec_model.py | 4 +- tests/fast_transformers/test_unisrec_net.py | 2 +- 17 files changed, 252 insertions(+), 187 deletions(-) diff --git a/rectools/fast_transformers/__init__.py b/rectools/fast_transformers/__init__.py index c074130f..1f129c37 100644 --- a/rectools/fast_transformers/__init__.py +++ b/rectools/fast_transformers/__init__.py @@ -1,13 +1,13 @@ """Fast Transformers: flat sequential recommenders without ItemNet hierarchy.""" -from .gpu_data import build_sequences, align_embeddings, GPUBatchDataset, make_dataloader +from .gpu_data import GPUBatchDataset, align_embeddings, build_sequences, make_dataloader from .lightning_wrap import FlatSASRecLightning from .model import FlatSASRecConfig, FlatSASRecModel from .net import FlatSASRec, SASRecBlock from .ranking import rank_topk -from .unisrec_net import UniSRec, FeedForward from .unisrec_lightning import UniSRecLightning from .unisrec_model import UniSRecModel +from .unisrec_net import FeedForward, UniSRec __all__ = [ "build_sequences", diff --git a/rectools/fast_transformers/gpu_data.py b/rectools/fast_transformers/gpu_data.py index c4e67852..5a8d7eee 100644 --- a/rectools/fast_transformers/gpu_data.py +++ b/rectools/fast_transformers/gpu_data.py @@ -3,7 +3,8 @@ import typing as tp import torch -from torch.utils.data import Dataset as TorchDataset, DataLoader +from torch.utils.data import DataLoader +from torch.utils.data import Dataset as 
TorchDataset def build_sequences( @@ -52,7 +53,9 @@ def build_sequences( if total_elements > 0: user_indices = torch.repeat_interleave(torch.arange(n_users, device=device), effective_lens) cumsum = effective_lens.cumsum(0) - offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave(cumsum - effective_lens, effective_lens) + offsets = torch.arange(total_elements, device=device) - torch.repeat_interleave( + cumsum - effective_lens, effective_lens + ) x_src = torch.repeat_interleave(ends - capped_lens, effective_lens) + offsets y_src = x_src + 1 diff --git a/rectools/fast_transformers/lightning_wrap.py b/rectools/fast_transformers/lightning_wrap.py index 698afa10..75d20a39 100644 --- a/rectools/fast_transformers/lightning_wrap.py +++ b/rectools/fast_transformers/lightning_wrap.py @@ -2,8 +2,8 @@ import typing as tp -import torch import pytorch_lightning as pl +import torch from torch import nn from .net import FlatSASRec @@ -47,7 +47,9 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> to if self.loss_name == "softmax": # logits: (B, L, n_items) — full catalog # targets need to be 0-indexed item ids (subtract 1 since item ids start from 1) - targets = y - 1 # shift to 0-based for CrossEntropyLoss; padding (0) becomes -1 -> ignore_index=0 won't work + targets = ( + y - 1 + ) # shift to 0-based for CrossEntropyLoss; padding (0) becomes -1 -> ignore_index=0 won't work # Actually, we set ignore_index=0 but padding maps to -1. 
# Let's use a different approach: set padding targets to 0 and use ignore_index=0 targets = y.clone() diff --git a/rectools/fast_transformers/model.py b/rectools/fast_transformers/model.py index e62f9943..ba2b2405 100644 --- a/rectools/fast_transformers/model.py +++ b/rectools/fast_transformers/model.py @@ -2,18 +2,15 @@ import typing as tp -import numpy as np import pandas as pd -import torch import pytorch_lightning as pl +import torch from scipy import sparse -from rectools import Columns from rectools.dataset import Dataset -from rectools.dataset.identifiers import IdMap from rectools.models.base import InternalRecoTriplet, ModelBase, ModelConfig -from rectools.models.nn.transformers.sasrec import SASRecDataPreparator from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler +from rectools.models.nn.transformers.sasrec import SASRecDataPreparator from rectools.types import InternalIdsArray from rectools.utils.config import BaseConfig @@ -157,10 +154,6 @@ def _fit(self, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> None: dp.process_dataset_train(dataset) self._data_preparator = dp - n_items = dp.item_id_map.size # includes extra tokens (padding) - # item ids in the preparator go from 0 (padding) to n_items-1 - # FlatSASRec expects n_items = max real item count (embedding table = n_items+1 with padding at 0) - # The preparator's item_id_map.size includes the padding token, so real items = size - 1 n_real_items = dp.item_id_map.size - dp.n_item_extra_tokens net = FlatSASRec( @@ -242,7 +235,6 @@ def _recommend_u2i( sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], ) -> InternalRecoTriplet: assert self._data_preparator is not None - device = next(self._net.parameters()).device # type: ignore user_embs = self._get_user_embeddings(dataset) # (n_users, D) item_embs = self._get_item_embeddings() # (n_items, D) @@ -278,7 +270,9 @@ def _recommend_u2i( whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])] u_ids, i_ids, scores = 
rank_topk( - user_embs, item_embs, k, + user_embs, + item_embs, + k, filter_csr=filter_csr, whitelist=whitelist, batch_size=self.recommend_batch_size, @@ -298,7 +292,6 @@ def _recommend_i2i( sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], ) -> InternalRecoTriplet: assert self._data_preparator is not None and self._net is not None - device = next(self._net.parameters()).device item_embs = self._get_item_embeddings() # (n_items, D) n_extra = self._data_preparator.n_item_extra_tokens @@ -313,7 +306,9 @@ def _recommend_i2i( whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])] t_ids, i_ids, scores = rank_topk( - target_embs, item_embs, k, + target_embs, + item_embs, + k, whitelist=whitelist, batch_size=self.recommend_batch_size, ) diff --git a/rectools/fast_transformers/net.py b/rectools/fast_transformers/net.py index 81d4dd7d..f9e06b00 100644 --- a/rectools/fast_transformers/net.py +++ b/rectools/fast_transformers/net.py @@ -127,19 +127,7 @@ def encode_last(self, x: torch.Tensor) -> torch.Tensor: Tensor (B, D) """ h = self.encode(x) # (B, L, D) - # Find last non-padding position per row - non_pad = (x != self.PADDING_IDX) # (B, L) - # lengths: number of non-pad tokens - lengths = non_pad.sum(dim=1) # (B,) - # Clamp to at least 1 to avoid index -1 for fully-padded rows - last_idx = (lengths - 1).clamp(min=0) - # We use left-padding, so last non-pad is at position (L - 1) if any token exists - # Actually with left padding, non-pad tokens are at the end, so the last position is L-1 - # But let's compute correctly: the last non-pad index - # With left-padding: first non-pad is at L - length, last non-pad is at L - 1 - B = x.shape[0] - last_pos = x.shape[1] - 1 # last position is always the last for left-padded sequences - return h[:, last_pos, :] # (B, D) + return h[:, -1, :] # left-padded: last position is always rightmost def all_item_embeddings(self) -> torch.Tensor: """ diff --git a/rectools/fast_transformers/unisrec_lightning.py 
b/rectools/fast_transformers/unisrec_lightning.py index 640b574d..118d5840 100644 --- a/rectools/fast_transformers/unisrec_lightning.py +++ b/rectools/fast_transformers/unisrec_lightning.py @@ -3,9 +3,9 @@ import math import typing as tp +import pytorch_lightning as pl import torch import torch.nn.functional as F -import pytorch_lightning as pl from torch.optim.lr_scheduler import LambdaLR from .unisrec_net import UniSRec @@ -63,7 +63,10 @@ def _get_all_embs(self) -> torch.Tensor: return self.net.project_all() def _get_pos_neg_logits( - self, hidden: torch.Tensor, labels: torch.Tensor, negatives: torch.Tensor, + self, + hidden: torch.Tensor, + labels: torch.Tensor, + negatives: torch.Tensor, ) -> torch.Tensor: """Compute (B, L, 1+N) logits where index 0 = positive.""" emb_pos = self._get_item_embs(labels) @@ -71,7 +74,8 @@ def _get_pos_neg_logits( emb_neg = self._get_item_embs(negatives) logits_neg = torch.matmul( - hidden.unsqueeze(2), emb_neg.transpose(2, 3), + hidden.unsqueeze(2), + emb_neg.transpose(2, 3), ).squeeze(2) return torch.cat([logits_pos.unsqueeze(-1), logits_neg], dim=-1) @@ -79,7 +83,9 @@ def _get_pos_neg_logits( # ── losses ── def _calc_loss( - self, hidden: torch.Tensor, batch: tp.Dict[str, torch.Tensor], + self, + hidden: torch.Tensor, + batch: tp.Dict[str, torch.Tensor], ) -> torch.Tensor: labels = batch["y"] has_neg = "negatives" in batch @@ -114,7 +120,9 @@ def _full_softmax_loss(self, hidden: torch.Tensor, labels: torch.Tensor) -> torc targets = labels.clone() targets[targets == 0] = -100 return F.cross_entropy( - logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100, + logits.view(-1, logits.size(-1)), + targets.view(-1), + ignore_index=-100, ) def _sampled_softmax_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: @@ -123,7 +131,9 @@ def _sampled_softmax_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> tor logits[:, :, [0, 1]] = logits[:, :, [1, 0]] targets = mask.long() # 1 where non-padding, 0 
where padding return F.cross_entropy( - logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=0, + logits.view(-1, logits.size(-1)), + targets.view(-1), + ignore_index=0, ) def _bce_loss(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: diff --git a/rectools/fast_transformers/unisrec_model.py b/rectools/fast_transformers/unisrec_model.py index c737900e..cbb7b632 100644 --- a/rectools/fast_transformers/unisrec_model.py +++ b/rectools/fast_transformers/unisrec_model.py @@ -3,13 +3,13 @@ import typing as tp from pathlib import Path -import torch import pytorch_lightning as pl +import torch from pytorch_lightning.callbacks import EarlyStopping +from .gpu_data import align_embeddings, build_sequences, make_dataloader +from .unisrec_lightning import SUPPORTED_LOSSES, SUPPORTED_OPTIMIZERS, SUPPORTED_SCHEDULERS, UniSRecLightning from .unisrec_net import UniSRec -from .unisrec_lightning import UniSRecLightning, SUPPORTED_LOSSES, SUPPORTED_OPTIMIZERS, SUPPORTED_SCHEDULERS -from .gpu_data import build_sequences, align_embeddings, make_dataloader class UniSRecModel: @@ -143,7 +143,12 @@ def _make_trainer(self, max_epochs: int, val_dl: tp.Any = None) -> pl.Trainer: ) def _make_lightning( - self, net: UniSRec, param_groups: tp.List[tp.Dict], use_id: bool, max_epochs: int, train_dl: tp.Any, + self, + net: UniSRec, + param_groups: tp.List[tp.Dict], + use_id: bool, + max_epochs: int, + train_dl: tp.Any, ) -> UniSRecLightning: total_steps = len(train_dl) * max_epochs if self.scheduler else None return UniSRecLightning( @@ -172,16 +177,22 @@ def _phase2_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: {"params": [net.whitening_bias], "lr": self.phase2_lr * 10.0, "weight_decay": 0.0}, ] if net.head is not None: - groups.append({ - "params": list(net.head.parameters()), - "lr": self.phase2_lr * self.lr_head, - "weight_decay": self.weight_decay, - }) + groups.append( + { + "params": list(net.head.parameters()), + "lr": self.phase2_lr * self.lr_head, 
+ "weight_decay": self.weight_decay, + } + ) else: groups = [ {"params": list(net.bn_input.parameters()), "lr": self.phase2_lr, "weight_decay": 0.0}, {"params": list(net.bn_score.parameters()), "lr": self.phase2_lr, "weight_decay": 0.0}, - {"params": list(net.head.parameters()), "lr": self.phase2_lr * self.lr_head, "weight_decay": self.weight_decay}, + { + "params": list(net.head.parameters()), + "lr": self.phase2_lr * self.lr_head, + "weight_decay": self.weight_decay, + }, ] return groups @@ -198,21 +209,27 @@ def _phase3_params(self, net: UniSRec) -> tp.List[tp.Dict[str, tp.Any]]: ] head: tp.List[tp.Dict[str, tp.Any]] = [] if net.head is not None: - head = [{"params": list(net.head.parameters()), "lr": self.phase3_lr * self.lr_head, "weight_decay": self.weight_decay}] + head = [ + { + "params": list(net.head.parameters()), + "lr": self.phase3_lr * self.lr_head, + "weight_decay": self.weight_decay, + } + ] transformer = [ {"params": list(net.pos_emb.parameters()), "lr": self.phase3_lr * self.lr_transformer, "weight_decay": 0.0}, { "params": ( - [p for l in net.attention_layers for p in l.parameters()] - + [p for l in net.forward_layers for p in l.parameters()] + [p for layer in net.attention_layers for p in layer.parameters()] + + [p for layer in net.forward_layers for p in layer.parameters()] ), "lr": self.phase3_lr * self.lr_transformer, "weight_decay": self.weight_decay, }, { "params": ( - [p for l in net.attention_layernorms for p in l.parameters()] - + [p for l in net.forward_layernorms for p in l.parameters()] + [p for layer in net.attention_layernorms for p in layer.parameters()] + + [p for layer in net.forward_layernorms for p in layer.parameters()] + list(net.last_layernorm.parameters()) ), "lr": self.phase3_lr, @@ -246,7 +263,9 @@ def fit( self """ x, y, unique_items, unique_users = build_sequences( - user_ids, item_ids, timestamps, + user_ids, + item_ids, + timestamps, max_len=self.session_max_len, min_interactions=self.train_min_user_interactions, ) @@ 
-303,12 +322,15 @@ def _run_phase(param_groups: tp.List[tp.Dict], use_id: bool, max_epochs: int) -> def save_checkpoint(self, path: tp.Union[str, Path]) -> None: assert self._net is not None - torch.save({ - "net": self._net.state_dict(), - "unique_items": self._unique_items, - "unique_users": self._unique_users, - "n_items": len(self._unique_items), - }, path) + torch.save( + { + "net": self._net.state_dict(), + "unique_items": self._unique_items, + "unique_users": self._unique_users, + "n_items": len(self._unique_items), + }, + path, + ) def load_checkpoint(self, path: tp.Union[str, Path], device: str = "cuda") -> None: ckpt = torch.load(path, map_location=device, weights_only=False) diff --git a/rectools/fast_transformers/unisrec_net.py b/rectools/fast_transformers/unisrec_net.py index d1329b20..47ebc7a9 100644 --- a/rectools/fast_transformers/unisrec_net.py +++ b/rectools/fast_transformers/unisrec_net.py @@ -51,12 +51,17 @@ def make_ffn(n_factors: int, ffn_type: str, expansion: int, dropout: float) -> n hidden = n_factors * expansion if ffn_type == "linear_gelu": return nn.Sequential( - nn.Linear(n_factors, hidden), nn.GELU(), nn.Dropout(dropout), - nn.Linear(hidden, n_factors), nn.Dropout(dropout), + nn.Linear(n_factors, hidden), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden, n_factors), + nn.Dropout(dropout), ) if ffn_type == "linear_relu": return nn.Sequential( - nn.Linear(n_factors, hidden), nn.ReLU(), nn.Dropout(dropout), + nn.Linear(n_factors, hidden), + nn.ReLU(), + nn.Dropout(dropout), nn.Linear(hidden, n_factors), ) raise ValueError(f"Unknown ffn_type: {ffn_type}. 
Choose from: conv1d, linear_gelu, linear_relu") @@ -238,8 +243,10 @@ def project_all(self) -> torch.Tensor: @property def transformer_params(self) -> tp.List[nn.Parameter]: modules = ( - list(self.attention_layernorms) + list(self.attention_layers) - + list(self.forward_layernorms) + list(self.forward_layers) + list(self.attention_layernorms) + + list(self.attention_layers) + + list(self.forward_layernorms) + + list(self.forward_layers) + [self.last_layernorm, self.pos_emb] ) return [p for m in modules for p in m.parameters()] @@ -272,9 +279,9 @@ def _encode(self, seqs: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: seqs = seqs + self.pos_emb(positions) seqs = self.emb_dropout(seqs) - pad_mask = (input_ids == self.PADDING_IDX) # (B, L) - pad_mask_3d = pad_mask.unsqueeze(-1) # (B, L, 1) - seqs = seqs.masked_fill(pad_mask_3d, 0.0) # zero out padding + pad_mask = input_ids == self.PADDING_IDX # (B, L) + pad_mask_3d = pad_mask.unsqueeze(-1) # (B, L, 1) + seqs = seqs.masked_fill(pad_mask_3d, 0.0) # zero out padding attn_mask = self._causal_mask(L, seqs.device) key_padding_mask = pad_mask @@ -284,7 +291,9 @@ def _encode(self, seqs: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: # Zero padding in Q/K/V so NaN can never appear in dot-products normed = normed.masked_fill(pad_mask_3d, 0.0) mha_out, _ = self.attention_layers[i]( - normed, normed, normed, + normed, + normed, + normed, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, diff --git a/scripts/compare_sasrec_unisrec.py b/scripts/compare_sasrec_unisrec.py index bf6ee18a..de39c3fd 100644 --- a/scripts/compare_sasrec_unisrec.py +++ b/scripts/compare_sasrec_unisrec.py @@ -17,9 +17,9 @@ from rectools import Columns from rectools.dataset import Dataset -from rectools.models import SASRecModel from rectools.fast_transformers import UniSRecModel from rectools.fast_transformers.gpu_data import build_sequences +from rectools.models import SASRecModel DATA_DIR = Path("data/ml-20m") 
CACHE_EMB_PATH = DATA_DIR / "qwen_embeddings.pt" @@ -94,7 +94,7 @@ def evaluate_unisrec(model, train_df, test_df, k=10, batch_size=256, use_id=Fals hits, ndcg_sum, mrr_sum, total = 0, 0.0, 0.0, 0 for start in tqdm(range(0, len(test_users), batch_size), desc="Eval UniSRec"): - batch_users = test_users[start:start + batch_size] + batch_users = test_users[start : start + batch_size] seqs, targets = [], [] for uid in batch_users: history = train_grouped.get(uid, []) @@ -151,44 +151,48 @@ def cleanup(): def write_report(timings: dict, metrics: dict, data_info: dict): gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A" + date_str = datetime.now().strftime("%Y-%m-%d %H:%M") + dataset_str = ( + f"ML-20M (min_rating={MIN_RATING}," f" min_item={MIN_ITEM_INTERACTIONS}," f" min_user={MIN_USER_INTERACTIONS})" + ) lines = [ - f"# SASRec vs UniSRec-ID Comparison", - f"", - f"**Date:** {datetime.now().strftime('%Y-%m-%d %H:%M')} ", + "# SASRec vs UniSRec-ID Comparison", + "", + f"**Date:** {date_str} ", f"**GPU:** {gpu_name} ", - f"**Dataset:** ML-20M (min_rating={MIN_RATING}, min_item={MIN_ITEM_INTERACTIONS}, min_user={MIN_USER_INTERACTIONS})", - f"", - f"## Data", - f"", - f"| | Count |", - f"|---|---:|", + f"**Dataset:** {dataset_str}", + "", + "## Data", + "", + "| | Count |", + "|---|---:|", f"| Interactions | {data_info['n_interactions']:,} |", f"| Users | {data_info['n_users']:,} |", f"| Items | {data_info['n_items']:,} |", f"| Train | {data_info['n_train']:,} |", f"| Val | {data_info['n_val']:,} |", f"| Test | {data_info['n_test']:,} |", - f"", - f"## Config", - f"", - f"| Parameter | Value |", - f"|---|---|", + "", + "## Config", + "", + "| Parameter | Value |", + "|---|---|", f"| n_factors | {N_FACTORS} |", f"| n_blocks | {N_BLOCKS} |", f"| n_heads | {N_HEADS} |", f"| session_max_len | {SESSION_MAX_LEN} |", f"| batch_size | {BATCH_SIZE} |", f"| lr | {LR} |", - f"| loss | softmax |", - f"| optimizer | Adam |", + "| loss | softmax |", + "| 
optimizer | Adam |", f"| epochs | {EPOCHS} |", f"| patience | {PATIENCE} |", - f"| dropout | 0.1 |", - f"", - f"## Timing", - f"", - f"| Stage | SASRec | UniSRec ID |", - f"|---|---:|---:|", + "| dropout | 0.1 |", + "", + "## Timing", + "", + "| Stage | SASRec | UniSRec ID |", + "|---|---:|---:|", ] for stage in ["data_load", "preprocessing", "model_init", "training", "eval"]: @@ -209,32 +213,42 @@ def write_report(timings: dict, metrics: dict, data_info: dict): s_epoch = timings.get("sasrec_training", 0) / max(timings.get("sasrec_epochs_done", 1), 1) u_epoch = timings.get("unisrec_training", 0) / max(timings.get("unisrec_epochs_done", 1), 1) - lines.extend([ - f"", - f"| | SASRec | UniSRec ID |", - f"|---|---:|---:|", - f"| Epochs completed | {timings.get('sasrec_epochs_done', EPOCHS)} | {timings.get('unisrec_epochs_done', EPOCHS)} |", - f"| Time per epoch | {s_epoch:.1f}s | {u_epoch:.1f}s |", - f"| Preprocessing speedup | — | {timings.get('prep_speedup', 0):.0f}x |", - ]) - - lines.extend([ - f"", - f"## Quality (test set, {metrics['sasrec']['n_users']:,} users)", - f"", - f"| Model | HR@10 | NDCG@10 | MRR@10 |", - f"|---|---:|---:|---:|", - ]) + s_epochs_done = timings.get("sasrec_epochs_done", EPOCHS) + u_epochs_done = timings.get("unisrec_epochs_done", EPOCHS) + prep_speedup = timings.get("prep_speedup", 0) + lines.extend( + [ + "", + "| | SASRec | UniSRec ID |", + "|---|---:|---:|", + f"| Epochs completed | {s_epochs_done} | {u_epochs_done} |", + f"| Time per epoch | {s_epoch:.1f}s | {u_epoch:.1f}s |", + f"| Preprocessing speedup | — | {prep_speedup:.0f}x |", + ] + ) + + n_test_users = metrics["sasrec"]["n_users"] + lines.extend( + [ + "", + f"## Quality (test set, {n_test_users:,} users)", + "", + "| Model | HR@10 | NDCG@10 | MRR@10 |", + "|---|---:|---:|---:|", + ] + ) for name, key in [("SASRec", "sasrec"), ("UniSRec ID", "unisrec")]: m = metrics[key] lines.append(f"| {name} | {m['HR@10']:.4f} | {m['NDCG@10']:.4f} | {m['MRR@10']:.4f} |") hr_diff = 
(metrics["unisrec"]["HR@10"] / metrics["sasrec"]["HR@10"] - 1) * 100 ndcg_diff = (metrics["unisrec"]["NDCG@10"] / metrics["sasrec"]["NDCG@10"] - 1) * 100 - lines.extend([ - f"", - f"UniSRec vs SASRec: HR@10 {hr_diff:+.1f}%, NDCG@10 {ndcg_diff:+.1f}%", - ]) + lines.extend( + [ + "", + f"UniSRec vs SASRec: HR@10 {hr_diff:+.1f}%, NDCG@10 {ndcg_diff:+.1f}%", + ] + ) report = "\n".join(lines) + "\n" REPORT_PATH.write_text(report) @@ -264,7 +278,10 @@ def main(): "n_val": len(val_ratings), "n_test": len(test_ratings), } - print(f"Data: {data_info['n_interactions']:,} interactions, {data_info['n_users']:,} users, {data_info['n_items']:,} items") + n_int = data_info["n_interactions"] + n_usr = data_info["n_users"] + n_itm = data_info["n_items"] + print(f"Data: {n_int:,} interactions, {n_usr:,} users, {n_itm:,} items") print(f"Split: train={data_info['n_train']:,}, val={data_info['n_val']:,}, test={data_info['n_test']:,}") user_ids_t, item_ids_t, timestamps_t = to_tensors(train_with_val) @@ -273,18 +290,20 @@ def main(): # ══════════════════════════════════════════════════════════════ # 1. SASRec (RecTools) # ══════════════════════════════════════════════════════════════ - print(f"\n{'='*70}") + print(f"\n{'=' * 70}") print(f"1. 
SASRec (RecTools) — {EPOCHS} epochs") - print(f"{'='*70}") + print(f"{'=' * 70}") # Preprocessing t0 = time.time() - df_rectools = pd.DataFrame({ - Columns.User: train_with_val["user_id"].values, - Columns.Item: train_with_val["item_id"].values, - Columns.Weight: 1.0, - Columns.Datetime: pd.to_datetime(train_with_val["timestamp"], unit="s"), - }) + df_rectools = pd.DataFrame( + { + Columns.User: train_with_val["user_id"].values, + Columns.Item: train_with_val["item_id"].values, + Columns.Weight: 1.0, + Columns.Datetime: pd.to_datetime(train_with_val["timestamp"], unit="s"), + } + ) dataset = Dataset.construct(df_rectools) timings["sasrec_preprocessing"] = time.time() - t0 print(f" Preprocessing (Dataset.construct): {timings['sasrec_preprocessing']:.2f}s") @@ -292,9 +311,11 @@ def main(): # Model init + training def sasrec_trainer(**kwargs): import pytorch_lightning as pl + callbacks = [] if PATIENCE is not None: from pytorch_lightning.callbacks import EarlyStopping + callbacks.append(EarlyStopping(monitor="val_loss", patience=PATIENCE, mode="min")) return pl.Trainer( max_epochs=EPOCHS, @@ -323,11 +344,13 @@ def sasrec_trainer(**kwargs): get_trainer_func=sasrec_trainer, ) if PATIENCE is not None: + def sasrec_val_mask(interactions_df, **kwargs): idx = interactions_df.groupby(Columns.User).tail(1).index mask = pd.Series(False, index=interactions_df.index) mask.loc[idx] = True return mask + sasrec_kwargs["get_val_mask_func"] = sasrec_val_mask t0 = time.time() @@ -346,15 +369,19 @@ def sasrec_val_mask(interactions_df, **kwargs): sasrec_metrics = evaluate_sasrec(sasrec, dataset, test_ratings) timings["sasrec_eval"] = time.time() - t0 print(f" Eval: {timings['sasrec_eval']:.1f}s") - print(f" HR@10={sasrec_metrics['HR@10']:.4f} NDCG@10={sasrec_metrics['NDCG@10']:.4f} MRR@10={sasrec_metrics['MRR@10']:.4f}") - del sasrec; cleanup() + hr = sasrec_metrics["HR@10"] + ndcg = sasrec_metrics["NDCG@10"] + mrr = sasrec_metrics["MRR@10"] + print(f" HR@10={hr:.4f} NDCG@10={ndcg:.4f} 
MRR@10={mrr:.4f}") + del sasrec + cleanup() # ══════════════════════════════════════════════════════════════ # 2. UniSRec ID # ══════════════════════════════════════════════════════════════ - print(f"\n{'='*70}") + print(f"\n{'=' * 70}") print(f"2. UniSRec ID — {EPOCHS} epochs") - print(f"{'='*70}") + print(f"{'=' * 70}") # Preprocessing torch.cuda.synchronize() @@ -408,8 +435,12 @@ def sasrec_val_mask(interactions_df, **kwargs): unisrec_metrics = evaluate_unisrec(unisrec_id, train_with_val, test_ratings, use_id=True) timings["unisrec_eval"] = time.time() - t0 print(f" Eval: {timings['unisrec_eval']:.1f}s") - print(f" HR@10={unisrec_metrics['HR@10']:.4f} NDCG@10={unisrec_metrics['NDCG@10']:.4f} MRR@10={unisrec_metrics['MRR@10']:.4f}") - del unisrec_id; cleanup() + hr = unisrec_metrics["HR@10"] + ndcg = unisrec_metrics["NDCG@10"] + mrr = unisrec_metrics["MRR@10"] + print(f" HR@10={hr:.4f} NDCG@10={ndcg:.4f} MRR@10={mrr:.4f}") + del unisrec_id + cleanup() # ── Report ── metrics = {"sasrec": sasrec_metrics, "unisrec": unisrec_metrics} diff --git a/tests/fast_transformers/test_gpu_data.py b/tests/fast_transformers/test_gpu_data.py index c3938e6f..7b69c1dd 100644 --- a/tests/fast_transformers/test_gpu_data.py +++ b/tests/fast_transformers/test_gpu_data.py @@ -1,12 +1,11 @@ """Tests for GPU-native sequence building and data utilities.""" import torch -import pytest from rectools.fast_transformers.gpu_data import ( - build_sequences, - align_embeddings, GPUBatchDataset, + align_embeddings, + build_sequences, make_dataloader, ) @@ -108,9 +107,7 @@ def test_max_len_truncation(self) -> None: item_ids = torch.tensor([10, 20, 30, 40, 50]) timestamps = torch.tensor([1, 2, 3, 4, 5]) - x, y, _, _ = build_sequences( - user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE - ) + x, y, _, _ = build_sequences(user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE) # 5 items total. 
capped_lens = min(5, 3+1) = 4, effective = 3 # Sorted items: 10->1, 20->2, 30->3, 40->4, 50->5 @@ -145,9 +142,7 @@ def test_left_padding(self) -> None: item_ids = torch.tensor([10, 20]) timestamps = torch.tensor([1, 2]) - x, y, _, _ = build_sequences( - user_ids, item_ids, timestamps, max_len=5, min_interactions=2, device=DEVICE - ) + x, y, _, _ = build_sequences(user_ids, item_ids, timestamps, max_len=5, min_interactions=2, device=DEVICE) # 2 items => effective_len = 1 (capped_lens = 2, effective = 1) # x = [0, 0, 0, 0, 1], y = [0, 0, 0, 0, 2] @@ -208,9 +203,7 @@ def test_output_dtypes(self) -> None: item_ids = torch.tensor([1, 2]) timestamps = torch.tensor([1, 2]) - x, y, _, _ = build_sequences( - user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE - ) + x, y, _, _ = build_sequences(user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE) assert x.dtype == torch.long assert y.dtype == torch.long @@ -221,9 +214,7 @@ def test_exact_max_len_sequence(self) -> None: item_ids = torch.tensor([10, 20, 30, 40]) timestamps = torch.tensor([1, 2, 3, 4]) - x, y, _, _ = build_sequences( - user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE - ) + x, y, _, _ = build_sequences(user_ids, item_ids, timestamps, max_len=3, min_interactions=2, device=DEVICE) # 4 items, max_len=3 => capped_lens = min(4, 4) = 4, effective = 3 # No padding needed @@ -257,12 +248,14 @@ class TestAlignEmbeddings: def test_2d_pretrained(self) -> None: """Align 2D pretrained embeddings to internal ID order.""" - pretrained = torch.tensor([ - [1.0, 2.0], # external item 0 - [3.0, 4.0], # external item 1 - [5.0, 6.0], # external item 2 - [7.0, 8.0], # external item 3 - ]) + pretrained = torch.tensor( + [ + [1.0, 2.0], # external item 0 + [3.0, 4.0], # external item 1 + [5.0, 6.0], # external item 2 + [7.0, 8.0], # external item 3 + ] + ) # unique_items: external IDs that map to internal IDs 1, 2, 3 unique_items = torch.tensor([2, 0, 3]) 
n_items = 3 @@ -281,10 +274,12 @@ def test_2d_pretrained(self) -> None: def test_3d_pretrained(self) -> None: """Align 3D pretrained embeddings (multi-variant).""" - pretrained = torch.tensor([ - [[1.0, 2.0], [3.0, 4.0]], # item 0, 2 variants - [[5.0, 6.0], [7.0, 8.0]], # item 1 - ]) + pretrained = torch.tensor( + [ + [[1.0, 2.0], [3.0, 4.0]], # item 0, 2 variants + [[5.0, 6.0], [7.0, 8.0]], # item 1 + ] + ) unique_items = torch.tensor([1, 0]) n_items = 2 @@ -310,10 +305,12 @@ def test_padding_row_is_zero(self) -> None: def test_out_of_range_indices(self) -> None: """Items with external IDs outside pretrained range should get zero embeddings.""" - pretrained = torch.tensor([ - [1.0, 2.0], # external 0 - [3.0, 4.0], # external 1 - ]) + pretrained = torch.tensor( + [ + [1.0, 2.0], # external 0 + [3.0, 4.0], # external 1 + ] + ) # External ID 5 is out of range (pretrained has only 2 rows) unique_items = torch.tensor([0, 5, 1]) n_items = 3 diff --git a/tests/fast_transformers/test_lightning_wrap.py b/tests/fast_transformers/test_lightning_wrap.py index ca3b5b30..e45fccfe 100644 --- a/tests/fast_transformers/test_lightning_wrap.py +++ b/tests/fast_transformers/test_lightning_wrap.py @@ -1,10 +1,10 @@ """Tests for FlatSASRecLightning wrapper.""" -import torch import pytest +import torch -from rectools.fast_transformers.net import FlatSASRec from rectools.fast_transformers.lightning_wrap import FlatSASRecLightning +from rectools.fast_transformers.net import FlatSASRec @pytest.fixture() @@ -61,9 +61,7 @@ def test_on_train_start_reinitializes_params(self, net: FlatSASRec) -> None: module = FlatSASRecLightning(net) # Snapshot parameters with dim > 1 before reinit - snapshots_before = { - name: p.clone() for name, p in module.net.named_parameters() if p.dim() > 1 - } + snapshots_before = {name: p.clone() for name, p in module.net.named_parameters() if p.dim() > 1} assert len(snapshots_before) > 0, "Expected at least one param with dim > 1" # Force parameters to a constant 
value so reinit is detectable diff --git a/tests/fast_transformers/test_model.py b/tests/fast_transformers/test_model.py index 7676fb2d..a230d160 100644 --- a/tests/fast_transformers/test_model.py +++ b/tests/fast_transformers/test_model.py @@ -2,19 +2,23 @@ import pickle -import numpy as np -import pandas as pd import pytest from rectools import Columns from rectools.dataset import Dataset -from rectools.fast_transformers import FlatSASRecConfig, FlatSASRecModel +from rectools.fast_transformers import FlatSASRecModel def _make_model(**kwargs) -> FlatSASRecModel: defaults = dict( - n_factors=16, n_blocks=1, n_heads=2, session_max_len=8, - epochs=1, batch_size=16, lr=1e-3, verbose=0, + n_factors=16, + n_blocks=1, + n_heads=2, + session_max_len=8, + epochs=1, + batch_size=16, + lr=1e-3, + verbose=0, ) defaults.update(kwargs) return FlatSASRecModel(**defaults) diff --git a/tests/fast_transformers/test_net.py b/tests/fast_transformers/test_net.py index 0d590466..62a14a3e 100644 --- a/tests/fast_transformers/test_net.py +++ b/tests/fast_transformers/test_net.py @@ -1,7 +1,7 @@ """Tests for FlatSASRec network.""" -import torch import pytest +import torch from rectools.fast_transformers.net import FlatSASRec @@ -37,10 +37,7 @@ def test_encode_last_shape(self, net: FlatSASRec) -> None: def test_padding_invariance(self, net: FlatSASRec) -> None: """Different left-padding should produce same last-position embedding.""" net.eval() - x1 = torch.tensor([[0, 0, 0, 1, 2]]) - x2 = torch.tensor([[0, 0, 0, 0, 2]]) - # Not exactly the same because sequence context differs, - # but if we use the same content the output should be identical + # Same content should produce identical output x_a = torch.tensor([[0, 0, 0, 5, 10]]) x_b = torch.tensor([[0, 0, 0, 5, 10]]) with torch.no_grad(): diff --git a/tests/fast_transformers/test_ranking.py b/tests/fast_transformers/test_ranking.py index 46a5066f..156175bc 100644 --- a/tests/fast_transformers/test_ranking.py +++ 
b/tests/fast_transformers/test_ranking.py @@ -82,9 +82,9 @@ def test_scores_sorted_descending_per_user(self): for uid in range(user_embs.shape[0]): mask = user_ids == uid user_scores = scores[mask] - assert np.all(user_scores[:-1] >= user_scores[1:]), ( - f"Scores for user {uid} are not in descending order: {user_scores}" - ) + assert np.all( + user_scores[:-1] >= user_scores[1:] + ), f"Scores for user {uid} are not in descending order: {user_scores}" def test_filter_csr_excludes_viewed_items(self): """Items present in filter_csr are excluded from recommendations.""" @@ -153,9 +153,7 @@ def test_filter_csr_and_whitelist_combined(self): shape=(3, 5), ) - user_ids, item_ids, scores = rank_topk( - user_embs, item_embs, k, filter_csr=filter_csr, whitelist=whitelist - ) + user_ids, item_ids, scores = rank_topk(user_embs, item_embs, k, filter_csr=filter_csr, whitelist=whitelist) # user0 whitelist scores: item0(2), item1(5), item3(4) # After filter (item1 excluded): item0(2), item3(4) diff --git a/tests/fast_transformers/test_unisrec_lightning.py b/tests/fast_transformers/test_unisrec_lightning.py index 855c0616..871cb2be 100644 --- a/tests/fast_transformers/test_unisrec_lightning.py +++ b/tests/fast_transformers/test_unisrec_lightning.py @@ -2,17 +2,17 @@ import math -import torch import pytest +import torch -from rectools.fast_transformers.unisrec_net import UniSRec from rectools.fast_transformers.unisrec_lightning import ( - UniSRecLightning, - _cosine_warmup_scheduler, SUPPORTED_LOSSES, SUPPORTED_OPTIMIZERS, SUPPORTED_SCHEDULERS, + UniSRecLightning, + _cosine_warmup_scheduler, ) +from rectools.fast_transformers.unisrec_net import UniSRec @pytest.fixture() @@ -170,7 +170,10 @@ def test_lr_at_end_equals_min_lr_ratio(self) -> None: min_lr_ratio = 0.1 opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) scheduler = _cosine_warmup_scheduler( - opt, warmup_steps=10, total_steps=100, min_lr_ratio=min_lr_ratio, + opt, + warmup_steps=10, + total_steps=100, + 
min_lr_ratio=min_lr_ratio, ) lr_fn = scheduler.lr_lambdas[0] # At total_steps, progress = 1, cos(pi) = -1 => factor = min_lr_ratio @@ -183,7 +186,10 @@ def test_lr_at_cosine_midpoint(self) -> None: min_lr_ratio = 0.0 opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) scheduler = _cosine_warmup_scheduler( - opt, warmup_steps=warmup_steps, total_steps=total_steps, min_lr_ratio=min_lr_ratio, + opt, + warmup_steps=warmup_steps, + total_steps=total_steps, + min_lr_ratio=min_lr_ratio, ) lr_fn = scheduler.lr_lambdas[0] midpoint = warmup_steps + (total_steps - warmup_steps) // 2 # 60 @@ -195,7 +201,10 @@ def test_lr_with_nonzero_min_lr_ratio(self) -> None: min_lr_ratio = 0.3 opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0) scheduler = _cosine_warmup_scheduler( - opt, warmup_steps=0, total_steps=100, min_lr_ratio=min_lr_ratio, + opt, + warmup_steps=0, + total_steps=100, + min_lr_ratio=min_lr_ratio, ) lr_fn = scheduler.lr_lambdas[0] # At step 0 (warmup_steps=0, so cosine phase), progress=0, cos(0)=1 => factor=1.0 diff --git a/tests/fast_transformers/test_unisrec_model.py b/tests/fast_transformers/test_unisrec_model.py index a3de7d7d..13bba453 100644 --- a/tests/fast_transformers/test_unisrec_model.py +++ b/tests/fast_transformers/test_unisrec_model.py @@ -143,7 +143,9 @@ def test_invalid_optimizer_raises(self) -> None: class TestScheduler: def test_cosine_warmup(self) -> None: user_ids, item_ids, timestamps = _make_interactions() - model = _make_model(scheduler="cosine_warmup", warmup_ratio=0.1, phase1_epochs=0, phase2_epochs=0, phase3_epochs=2) + model = _make_model( + scheduler="cosine_warmup", warmup_ratio=0.1, phase1_epochs=0, phase2_epochs=0, phase3_epochs=2 + ) model.fit(user_ids, item_ids, timestamps) assert model.is_fitted diff --git a/tests/fast_transformers/test_unisrec_net.py b/tests/fast_transformers/test_unisrec_net.py index 61889975..2298beba 100644 --- a/tests/fast_transformers/test_unisrec_net.py +++ 
b/tests/fast_transformers/test_unisrec_net.py @@ -1,7 +1,7 @@ """Tests for UniSRec network.""" -import torch import pytest +import torch from rectools.fast_transformers.unisrec_net import UniSRec From d68834f077a51fab5abc562a74ca35224449541b Mon Sep 17 00:00:00 2001 From: TOPAPEC Date: Fri, 24 Apr 2026 23:44:35 +0000 Subject: [PATCH 7/7] feat: add ONNX export, hash ID mapping, and map_item_ids MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add hash-based ID mapping (splitmix64) as alternative to dense torch.unique mapping in build_sequences and align_embeddings. - Add UniSRecModel.export_to_onnx() for native ONNX export of encoder and item embeddings (project_all). - Add UniSRecModel.map_item_ids() for external→internal ID conversion at inference time (works for both dense and hash modes). - Remove FlatSASRecModel/FlatSASRecLightning (RecTools-coupled wrappers that duplicated UniSRecModel functionality). - Add tests: hash mapping (including string-derived IDs), ONNX export roundtrip, map_item_ids for both modes. 
--- rectools/fast_transformers/__init__.py | 8 +- rectools/fast_transformers/gpu_data.py | 42 ++- rectools/fast_transformers/lightning_wrap.py | 76 ----- rectools/fast_transformers/model.py | 320 ------------------ rectools/fast_transformers/unisrec_model.py | 95 +++++- tests/fast_transformers/conftest.py | 31 -- tests/fast_transformers/test_gpu_data.py | 177 ++++++++++ .../fast_transformers/test_lightning_wrap.py | 174 ---------- tests/fast_transformers/test_model.py | 93 ----- tests/fast_transformers/test_onnx_export.py | 252 ++++++++++++++ tests/fast_transformers/test_unisrec_model.py | 43 +++ 11 files changed, 605 insertions(+), 706 deletions(-) delete mode 100644 rectools/fast_transformers/lightning_wrap.py delete mode 100644 rectools/fast_transformers/model.py delete mode 100644 tests/fast_transformers/conftest.py delete mode 100644 tests/fast_transformers/test_lightning_wrap.py delete mode 100644 tests/fast_transformers/test_model.py create mode 100644 tests/fast_transformers/test_onnx_export.py diff --git a/rectools/fast_transformers/__init__.py b/rectools/fast_transformers/__init__.py index 1f129c37..7ad04123 100644 --- a/rectools/fast_transformers/__init__.py +++ b/rectools/fast_transformers/__init__.py @@ -1,8 +1,6 @@ """Fast Transformers: flat sequential recommenders without ItemNet hierarchy.""" -from .gpu_data import GPUBatchDataset, align_embeddings, build_sequences, make_dataloader -from .lightning_wrap import FlatSASRecLightning -from .model import FlatSASRecConfig, FlatSASRecModel +from .gpu_data import GPUBatchDataset, align_embeddings, build_sequences, hash_item_ids, make_dataloader from .net import FlatSASRec, SASRecBlock from .ranking import rank_topk from .unisrec_lightning import UniSRecLightning @@ -12,13 +10,11 @@ __all__ = [ "build_sequences", "align_embeddings", + "hash_item_ids", "GPUBatchDataset", "make_dataloader", "FlatSASRec", "SASRecBlock", - "FlatSASRecLightning", - "FlatSASRecModel", - "FlatSASRecConfig", "rank_topk", "UniSRec", 
"FeedForward", diff --git a/rectools/fast_transformers/gpu_data.py b/rectools/fast_transformers/gpu_data.py index 5a8d7eee..5906706e 100644 --- a/rectools/fast_transformers/gpu_data.py +++ b/rectools/fast_transformers/gpu_data.py @@ -7,6 +7,26 @@ from torch.utils.data import Dataset as TorchDataset +def _splitmix64(x: torch.Tensor) -> torch.Tensor: + """Vectorized splitmix64 bit-mixer: element-wise int64 hash over a torch tensor. + + Standard library hashes (``hash()``, ``hashlib``) operate on scalar Python objects + and cannot be vectorized across GPU tensors. Splitmix64 is pure int64 arithmetic, + so it maps naturally to ``torch.Tensor`` ops and runs on any device. + + Reference: https://xorshift.di.unimi.it/splitmix64.c (Vigna, 2015). + """ + x = x.long() + x = (x ^ (x >> 30)) * (-4658895280553007687) # 0xbf58476d1ce4e5b9 as signed int64 + x = (x ^ (x >> 27)) * (-7723592293110705685) # 0x94d049bb133111eb as signed int64 + return x ^ (x >> 31) + + +def hash_item_ids(item_ids: torch.Tensor, dict_size: int) -> torch.Tensor: + """Map arbitrary integer item IDs to [1, dict_size] via splitmix64 hash.""" + return _splitmix64(item_ids) % dict_size + 1 + + def build_sequences( user_ids: torch.Tensor, item_ids: torch.Tensor, @@ -14,13 +34,22 @@ def build_sequences( max_len: int, min_interactions: int = 2, device: str = "cuda", + id_mapping: str = "dense", ) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: user_ids = user_ids.to(device) item_ids = item_ids.to(device) timestamps = timestamps.to(device) - unique_items, item_inv = torch.unique(item_ids, return_inverse=True) - internal_items = item_inv + 1 + unique_items = torch.unique(item_ids) + n_unique = len(unique_items) + + if id_mapping == "dense": + _, item_inv = torch.unique(item_ids, return_inverse=True) + internal_items = item_inv + 1 + elif id_mapping == "hash": + internal_items = hash_item_ids(item_ids, n_unique) + else: + raise ValueError(f"Unknown id_mapping: {id_mapping}. 
Use 'dense' or 'hash'") unique_users, user_inv = torch.unique(user_ids, return_inverse=True) @@ -74,16 +103,23 @@ def align_embeddings( pretrained: torch.Tensor, unique_items: torch.Tensor, n_items: int, + id_mapping: str = "dense", ) -> torch.Tensor: idx = unique_items.long().cpu() valid = (idx >= 0) & (idx < pretrained.shape[0]) if pretrained.ndim == 2: aligned = torch.zeros(n_items + 1, pretrained.shape[1]) - aligned[1:][valid] = pretrained[idx[valid]] else: aligned = torch.zeros(n_items + 1, pretrained.shape[1], pretrained.shape[2]) + + if id_mapping == "dense": aligned[1:][valid] = pretrained[idx[valid]] + elif id_mapping == "hash": + positions = hash_item_ids(idx, n_items) + aligned[positions[valid]] = pretrained[idx[valid]] + else: + raise ValueError(f"Unknown id_mapping: {id_mapping}. Use 'dense' or 'hash'") return aligned diff --git a/rectools/fast_transformers/lightning_wrap.py b/rectools/fast_transformers/lightning_wrap.py deleted file mode 100644 index 75d20a39..00000000 --- a/rectools/fast_transformers/lightning_wrap.py +++ /dev/null @@ -1,76 +0,0 @@ -"""PyTorch Lightning wrapper for FlatSASRec.""" - -import typing as tp - -import pytorch_lightning as pl -import torch -from torch import nn - -from .net import FlatSASRec - - -class FlatSASRecLightning(pl.LightningModule): - """Lightning module wrapping FlatSASRec with softmax / BCE losses.""" - - SUPPORTED_LOSSES = ("softmax", "BCE") - - def __init__( - self, - net: FlatSASRec, - lr: float = 1e-3, - loss: str = "softmax", - n_negatives: int = 1, - ) -> None: - super().__init__() - self.net = net - self.lr = lr - self.loss_name = loss - self.n_negatives = n_negatives - - if loss == "softmax": - self.loss_fn = nn.CrossEntropyLoss(ignore_index=0) - elif loss == "BCE": - self.loss_fn = nn.BCEWithLogitsLoss(reduction="none") - else: - raise ValueError(f"Unsupported loss: {loss}. 
Use one of {self.SUPPORTED_LOSSES}") - - def on_train_start(self) -> None: - for p in self.net.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: - logits = self.net(batch) - y = batch["y"] # (B, L) - mask = y != FlatSASRec.PADDING_IDX # ignore padding positions - - if self.loss_name == "softmax": - # logits: (B, L, n_items) — full catalog - # targets need to be 0-indexed item ids (subtract 1 since item ids start from 1) - targets = ( - y - 1 - ) # shift to 0-based for CrossEntropyLoss; padding (0) becomes -1 -> ignore_index=0 won't work - # Actually, we set ignore_index=0 but padding maps to -1. - # Let's use a different approach: set padding targets to 0 and use ignore_index=0 - targets = y.clone() - targets[~mask] = 0 - # For CE loss: targets should index into logits dim=-1 which is [0..n_items-1] - # Our item ids in y are 1..n_items, so subtract 1 - targets = targets - 1 - targets[~mask] = -100 # PyTorch ignore index - loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-100) - else: - # BCE: logits shape (B, L, 1+N) - B, L, C = logits.shape - labels = torch.zeros(B, L, C, device=logits.device) - labels[:, :, 0] = 1.0 # first column is positive - loss_per_elem = self.loss_fn(logits, labels) # (B, L, C) - # Mask out padding positions - loss_per_elem = loss_per_elem * mask.unsqueeze(-1).float() - loss = loss_per_elem.sum() / mask.sum().clamp(min=1) / C - - self.log("train_loss", loss, prog_bar=True) - return loss - - def configure_optimizers(self) -> torch.optim.Optimizer: - return torch.optim.Adam(self.parameters(), lr=self.lr, betas=(0.9, 0.98)) diff --git a/rectools/fast_transformers/model.py b/rectools/fast_transformers/model.py deleted file mode 100644 index ba2b2405..00000000 --- a/rectools/fast_transformers/model.py +++ /dev/null @@ -1,320 +0,0 @@ -"""FlatSASRecModel: standalone flat sequential 
recommender built on ModelBase.""" - -import typing as tp - -import pandas as pd -import pytorch_lightning as pl -import torch -from scipy import sparse - -from rectools.dataset import Dataset -from rectools.models.base import InternalRecoTriplet, ModelBase, ModelConfig -from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler -from rectools.models.nn.transformers.sasrec import SASRecDataPreparator -from rectools.types import InternalIdsArray -from rectools.utils.config import BaseConfig - -from .lightning_wrap import FlatSASRecLightning -from .net import FlatSASRec -from .ranking import rank_topk - - -class FlatSASRecConfig(BaseConfig): - """Configuration for FlatSASRecModel.""" - - n_factors: int = 64 - n_blocks: int = 2 - n_heads: int = 2 - session_max_len: int = 32 - dropout: float = 0.1 - loss: str = "softmax" - n_negatives: int = 1 - epochs: int = 5 - batch_size: int = 128 - lr: float = 1e-3 - recommend_batch_size: int = 256 - dataloader_num_workers: int = 0 - train_min_user_interactions: int = 2 - - -class FlatSASRecModelConfig(ModelConfig): - """Full model config including cls.""" - - model: FlatSASRecConfig = FlatSASRecConfig() - - -class FlatSASRecModel(ModelBase[FlatSASRecModelConfig]): - """ - Flat SASRec model: sequential recommender without the ItemNet hierarchy. - - Uses SASRecDataPreparator for data processing and a standalone FlatSASRec - network for encoding. 
- """ - - config_class = FlatSASRecModelConfig - recommends_for_warm = False - recommends_for_cold = False - - def __init__( - self, - n_factors: int = 64, - n_blocks: int = 2, - n_heads: int = 2, - session_max_len: int = 32, - dropout: float = 0.1, - loss: str = "softmax", - n_negatives: int = 1, - epochs: int = 5, - batch_size: int = 128, - lr: float = 1e-3, - recommend_batch_size: int = 256, - dataloader_num_workers: int = 0, - train_min_user_interactions: int = 2, - verbose: int = 0, - ) -> None: - super().__init__(verbose=verbose) - - if loss not in FlatSASRecLightning.SUPPORTED_LOSSES: - raise ValueError(f"Unsupported loss '{loss}'. Choose from {FlatSASRecLightning.SUPPORTED_LOSSES}") - - self.n_factors = n_factors - self.n_blocks = n_blocks - self.n_heads = n_heads - self.session_max_len = session_max_len - self.dropout = dropout - self.loss = loss - self.n_negatives = n_negatives - self.epochs = epochs - self.batch_size = batch_size - self.lr = lr - self.recommend_batch_size = recommend_batch_size - self.dataloader_num_workers = dataloader_num_workers - self.train_min_user_interactions = train_min_user_interactions - - self._net: tp.Optional[FlatSASRec] = None - self._lightning: tp.Optional[FlatSASRecLightning] = None - self._data_preparator: tp.Optional[SASRecDataPreparator] = None - - def _get_config(self) -> FlatSASRecModelConfig: - return FlatSASRecModelConfig( - cls=self.__class__, - verbose=self.verbose, - model=FlatSASRecConfig( - n_factors=self.n_factors, - n_blocks=self.n_blocks, - n_heads=self.n_heads, - session_max_len=self.session_max_len, - dropout=self.dropout, - loss=self.loss, - n_negatives=self.n_negatives, - epochs=self.epochs, - batch_size=self.batch_size, - lr=self.lr, - recommend_batch_size=self.recommend_batch_size, - dataloader_num_workers=self.dataloader_num_workers, - train_min_user_interactions=self.train_min_user_interactions, - ), - ) - - @classmethod - def _from_config(cls, config: FlatSASRecModelConfig) -> "FlatSASRecModel": - 
m = config.model - return cls( - n_factors=m.n_factors, - n_blocks=m.n_blocks, - n_heads=m.n_heads, - session_max_len=m.session_max_len, - dropout=m.dropout, - loss=m.loss, - n_negatives=m.n_negatives, - epochs=m.epochs, - batch_size=m.batch_size, - lr=m.lr, - recommend_batch_size=m.recommend_batch_size, - dataloader_num_workers=m.dataloader_num_workers, - train_min_user_interactions=m.train_min_user_interactions, - verbose=config.verbose, - ) - - def _fit(self, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> None: - negative_sampler = None - n_negatives_dp: tp.Optional[int] = None - if self.loss == "BCE": - negative_sampler = CatalogUniformSampler(n_negatives=self.n_negatives) - n_negatives_dp = self.n_negatives - - dp = SASRecDataPreparator( - session_max_len=self.session_max_len, - batch_size=self.batch_size, - dataloader_num_workers=self.dataloader_num_workers, - train_min_user_interactions=self.train_min_user_interactions, - n_negatives=n_negatives_dp, - negative_sampler=negative_sampler, - ) - dp.process_dataset_train(dataset) - self._data_preparator = dp - - n_real_items = dp.item_id_map.size - dp.n_item_extra_tokens - - net = FlatSASRec( - n_items=n_real_items, - n_factors=self.n_factors, - n_blocks=self.n_blocks, - n_heads=self.n_heads, - session_max_len=self.session_max_len, - dropout=self.dropout, - ) - - lightning_model = FlatSASRecLightning( - net=net, - lr=self.lr, - loss=self.loss, - n_negatives=self.n_negatives, - ) - - train_dl = dp.get_dataloader_train() - val_dl = dp.get_dataloader_val() - - trainer = pl.Trainer( - max_epochs=self.epochs, - enable_checkpointing=False, - enable_model_summary=False, - logger=self.verbose > 0, - enable_progress_bar=self.verbose > 0, - ) - trainer.fit(lightning_model, train_dataloaders=train_dl, val_dataloaders=val_dl) - - self._net = net - self._lightning = lightning_model - - def _custom_transform_dataset_u2i( - self, - dataset: Dataset, - users: tp.Any, - on_unsupported_targets: tp.Any, - context: 
tp.Optional[pd.DataFrame] = None, - ) -> Dataset: - assert self._data_preparator is not None - return self._data_preparator.transform_dataset_u2i(dataset, users) - - def _custom_transform_dataset_i2i( - self, dataset: Dataset, target_items: tp.Any, on_unsupported_targets: tp.Any - ) -> Dataset: - assert self._data_preparator is not None - return self._data_preparator.transform_dataset_i2i(dataset) - - @torch.no_grad() - def _get_user_embeddings(self, dataset: Dataset) -> torch.Tensor: - """Compute user embeddings from their interaction sequences.""" - assert self._data_preparator is not None and self._net is not None - self._net.eval() - - recommend_dl = self._data_preparator.get_dataloader_recommend(dataset, self.recommend_batch_size) - device = next(self._net.parameters()).device - - all_embs = [] - for batch in recommend_dl: - x = batch["x"].to(device) - embs = self._net.encode_last(x) # (batch, D) - all_embs.append(embs) - return torch.cat(all_embs, dim=0) - - @torch.no_grad() - def _get_item_embeddings(self) -> torch.Tensor: - """Get all item embeddings from the network.""" - assert self._net is not None - self._net.eval() - return self._net.all_item_embeddings() - - def _recommend_u2i( - self, - user_ids: InternalIdsArray, - dataset: Dataset, - k: int, - filter_viewed: bool, - sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], - ) -> InternalRecoTriplet: - assert self._data_preparator is not None - - user_embs = self._get_user_embeddings(dataset) # (n_users, D) - item_embs = self._get_item_embeddings() # (n_items, D) - - # Build filter matrix - filter_csr = None - if filter_viewed: - ui_mat = dataset.get_user_item_matrix(include_weights=False) - n_users_mat = ui_mat.shape[0] - n_items_emb = item_embs.shape[0] - n_extra = self._data_preparator.n_item_extra_tokens - # item_embs[i] corresponds to preparator internal item id (i + n_extra). - # ui_mat columns are dataset internal item ids which share the preparator's id_map. 
- # Slice out the extra-token columns and pad/trim to exactly n_items_emb cols. - if ui_mat.shape[1] > n_extra: - sliced = ui_mat[:, n_extra:] - else: - sliced = sparse.csr_matrix((n_users_mat, 0)) - n_cols = sliced.shape[1] - if n_cols < n_items_emb: - pad = sparse.csr_matrix((n_users_mat, n_items_emb - n_cols)) - filter_csr = sparse.hstack([sliced, pad], format="csr") - elif n_cols > n_items_emb: - filter_csr = sliced[:, :n_items_emb] - else: - filter_csr = sliced - - # Map whitelist to item_embs indices (0-based, without extra tokens) - whitelist = None - if sorted_item_ids_to_recommend is not None: - n_extra = self._data_preparator.n_item_extra_tokens - wl = sorted_item_ids_to_recommend - n_extra - whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])] - - u_ids, i_ids, scores = rank_topk( - user_embs, - item_embs, - k, - filter_csr=filter_csr, - whitelist=whitelist, - batch_size=self.recommend_batch_size, - ) - - # Convert item indices back to preparator's internal ids - n_extra = self._data_preparator.n_item_extra_tokens - i_ids = i_ids + n_extra - - return u_ids, i_ids, scores - - def _recommend_i2i( - self, - target_ids: InternalIdsArray, - dataset: Dataset, - k: int, - sorted_item_ids_to_recommend: tp.Optional[InternalIdsArray], - ) -> InternalRecoTriplet: - assert self._data_preparator is not None and self._net is not None - - item_embs = self._get_item_embeddings() # (n_items, D) - n_extra = self._data_preparator.n_item_extra_tokens - - # Target embeddings: target_ids are preparator internal ids - target_emb_idx = target_ids - n_extra - target_embs = item_embs[target_emb_idx] # (n_targets, D) - - whitelist = None - if sorted_item_ids_to_recommend is not None: - wl = sorted_item_ids_to_recommend - n_extra - whitelist = wl[(wl >= 0) & (wl < item_embs.shape[0])] - - t_ids, i_ids, scores = rank_topk( - target_embs, - item_embs, - k, - whitelist=whitelist, - batch_size=self.recommend_batch_size, - ) - - # Map back - result_target_ids = target_ids[t_ids] - 
result_item_ids = i_ids + n_extra - - return result_target_ids, result_item_ids, scores diff --git a/rectools/fast_transformers/unisrec_model.py b/rectools/fast_transformers/unisrec_model.py index cbb7b632..5f70f6bc 100644 --- a/rectools/fast_transformers/unisrec_model.py +++ b/rectools/fast_transformers/unisrec_model.py @@ -7,11 +7,20 @@ import torch from pytorch_lightning.callbacks import EarlyStopping -from .gpu_data import align_embeddings, build_sequences, make_dataloader +from .gpu_data import align_embeddings, build_sequences, hash_item_ids, make_dataloader from .unisrec_lightning import SUPPORTED_LOSSES, SUPPORTED_OPTIMIZERS, SUPPORTED_SCHEDULERS, UniSRecLightning from .unisrec_net import UniSRec +class _ProjectAllWrapper(torch.nn.Module): + def __init__(self, net: UniSRec) -> None: + super().__init__() + self.net = net + + def forward(self) -> torch.Tensor: + return self.net.project_all() + + class UniSRecModel: """ UniSRec sequential recommender with pretrained text embeddings. 
@@ -73,6 +82,7 @@ def __init__( batch_size: int = 128, dataloader_num_workers: int = 0, train_min_user_interactions: int = 2, + id_mapping: str = "dense", verbose: int = 0, ) -> None: if loss not in SUPPORTED_LOSSES: @@ -118,6 +128,7 @@ def __init__( self.batch_size = batch_size self.dataloader_num_workers = dataloader_num_workers self.train_min_user_interactions = train_min_user_interactions + self.id_mapping = id_mapping self.verbose = verbose self._net: tp.Optional[UniSRec] = None @@ -268,12 +279,13 @@ def fit( timestamps, max_len=self.session_max_len, min_interactions=self.train_min_user_interactions, + id_mapping=self.id_mapping, ) self._unique_items = unique_items.cpu() self._unique_users = unique_users.cpu() n_items = len(unique_items) - aligned_emb = align_embeddings(self.pretrained_item_embeddings, unique_items, n_items) + aligned_emb = align_embeddings(self.pretrained_item_embeddings, unique_items, n_items, self.id_mapping) net = UniSRec( n_items=n_items, @@ -328,6 +340,7 @@ def save_checkpoint(self, path: tp.Union[str, Path]) -> None: "unique_items": self._unique_items, "unique_users": self._unique_users, "n_items": len(self._unique_items), + "id_mapping": self.id_mapping, }, path, ) @@ -337,8 +350,9 @@ def load_checkpoint(self, path: tp.Union[str, Path], device: str = "cuda") -> No self._unique_items = ckpt["unique_items"].cpu() self._unique_users = ckpt["unique_users"].cpu() n_items = ckpt["n_items"] + self.id_mapping = ckpt.get("id_mapping", "dense") - aligned_emb = align_embeddings(self.pretrained_item_embeddings, self._unique_items, n_items) + aligned_emb = align_embeddings(self.pretrained_item_embeddings, self._unique_items, n_items, self.id_mapping) self._net = UniSRec( n_items=n_items, @@ -359,6 +373,81 @@ def load_checkpoint(self, path: tp.Union[str, Path], device: str = "cuda") -> No self._net.to(device).eval() self.is_fitted = True + # ── ONNX export ── + + def export_to_onnx( + self, + encoder_path: tp.Union[str, Path], + items_path: 
tp.Optional[tp.Union[str, Path]] = None, + opset_version: int = 18, + ) -> None: + """Export the model to ONNX. + + Parameters + ---------- + encoder_path + Path for the encoder graph (input_ids -> hidden states). + items_path + If given, also exports project_all (-> item embeddings). + opset_version + ONNX opset version (default 18). + """ + assert self._net is not None, "Model not fitted or loaded" + net = self._net + was_training = net.training + net.eval() + + device = next(net.parameters()).device + dummy = torch.zeros(1, 5, dtype=torch.long, device=device) + + torch.onnx.export( + net, + (dummy, False), + str(encoder_path), + input_names=["input_ids"], + output_names=["hidden"], + opset_version=opset_version, + ) + + if items_path is not None: + wrapper = _ProjectAllWrapper(net) + wrapper.eval() + torch.onnx.export( + wrapper, + (), + str(items_path), + input_names=[], + output_names=["item_embs"], + opset_version=opset_version, + ) + + if was_training: + net.train() + + def map_item_ids(self, external_ids: torch.Tensor) -> torch.Tensor: + """Map external item IDs to internal IDs used by the model. + + Parameters + ---------- + external_ids : LongTensor + External item IDs. + + Returns + ------- + LongTensor + Internal IDs in ``[0, n_items]``. 0 means unknown item. 
+ """ + assert self._unique_items is not None, "Model not fitted or loaded" + if self.id_mapping == "hash": + n_items = len(self._unique_items) + known = torch.isin(external_ids, self._unique_items) + result = torch.zeros_like(external_ids) + result[known] = hash_item_ids(external_ids[known], n_items) + return result + + lookup = {int(v): i + 1 for i, v in enumerate(self._unique_items.tolist())} + return torch.tensor([lookup.get(int(x), 0) for x in external_ids.tolist()], dtype=torch.long) + @property def net(self) -> UniSRec: assert self._net is not None, "Model not fitted or loaded" diff --git a/tests/fast_transformers/conftest.py b/tests/fast_transformers/conftest.py deleted file mode 100644 index ddf4468f..00000000 --- a/tests/fast_transformers/conftest.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Fixtures for fast_transformers tests.""" - -import numpy as np -import pandas as pd -import pytest - -from rectools import Columns -from rectools.dataset import Dataset - - -@pytest.fixture() -def tiny_dataset() -> Dataset: - """20 users x 25 items, each user has 3-8 interactions.""" - rng = np.random.RandomState(42) - n_users, n_items = 20, 25 - - rows = [] - for u in range(n_users): - n_inter = rng.randint(3, 9) - items = rng.choice(n_items, size=n_inter, replace=False) - for rank, item in enumerate(items): - rows.append( - { - Columns.User: u, - Columns.Item: item, - Columns.Weight: 1.0, - Columns.Datetime: pd.Timestamp("2023-01-01") + pd.Timedelta(days=rank), - } - ) - df = pd.DataFrame(rows) - return Dataset.construct(df) diff --git a/tests/fast_transformers/test_gpu_data.py b/tests/fast_transformers/test_gpu_data.py index 7b69c1dd..7717b6fe 100644 --- a/tests/fast_transformers/test_gpu_data.py +++ b/tests/fast_transformers/test_gpu_data.py @@ -1,11 +1,15 @@ """Tests for GPU-native sequence building and data utilities.""" +import hashlib + +import pytest import torch from rectools.fast_transformers.gpu_data import ( GPUBatchDataset, align_embeddings, build_sequences, + 
hash_item_ids, make_dataloader, ) @@ -455,3 +459,176 @@ def test_single_sample_batch(self) -> None: batch = next(iter(dl)) assert batch["x"].shape == (1, 3) assert batch["y"].shape == (1, 3) + + +class TestHashItemIds: + """Tests for hash_item_ids and _splitmix64.""" + + def test_output_range(self) -> None: + ids = torch.tensor([0, 1, 100, 999, -5]) + result = hash_item_ids(ids, 50) + assert result.min() >= 1 + assert result.max() <= 50 + + def test_deterministic(self) -> None: + ids = torch.tensor([1, 2, 3]) + r1 = hash_item_ids(ids, 100) + r2 = hash_item_ids(ids, 100) + assert r1.tolist() == r2.tolist() + + def test_different_inputs_spread(self) -> None: + ids = torch.arange(100) + result = hash_item_ids(ids, 1000) + assert len(result.unique()) >= 90 + + def test_large_negative_values(self) -> None: + ids = torch.tensor([-(2**62), -(2**60), -1, 0, 1, 2**60, 2**62]) + result = hash_item_ids(ids, 200) + assert result.min() >= 1 + assert result.max() <= 200 + + def test_string_derived_ids(self) -> None: + """Workflow: hash strings via hashlib -> int64 tensor -> hash_item_ids.""" + strings = ["item_abc", "product_42", "sku-99", "uuid-xxx-yyy", ""] + int_ids = torch.tensor( + [int.from_bytes(hashlib.sha256(s.encode()).digest()[:8], "little", signed=True) for s in strings], + dtype=torch.long, + ) + result = hash_item_ids(int_ids, 100) + assert result.min() >= 1 + assert result.max() <= 100 + assert result.shape == (5,) + + def test_string_ids_deterministic(self) -> None: + strings = ["hello", "world"] + int_ids = torch.tensor( + [int.from_bytes(hashlib.sha256(s.encode()).digest()[:8], "little", signed=True) for s in strings], + dtype=torch.long, + ) + r1 = hash_item_ids(int_ids, 50) + r2 = hash_item_ids(int_ids, 50) + assert r1.tolist() == r2.tolist() + + def test_string_ids_spread(self) -> None: + """Many distinct strings should produce well-spread hash values.""" + strings = [f"item_{i}" for i in range(200)] + int_ids = torch.tensor( + 
[int.from_bytes(hashlib.sha256(s.encode()).digest()[:8], "little", signed=True) for s in strings], + dtype=torch.long, + ) + result = hash_item_ids(int_ids, 1000) + assert len(result.unique()) >= 180 + + +class TestBuildSequencesHash: + """Tests for build_sequences with id_mapping='hash'.""" + + def test_basic_shape(self) -> None: + user_ids = torch.tensor([0, 0, 0, 1, 1, 1]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + x, y, unique_items, result_users = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE, id_mapping="hash" + ) + assert x.shape == (2, 4) + assert y.shape == (2, 4) + assert result_users.tolist() == [0, 1] + + def test_values_in_range(self) -> None: + user_ids = torch.tensor([0, 0, 0, 1, 1, 1]) + item_ids = torch.tensor([10, 20, 30, 40, 50, 60]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + x, y, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE, id_mapping="hash" + ) + n_unique = len(unique_items) + nonzero_x = x[x != 0] + assert nonzero_x.min() >= 1 + assert nonzero_x.max() <= n_unique + nonzero_y = y[y != 0] + assert nonzero_y.min() >= 1 + assert nonzero_y.max() <= n_unique + + def test_left_padding_preserved(self) -> None: + user_ids = torch.tensor([0, 0]) + item_ids = torch.tensor([10, 20]) + timestamps = torch.tensor([1, 2]) + x, y, _, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=5, min_interactions=2, device=DEVICE, id_mapping="hash" + ) + assert x[0, :4].tolist() == [0, 0, 0, 0] + assert x[0, 4] != 0 + + def test_unique_items_unchanged(self) -> None: + """unique_items is always the sorted set of external IDs, regardless of id_mapping.""" + user_ids = torch.tensor([0, 0, 0]) + item_ids = torch.tensor([100, 50, 200]) + timestamps = torch.tensor([1, 2, 3]) + _, _, unique_items, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=5, min_interactions=2, 
device=DEVICE, id_mapping="hash" + ) + assert unique_items.tolist() == [50, 100, 200] + + def test_invalid_id_mapping_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown id_mapping"): + build_sequences( + torch.tensor([0, 0]), + torch.tensor([1, 2]), + torch.tensor([1, 2]), + max_len=3, + min_interactions=2, + device=DEVICE, + id_mapping="invalid", + ) + + def test_same_item_same_hash(self) -> None: + """Same external item ID used by different users should get the same internal hash.""" + user_ids = torch.tensor([0, 0, 0, 1, 1, 1]) + item_ids = torch.tensor([10, 20, 30, 20, 30, 40]) + timestamps = torch.tensor([1, 2, 3, 4, 5, 6]) + x, y, _, _ = build_sequences( + user_ids, item_ids, timestamps, max_len=4, min_interactions=2, device=DEVICE, id_mapping="hash" + ) + hash_20 = hash_item_ids(torch.tensor([20]), len(torch.unique(item_ids))).item() + hash_30 = hash_item_ids(torch.tensor([30]), len(torch.unique(item_ids))).item() + all_vals = torch.cat([x.flatten(), y.flatten()]) + assert hash_20 in all_vals.tolist() + assert hash_30 in all_vals.tolist() + + +class TestAlignEmbeddingsHash: + """Tests for align_embeddings with id_mapping='hash'.""" + + def test_embeddings_at_hash_positions(self) -> None: + pretrained = torch.zeros(4, 2) + pretrained[1] = torch.tensor([3.0, 4.0]) + pretrained[2] = torch.tensor([5.0, 6.0]) + pretrained[3] = torch.tensor([7.0, 8.0]) + unique_items = torch.tensor([1, 2, 3]) + n_items = 10 + aligned = align_embeddings(pretrained, unique_items, n_items, id_mapping="hash") + assert aligned.shape == (11, 2) + assert aligned[0].tolist() == [0.0, 0.0] + positions = hash_item_ids(unique_items, n_items) + for i, ext_id in enumerate(unique_items): + pos = positions[i].item() + assert aligned[pos].tolist() == pretrained[ext_id].tolist() + + def test_3d_hash_mode(self) -> None: + pretrained = torch.zeros(4, 2, 2) + pretrained[1] = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + pretrained[2] = torch.tensor([[5.0, 6.0], [7.0, 8.0]]) + 
pretrained[3] = torch.tensor([[9.0, 10.0], [11.0, 12.0]]) + unique_items = torch.tensor([1, 2, 3]) + n_items = 10 + aligned = align_embeddings(pretrained, unique_items, n_items, id_mapping="hash") + assert aligned.shape == (11, 2, 2) + assert aligned[0].tolist() == [[0.0, 0.0], [0.0, 0.0]] + positions = hash_item_ids(unique_items, n_items) + for i, ext_id in enumerate(unique_items): + pos = positions[i].item() + torch.testing.assert_close(aligned[pos], pretrained[ext_id]) + + def test_invalid_id_mapping_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown id_mapping"): + align_embeddings(torch.randn(5, 2), torch.tensor([1, 2]), 2, id_mapping="bad") diff --git a/tests/fast_transformers/test_lightning_wrap.py b/tests/fast_transformers/test_lightning_wrap.py deleted file mode 100644 index e45fccfe..00000000 --- a/tests/fast_transformers/test_lightning_wrap.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Tests for FlatSASRecLightning wrapper.""" - -import pytest -import torch - -from rectools.fast_transformers.lightning_wrap import FlatSASRecLightning -from rectools.fast_transformers.net import FlatSASRec - - -@pytest.fixture() -def net() -> FlatSASRec: - return FlatSASRec( - n_items=10, - n_factors=8, - n_blocks=1, - n_heads=1, - session_max_len=5, - dropout=0.0, - ) - - -class TestFlatSASRecLightning: - # ---- constructor ---- - - def test_init_softmax_loss(self, net: FlatSASRec) -> None: - module = FlatSASRecLightning(net, loss="softmax") - assert module.loss_name == "softmax" - assert isinstance(module.loss_fn, torch.nn.CrossEntropyLoss) - - def test_init_bce_loss(self, net: FlatSASRec) -> None: - module = FlatSASRecLightning(net, loss="BCE") - assert module.loss_name == "BCE" - assert isinstance(module.loss_fn, torch.nn.BCEWithLogitsLoss) - - def test_init_invalid_loss_raises(self, net: FlatSASRec) -> None: - with pytest.raises(ValueError, match="Unsupported loss"): - FlatSASRecLightning(net, loss="mse") - - def test_init_stores_hyperparams(self, net: 
FlatSASRec) -> None: - module = FlatSASRecLightning(net, lr=0.005, n_negatives=4) - assert module.lr == 0.005 - assert module.n_negatives == 4 - - # ---- configure_optimizers ---- - - def test_configure_optimizers_type_and_lr(self, net: FlatSASRec) -> None: - lr = 2e-4 - module = FlatSASRecLightning(net, lr=lr) - optimizer = module.configure_optimizers() - assert isinstance(optimizer, torch.optim.Adam) - assert optimizer.defaults["lr"] == lr - - def test_configure_optimizers_betas(self, net: FlatSASRec) -> None: - module = FlatSASRecLightning(net) - optimizer = module.configure_optimizers() - assert optimizer.defaults["betas"] == (0.9, 0.98) - - # ---- on_train_start ---- - - def test_on_train_start_reinitializes_params(self, net: FlatSASRec) -> None: - module = FlatSASRecLightning(net) - - # Snapshot parameters with dim > 1 before reinit - snapshots_before = {name: p.clone() for name, p in module.net.named_parameters() if p.dim() > 1} - assert len(snapshots_before) > 0, "Expected at least one param with dim > 1" - - # Force parameters to a constant value so reinit is detectable - with torch.no_grad(): - for p in module.net.parameters(): - if p.dim() > 1: - p.fill_(42.0) - - module.on_train_start() - - changed = False - for name, p in module.net.named_parameters(): - if p.dim() > 1 and not torch.all(p == 42.0): - changed = True - break - assert changed, "on_train_start should reinitialize parameters via xavier_uniform_" - - # ---- training_step with softmax ---- - - def test_training_step_softmax_returns_scalar(self, net: FlatSASRec) -> None: - module = FlatSASRecLightning(net, loss="softmax") - batch = { - "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), - "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), - } - loss = module.training_step(batch, batch_idx=0) - assert loss.dim() == 0, "Loss should be a scalar" - assert not torch.isnan(loss), "Loss should not be NaN" - assert not torch.isinf(loss), "Loss should not be Inf" - - def 
test_training_step_softmax_positive_loss(self, net: FlatSASRec) -> None: - module = FlatSASRecLightning(net, loss="softmax") - batch = { - "x": torch.tensor([[1, 2, 3, 4, 5]]), - "y": torch.tensor([[2, 3, 4, 5, 6]]), - } - loss = module.training_step(batch, batch_idx=0) - assert loss.item() > 0, "Cross-entropy loss should be positive" - - def test_training_step_softmax_all_padding_returns_nan(self, net: FlatSASRec) -> None: - """When all targets are padding (y=0), cross_entropy with ignore_index=-100 returns NaN.""" - module = FlatSASRecLightning(net, loss="softmax") - batch = { - "x": torch.tensor([[0, 0, 0, 0, 0]]), - "y": torch.tensor([[0, 0, 0, 0, 0]]), - } - loss = module.training_step(batch, batch_idx=0) - assert loss.dim() == 0 - # PyTorch cross_entropy returns NaN when all targets are ignored - assert torch.isnan(loss) - - # ---- training_step with BCE ---- - - def test_training_step_bce_returns_scalar(self, net: FlatSASRec) -> None: - n_negatives = 3 - module = FlatSASRecLightning(net, loss="BCE", n_negatives=n_negatives) - batch = { - "x": torch.tensor([[0, 0, 1, 2, 3], [0, 4, 5, 6, 7]]), - "y": torch.tensor([[0, 0, 2, 3, 4], [0, 5, 6, 7, 8]]), - "negatives": torch.randint(1, 10, (2, 5, n_negatives)), - } - loss = module.training_step(batch, batch_idx=0) - assert loss.dim() == 0, "Loss should be a scalar" - assert not torch.isnan(loss), "Loss should not be NaN" - assert not torch.isinf(loss), "Loss should not be Inf" - - def test_training_step_bce_positive_loss(self, net: FlatSASRec) -> None: - n_negatives = 2 - module = FlatSASRecLightning(net, loss="BCE", n_negatives=n_negatives) - batch = { - "x": torch.tensor([[1, 2, 3, 4, 5]]), - "y": torch.tensor([[2, 3, 4, 5, 6]]), - "negatives": torch.randint(1, 10, (1, 5, n_negatives)), - } - loss = module.training_step(batch, batch_idx=0) - assert loss.item() > 0, "BCE loss should be positive" - - def test_training_step_bce_mask_reduces_loss(self, net: FlatSASRec) -> None: - """Padding positions should not 
contribute to BCE loss.""" - n_negatives = 2 - module = FlatSASRecLightning(net, loss="BCE", n_negatives=n_negatives) - module.eval() - - torch.manual_seed(0) - negs = torch.randint(1, 10, (1, 5, n_negatives)) - - # Batch with no padding - batch_full = { - "x": torch.tensor([[1, 2, 3, 4, 5]]), - "y": torch.tensor([[2, 3, 4, 5, 6]]), - "negatives": negs.clone(), - } - # Batch with partial padding - batch_padded = { - "x": torch.tensor([[0, 0, 3, 4, 5]]), - "y": torch.tensor([[0, 0, 4, 5, 6]]), - "negatives": negs.clone(), - } - - with torch.no_grad(): - loss_full = module.training_step(batch_full, batch_idx=0) - loss_padded = module.training_step(batch_padded, batch_idx=0) - - # Losses should differ because the padded batch masks out some positions - assert loss_full.item() != pytest.approx(loss_padded.item(), abs=1e-6) - - # ---- supported losses constant ---- - - def test_supported_losses_tuple(self) -> None: - assert FlatSASRecLightning.SUPPORTED_LOSSES == ("softmax", "BCE") diff --git a/tests/fast_transformers/test_model.py b/tests/fast_transformers/test_model.py deleted file mode 100644 index a230d160..00000000 --- a/tests/fast_transformers/test_model.py +++ /dev/null @@ -1,93 +0,0 @@ -"""Tests for FlatSASRecModel.""" - -import pickle - -import pytest - -from rectools import Columns -from rectools.dataset import Dataset -from rectools.fast_transformers import FlatSASRecModel - - -def _make_model(**kwargs) -> FlatSASRecModel: - defaults = dict( - n_factors=16, - n_blocks=1, - n_heads=2, - session_max_len=8, - epochs=1, - batch_size=16, - lr=1e-3, - verbose=0, - ) - defaults.update(kwargs) - return FlatSASRecModel(**defaults) - - -class TestFitRecommend: - def test_recommend_columns(self, tiny_dataset: Dataset) -> None: - model = _make_model() - model.fit(tiny_dataset) - users = list(range(5)) - reco = model.recommend(users=users, dataset=tiny_dataset, k=3, filter_viewed=False) - assert set(reco.columns) == {Columns.User, Columns.Item, Columns.Score, 
Columns.Rank} - assert reco[Columns.User].nunique() == 5 - - def test_filter_viewed(self, tiny_dataset: Dataset) -> None: - model = _make_model() - model.fit(tiny_dataset) - users = list(range(5)) - reco = model.recommend(users=users, dataset=tiny_dataset, k=5, filter_viewed=True) - interactions = tiny_dataset.get_raw_interactions() - for uid in users: - viewed = set(interactions[interactions[Columns.User] == uid][Columns.Item]) - recommended = set(reco[reco[Columns.User] == uid][Columns.Item]) - assert viewed.isdisjoint(recommended), f"User {uid} got viewed items in recommendations" - - def test_i2i(self, tiny_dataset: Dataset) -> None: - model = _make_model() - model.fit(tiny_dataset) - items = list(range(5)) - reco = model.recommend_to_items(target_items=items, dataset=tiny_dataset, k=3) - assert set(reco.columns) == {Columns.TargetItem, Columns.Item, Columns.Score, Columns.Rank} - assert reco[Columns.TargetItem].nunique() == 5 - - def test_metrics_positive(self, tiny_dataset: Dataset) -> None: - model = _make_model(epochs=3) - model.fit(tiny_dataset) - users = list(range(tiny_dataset.user_id_map.size)) - reco = model.recommend(users=users, dataset=tiny_dataset, k=5, filter_viewed=False) - assert len(reco) > 0 - assert reco[Columns.Score].notna().all() - - -class TestConfig: - def test_config_roundtrip(self) -> None: - model = _make_model(n_factors=32, n_blocks=3) - config = model.get_config(mode="pydantic") - model2 = FlatSASRecModel.from_config(config) - assert model2.n_factors == 32 - assert model2.n_blocks == 3 - - def test_pickle_roundtrip(self, tiny_dataset: Dataset) -> None: - model = _make_model() - model.fit(tiny_dataset) - data = pickle.dumps(model) - model2 = pickle.loads(data) - assert model2.is_fitted - users = list(range(3)) - reco = model2.recommend(users=users, dataset=tiny_dataset, k=3, filter_viewed=False) - assert len(reco) > 0 - - -class TestLosses: - def test_bce_training(self, tiny_dataset: Dataset) -> None: - model = 
_make_model(loss="BCE", n_negatives=2) - model.fit(tiny_dataset) - users = list(range(5)) - reco = model.recommend(users=users, dataset=tiny_dataset, k=3, filter_viewed=False) - assert len(reco) > 0 - - def test_invalid_loss(self) -> None: - with pytest.raises(ValueError, match="Unsupported loss"): - _make_model(loss="invalid_loss_name") diff --git a/tests/fast_transformers/test_onnx_export.py b/tests/fast_transformers/test_onnx_export.py new file mode 100644 index 00000000..39c2ac36 --- /dev/null +++ b/tests/fast_transformers/test_onnx_export.py @@ -0,0 +1,252 @@ +"""Tests for ONNX export of UniSRec network and UniSRecModel.export_to_onnx.""" + +from pathlib import Path + +import numpy as np +import pytest +import torch + +onnx = pytest.importorskip("onnx") +ort = pytest.importorskip("onnxruntime") + +from rectools.fast_transformers.unisrec_model import UniSRecModel # noqa: E402 +from rectools.fast_transformers.unisrec_net import UniSRec # noqa: E402 + + +@pytest.fixture() +def net() -> UniSRec: + torch.manual_seed(0) + pretrained = torch.randn(11, 32) + pretrained[0] = 0.0 + model = UniSRec( + n_items=10, + pretrained_embeddings=pretrained, + n_factors=16, + projection_hidden=32, + n_blocks=1, + n_heads=2, + session_max_len=8, + dropout=0.0, + adaptor_dropout=0.0, + ) + model.eval() + return model + + +def _export_and_load(net: torch.nn.Module, args, tmp_path: Path, **kwargs): + path = str(tmp_path / "model.onnx") + torch.onnx.export(net, args, path, opset_version=18, **kwargs) + model = onnx.load(path) + onnx.checker.check_model(model) + return ort.InferenceSession(path) + + +class TestUniSRecOnnxExport: + def test_export_succeeds(self, net: UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + path = str(tmp_path / "model.onnx") + torch.onnx.export( + net, + (dummy, False), + path, + input_names=["input_ids"], + output_names=["hidden"], + opset_version=18, + ) + model = onnx.load(path) + onnx.checker.check_model(model) 
+ + def test_forward_roundtrip(self, net: UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + sess = _export_and_load( + net, + (dummy, False), + tmp_path, + input_names=["input_ids"], + output_names=["hidden"], + ) + with torch.no_grad(): + expected = net(dummy, use_id=False).numpy() + result = sess.run(None, {"input_ids": dummy.numpy()})[0] + np.testing.assert_allclose(result, expected, atol=1e-5) + + @pytest.mark.xfail(reason="torch.onnx.export ignores dynamic_shapes for tuple args with bool") + def test_dynamic_batch(self, net: UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + batch = torch.export.Dim("batch", min=1) + sess = _export_and_load( + net, + (dummy, False), + tmp_path, + input_names=["input_ids"], + output_names=["hidden"], + dynamic_shapes=({0: batch}, None), + ) + batch_input = torch.tensor( + [[0, 0, 1, 2, 3], [0, 1, 4, 5, 6], [0, 0, 0, 7, 8]], + dtype=torch.long, + ) + with torch.no_grad(): + expected = net(batch_input, use_id=False).numpy() + result = sess.run(None, {"input_ids": batch_input.numpy()})[0] + assert result.shape[0] == 3 + np.testing.assert_allclose(result, expected, atol=1e-5) + + def test_different_sequence_lengths(self, net: UniSRec, tmp_path: Path) -> None: + dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long) + batch = torch.export.Dim("batch", min=1) + seq_len = torch.export.Dim("seq_len", min=1, max=8) + sess = _export_and_load( + net, + (dummy, False), + tmp_path, + input_names=["input_ids"], + output_names=["hidden"], + dynamic_shapes=({0: batch, 1: seq_len}, None), + ) + short = torch.tensor([[0, 1, 2]], dtype=torch.long) + with torch.no_grad(): + expected = net(short, use_id=False).numpy() + result = sess.run(None, {"input_ids": short.numpy()})[0] + assert result.shape == (1, 3, 16) + np.testing.assert_allclose(result, expected, atol=1e-5) + + def test_padding_only_input(self, net: UniSRec, tmp_path: Path) -> None: + dummy = 
torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long)
+        sess = _export_and_load(
+            net,
+            (dummy, False),
+            tmp_path,
+            input_names=["input_ids"],
+            output_names=["hidden"],
+        )
+        all_pad = torch.zeros(1, 5, dtype=torch.long)  # a sequence of nothing but padding id 0
+        with torch.no_grad():
+            expected = net(all_pad, use_id=False).numpy()
+        result = sess.run(None, {"input_ids": all_pad.numpy()})[0]
+        np.testing.assert_allclose(result, expected, atol=1e-5)  # ONNX must match torch even on all-pad input
+
+    def test_output_shape(self, net: UniSRec, tmp_path: Path) -> None:
+        dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long)  # batch=1, seq_len=5
+        sess = _export_and_load(
+            net,
+            (dummy, False),
+            tmp_path,
+            input_names=["input_ids"],
+            output_names=["hidden"],
+        )
+        result = sess.run(None, {"input_ids": dummy.numpy()})[0]
+        assert result.shape == (1, 5, 16)  # (batch, seq_len, hidden) -- 16 presumably the net fixture's n_factors; confirm
+
+    def test_project_all_roundtrip(self, net: UniSRec, tmp_path: Path) -> None:
+        class _ProjectAll(torch.nn.Module):  # zero-arg wrapper so project_all() is traceable as a forward()
+            def __init__(self, inner: UniSRec):
+                super().__init__()
+                self.inner = inner
+
+            def forward(self) -> torch.Tensor:
+                return self.inner.project_all()
+
+        wrapper = _ProjectAll(net)
+        wrapper.eval()  # disable dropout etc. so torch and ONNX outputs are comparable
+        path = str(tmp_path / "project_all.onnx")
+        torch.onnx.export(
+            wrapper,
+            (),
+            path,
+            input_names=[],
+            output_names=["item_embs"],
+            opset_version=18,
+        )
+        model = onnx.load(path)
+        onnx.checker.check_model(model)  # structural validity of the exported graph
+        sess = ort.InferenceSession(path)
+        with torch.no_grad():
+            expected = net.project_all().numpy()
+        result = sess.run(None, {})[0]
+        assert result.shape == (11, 16)  # 10 items + padding row, projected to 16 dims
+        np.testing.assert_allclose(result, expected, atol=1e-5)
+
+
+class TestUniSRecModelExport:
+    """Tests for UniSRecModel.export_to_onnx."""
+
+    @pytest.fixture()
+    def model(self) -> UniSRecModel:
+        torch.manual_seed(0)  # deterministic weights so exported and in-process outputs agree
+        pretrained = torch.randn(11, 32)
+        pretrained[0] = 0.0  # row 0 is the padding item (id 0 is used as padding in the inputs below)
+        m = UniSRecModel(
+            pretrained_item_embeddings=pretrained,
+            n_factors=16,
+            projection_hidden=32,
+            n_blocks=1,
+            n_heads=2,
+            session_max_len=8,
+            phase1_epochs=0,
+            phase2_epochs=0,
+            phase3_epochs=0,
+        )
+        from rectools.fast_transformers.gpu_data import align_embeddings
+
+        unique_items = torch.arange(1, 11)
+        aligned = align_embeddings(pretrained, unique_items, 10)
+        net = UniSRec(
+            n_items=10,
+            pretrained_embeddings=aligned,
+            n_factors=16,
+            projection_hidden=32,
+            n_blocks=1,
+            n_heads=2,
+            session_max_len=8,
+            dropout=0.0,
+            adaptor_dropout=0.0,
+        )
+        net.eval()
+        m._net = net  # hand-wire a fitted state without running fit() -- keeps the fixture fast
+        m._unique_items = unique_items
+        m._unique_users = torch.arange(5)
+        m.is_fitted = True
+        return m
+
+    def test_export_encoder(self, model: UniSRecModel, tmp_path: Path) -> None:
+        path = tmp_path / "encoder.onnx"
+        model.export_to_onnx(str(path))
+        loaded = onnx.load(str(path))
+        onnx.checker.check_model(loaded)  # export produces a structurally valid model
+
+    def test_export_encoder_roundtrip(self, model: UniSRecModel, tmp_path: Path) -> None:
+        path = tmp_path / "encoder.onnx"
+        model.export_to_onnx(str(path))
+        sess = ort.InferenceSession(str(path))
+        dummy = torch.tensor([[0, 0, 1, 2, 3]], dtype=torch.long)
+        with torch.no_grad():
+            expected = model.net(dummy, use_id=False).numpy()  # use_id=False: text-embedding path, same as export
+        result = sess.run(None, {"input_ids": dummy.numpy()})[0]
+        np.testing.assert_allclose(result, expected, atol=1e-5)
+
+    def test_export_encoder_and_items(self, model: UniSRecModel, tmp_path: Path) -> None:
+        enc_path = tmp_path / "encoder.onnx"
+        items_path = tmp_path / "items.onnx"
+        model.export_to_onnx(str(enc_path), items_path=str(items_path))  # two-artifact export: encoder + item table
+
+        loaded_enc = onnx.load(str(enc_path))
+        onnx.checker.check_model(loaded_enc)
+        loaded_items = onnx.load(str(items_path))
+        onnx.checker.check_model(loaded_items)
+
+    def test_items_roundtrip(self, model: UniSRecModel, tmp_path: Path) -> None:
+        items_path = tmp_path / "items.onnx"
+        model.export_to_onnx(str(tmp_path / "enc.onnx"), items_path=str(items_path))
+        sess = ort.InferenceSession(str(items_path))
+        with torch.no_grad():
+            expected = model.net.project_all().numpy()
+        result = sess.run(None, {})[0]
+        assert result.shape == (11, 16)  # 10 items + padding row, n_factors=16
+        np.testing.assert_allclose(result, expected, atol=1e-5)
+
+    def test_unfitted_model_raises(self,
tmp_path: Path) -> None:
+        pretrained = torch.randn(5, 8)
+        m = UniSRecModel(pretrained_item_embeddings=pretrained, n_factors=8)
+        with pytest.raises(AssertionError):  # export_to_onnx asserts is_fitted before exporting
+            m.export_to_onnx(str(tmp_path / "model.onnx"))
diff --git a/tests/fast_transformers/test_unisrec_model.py b/tests/fast_transformers/test_unisrec_model.py
index 13bba453..38965890 100644
--- a/tests/fast_transformers/test_unisrec_model.py
+++ b/tests/fast_transformers/test_unisrec_model.py
@@ -4,6 +4,7 @@
 import torch
 
 from rectools.fast_transformers import UniSRecModel
+from rectools.fast_transformers.gpu_data import hash_item_ids
 
 
 def _make_embeddings(n_items: int = 25, dim: int = 64) -> torch.Tensor:
@@ -187,3 +188,45 @@ def test_patience(self) -> None:
         model = _make_model(patience=2, phase1_epochs=0, phase2_epochs=0, phase3_epochs=5)
         model.fit(user_ids, item_ids, timestamps)
         assert model.is_fitted
+
+
+class TestMapItemIds:
+    def test_dense_known_items(self) -> None:
+        user_ids, item_ids, timestamps = _make_interactions()
+        model = _make_model(phase1_epochs=1, phase2_epochs=0, phase3_epochs=0)
+        model.fit(user_ids, item_ids, timestamps)
+        unique = model.item_id_mapping
+        result = model.map_item_ids(unique)
+        expected = torch.arange(1, len(unique) + 1, dtype=torch.long)  # dense mapping: known ids -> 1..n (0 reserved)
+        assert result.tolist() == expected.tolist()
+
+    def test_dense_unknown_items(self) -> None:
+        user_ids, item_ids, timestamps = _make_interactions()
+        model = _make_model(phase1_epochs=1, phase2_epochs=0, phase3_epochs=0)
+        model.fit(user_ids, item_ids, timestamps)
+        unknown = torch.tensor([9999, 8888], dtype=torch.long)  # ids never seen during fit
+        result = model.map_item_ids(unknown)
+        assert result.tolist() == [0, 0]  # unknown ids map to the padding index 0
+
+    def test_hash_known_items(self) -> None:
+        user_ids, item_ids, timestamps = _make_interactions()
+        model = _make_model(phase1_epochs=1, phase2_epochs=0, phase3_epochs=0, id_mapping="hash")
+        model.fit(user_ids, item_ids, timestamps)
+        unique = model.item_id_mapping
+        n_items = len(unique)
+        result = model.map_item_ids(unique)
+        expected = hash_item_ids(unique, n_items)  # hash mode must agree with the shared hashing helper
+        assert result.tolist() == expected.tolist()
+
+    def test_hash_unknown_items(self) -> None:
+        user_ids, item_ids, timestamps = _make_interactions()
+        model = _make_model(phase1_epochs=1, phase2_epochs=0, phase3_epochs=0, id_mapping="hash")
+        model.fit(user_ids, item_ids, timestamps)
+        unknown = torch.tensor([9999, 8888], dtype=torch.long)
+        result = model.map_item_ids(unknown)
+        assert result.tolist() == [0, 0]  # even under hashing, unseen ids fall back to 0 rather than a hashed bucket
+
+    def test_unfitted_raises(self) -> None:
+        model = _make_model()
+        with pytest.raises(AssertionError):  # mapping requires a fitted model (asserts is_fitted)
+            model.map_item_ids(torch.tensor([1, 2]))