2626#define OPENSSL_NO_DEPRECATED 1
2727
2828#include "Python.h"
29+ #include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
2930#include "pycore_fileutils.h" // _PyIsSelectable_fd()
3031#include "pycore_long.h" // _PyLong_UnsignedLongLong_Converter()
3132#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1()
@@ -4669,12 +4670,15 @@ _servername_callback(SSL *s, int *al, void *args)
46694670 PyObject * result ;
46704671 /* The high-level ssl.SSLSocket object */
46714672 PyObject * ssl_socket ;
4673+ PyObject * sni_cb ;
46724674 const char * servername = SSL_get_servername (s , TLSEXT_NAMETYPE_host_name );
46734675 PyGILState_STATE gstate = PyGILState_Ensure ();
46744676
4675- if (sslctx -> set_sni_cb == NULL ) {
4676- /* remove race condition in this the call back while if removing the
4677- * callback is in progress */
4677+ Py_BEGIN_CRITICAL_SECTION (sslctx );
4678+ sni_cb = Py_XNewRef (sslctx -> set_sni_cb );
4679+ Py_END_CRITICAL_SECTION ();
4680+
4681+ if (sni_cb == NULL ) {
46784682 PyGILState_Release (gstate );
46794683 return SSL_TLSEXT_ERR_OK ;
46804684 }
@@ -4701,7 +4705,7 @@ _servername_callback(SSL *s, int *al, void *args)
47014705 goto error ;
47024706
47034707 if (servername == NULL ) {
4704- result = PyObject_CallFunctionObjArgs (sslctx -> set_sni_cb , ssl_socket ,
4708+ result = PyObject_CallFunctionObjArgs (sni_cb , ssl_socket ,
47054709 Py_None , sslctx , NULL );
47064710 }
47074711 else {
@@ -4728,7 +4732,7 @@ _servername_callback(SSL *s, int *al, void *args)
47284732 }
47294733 Py_DECREF (servername_bytes );
47304734 result = PyObject_CallFunctionObjArgs (
4731- sslctx -> set_sni_cb , ssl_socket , servername_str ,
4735+ sni_cb , ssl_socket , servername_str ,
47324736 sslctx , NULL );
47334737 Py_DECREF (servername_str );
47344738 }
@@ -4738,7 +4742,7 @@ _servername_callback(SSL *s, int *al, void *args)
47384742 PyErr_FormatUnraisable ("Exception ignored "
47394743 "in ssl servername callback "
47404744 "while calling set SNI callback %R" ,
4741- sslctx -> set_sni_cb );
4745+ sni_cb );
47424746 * al = SSL_AD_HANDSHAKE_FAILURE ;
47434747 ret = SSL_TLSEXT_ERR_ALERT_FATAL ;
47444748 }
@@ -4763,11 +4767,13 @@ _servername_callback(SSL *s, int *al, void *args)
47634767 Py_DECREF (result );
47644768 }
47654769
4770+ Py_DECREF (sni_cb );
47664771 PyGILState_Release (gstate );
47674772 return ret ;
47684773
47694774error :
47704775 Py_XDECREF (ssl_socket );
4776+ Py_XDECREF (sni_cb );
47714777 * al = SSL_AD_INTERNAL_ERROR ;
47724778 ret = SSL_TLSEXT_ERR_ALERT_FATAL ;
47734779 PyGILState_Release (gstate );
@@ -4813,20 +4819,18 @@ _ssl__SSLContext_sni_callback_set_impl(PySSLContext *self, PyObject *value)
48134819 "sni_callback cannot be set on TLS_CLIENT context" );
48144820 return -1 ;
48154821 }
4816- Py_CLEAR (self -> set_sni_cb );
4817- if (value == Py_None ) {
4822+ if (!PyCallable_Check (value )) {
48184823 SSL_CTX_set_tlsext_servername_callback (self -> ctx , NULL );
4819- }
4820- else {
4821- if (!PyCallable_Check (value )) {
4822- SSL_CTX_set_tlsext_servername_callback (self -> ctx , NULL );
4823- PyErr_SetString (PyExc_TypeError ,
4824- "not a callable object" );
4824+ Py_CLEAR (self -> set_sni_cb );
4825+ if (value != Py_None ) {
4826+ PyErr_SetString (PyExc_TypeError , "not a callable object" );
48254827 return -1 ;
48264828 }
4827- self -> set_sni_cb = Py_NewRef (value );
4828- SSL_CTX_set_tlsext_servername_callback (self -> ctx , _servername_callback );
4829+ }
4830+ else {
4831+ Py_XSETREF (self -> set_sni_cb , Py_NewRef (value ));
48294832 SSL_CTX_set_tlsext_servername_arg (self -> ctx , self );
4833+ SSL_CTX_set_tlsext_servername_callback (self -> ctx , _servername_callback );
48304834 }
48314835 return 0 ;
48324836}
0 commit comments