Skip to content

Commit 89d0b19

Browse files
committed
Add test_sni_callback_race
1 parent acefff9 commit 89d0b19

1 file changed

Lines changed: 50 additions & 0 deletions

File tree

Lib/test/test_ssl.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,56 @@ def dummycallback(sock, servername, ctx, cycle=ctx):
16061606
gc.collect()
16071607
self.assertIs(wr(), None)
16081608

1609+
@unittest.skipUnless(support.Py_GIL_DISABLED,
1610+
"test is only useful if the GIL is disabled")
1611+
@threading_helper.requires_working_threading()
1612+
def test_sni_callback_race(self):
1613+
# Replacing sni_callback while handshakes are in-flight must not
1614+
# crash (use-after-free on the callback in free-threaded builds).
1615+
client_ctx, server_ctx, hostname = testing_context()
1616+
server_ctx.sni_callback = lambda *a: None
1617+
done = threading.Event()
1618+
1619+
def do_handshakes():
1620+
for _ in range(100):
1621+
c_in = ssl.MemoryBIO()
1622+
c_out = ssl.MemoryBIO()
1623+
s_in = ssl.MemoryBIO()
1624+
s_out = ssl.MemoryBIO()
1625+
client = client_ctx.wrap_bio(
1626+
c_in, c_out, server_hostname=hostname)
1627+
server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
1628+
for _ in range(50):
1629+
try:
1630+
client.do_handshake()
1631+
except ssl.SSLWantReadError:
1632+
pass
1633+
except ssl.SSLError:
1634+
break
1635+
if c_out.pending:
1636+
s_in.write(c_out.read())
1637+
try:
1638+
server.do_handshake()
1639+
except ssl.SSLWantReadError:
1640+
pass
1641+
except ssl.SSLError:
1642+
break
1643+
if s_out.pending:
1644+
c_in.write(s_out.read())
1645+
1646+
def toggle_callback():
1647+
while not done.is_set():
1648+
server_ctx.sni_callback = lambda *a: None
1649+
server_ctx.sni_callback = None
1650+
1651+
threads = [threading.Thread(target=do_handshakes) for _ in range(4)]
1652+
threads.append(threading.Thread(target=toggle_callback))
1653+
1654+
with threading_helper.catch_threading_exception() as cm:
1655+
with threading_helper.start_threads(threads, done.set):
1656+
pass
1657+
self.assertIsNone(cm.exc_value)
1658+
16091659
def test_cert_store_stats(self):
16101660
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
16111661
self.assertEqual(ctx.cert_store_stats(),

0 commit comments

Comments
 (0)