Skip to content

Commit 563ebdf

Browse files
committed
Improved typechecking for tests [skip ci]
1 parent 93fabd2 commit 563ebdf

3 files changed

Lines changed: 27 additions & 22 deletions

File tree

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
lint:
44
pycodestyle . --ignore=E501 --exclude=.venv
55

6+
check:
7+
ty check pgvector tests
8+
69
build:
710
python3 -m build
811

tests/test_django.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# type: ignore
2+
13
import django
24
from django.conf import settings
35
from django.contrib.postgres.fields import ArrayField

tests/test_sqlmodel.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Item(SQLModel, table=True):
2626

2727
index = Index(
2828
'sqlmodel_index',
29-
Item.embedding,
29+
Item.embedding, # type: ignore
3030
postgresql_using='hnsw',
3131
postgresql_with={'m': 16, 'ef_construction': 64},
3232
postgresql_ops={'embedding': 'vector_l2_ops'}
@@ -65,41 +65,41 @@ def test_orm(self):
6565
assert items[0].id == 1
6666
assert items[1].id == 2
6767
assert items[2].id == 3
68-
assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3]))
69-
assert items[0].embedding.dtype == np.float32
70-
assert np.array_equal(items[1].embedding, np.array([4, 5, 6]))
71-
assert items[1].embedding.dtype == np.float32
68+
assert np.array_equal(items[0].embedding, np.array([1.5, 2, 3])) # type: ignore
69+
assert items[0].embedding.dtype == np.float32 # type: ignore
70+
assert np.array_equal(items[1].embedding, np.array([4, 5, 6])) # type: ignore
71+
assert items[1].embedding.dtype == np.float32 # type: ignore
7272
assert items[2].embedding is None
7373

7474
def test_vector(self):
7575
with Session(engine) as session:
7676
session.add(Item(id=1, embedding=[1, 2, 3]))
7777
session.commit()
7878
item = session.get_one(Item, 1)
79-
assert np.array_equal(item.embedding, np.array([1, 2, 3]))
79+
assert np.array_equal(item.embedding, np.array([1, 2, 3])) # type: ignore
8080

8181
def test_vector_l2_distance(self):
8282
create_items()
8383
with Session(engine) as session:
84-
items = session.exec(select(Item).order_by(Item.embedding.l2_distance([1, 1, 1])))
84+
items = session.exec(select(Item).order_by(Item.embedding.l2_distance([1, 1, 1]))) # type: ignore
8585
assert [v.id for v in items] == [1, 3, 2]
8686

8787
def test_vector_max_inner_product(self):
8888
create_items()
8989
with Session(engine) as session:
90-
items = session.exec(select(Item).order_by(Item.embedding.max_inner_product([1, 1, 1])))
90+
items = session.exec(select(Item).order_by(Item.embedding.max_inner_product([1, 1, 1]))) # type: ignore
9191
assert [v.id for v in items] == [2, 3, 1]
9292

9393
def test_vector_cosine_distance(self):
9494
create_items()
9595
with Session(engine) as session:
96-
items = session.exec(select(Item).order_by(Item.embedding.cosine_distance([1, 1, 1])))
96+
items = session.exec(select(Item).order_by(Item.embedding.cosine_distance([1, 1, 1]))) # type: ignore
9797
assert [v.id for v in items] == [1, 2, 3]
9898

9999
def test_vector_l1_distance(self):
100100
create_items()
101101
with Session(engine) as session:
102-
items = session.exec(select(Item).order_by(Item.embedding.l1_distance([1, 1, 1])))
102+
items = session.exec(select(Item).order_by(Item.embedding.l1_distance([1, 1, 1]))) # type: ignore
103103
assert [v.id for v in items] == [1, 3, 2]
104104

105105
def test_halfvec(self):
@@ -112,25 +112,25 @@ def test_halfvec(self):
112112
def test_halfvec_l2_distance(self):
113113
create_items()
114114
with Session(engine) as session:
115-
items = session.exec(select(Item).order_by(Item.half_embedding.l2_distance([1, 1, 1])))
115+
items = session.exec(select(Item).order_by(Item.half_embedding.l2_distance([1, 1, 1]))) # type: ignore
116116
assert [v.id for v in items] == [1, 3, 2]
117117

118118
def test_halfvec_max_inner_product(self):
119119
create_items()
120120
with Session(engine) as session:
121-
items = session.exec(select(Item).order_by(Item.half_embedding.max_inner_product([1, 1, 1])))
121+
items = session.exec(select(Item).order_by(Item.half_embedding.max_inner_product([1, 1, 1]))) # type: ignore
122122
assert [v.id for v in items] == [2, 3, 1]
123123

124124
def test_halfvec_cosine_distance(self):
125125
create_items()
126126
with Session(engine) as session:
127-
items = session.exec(select(Item).order_by(Item.half_embedding.cosine_distance([1, 1, 1])))
127+
items = session.exec(select(Item).order_by(Item.half_embedding.cosine_distance([1, 1, 1]))) # type: ignore
128128
assert [v.id for v in items] == [1, 2, 3]
129129

130130
def test_halfvec_l1_distance(self):
131131
create_items()
132132
with Session(engine) as session:
133-
items = session.exec(select(Item).order_by(Item.half_embedding.l1_distance([1, 1, 1])))
133+
items = session.exec(select(Item).order_by(Item.half_embedding.l1_distance([1, 1, 1]))) # type: ignore
134134
assert [v.id for v in items] == [1, 3, 2]
135135

136136
def test_bit(self):
@@ -143,13 +143,13 @@ def test_bit(self):
143143
def test_bit_hamming_distance(self):
144144
create_items()
145145
with Session(engine) as session:
146-
items = session.exec(select(Item).order_by(Item.binary_embedding.hamming_distance('101')))
146+
items = session.exec(select(Item).order_by(Item.binary_embedding.hamming_distance('101'))) # type: ignore
147147
assert [v.id for v in items] == [2, 3, 1]
148148

149149
def test_bit_jaccard_distance(self):
150150
create_items()
151151
with Session(engine) as session:
152-
items = session.exec(select(Item).order_by(Item.binary_embedding.jaccard_distance('101')))
152+
items = session.exec(select(Item).order_by(Item.binary_embedding.jaccard_distance('101'))) # type: ignore
153153
assert [v.id for v in items] == [2, 3, 1]
154154

155155
def test_sparsevec(self):
@@ -162,37 +162,37 @@ def test_sparsevec(self):
162162
def test_sparsevec_l2_distance(self):
163163
create_items()
164164
with Session(engine) as session:
165-
items = session.exec(select(Item).order_by(Item.sparse_embedding.l2_distance([1, 1, 1])))
165+
items = session.exec(select(Item).order_by(Item.sparse_embedding.l2_distance([1, 1, 1]))) # type: ignore
166166
assert [v.id for v in items] == [1, 3, 2]
167167

168168
def test_sparsevec_max_inner_product(self):
169169
create_items()
170170
with Session(engine) as session:
171-
items = session.exec(select(Item).order_by(Item.sparse_embedding.max_inner_product([1, 1, 1])))
171+
items = session.exec(select(Item).order_by(Item.sparse_embedding.max_inner_product([1, 1, 1]))) # type: ignore
172172
assert [v.id for v in items] == [2, 3, 1]
173173

174174
def test_sparsevec_cosine_distance(self):
175175
create_items()
176176
with Session(engine) as session:
177-
items = session.exec(select(Item).order_by(Item.sparse_embedding.cosine_distance([1, 1, 1])))
177+
items = session.exec(select(Item).order_by(Item.sparse_embedding.cosine_distance([1, 1, 1]))) # type: ignore
178178
assert [v.id for v in items] == [1, 2, 3]
179179

180180
def test_sparsevec_l1_distance(self):
181181
create_items()
182182
with Session(engine) as session:
183-
items = session.exec(select(Item).order_by(Item.sparse_embedding.l1_distance([1, 1, 1])))
183+
items = session.exec(select(Item).order_by(Item.sparse_embedding.l1_distance([1, 1, 1]))) # type: ignore
184184
assert [v.id for v in items] == [1, 3, 2]
185185

186186
def test_filter(self):
187187
create_items()
188188
with Session(engine) as session:
189-
items = session.exec(select(Item).filter(Item.embedding.l2_distance([1, 1, 1]) < 1))
189+
items = session.exec(select(Item).filter(Item.embedding.l2_distance([1, 1, 1]) < 1)) # type: ignore
190190
assert [v.id for v in items] == [1]
191191

192192
def test_select(self):
193193
with Session(engine) as session:
194194
session.add(Item(embedding=[2, 3, 3]))
195-
items = session.exec(select(Item.embedding.l2_distance([1, 1, 1]))).all()
195+
items = session.exec(select(Item.embedding.l2_distance([1, 1, 1]))).all() # type: ignore
196196
assert items[0] == 3
197197

198198
def test_vector_avg(self):

0 commit comments

Comments
 (0)