1010from modelcache .manager .vector_data import manager
1111from modelcache .manager import CacheBase , VectorBase , get_data_manager , data_manager
1212from modelcache .similarity_evaluation .distance import SearchDistanceEvaluation
13- from modelcache .processor .pre import query_multi_splicing
14- from modelcache .processor .pre import insert_multi_splicing
13+ from modelcache .processor .pre import query_multi_splicing ,insert_multi_splicing , query_with_role
1514from concurrent .futures import ThreadPoolExecutor
1615from modelcache .utils .model_filter import model_blacklist_filter
1716from modelcache .embedding import Data2VecAudio
@@ -36,13 +35,17 @@ def response_hitquery(cache_resp):
3635
3736if manager .MPNet_base :
3837 mpnet_base = MPNet_Base ()
39- embedding_func = lambda x : mpnet_base .embedding_func ( x )
38+ embedding_func = mpnet_base .to_embeddings
4039 dimension = mpnet_base .dimension
4140 data_manager .NORMALIZE = False
41+ query_pre_embedding_func = query_with_role
42+ insert_pre_embedding_func = query_with_role
4243else :
4344 data2vec = Data2VecAudio ()
4445 embedding_func = data2vec .to_embeddings
4546 dimension = data2vec .dimension
47+ query_pre_embedding_func = query_multi_splicing
48+ insert_pre_embedding_func = insert_multi_splicing
4649
4750mysql_config = configparser .ConfigParser ()
4851mysql_config .read ('modelcache/config/mysql_config.ini' )
@@ -95,8 +98,8 @@ def response_hitquery(cache_resp):
9598 embedding_func = embedding_func ,
9699 data_manager = data_manager ,
97100 similarity_evaluation = SearchDistanceEvaluation (),
98- query_pre_embedding_func = query_multi_splicing ,
99- insert_pre_embedding_func = insert_multi_splicing ,
101+ query_pre_embedding_func = query_pre_embedding_func ,
102+ insert_pre_embedding_func = insert_pre_embedding_func ,
100103)
101104
102105global executor
0 commit comments