# -*- coding: utf-8 -*-
2+ from typing import List
3+
4+ import numpy as np
5+ from typing import List
6+
7+ import numpy as np
8+ from modelcache .manager .vector_data .base import VectorBase , VectorData
9+ # from modelcache.utils import import_redis
10+ # from modelcache.utils.log import gptcache_log
11+
12+ # import_redis()
13+ #
14+ # # pylint: disable=C0413
15+ # from redis.commands.search.indexDefinition import IndexDefinition, IndexType
16+ # from redis.commands.search.query import Query
17+ # from redis.commands.search.field import TagField, VectorField
18+ # from redis.client import Redis
19+
20+
class RedisVectorStore(VectorBase):
    """vector store: Redis

    :param host: redis host, defaults to "localhost".
    :type host: str
    :param port: redis port, defaults to "6379".
    :type port: str
    :param username: redis username, defaults to "".
    :type username: str
    :param password: redis password, defaults to "".
    :type password: str
    :param dimension: the dimension of the vector, defaults to 0.
    :type dimension: int
    :param collection_name: the name of the index for Redis, defaults to "gptcache".
    :type collection_name: str
    :param top_k: the number of the vectors results to return, defaults to 1.
    :type top_k: int
    :param namespace: optional prefix prepended to every document key, defaults to "".
    :type namespace: str

    Example:
        .. code-block:: python

            from gptcache.manager import VectorBase

            vector_base = VectorBase("redis", dimension=10)
    """

    # NOTE(review): the redis imports at the top of this file are commented out,
    # so Redis, Query, TagField, VectorField, IndexDefinition, IndexType and
    # gptcache_log are unresolved at runtime. Re-enable those imports (and
    # import_redis()) before instantiating this class.

    def __init__(
        self,
        host: str = "localhost",
        port: str = "6379",
        username: str = "",
        password: str = "",
        dimension: int = 0,
        collection_name: str = "gptcache",
        top_k: int = 1,
        namespace: str = "",
    ):
        self._client = Redis(
            host=host, port=int(port), username=username, password=password
        )
        self.top_k = top_k
        self.dimension = dimension
        self.collection_name = collection_name
        self.namespace = namespace
        # Every stored hash key is "<namespace>doc:<id>"; the same prefix is
        # registered on the index definition and stripped back off in search().
        self.doc_prefix = f"{self.namespace}doc:"
        self._create_collection(collection_name)

    def _check_index_exists(self, index_name: str) -> bool:
        """Return True if the Redis search index *index_name* already exists.

        FT.INFO raises when the index is missing, so existence is probed via
        EAFP. Catch Exception (not a bare ``except``) so KeyboardInterrupt and
        SystemExit still propagate.
        """
        try:
            self._client.ft(index_name).info()
        except Exception:  # redis raises ResponseError for a missing index
            gptcache_log.info("Index does not exist")
            return False
        gptcache_log.info("Index already exists")
        return True

    def _create_collection(self, collection_name):
        """Create the FLAT/COSINE vector index unless it already exists."""
        if self._check_index_exists(collection_name):
            gptcache_log.info(
                "The %s already exists, and it will be used directly", collection_name
            )
            return
        schema = (
            TagField("tag"),  # Tag Field Name
            VectorField(
                "vector",  # Vector Field Name
                "FLAT",  # Vector Index Type: FLAT or HNSW
                {
                    "TYPE": "FLOAT32",  # FLOAT32 or FLOAT64
                    "DIM": self.dimension,  # Number of Vector Dimensions
                    "DISTANCE_METRIC": "COSINE",  # Vector Search Distance Metric
                },
            ),
        )
        # Only hashes whose key starts with doc_prefix are indexed.
        definition = IndexDefinition(
            prefix=[self.doc_prefix], index_type=IndexType.HASH
        )
        self._client.ft(collection_name).create_index(
            fields=schema, definition=definition
        )

    def mul_add(self, datas: List[VectorData]):
        """Store a batch of vectors as Redis hashes via a single pipeline.

        Each vector is serialized to raw float32 bytes under the "vector"
        field of key "<doc_prefix><id>".
        """
        pipe = self._client.pipeline()
        for data in datas:
            key: int = data.id
            obj = {
                "vector": data.data.astype(np.float32).tobytes(),
            }
            pipe.hset(f"{self.doc_prefix}{key}", mapping=obj)
        pipe.execute()

    def search(self, data: np.ndarray, top_k: int = -1):
        """KNN-search the index for *data*.

        :param data: query vector; cast to float32 to match the stored format.
        :param top_k: number of neighbors; non-positive falls back to self.top_k.
        :return: list of (score, id) tuples, id recovered by stripping doc_prefix.
        """
        # Hoist the effective k so it is computed once for both the KNN clause
        # and the paging window.
        k = top_k if top_k > 0 else self.top_k
        query = (
            Query(f"*=>[KNN {k} @vector $vec as score]")
            .sort_by("score")
            .return_fields("id", "score")
            .paging(0, k)
            .dialect(2)
        )
        query_params = {"vec": data.astype(np.float32).tobytes()}
        results = (
            self._client.ft(self.collection_name)
            .search(query, query_params=query_params)
            .docs
        )
        return [
            (float(result.score), int(result.id[len(self.doc_prefix):]))
            for result in results
        ]

    def rebuild(self, ids=None) -> bool:
        """No-op: Redis maintains its index incrementally; nothing to rebuild."""
        pass

    def delete(self, ids) -> None:
        """Delete the hashes for the given vector ids in one pipeline."""
        pipe = self._client.pipeline()
        for data_id in ids:
            pipe.delete(f"{self.doc_prefix}{data_id}")
        pipe.execute()