diff --git a/model2vec/distill/inference.py b/model2vec/distill/inference.py index b08a021..98fe8cd 100644 --- a/model2vec/distill/inference.py +++ b/model2vec/distill/inference.py @@ -137,10 +137,9 @@ def _encode_with_model( # NOTE: If the dtype is bfloat 16, we convert to float32, # because numpy does not suport bfloat16 # See here: https://github.com/numpy/numpy/issues/19808 - if hidden.dtype == torch.bfloat16: - hidden = hidden.float() + hidden = hidden.float() pooler = getattr(outputs, "pooler_output", None) - if pooler is not None and pooler.dtype == torch.bfloat16: + if pooler is not None: pooler = pooler.float() return hidden, pooler, encodings_on_device