@@ -1613,11 +1613,19 @@ def test_sni_callback_race(self):
16131613 # Replacing sni_callback while handshakes are in-flight must not
16141614 # crash (use-after-free on the callback in free-threaded builds).
16151615 client_ctx , server_ctx , hostname = testing_context ()
1616- server_ctx .sni_callback = lambda * a : None
1616+
1617+ def make_callback (n ):
1618+ def sni_cb (_ssl_obj , _servername , _ctx ):
1619+ if n == - 1 and _servername == "" :
1620+ raise AssertionError ("unreachable" )
1621+ return None
1622+ return sni_cb
1623+
1624+ server_ctx .sni_callback = make_callback (0 )
16171625 done = threading .Event ()
16181626
16191627 def do_handshakes ():
1620- for _ in range ( 100 ):
1628+ while not done . is_set ( ):
16211629 c_in = ssl .MemoryBIO ()
16221630 c_out = ssl .MemoryBIO ()
16231631 s_in = ssl .MemoryBIO ()
@@ -1644,16 +1652,21 @@ def do_handshakes():
16441652 c_in .write (s_out .read ())
16451653
16461654 def toggle_callback ():
1655+ i = 0
16471656 while not done .is_set ():
1648- server_ctx .sni_callback = lambda * a : None
1657+ server_ctx .sni_callback = make_callback ( i )
16491658 server_ctx .sni_callback = None
1659+ server_ctx .sni_callback = make_callback (- i )
1660+ i += 1
16501661
1651- threads = [threading .Thread (target = do_handshakes ) for _ in range (4 )]
1662+ workers = max (4 , (os .cpu_count () or 4 ) * 2 )
1663+ threads = [threading .Thread (target = do_handshakes )
1664+ for _ in range (workers )]
16521665 threads .append (threading .Thread (target = toggle_callback ))
16531666
16541667 with threading_helper .catch_threading_exception () as cm :
1655- with threading_helper .start_threads (threads , done . set ):
1656- pass
1668+ with threading_helper .start_threads (threads ):
1669+ done . set ()
16571670 self .assertIsNone (cm .exc_value )
16581671
16591672 def test_cert_store_stats (self ):
0 commit comments