Skip to content

Commit 2744371

Browse files
[3.14] gh-149816: Fix SNI callback callable race (GH-150018) (GH-150100)
(cherry picked from commit 8b31d08) Co-authored-by: Kirill Ignatev <kiri11@users.noreply.github.com>
1 parent dc3a0b4 commit 2744371

3 files changed

Lines changed: 74 additions & 16 deletions

File tree

Lib/test/test_ssl.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,59 @@ def dummycallback(sock, servername, ctx, cycle=ctx):
15331533
gc.collect()
15341534
self.assertIs(wr(), None)
15351535

1536+
@unittest.skipUnless(support.Py_GIL_DISABLED,
1537+
"test is only useful if the GIL is disabled")
1538+
@threading_helper.requires_working_threading()
1539+
def test_sni_callback_race(self):
1540+
# Replacing sni_callback while handshakes are in-flight must not
1541+
# crash (use-after-free on the callback in free-threaded builds).
1542+
client_ctx, server_ctx, hostname = testing_context()
1543+
1544+
server_ctx.sni_callback = lambda *a: None
1545+
done = threading.Event()
1546+
1547+
def do_handshakes():
1548+
while not done.is_set():
1549+
c_in = ssl.MemoryBIO()
1550+
c_out = ssl.MemoryBIO()
1551+
s_in = ssl.MemoryBIO()
1552+
s_out = ssl.MemoryBIO()
1553+
client = client_ctx.wrap_bio(
1554+
c_in, c_out, server_hostname=hostname)
1555+
server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
1556+
for _ in range(50):
1557+
try:
1558+
client.do_handshake()
1559+
except ssl.SSLWantReadError:
1560+
pass
1561+
except ssl.SSLError:
1562+
break
1563+
if c_out.pending:
1564+
s_in.write(c_out.read())
1565+
try:
1566+
server.do_handshake()
1567+
except ssl.SSLWantReadError:
1568+
pass
1569+
except ssl.SSLError:
1570+
break
1571+
if s_out.pending:
1572+
c_in.write(s_out.read())
1573+
1574+
def toggle_callback():
1575+
while not done.is_set():
1576+
server_ctx.sni_callback = lambda *a: None
1577+
server_ctx.sni_callback = None
1578+
1579+
workers = max(4, (os.cpu_count() or 4) * 2)
1580+
threads = [threading.Thread(target=do_handshakes)
1581+
for _ in range(workers)]
1582+
threads.append(threading.Thread(target=toggle_callback))
1583+
1584+
with threading_helper.catch_threading_exception() as cm:
1585+
with threading_helper.start_threads(threads):
1586+
done.set()
1587+
self.assertIsNone(cm.exc_value)
1588+
15361589
def test_cert_store_stats(self):
15371590
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
15381591
self.assertEqual(ctx.cert_store_stats(),
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix race condition in :attr:`ssl.SSLContext.sni_callback`

Modules/_ssl.c

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#define OPENSSL_NO_DEPRECATED 1
2727

2828
#include "Python.h"
29+
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
2930
#include "pycore_fileutils.h" // _PyIsSelectable_fd()
3031
#include "pycore_long.h" // _PyLong_UnsignedLongLong_Converter()
3132
#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1()
@@ -4669,12 +4670,15 @@ _servername_callback(SSL *s, int *al, void *args)
46694670
PyObject *result;
46704671
/* The high-level ssl.SSLSocket object */
46714672
PyObject *ssl_socket;
4673+
PyObject *sni_cb;
46724674
const char *servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);
46734675
PyGILState_STATE gstate = PyGILState_Ensure();
46744676

4675-
if (sslctx->set_sni_cb == NULL) {
4676-
/* remove race condition in this the call back while if removing the
4677-
* callback is in progress */
4677+
Py_BEGIN_CRITICAL_SECTION(sslctx);
4678+
sni_cb = Py_XNewRef(sslctx->set_sni_cb);
4679+
Py_END_CRITICAL_SECTION();
4680+
4681+
if (sni_cb == NULL) {
46784682
PyGILState_Release(gstate);
46794683
return SSL_TLSEXT_ERR_OK;
46804684
}
@@ -4701,7 +4705,7 @@ _servername_callback(SSL *s, int *al, void *args)
47014705
goto error;
47024706

47034707
if (servername == NULL) {
4704-
result = PyObject_CallFunctionObjArgs(sslctx->set_sni_cb, ssl_socket,
4708+
result = PyObject_CallFunctionObjArgs(sni_cb, ssl_socket,
47054709
Py_None, sslctx, NULL);
47064710
}
47074711
else {
@@ -4728,7 +4732,7 @@ _servername_callback(SSL *s, int *al, void *args)
47284732
}
47294733
Py_DECREF(servername_bytes);
47304734
result = PyObject_CallFunctionObjArgs(
4731-
sslctx->set_sni_cb, ssl_socket, servername_str,
4735+
sni_cb, ssl_socket, servername_str,
47324736
sslctx, NULL);
47334737
Py_DECREF(servername_str);
47344738
}
@@ -4738,7 +4742,7 @@ _servername_callback(SSL *s, int *al, void *args)
47384742
PyErr_FormatUnraisable("Exception ignored "
47394743
"in ssl servername callback "
47404744
"while calling set SNI callback %R",
4741-
sslctx->set_sni_cb);
4745+
sni_cb);
47424746
*al = SSL_AD_HANDSHAKE_FAILURE;
47434747
ret = SSL_TLSEXT_ERR_ALERT_FATAL;
47444748
}
@@ -4763,11 +4767,13 @@ _servername_callback(SSL *s, int *al, void *args)
47634767
Py_DECREF(result);
47644768
}
47654769

4770+
Py_DECREF(sni_cb);
47664771
PyGILState_Release(gstate);
47674772
return ret;
47684773

47694774
error:
47704775
Py_XDECREF(ssl_socket);
4776+
Py_XDECREF(sni_cb);
47714777
*al = SSL_AD_INTERNAL_ERROR;
47724778
ret = SSL_TLSEXT_ERR_ALERT_FATAL;
47734779
PyGILState_Release(gstate);
@@ -4813,20 +4819,18 @@ _ssl__SSLContext_sni_callback_set_impl(PySSLContext *self, PyObject *value)
48134819
"sni_callback cannot be set on TLS_CLIENT context");
48144820
return -1;
48154821
}
4816-
Py_CLEAR(self->set_sni_cb);
4817-
if (value == Py_None) {
4822+
if (!PyCallable_Check(value)) {
48184823
SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
4819-
}
4820-
else {
4821-
if (!PyCallable_Check(value)) {
4822-
SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
4823-
PyErr_SetString(PyExc_TypeError,
4824-
"not a callable object");
4824+
Py_CLEAR(self->set_sni_cb);
4825+
if (value != Py_None) {
4826+
PyErr_SetString(PyExc_TypeError, "not a callable object");
48254827
return -1;
48264828
}
4827-
self->set_sni_cb = Py_NewRef(value);
4828-
SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
4829+
}
4830+
else {
4831+
Py_XSETREF(self->set_sni_cb, Py_NewRef(value));
48294832
SSL_CTX_set_tlsext_servername_arg(self->ctx, self);
4833+
SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
48304834
}
48314835
return 0;
48324836
}

0 commit comments

Comments
 (0)