14 changes: 12 additions & 2 deletions model2vec/train/base.py
@@ -18,7 +18,13 @@
 from model2vec import StaticModel
 from model2vec.inference import StaticModelPipeline
 from model2vec.train.dataset import TextDataset
-from model2vec.train.utils import get_probable_pad_token_id, suppress_lightning_warnings, to_pipeline, train_test_split
+from model2vec.train.utils import (
+    get_probable_pad_token_id,
+    logit,
+    suppress_lightning_warnings,
+    to_pipeline,
+    train_test_split,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -79,11 +85,15 @@ def __init__(
         self.freeze = freeze
         self.embeddings = nn.Embedding.from_pretrained(vectors.clone(), freeze=self.freeze, padding_idx=pad_id)
         self.head = self.construct_head()
-        self.w = self.construct_weights() if weights is None else nn.Parameter(weights.float(), requires_grad=True)
+        self._weights = weights
+        self.w = self.construct_weights()
         self.tokenizer = tokenizer
 
     def construct_weights(self) -> nn.Parameter:
         """Construct the weights for the model."""
+        if self._weights is not None:
+            w = logit(self._weights)
+            return nn.Parameter(w.float(), requires_grad=True)
         weights = torch.zeros(len(self.token_mapping))
         weights[self.pad_id] = -10_000
         return nn.Parameter(weights, requires_grad=not self.freeze)
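The change above splits weight handling in two: the raw weights are stashed on `self._weights`, and `construct_weights` either maps them through `logit` or falls back to zeros, with the pad token pinned to -10_000 so it stays effectively masked. Initializing with `logit` only makes sense if the model squashes `w` back through a sigmoid in its forward pass, which this diff does not show; under that assumption, a minimal sketch of the round trip:

import torch

# Hypothetical per-token weights in (0, 1), as a StaticModel might carry them.
weights = torch.tensor([0.2, 0.9, 0.5])

# Same formula as utils.logit: the inverse of torch.sigmoid.
w = -torch.log((1 / weights) - 1)

# Assuming the forward pass applies sigmoid(w), the original weights
# are recovered exactly (up to float precision).
print(torch.sigmoid(w))  # tensor([0.2000, 0.9000, 0.5000])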
6 changes: 6 additions & 0 deletions model2vec/train/utils.py
@@ -7,6 +7,7 @@
 from typing import TYPE_CHECKING, Any, Callable
 
 import numpy as np
+import torch
 from sklearn.model_selection import train_test_split as sklearn_split
 from sklearn.neural_network import MLPClassifier, MLPRegressor
 from sklearn.pipeline import make_pipeline
@@ -111,3 +112,8 @@ class TipFilter(logging.Filter):
     def filter(self, record: logging.LogRecord) -> bool:
         """Filter out tip messages from lightning."""
         return "💡 Tip" not in record.getMessage()
+
+
+def logit(x: torch.Tensor) -> torch.Tensor:
+    """Invert a sigmoid."""
+    return -torch.log((1 / x) - 1)
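Algebraically, -log((1/x) - 1) equals log(x / (1 - x)), the standard logit, and it is well-defined only for x strictly between 0 and 1: exact 0 or 1 produce -inf/+inf rather than raising an error. A quick sanity check (a sketch, not part of the diff):

import torch

x = torch.tensor([0.25, 0.5, 0.75])
y = -torch.log((1 / x) - 1)  # the logit defined above

# sigmoid undoes it, confirming logit is the inverse of sigmoid.
print(torch.allclose(torch.sigmoid(y), x))  # True

# Boundary values diverge instead of erroring.
print(-torch.log((1 / torch.tensor([1.0, 0.0])) - 1))  # tensor([inf, -inf])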
19 changes: 18 additions & 1 deletion tests/test_trainable.py
@@ -13,7 +13,7 @@
 from model2vec.train.base import BaseFinetuneable
 from model2vec.train.dataset import TextDataset
 from model2vec.train.similarity import StaticModelForSimilarity
-from model2vec.train.utils import get_probable_pad_token_id, train_test_split
+from model2vec.train.utils import get_probable_pad_token_id, logit, train_test_split
 
 
 @pytest.mark.parametrize("n_layers", [0, 1, 2, 3])
@@ -74,6 +74,17 @@ def test_init_classifier_from_model(mock_vectors: np.ndarray, mock_tokenizer: To
     assert s.w.shape[0] == mock_vectors.shape[0]
 
 
+def test_init_classifier_from_model_w(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
+    """Test initialization from a static model with weights."""
+    model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, weights=np.ones(len(mock_vectors)))
+    s = StaticModelForClassification.from_static_model(model=model)
+    assert s._weights is not None
+    assert torch.all(s._weights == torch.ones(len(mock_vectors)))
+    w = s.construct_weights()
+    assert w.shape[0] == mock_vectors.shape[0]
+    assert torch.all(w == logit(torch.ones(len(mock_vectors))))
+
+
 def test_pad_token(mock_tokenizer: Tokenizer) -> None:
     """Test initialization from a static model."""
     tokenizer_model = TokenizerModel.from_tokenizer(mock_tokenizer)
@@ -360,3 +371,9 @@ def test_determine_interval() -> None:
     )
     assert val_check_interval == 100
     assert check_val_every_epoch is None
+
+
+def test_logit() -> None:
+    """Test that logit inverts sigmoid."""
+    x = torch.arange(10).float() / 10
+    assert torch.allclose(logit(torch.sigmoid(x)), x, atol=1e-6)
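One subtlety worth noting in `test_init_classifier_from_model_w` above: with weights of exactly 1.0, `logit` returns +inf, and the final equality assertion still passes because IEEE infinities compare equal to each other. A small illustration (a sketch, not part of the test suite):

import torch

w = -torch.log((1 / torch.ones(3)) - 1)  # logit(1.0), elementwise
print(w)                  # tensor([inf, inf, inf])
print(torch.all(w == w))  # tensor(True): inf == inf holds, unlike NaN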