Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import re
from typing import cast

import numpy as np
from huggingface_hub.hf_api import model_info
from skeletoken import TokenizerModel
from skeletoken.external.transformers import reshape_embeddings
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
from transformers.modeling_utils import PreTrainedModel

from model2vec.distill.inference import PCADimType, PoolingMode, create_embeddings, post_process_embeddings
from model2vec.distill.inference import PCADimType, PoolingMode, apply_pca, compute_weights, create_embeddings
from model2vec.distill.utils import select_optimal_device
from model2vec.model import StaticModel
from model2vec.quantization import DType, quantize_embeddings
Expand Down Expand Up @@ -108,16 +107,17 @@ def distill_from_model(
pooling=pooling,
)

# Maybe apply quantization
# Apply quantization
if vocabulary_quantization is not None:
_, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient)
weights = compute_weights(len(embeddings), sif_coefficient=sif_coefficient)
embeddings, token_mapping, weights = quantize_vocabulary(
n_clusters=vocabulary_quantization, weights=weights, embeddings=np.asarray(embeddings)
n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
)
embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient=sif_coefficient)
embeddings = apply_pca(embeddings, pca_dims)
else:
# Post-process the embeddings.
embeddings, weights = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
weights = compute_weights(len(embeddings), sif_coefficient=sif_coefficient)
embeddings = apply_pca(embeddings, pca_dims)
embeddings = embeddings * weights[:, None]
weights = None
token_mapping = None
Expand Down
28 changes: 15 additions & 13 deletions model2vec/distill/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,20 @@ def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch
return pooler.cpu()


def post_process_embeddings(
embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
) -> tuple[np.ndarray, np.ndarray]:
"""Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law."""
def compute_weights(n_embeddings: int, sif_coefficient: float | None) -> np.ndarray:
"""Compute the weights based on Zipf's law and a SIF coefficient."""
if sif_coefficient is None:
return np.ones(n_embeddings)
logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
inv_rank = 1 / (np.arange(2, n_embeddings + 2))
proba = inv_rank / np.sum(inv_rank)
weight = sif_coefficient / (sif_coefficient + proba)

return weight


def apply_pca(embeddings: np.ndarray, pca_dims: PCADimType) -> np.ndarray:
"""Apply PCA to the embeddings."""
if pca_dims is not None:
if pca_dims == "auto":
pca_dims = embeddings.shape[1]
Expand Down Expand Up @@ -241,12 +251,4 @@ def post_process_embeddings(
logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
logger.info(f"Explained variance: {explained_variance:.3f}.")

if sif_coefficient is not None:
logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
proba = inv_rank / np.sum(inv_rank)
weight = sif_coefficient / (sif_coefficient + proba)
else:
weight = np.ones(embeddings.shape[0])

return embeddings, weight
return embeddings
36 changes: 34 additions & 2 deletions tests/test_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from transformers.tokenization_utils_tokenizers import PreTrainedTokenizerFast

from model2vec.distill.distillation import distill, distill_from_model
from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings
from model2vec.distill.inference import PoolingMode, apply_pca, compute_weights, create_embeddings
from model2vec.model import StaticModel
from model2vec.tokenizer import clean_and_create_vocabulary

Expand Down Expand Up @@ -88,6 +88,36 @@ def test_distill_from_model(
assert token in static_model.tokens or normalized in static_model.tokens


@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@patch("transformers.AutoModel.from_pretrained")
def test_distill_quantization(
    mock_auto_model: MagicMock,
    mock_model_info: MagicMock,
    mock_berttokenizer: PreTrainedTokenizerFast,
    mock_transformer: PreTrainedModel,
) -> None:
    """Distilling with vocabulary quantization produces a clustered static model."""
    # Stub out the model download and the Hugging Face API metadata call.
    mock_auto_model.return_value = mock_transformer
    mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})

    static_model = distill_from_model(
        model=mock_transformer,
        tokenizer=mock_berttokenizer,
        vocabulary=None,
        device="cpu",
        pca_dims="auto",
        sif_coefficient=1e-4,
        token_remove_pattern=None,
        vocabulary_quantization=3,
    )

    # Quantizing to 3 clusters collapses the embedding matrix to 3 centroids,
    # while every vocabulary entry keeps a weight and a mapping to a centroid.
    assert static_model.embedding.shape == (3, 768)
    assert static_model.token_mapping is not None
    assert static_model.weights is not None
    assert len(static_model.weights) == static_model.tokenizer.get_vocab_size()


@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@patch("transformers.AutoModel.from_pretrained")
def test_distill_removal_pattern_all_tokens(
Expand Down Expand Up @@ -259,7 +289,9 @@ def test__post_process_embeddings(
# The implementation logs a warning and skips reduction; no exception expected.
pass

processed_embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient)
processed_embeddings = apply_pca(embeddings, pca_dims)
weights = compute_weights(len(processed_embeddings), sif_coefficient=sif_coefficient)
processed_embeddings = processed_embeddings * weights[:, None]

# Assert the shape is correct
assert processed_embeddings.shape == expected_shape
Expand Down
Loading