|
14 | 14 | import numpy as np |
15 | 15 | import os |
16 | 16 | import pgvector.django |
17 | | -from pgvector import HalfVector, SparseVector |
| 17 | +from pgvector import HalfVector, SparseVector, Vector |
18 | 18 | from pgvector.django import VectorExtension, VectorField, HalfVectorField, BitField, SparseVectorField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HammingDistance, JaccardDistance |
19 | 19 | from unittest import mock |
20 | 20 |
|
@@ -167,12 +167,11 @@ def setup_method(self): |
167 | 167 | def test_vector(self): |
168 | 168 | Item(id=1, embedding=[1, 2, 3]).save() |
169 | 169 | item = Item.objects.get(pk=1) |
170 | | - assert np.array_equal(item.embedding, [1, 2, 3]) |
171 | | - assert item.embedding.dtype == np.float32 |
| 170 | + assert item.embedding == Vector([1, 2, 3]) |
172 | 171 |
|
173 | 172 | def test_vector_l2_distance(self): |
174 | 173 | create_items() |
175 | | - distance = L2Distance('embedding', [1, 1, 1]) |
| 174 | + distance = L2Distance('embedding', Vector([1, 1, 1])) |
176 | 175 | items = Item.objects.annotate(distance=distance).order_by(distance) |
177 | 176 | assert [v.id for v in items] == [1, 3, 2] |
178 | 177 | assert [v.distance for v in items] == [0, 1, sqrt(3)] |
@@ -295,15 +294,15 @@ def test_vector_avg(self): |
295 | 294 | Item(embedding=[1, 2, 3]).save() |
296 | 295 | Item(embedding=[4, 5, 6]).save() |
297 | 296 | avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg'] |
298 | | - assert np.array_equal(avg, [2.5, 3.5, 4.5]) |
| 297 | + assert avg == Vector([2.5, 3.5, 4.5]) |
299 | 298 |
|
300 | 299 | def test_vector_sum(self): |
301 | 300 | sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] |
302 | 301 | assert sum is None |
303 | 302 | Item(embedding=[1, 2, 3]).save() |
304 | 303 | Item(embedding=[4, 5, 6]).save() |
305 | 304 | sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum'] |
306 | | - assert np.array_equal(sum, [5, 7, 9]) |
| 305 | + assert sum == Vector([5, 7, 9]) |
307 | 306 |
|
308 | 307 | def test_halfvec_avg(self): |
309 | 308 | avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg'] |
@@ -349,7 +348,7 @@ def test_vector_form_save(self): |
349 | 348 | assert form.has_changed() |
350 | 349 | assert form.is_valid() |
351 | 350 | assert form.save() |
352 | | - assert np.array_equal(Item.objects.get(pk=1).embedding, [4, 5, 6]) |
| 351 | + assert Item.objects.get(pk=1).embedding == Vector([4, 5, 6]) |
353 | 352 |
|
354 | 353 | def test_vector_form_save_missing(self): |
355 | 354 | Item(id=1).save() |
@@ -467,8 +466,7 @@ def test_vector_array(self): |
467 | 466 |
|
468 | 467 | # this fails if the driver does not cast arrays |
469 | 468 | item = Item.objects.get(pk=1) |
470 | | - assert np.array_equal(item.embeddings[0], [1, 2, 3]) |
471 | | - assert np.array_equal(item.embeddings[1], [4, 5, 6]) |
| 469 | + assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])] |
472 | 470 |
|
473 | 471 | def test_double_array(self): |
474 | 472 | Item(id=1, double_embedding=[1, 1, 1]).save() |
|
0 commit comments