From d3e9aa83b3d70fffe982c737e44e924e27e756b2 Mon Sep 17 00:00:00 2001 From: stephantul Date: Thu, 16 Apr 2026 15:28:59 +0200 Subject: [PATCH] feat: convert output to float in distill --- model2vec/distill/inference.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/model2vec/distill/inference.py b/model2vec/distill/inference.py index b08a021..98fe8cd 100644 --- a/model2vec/distill/inference.py +++ b/model2vec/distill/inference.py @@ -137,10 +137,9 @@ def _encode_with_model( # NOTE: If the dtype is bfloat 16, we convert to float32, # because numpy does not suport bfloat16 # See here: https://github.com/numpy/numpy/issues/19808 - if hidden.dtype == torch.bfloat16: - hidden = hidden.float() + hidden = hidden.float() pooler = getattr(outputs, "pooler_output", None) - if pooler is not None and pooler.dtype == torch.bfloat16: + if pooler is not None: pooler = pooler.float() return hidden, pooler, encodings_on_device