diff --git a/model2vec/train/base.py b/model2vec/train/base.py
index 5c0a019..8ce093e 100644
--- a/model2vec/train/base.py
+++ b/model2vec/train/base.py
@@ -18,7 +18,13 @@
 from model2vec import StaticModel
 from model2vec.inference import StaticModelPipeline
 from model2vec.train.dataset import TextDataset
-from model2vec.train.utils import get_probable_pad_token_id, suppress_lightning_warnings, to_pipeline, train_test_split
+from model2vec.train.utils import (
+    get_probable_pad_token_id,
+    logit,
+    suppress_lightning_warnings,
+    to_pipeline,
+    train_test_split,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -79,11 +85,15 @@ def __init__(
         self.freeze = freeze
         self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=self.freeze, padding_idx=pad_id)
         self.head = self.construct_head()
-        self.w = self.construct_weights() if weights is None else nn.Parameter(weights.float(), requires_grad=True)
+        self._weights = weights
+        self.w = self.construct_weights()
         self.tokenizer = tokenizer
 
     def construct_weights(self) -> nn.Parameter:
         """Construct the weights for the model."""
+        if self._weights is not None:
+            w = logit(self._weights)
+            return nn.Parameter(w.float(), requires_grad=True)
         weights = torch.zeros(len(self.token_mapping))
         weights[self.pad_id] = -10_000
         return nn.Parameter(weights, requires_grad=not self.freeze)
diff --git a/model2vec/train/utils.py b/model2vec/train/utils.py
index d206c77..683c7ff 100644
--- a/model2vec/train/utils.py
+++ b/model2vec/train/utils.py
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Any, Callable
 
 import numpy as np
+import torch
 from sklearn.model_selection import train_test_split as sklearn_split
 from sklearn.neural_network import MLPClassifier, MLPRegressor
 from sklearn.pipeline import make_pipeline
@@ -111,3 +112,8 @@ class TipFilter(logging.Filter):
     def filter(self, record: logging.LogRecord) -> bool:
         """Filter out tip messages from lightning."""
         return "💡 Tip" not in record.getMessage()
+
+
+def logit(x: torch.Tensor) -> torch.Tensor:
+    """Invert a sigmoid."""
+    return -torch.log((1 / x) - 1)
diff --git a/tests/test_trainable.py b/tests/test_trainable.py
index 6f328da..3541ef8 100644
--- a/tests/test_trainable.py
+++ b/tests/test_trainable.py
@@ -13,7 +13,7 @@
 from model2vec.train.base import BaseFinetuneable
 from model2vec.train.dataset import TextDataset
 from model2vec.train.similarity import StaticModelForSimilarity
-from model2vec.train.utils import get_probable_pad_token_id, train_test_split
+from model2vec.train.utils import get_probable_pad_token_id, logit, train_test_split
 
 
 @pytest.mark.parametrize("n_layers", [0, 1, 2, 3])
@@ -74,6 +74,17 @@ def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: To
     assert s.w.shape[0] == mock_vectors.shape[0]
 
 
+def test_init_classifier_from_model_w(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
+    """Test initialization from a static model with weights."""
+    model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, weights=np.ones(len(mock_vectors)))
+    s = StaticModelForClassification.from_static_model(model=model)
+    assert s._weights is not None
+    assert torch.all(s._weights == torch.ones(len(mock_vectors)))
+    w = s.construct_weights()
+    assert w.shape[0] == mock_vectors.shape[0]
+    assert torch.all(w == logit(torch.ones(len(mock_vectors))))
+
+
 def test_pad_token(mock_tokenizer: Tokenizer) -> None:
     """Test initializion from a static model."""
     tokenizer_model = TokenizerModel.from_tokenizer(mock_tokenizer)
@@ -360,3 +371,9 @@ def test_determine_interval() -> None:
     )
     assert val_check_interval == 100
     assert check_val_every_epoch is None
+
+
+def test_logit() -> None:
+    """Test that logit inverts sigmoid."""
+    x = torch.arange(10).float() / 10
+    assert torch.allclose(logit(torch.sigmoid(x)), x, atol=1e-6)
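
The new construct_weights path stores the static model's per-token weights and maps them through logit, so that (assuming the forward pass squashes self.w with a sigmoid, as the -10_000 pad weight suggests) the finetunable model starts out reproducing the original weighting. A minimal standalone sketch of that round trip, using only torch and the logit definition from this patch; the example weight values are made up:

import torch


def logit(x: torch.Tensor) -> torch.Tensor:
    """Invert a sigmoid, as defined in model2vec/train/utils.py above."""
    return -torch.log((1 / x) - 1)


# Hypothetical per-token weights in (0, 1), e.g. taken from a static model.
weights = torch.tensor([0.25, 0.5, 0.9])

# Initializing the trainable parameter as logit(weights) means a sigmoid
# applied later recovers the original weights (up to floating point error).
w = logit(weights)
assert torch.allclose(torch.sigmoid(w), weights, atol=1e-6)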