|
6 | 6 |
|
7 | 7 | try: |
8 | 8 | import numpy as np |
| 9 | + NUMPY_AVAILABLE = True |
9 | 10 | except ImportError: |
10 | | - np = None |
| 11 | + NUMPY_AVAILABLE = False |
11 | 12 |
|
12 | 13 | conn = psycopg.connect(dbname='pgvector_python_test', autocommit=True) |
13 | 14 |
|
@@ -45,26 +46,26 @@ def test_vector_binary_format_correct(self): |
45 | 46 | res = next(conn.execute('SELECT %b::vector::text', (embedding,)))[0] |
46 | 47 | assert res == '[1.5,2,3]' |
47 | 48 |
|
48 | | - @pytest.mark.skipif(np is None, reason='NumPy required') |
| 49 | + @pytest.mark.skipif(NUMPY_AVAILABLE, reason='NumPy required') |
49 | 50 | def test_vector_numpy_binary_format(self): |
50 | 51 | embedding = np.array([1.5, 2, 3]) |
51 | 52 | res = next(conn.execute('SELECT %b::vector', (embedding,), binary=True))[0] |
52 | 53 | assert res == Vector(embedding) |
53 | 54 |
|
54 | | - @pytest.mark.skipif(np is None, reason='NumPy required') |
| 55 | + @pytest.mark.skipif(NUMPY_AVAILABLE, reason='NumPy required') |
55 | 56 | def test_vector_numpy_text_format(self): |
56 | 57 | embedding = np.array([1.5, 2, 3]) |
57 | 58 | res = next(conn.execute('SELECT %t::vector', (embedding,)))[0] |
58 | 59 | assert res == Vector(embedding) |
59 | 60 |
|
60 | | - @pytest.mark.skipif(np is None, reason='NumPy required') |
| 61 | + @pytest.mark.skipif(NUMPY_AVAILABLE, reason='NumPy required') |
61 | 62 | def test_vector_numpy_binary_format_non_contiguous(self): |
62 | 63 | embedding = np.flipud(np.array([1.5, 2, 3])) |
63 | 64 | assert not embedding.data.contiguous |
64 | 65 | res = next(conn.execute('SELECT %b::vector', (embedding,)))[0] |
65 | 66 | assert res == Vector([3, 2, 1.5]) |
66 | 67 |
|
67 | | - @pytest.mark.skipif(np is None, reason='NumPy required') |
| 68 | + @pytest.mark.skipif(NUMPY_AVAILABLE, reason='NumPy required') |
68 | 69 | def test_vector_numpy_text_format_non_contiguous(self): |
69 | 70 | embedding = np.flipud(np.array([1.5, 2, 3])) |
70 | 71 | assert not embedding.data.contiguous |
|
0 commit comments