Skip to content

Commit 616ec4f

Browse files
committed
Fix inventory flushing
- it's a combo of handling a tuple and pickling of raw data, and is triggered under rare occasions
1 parent a9cc560 commit 616ec4f

1 file changed

Lines changed: 196 additions & 0 deletions

File tree

src/tests/test_inventory_flush.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
"""
2+
Integration test for SqliteInventory.flush()
3+
4+
Runs with both the producer (SqliteInventory) and the consumer (sqlThread)
5+
active, to verify that inventory items with various payload types survive
6+
the queue round-trip and land correctly in the database.
7+
8+
Reproduces the bug where BLOB fields (payload, tag) passed as memoryview
9+
or bytearray cause: "Error binding parameter 3 - probably unsupported type."
10+
"""
11+
12+
import os
13+
import sqlite3
14+
import struct
15+
import tempfile
16+
import threading
17+
import time
18+
import unittest
19+
20+
from .common import skip_python3
21+
22+
skip_python3()
23+
24+
os.environ['BITMESSAGE_HOME'] = tempfile.mkdtemp()
25+
26+
from pybitmessage.helper_sql import ( # noqa:E402
27+
sqlQuery, sql_ready, sqlStoredProcedure)
28+
from pybitmessage.class_sqlThread import sqlThread # noqa:E402
29+
from pybitmessage.storage.sqlite import SqliteInventory # noqa:E402
30+
31+
32+
class TestInventoryFlush(unittest.TestCase):
33+
"""
34+
Integration test: exercises flush() end-to-end with the real sqlThread
35+
consumer running, so that type errors in parameter binding surface here
36+
rather than silently killing a production thread.
37+
"""
38+
39+
@classmethod
40+
def setUpClass(cls):
41+
sql_lookup = sqlThread()
42+
sql_lookup.daemon = True
43+
sql_lookup.start()
44+
sql_ready.wait()
45+
cls.inventory = SqliteInventory()
46+
47+
@classmethod
48+
def tearDownClass(cls):
49+
sqlStoredProcedure('exit')
50+
for thread in threading.enumerate():
51+
if thread.name == "SQL":
52+
thread.join(timeout=10)
53+
54+
# -- helpers ----------------------------------------------------------
55+
56+
@staticmethod
57+
def _make_hash(seed):
58+
"""Return a 32-byte hash derived from *seed*."""
59+
return (b'\x00' * 31 + bytes([seed & 0xFF]))[-32:]
60+
61+
def _insert_and_flush(self, obj_hash, obj_type, stream,
62+
payload, expires, tag):
63+
"""
64+
Put one item into the in-memory inventory, call flush(),
65+
and return the row that the sqlThread actually wrote.
66+
"""
67+
self.inventory[obj_hash] = (obj_type, stream, payload, expires, tag)
68+
self.inventory.flush()
69+
70+
rows = sqlQuery(
71+
'SELECT objecttype, streamnumber, payload, expirestime, tag'
72+
' FROM inventory WHERE hash=?',
73+
sqlite3.Binary(obj_hash))
74+
return rows
75+
76+
def _cleanup_hash(self, obj_hash):
77+
"""Remove a test row so tests stay independent."""
78+
from pybitmessage.helper_sql import sqlExecute
79+
sqlExecute('DELETE FROM inventory WHERE hash=?',
80+
sqlite3.Binary(obj_hash))
81+
82+
# -- test cases -------------------------------------------------------
83+
84+
def test_flush_with_bytes_payload(self):
85+
"""Baseline: payload and tag are plain bytes — must always work."""
86+
h = self._make_hash(1)
87+
payload = b'\x80\x01' + os.urandom(64)
88+
tag = b'\xff' * 32
89+
expires = int(time.time()) + 3600
90+
91+
rows = self._insert_and_flush(h, 2, 1, payload, expires, tag)
92+
try:
93+
self.assertEqual(len(rows), 1, "Row not found after flush")
94+
self.assertEqual(rows[0][0], 2)
95+
self.assertEqual(rows[0][1], 1)
96+
self.assertEqual(bytes(rows[0][2]), payload)
97+
self.assertEqual(rows[0][3], expires)
98+
self.assertEqual(bytes(rows[0][4]), tag)
99+
finally:
100+
self._cleanup_hash(h)
101+
102+
def test_flush_with_memoryview_payload(self):
103+
"""
104+
Reproduce the production crash:
105+
payload and tag as memoryview objects cause
106+
'Error binding parameter 3 - probably unsupported type.'
107+
"""
108+
h = self._make_hash(2)
109+
raw_payload = b'\x80\x02' + os.urandom(64)
110+
raw_tag = b'\xee' * 32
111+
expires = int(time.time()) + 3600
112+
113+
rows = self._insert_and_flush(
114+
h, 2, 1, memoryview(raw_payload), expires, memoryview(raw_tag))
115+
try:
116+
self.assertEqual(len(rows), 1,
117+
"Row not found — flush likely crashed on "
118+
"memoryview parameters")
119+
self.assertEqual(bytes(rows[0][2]), raw_payload)
120+
self.assertEqual(bytes(rows[0][4]), raw_tag)
121+
finally:
122+
self._cleanup_hash(h)
123+
124+
def test_flush_with_bytearray_payload(self):
125+
"""bytearray is another bytes-like type that could trip sqlite3."""
126+
h = self._make_hash(3)
127+
raw_payload = bytearray(b'\x80\x03' + os.urandom(64))
128+
raw_tag = bytearray(b'\xdd' * 32)
129+
expires = int(time.time()) + 3600
130+
131+
rows = self._insert_and_flush(
132+
h, 2, 1, raw_payload, expires, raw_tag)
133+
try:
134+
self.assertEqual(len(rows), 1,
135+
"Row not found — flush likely crashed on "
136+
"bytearray parameters")
137+
self.assertEqual(bytes(rows[0][2]), bytes(raw_payload))
138+
self.assertEqual(bytes(rows[0][4]), bytes(raw_tag))
139+
finally:
140+
self._cleanup_hash(h)
141+
142+
def test_flush_with_empty_tag(self):
143+
"""Empty tag (b'') must not break the INSERT."""
144+
h = self._make_hash(4)
145+
payload = b'\x80\x04' + os.urandom(64)
146+
expires = int(time.time()) + 3600
147+
148+
rows = self._insert_and_flush(h, 2, 1, payload, expires, b'')
149+
try:
150+
self.assertEqual(len(rows), 1)
151+
self.assertEqual(bytes(rows[0][4]), b'')
152+
finally:
153+
self._cleanup_hash(h)
154+
155+
def test_flush_multiple_items(self):
156+
"""Flush a batch and verify every row arrives."""
157+
count = 20
158+
hashes = [self._make_hash(0x10 + i) for i in range(count)]
159+
expires = int(time.time()) + 3600
160+
161+
for i, h in enumerate(hashes):
162+
payload = struct.pack('>I', i) + os.urandom(60)
163+
tag = struct.pack('>I', i) + b'\x00' * 28
164+
# mix types on purpose
165+
if i % 3 == 0:
166+
payload = memoryview(payload)
167+
tag = memoryview(tag)
168+
elif i % 3 == 1:
169+
payload = bytearray(payload)
170+
tag = bytearray(tag)
171+
self.inventory[h] = (2, 1, payload, expires, tag)
172+
173+
self.inventory.flush()
174+
175+
try:
176+
for i, h in enumerate(hashes):
177+
rows = sqlQuery(
178+
'SELECT objecttype FROM inventory WHERE hash=?',
179+
sqlite3.Binary(h))
180+
self.assertEqual(
181+
len(rows), 1,
182+
"Item {} missing after batch flush".format(i))
183+
finally:
184+
for h in hashes:
185+
self._cleanup_hash(h)
186+
187+
def test_flush_clears_memory_cache(self):
188+
"""After flush the in-memory _inventory dict must be empty."""
189+
h = self._make_hash(0xF0)
190+
self.inventory[h] = (
191+
2, 1, b'\x00' * 32, int(time.time()) + 3600, b'')
192+
self.inventory.flush()
193+
try:
194+
self.assertEqual(len(self.inventory._inventory), 0)
195+
finally:
196+
self._cleanup_hash(h)

0 commit comments

Comments
 (0)