Skip to content

Commit 60990c2

Browse files
committed
Changed vector type to return Vector object instead of NumPy array - closes #99
1 parent f58842c commit 60990c2

15 files changed

Lines changed: 93 additions & 127 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
## 0.5.0 (unreleased)
22

33
- Added type hints
4+
- Changed `vector` type to return `Vector` object instead of NumPy array
45
- Removed `utils` package (use top-level `pgvector` package instead)
56
- Removed re-exported classes (use top-level `pgvector` package instead)
67
- Dropped support for Python < 3.10

pgvector/django/vector.py

Lines changed: 10 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
from django import forms
22
from django.db.models import Field
3-
import numpy as np
43
from typing import Any
54
from .. import Vector
65

@@ -25,30 +24,23 @@ def db_type(self, connection: Any) -> str:
2524
return 'vector'
2625
return 'vector(%d)' % self.dimensions
2726

28-
def from_db_value(self, value: Any, expression: Any, connection: Any) -> np.ndarray | None:
27+
def from_db_value(self, value: Any, expression: Any, connection: Any) -> Vector | None:
2928
return Vector._from_db(value)
3029

31-
def to_python(self, value: Any) -> np.ndarray | None:
32-
if isinstance(value, list):
33-
return np.array(value, dtype=np.float32)
34-
return Vector._from_db(value)
30+
def to_python(self, value: Any) -> Vector | None:
31+
if value is None or isinstance(value, Vector):
32+
return value
33+
elif isinstance(value, str):
34+
return Vector._from_db(value)
35+
else:
36+
return Vector(value)
3537

3638
def get_prep_value(self, value: Any) -> str | None:
3739
return Vector._to_db(value)
3840

3941
def value_to_string(self, obj: Any) -> str | None:
4042
return self.get_prep_value(self.value_from_object(obj))
4143

42-
def validate(self, value: Any, model_instance: Any) -> None:
43-
if isinstance(value, np.ndarray):
44-
value = value.tolist()
45-
super().validate(value, model_instance)
46-
47-
def run_validators(self, value: Any) -> None:
48-
if isinstance(value, np.ndarray):
49-
value = value.tolist()
50-
super().run_validators(value)
51-
5244
def formfield(self, form_class: Any = None, choices_form_class: Any = None, **kwargs: Any) -> forms.Field:
5345
return super().formfield(
5446
form_class=VectorFormField if form_class is None else form_class,
@@ -59,19 +51,14 @@ def formfield(self, form_class: Any = None, choices_form_class: Any = None, **kw
5951

6052
class VectorWidget(forms.TextInput):
6153
def format_value(self, value: Any) -> str | None:
62-
if isinstance(value, np.ndarray):
63-
value = value.tolist()
54+
if isinstance(value, Vector):
55+
value = value.to_list()
6456
return super().format_value(value)
6557

6658

6759
class VectorFormField(forms.CharField):
6860
widget = VectorWidget
6961

70-
def has_changed(self, initial: Any, data: Any) -> bool:
71-
if isinstance(initial, np.ndarray):
72-
initial = initial.tolist()
73-
return super().has_changed(initial, data)
74-
7562
def to_python(self, value: Any) -> Any:
7663
if isinstance(value, str) and value == '':
7764
return None

pgvector/peewee/vector.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
from peewee import Expression, Field
32
from typing import Any
43
from .. import Vector
@@ -17,7 +16,7 @@ def get_modifiers(self) -> list[int] | None:
1716
def db_value(self, value: object) -> str | None:
1817
return Vector._to_db(value)
1918

20-
def python_value(self, value: Any) -> np.ndarray | None:
19+
def python_value(self, value: Any) -> Vector | None:
2120
return Vector._from_db(value)
2221

2322
def _distance(self, op: str, vector: object) -> Expression:

pgvector/psycopg/vector.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ class VectorLoader(Loader):
3131

3232
format = Format.TEXT
3333

34-
def load(self, data: Buffer) -> np.ndarray | None:
34+
def load(self, data: Buffer) -> Vector | None:
3535
if isinstance(data, memoryview):
3636
data = bytes(data)
3737
return Vector._from_db(data.decode('utf8'))
@@ -41,7 +41,7 @@ class VectorBinaryLoader(VectorLoader):
4141

4242
format = Format.BINARY
4343

44-
def load(self, data: Buffer) -> np.ndarray | None:
44+
def load(self, data: Buffer) -> Vector | None:
4545
if isinstance(data, (bytearray, memoryview)):
4646
data = bytes(data)
4747
return Vector._from_db_binary(data)

pgvector/psycopg2/vector.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ def getquoted(self) -> bytes:
1111
return adapt(Vector._to_db(self._value)).getquoted()
1212

1313

14-
def cast_vector(value: str | None, cur: cursor) -> np.ndarray | None:
14+
def cast_vector(value: str | None, cur: cursor) -> Vector | None:
1515
return Vector._from_db(value)
1616

1717

pgvector/sqlalchemy/vector.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
from sqlalchemy.dialects.postgresql.base import ischema_names
32
from sqlalchemy.types import UserDefinedType, Float, String
43
from sqlalchemy import Dialect, Operators
@@ -32,7 +31,7 @@ def process(value: Any) -> Any:
3231
return process
3332

3433
def result_processor(self, dialect: Dialect, coltype: Any) -> Any:
35-
def process(value: Any) -> np.ndarray | None:
34+
def process(value: Any) -> Vector | None:
3635
return Vector._from_db(value)
3736
return process
3837

pgvector/vector.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -82,15 +82,15 @@ def _to_db_binary(cls, value: object) -> bytes | None:
8282
return value.to_binary()
8383

8484
@classmethod
85-
def _from_db(cls, value: str | np.ndarray | None) -> np.ndarray | None:
86-
if value is None or isinstance(value, np.ndarray):
85+
def _from_db(cls, value: str | Vector | None) -> Vector | None:
86+
if value is None or isinstance(value, Vector):
8787
return value
8888

89-
return cls.from_text(value).to_numpy().astype(np.float32)
89+
return cls.from_text(value)
9090

9191
@classmethod
92-
def _from_db_binary(cls, value: bytes | np.ndarray | None) -> np.ndarray | None:
93-
if value is None or isinstance(value, np.ndarray):
92+
def _from_db_binary(cls, value: bytes | Vector | None) -> Vector | None:
93+
if value is None or isinstance(value, Vector):
9494
return value
9595

96-
return cls.from_binary(value).to_numpy().astype(np.float32)
96+
return cls.from_binary(value)

tests/test_asyncpg.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,8 @@ async def test_vector(self):
2020
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
2121

2222
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
23-
assert np.array_equal(res[0]['embedding'], embedding.to_numpy())
24-
assert res[0]['embedding'].dtype == np.float32
25-
assert np.array_equal(res[1]['embedding'], embedding2)
23+
assert res[0]['embedding'] == embedding
24+
assert res[1]['embedding'] == Vector(embedding2)
2625
assert res[2]['embedding'] is None
2726

2827
# ensures binary format is correct
@@ -116,10 +115,8 @@ async def test_vector_array(self):
116115
await conn.execute("INSERT INTO asyncpg_items (embeddings) VALUES (ARRAY[$1, $2]::vector[])", embeddings2[0], embeddings2[1])
117116

118117
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
119-
assert np.array_equal(res[0]['embeddings'][0], embeddings[0].to_numpy())
120-
assert np.array_equal(res[0]['embeddings'][1], embeddings[1].to_numpy())
121-
assert np.array_equal(res[1]['embeddings'][0], embeddings2[0])
122-
assert np.array_equal(res[1]['embeddings'][1], embeddings2[1])
118+
assert res[0]['embeddings'] == embeddings
119+
assert res[1]['embeddings'] == [Vector(e) for e in embeddings2]
123120

124121
await conn.close()
125122

@@ -140,7 +137,6 @@ async def init(conn):
140137
await conn.execute("INSERT INTO asyncpg_items (embedding) VALUES ($1), ($2), (NULL)", embedding, embedding2)
141138

142139
res = await conn.fetch("SELECT * FROM asyncpg_items ORDER BY id")
143-
assert np.array_equal(res[0]['embedding'], embedding.to_numpy())
144-
assert res[0]['embedding'].dtype == np.float32
145-
assert np.array_equal(res[1]['embedding'], embedding2)
140+
assert res[0]['embedding'] == embedding
141+
assert res[1]['embedding'] == Vector(embedding2)
146142
assert res[2]['embedding'] is None

tests/test_django.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
import numpy as np
1515
import os
1616
import pgvector.django
17-
from pgvector import HalfVector, SparseVector
17+
from pgvector import HalfVector, SparseVector, Vector
1818
from pgvector.django import VectorExtension, VectorField, HalfVectorField, BitField, SparseVectorField, IvfflatIndex, HnswIndex, L2Distance, MaxInnerProduct, CosineDistance, L1Distance, HammingDistance, JaccardDistance
1919
from unittest import mock
2020

@@ -167,12 +167,11 @@ def setup_method(self):
167167
def test_vector(self):
168168
Item(id=1, embedding=[1, 2, 3]).save()
169169
item = Item.objects.get(pk=1)
170-
assert np.array_equal(item.embedding, [1, 2, 3])
171-
assert item.embedding.dtype == np.float32
170+
assert item.embedding == Vector([1, 2, 3])
172171

173172
def test_vector_l2_distance(self):
174173
create_items()
175-
distance = L2Distance('embedding', [1, 1, 1])
174+
distance = L2Distance('embedding', Vector([1, 1, 1]))
176175
items = Item.objects.annotate(distance=distance).order_by(distance)
177176
assert [v.id for v in items] == [1, 3, 2]
178177
assert [v.distance for v in items] == [0, 1, sqrt(3)]
@@ -295,15 +294,15 @@ def test_vector_avg(self):
295294
Item(embedding=[1, 2, 3]).save()
296295
Item(embedding=[4, 5, 6]).save()
297296
avg = Item.objects.aggregate(Avg('embedding'))['embedding__avg']
298-
assert np.array_equal(avg, [2.5, 3.5, 4.5])
297+
assert avg == Vector([2.5, 3.5, 4.5])
299298

300299
def test_vector_sum(self):
301300
sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum']
302301
assert sum is None
303302
Item(embedding=[1, 2, 3]).save()
304303
Item(embedding=[4, 5, 6]).save()
305304
sum = Item.objects.aggregate(Sum('embedding'))['embedding__sum']
306-
assert np.array_equal(sum, [5, 7, 9])
305+
assert sum == Vector([5, 7, 9])
307306

308307
def test_halfvec_avg(self):
309308
avg = Item.objects.aggregate(Avg('half_embedding'))['half_embedding__avg']
@@ -349,7 +348,7 @@ def test_vector_form_save(self):
349348
assert form.has_changed()
350349
assert form.is_valid()
351350
assert form.save()
352-
assert np.array_equal(Item.objects.get(pk=1).embedding, [4, 5, 6])
351+
assert Item.objects.get(pk=1).embedding == Vector([4, 5, 6])
353352

354353
def test_vector_form_save_missing(self):
355354
Item(id=1).save()
@@ -467,8 +466,7 @@ def test_vector_array(self):
467466

468467
# this fails if the driver does not cast arrays
469468
item = Item.objects.get(pk=1)
470-
assert np.array_equal(item.embeddings[0], [1, 2, 3])
471-
assert np.array_equal(item.embeddings[1], [4, 5, 6])
469+
assert item.embeddings == [Vector([1, 2, 3]), Vector([4, 5, 6])]
472470

473471
def test_double_array(self):
474472
Item(id=1, double_embedding=[1, 1, 1]).save()

tests/test_peewee.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from math import sqrt
22
import numpy as np
33
from peewee import Model, PostgresqlDatabase, fn
4-
from pgvector import HalfVector, SparseVector
4+
from pgvector import HalfVector, SparseVector, Vector
55
from pgvector.peewee import VectorField, HalfVectorField, FixedBitField, SparseVectorField
66

77
db = PostgresqlDatabase('pgvector_python_test')
@@ -43,8 +43,7 @@ def setup_method(self):
4343
def test_vector(self):
4444
Item.create(id=1, embedding=[1, 2, 3])
4545
item = Item.get_by_id(1)
46-
assert np.array_equal(item.embedding, [1, 2, 3])
47-
assert item.embedding.dtype == np.float32
46+
assert item.embedding == Vector([1, 2, 3])
4847

4948
def test_vector_l2_distance(self):
5049
create_items()
@@ -170,15 +169,15 @@ def test_vector_avg(self):
170169
Item.create(embedding=[1, 2, 3])
171170
Item.create(embedding=[4, 5, 6])
172171
avg = Item.select(fn.avg(Item.embedding).coerce(True)).scalar()
173-
assert np.array_equal(avg, [2.5, 3.5, 4.5])
172+
assert avg == Vector([2.5, 3.5, 4.5])
174173

175174
def test_vector_sum(self):
176175
sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar()
177176
assert sum is None
178177
Item.create(embedding=[1, 2, 3])
179178
Item.create(embedding=[4, 5, 6])
180179
sum = Item.select(fn.sum(Item.embedding).coerce(True)).scalar()
181-
assert np.array_equal(sum, [5, 7, 9])
180+
assert sum == Vector([5, 7, 9])
182181

183182
def test_halfvec_avg(self):
184183
avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar()

0 commit comments

Comments
 (0)