diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py
index 683a78556..277c01f30 100644
--- a/fastchat/serve/model_worker.py
+++ b/fastchat/serve/model_worker.py
@@ -171,7 +171,7 @@ def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
         mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
         masked_embeddings = data * mask
         sum_embeddings = torch.sum(masked_embeddings, dim=1)
-        token_num = torch.sum(attention_mask).item()
+        token_num = attention_mask.sum(dim=1, keepdim=True)
 
         return sum_embeddings, token_num
 
@@ -224,7 +224,7 @@ def get_embeddings(self, params):
             ):
                 embedding = embedding / token_num
                 normalized_embeddings = F.normalize(embedding, p=2, dim=1)
-                ret["token_num"] = token_num
+                ret["token_num"] = token_num.sum().item()
             else:
                 all_embeddings = []
                 all_token_num = 0
@@ -273,7 +273,7 @@ def get_embeddings(self, params):
                 embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
                 normalized_embeddings = F.normalize(embedding, p=2, dim=1)
 
-                ret["token_num"] = all_token_num
+                ret["token_num"] = all_token_num.sum().item()
 
             if base64_encode == "base64":
                 out_embeddings = self.__encode_base64(normalized_embeddings)