From 2cd02daf6610c098982668872a7f0f04e319e8ae Mon Sep 17 00:00:00 2001 From: Iris Date: Thu, 11 Jun 2026 13:53:42 -0700 Subject: [PATCH] add srvAllowedHostsSuffix option to srv uri --- pymongo/asynchronous/mongo_client.py | 22 ++++++++++++++++--- pymongo/asynchronous/monitor.py | 1 + pymongo/asynchronous/settings.py | 7 ++++++ pymongo/asynchronous/srv_resolver.py | 20 ++++++++++++----- pymongo/asynchronous/uri_parser.py | 8 ++++++- pymongo/common.py | 1 + pymongo/synchronous/mongo_client.py | 22 ++++++++++++++++--- pymongo/synchronous/monitor.py | 1 + pymongo/synchronous/settings.py | 7 ++++++ pymongo/synchronous/srv_resolver.py | 20 ++++++++++++----- pymongo/synchronous/uri_parser.py | 8 ++++++- pymongo/uri_parser_shared.py | 1 + .../srvAllowedHostsSuffix-mismatch.json | 5 +++++ .../srvAllowedHostsSuffix-with_dot.json | 11 ++++++++++ .../srvAllowedHostsSuffix-without_dot.json | 11 ++++++++++ 15 files changed, 125 insertions(+), 20 deletions(-) create mode 100644 test/srv_seedlist/replica-set/srvAllowedHostsSuffix-mismatch.json create mode 100644 test/srv_seedlist/replica-set/srvAllowedHostsSuffix-with_dot.json create mode 100644 test/srv_seedlist/replica-set/srvAllowedHostsSuffix-without_dot.json diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 412a13ec70..ad9d8357f6 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -808,6 +808,7 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") + srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") for entity in self._host: @@ -858,6 +859,8 @@ def __init__( srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + if srv_allowed_hosts_suffix is None: + srv_allowed_hosts_suffix = opts.get("srvallowedhostssuffix") opts = self._normalize_and_validate_options(opts, self._seeds) # Username and password passed as kwargs override user info in URI. @@ -895,7 +898,9 @@ def __init__( self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries) - self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options( + self._seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix + ) self._opened = False self._closed = False @@ -913,6 +918,7 @@ async def _resolve_srv(self) -> None: opts = common._CaseInsensitiveDictionary() srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") + srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, @@ -933,6 +939,7 @@ async def _resolve_srv(self) -> None: connect_timeout=timeout, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, + srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, ) seeds.update(res["nodelist"]) opts = res["options"] @@ -965,6 +972,8 @@ async def _resolve_srv(self) -> None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + if srv_allowed_hosts_suffix is None: + srv_allowed_hosts_suffix = opts.get("srvAllowedHostsSuffix") opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. @@ -974,10 +983,16 @@ async def _resolve_srv(self) -> None: username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC ) - self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options( + seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix + ) def _init_based_on_options( - self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + self, + seeds: Collection[tuple[str, int]], + srv_max_hosts: Any, + srv_service_name: Any, + srv_allowed_hosts_suffix: Any, ) -> None: self._event_listeners = self._options.pool_options._event_listeners self._topology_settings = TopologySettings( @@ -996,6 +1011,7 @@ def _init_based_on_options( load_balanced=self._options.load_balanced, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, + srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, server_monitoring_mode=self._options.server_monitoring_mode, topology_id=self._topology_settings._topology_id if self._topology_settings else None, ) diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 45c12b219f..a0ee5e50ac 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -418,6 +418,7 @@ async def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: self._fqdn, self._settings.pool_options.connect_timeout, self._settings.srv_service_name, + srv_allowed_hosts_suffix=self._settings.srv_allowed_hosts_suffix, ) seedlist, ttl = await resolver.get_hosts_and_min_ttl() if len(seedlist) == 0: diff --git a/pymongo/asynchronous/settings.py b/pymongo/asynchronous/settings.py index 9c2331971a..40ee8482bf 100644 --- a/pymongo/asynchronous/settings.py +++ b/pymongo/asynchronous/settings.py @@ -50,6 +50,7 @@ def __init__( load_balanced: Optional[bool] = None, srv_service_name: str = common.SRV_SERVICE_NAME, srv_max_hosts: int = 0, + srv_allowed_hosts_suffix: Optional[str] = None, server_monitoring_mode: str = common.SERVER_MONITORING_MODE, topology_id: Optional[ObjectId] = None, ): @@ -78,6 +79,7 @@ def __init__( self._load_balanced = load_balanced self._srv_service_name = srv_service_name self._srv_max_hosts = srv_max_hosts or 0 + self._srv_allowed_hosts_suffix = srv_allowed_hosts_suffix self._server_monitoring_mode = server_monitoring_mode if topology_id is not None: self._topology_id = topology_id @@ -155,6 +157,11 @@ def srv_max_hosts(self) -> int: """The srvMaxHosts.""" return self._srv_max_hosts + @property + def srv_allowed_hosts_suffix(self) -> Optional[str]: + """The srvAllowedHostsSuffix.""" + return self._srv_allowed_hosts_suffix + @property def server_monitoring_mode(self) -> str: """The serverMonitoringMode.""" diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py index 9c4d9a9d57..d5e82c2086 100644 --- a/pymongo/asynchronous/srv_resolver.py +++ b/pymongo/asynchronous/srv_resolver.py @@ -70,11 +70,15 @@ def __init__( connect_timeout: Optional[float], srv_service_name: str, srv_max_hosts: int = 0, + srv_allowed_hosts_suffix: Optional[str] = None, ): self.__fqdn = fqdn self.__srv = srv_service_name self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT self.__srv_max_hosts = srv_max_hosts or 0 + self.__srv_allowed_hosts_suffix = ( + "." + srv_allowed_hosts_suffix.lower().lstrip(".") if srv_allowed_hosts_suffix else None + ) # ensure there's a . at the beginning of the domain # Validate the fully qualified domain name. try: ipaddress.ip_address(fqdn) @@ -134,12 +138,16 @@ async def _get_srv_response_and_hosts( raise ConfigurationError( "Invalid SRV host: return address is identical to SRV hostname" ) - try: - nlist = srv_host.split(".")[1:][-self.__slen :] - except Exception as exc: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc - if self.__plist != nlist: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") + if self.__srv_allowed_hosts_suffix is not None: + if not srv_host.endswith(self.__srv_allowed_hosts_suffix): + raise ConfigurationError(f"Invalid SRV host: {node[0]}") + else: + try: + nlist = srv_host.split(".")[1:][-self.__slen :] + except Exception as exc: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc + if self.__plist != nlist: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") if self.__srv_max_hosts: nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) return results, nodes diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py index 055b04d75a..e86e59dd6c 100644 --- a/pymongo/asynchronous/uri_parser.py +++ b/pymongo/asynchronous/uri_parser.py @@ -47,6 +47,7 @@ async def parse_uri( connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, + srv_allowed_hosts_suffix: Optional[str] = None, ) -> dict[str, Any]: """Parse and validate a MongoDB URI. @@ -115,6 +116,7 @@ async def parse_uri( connect_timeout, srv_service_name, srv_max_hosts, + srv_allowed_hosts_suffix, ) ) result["options"] = _make_options_case_sensitive(result["options"]) @@ -130,6 +132,7 @@ async def _parse_srv( connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, + srv_allowed_hosts_suffix: Optional[str] = None, ) -> dict[str, Any]: if uri.startswith(SCHEME): is_srv = False @@ -157,6 +160,7 @@ async def _parse_srv( hosts = unquote_plus(hosts) srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + srv_allowed_hosts_suffix = srv_allowed_hosts_suffix or options.get("srvAllowedHostsSuffix") if is_srv: nodes = split_hosts(hosts, default_port=None) fqdn, port = nodes[0] @@ -164,7 +168,9 @@ async def _parse_srv( # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + dns_resolver = _SrvResolver( + fqdn, connect_timeout, srv_service_name, srv_max_hosts, srv_allowed_hosts_suffix + ) nodes = await dns_resolver.get_hosts() dns_options = await dns_resolver.get_options() if dns_options: diff --git a/pymongo/common.py b/pymongo/common.py index ea349b3d23..85db10d2f0 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -721,6 +721,7 @@ def validate_server_monitoring_mode(option: str, value: str) -> str: "zlibcompressionlevel": validate_zlib_compression_level, "srvservicename": validate_string, "srvmaxhosts": validate_non_negative_integer, + "srvallowedhostssuffix": validate_string, "timeoutms": validate_timeoutms, "servermonitoringmode": validate_server_monitoring_mode, "maxadaptiveretries": validate_non_negative_integer, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 2bd6f31b72..85523babf6 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -808,6 +808,7 @@ def __init__( fqdn = None srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") + srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") if len([h for h in self._host if "/" in h]) > 1: raise ConfigurationError("host must not contain multiple MongoDB URIs") for entity in self._host: @@ -858,6 +859,8 @@ def __init__( srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + if srv_allowed_hosts_suffix is None: + srv_allowed_hosts_suffix = opts.get("srvallowedhostssuffix") opts = self._normalize_and_validate_options(opts, self._seeds) # Username and password passed as kwargs override user info in URI. @@ -895,7 +898,9 @@ def __init__( self._retry_policy = _RetryPolicy(attempts=self._options.max_adaptive_retries) - self._init_based_on_options(self._seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options( + self._seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix + ) self._opened = False self._closed = False @@ -913,6 +918,7 @@ def _resolve_srv(self) -> None: opts = common._CaseInsensitiveDictionary() srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") + srv_allowed_hosts_suffix = keyword_opts.get("srvallowedhostssuffix") for entity in self._host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, @@ -933,6 +939,7 @@ def _resolve_srv(self) -> None: connect_timeout=timeout, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, + srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, ) seeds.update(res["nodelist"]) opts = res["options"] @@ -965,6 +972,8 @@ def _resolve_srv(self) -> None: srv_service_name = opts.get("srvServiceName", common.SRV_SERVICE_NAME) srv_max_hosts = srv_max_hosts or opts.get("srvmaxhosts") + if srv_allowed_hosts_suffix is None: + srv_allowed_hosts_suffix = opts.get("srvAllowedHostsSuffix") opts = self._normalize_and_validate_options(opts, seeds) # Username and password passed as kwargs override user info in URI. @@ -974,10 +983,16 @@ def _resolve_srv(self) -> None: username, password, self._resolve_srv_info["dbase"], opts, _IS_SYNC ) - self._init_based_on_options(seeds, srv_max_hosts, srv_service_name) + self._init_based_on_options( + seeds, srv_max_hosts, srv_service_name, srv_allowed_hosts_suffix + ) def _init_based_on_options( - self, seeds: Collection[tuple[str, int]], srv_max_hosts: Any, srv_service_name: Any + self, + seeds: Collection[tuple[str, int]], + srv_max_hosts: Any, + srv_service_name: Any, + srv_allowed_hosts_suffix: Any, ) -> None: self._event_listeners = self._options.pool_options._event_listeners self._topology_settings = TopologySettings( @@ -996,6 +1011,7 @@ def _init_based_on_options( load_balanced=self._options.load_balanced, srv_service_name=srv_service_name, srv_max_hosts=srv_max_hosts, + srv_allowed_hosts_suffix=srv_allowed_hosts_suffix, server_monitoring_mode=self._options.server_monitoring_mode, topology_id=self._topology_settings._topology_id if self._topology_settings else None, ) diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index f395588814..9ecc42505c 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -416,6 +416,7 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: self._fqdn, self._settings.pool_options.connect_timeout, self._settings.srv_service_name, + srv_allowed_hosts_suffix=self._settings.srv_allowed_hosts_suffix, ) seedlist, ttl = resolver.get_hosts_and_min_ttl() if len(seedlist) == 0: diff --git a/pymongo/synchronous/settings.py b/pymongo/synchronous/settings.py index 61b86fa18d..ea54fca3f9 100644 --- a/pymongo/synchronous/settings.py +++ b/pymongo/synchronous/settings.py @@ -50,6 +50,7 @@ def __init__( load_balanced: Optional[bool] = None, srv_service_name: str = common.SRV_SERVICE_NAME, srv_max_hosts: int = 0, + srv_allowed_hosts_suffix: Optional[str] = None, server_monitoring_mode: str = common.SERVER_MONITORING_MODE, topology_id: Optional[ObjectId] = None, ): @@ -78,6 +79,7 @@ def __init__( self._load_balanced = load_balanced self._srv_service_name = srv_service_name self._srv_max_hosts = srv_max_hosts or 0 + self._srv_allowed_hosts_suffix = srv_allowed_hosts_suffix self._server_monitoring_mode = server_monitoring_mode if topology_id is not None: self._topology_id = topology_id @@ -155,6 +157,11 @@ def srv_max_hosts(self) -> int: """The srvMaxHosts.""" return self._srv_max_hosts + @property + def srv_allowed_hosts_suffix(self) -> Optional[str]: + """The srvAllowedHostsSuffix.""" + return self._srv_allowed_hosts_suffix + @property def server_monitoring_mode(self) -> str: """The serverMonitoringMode.""" diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/synchronous/srv_resolver.py index 4802310698..8d26c2fb28 100644 --- a/pymongo/synchronous/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -70,11 +70,15 @@ def __init__( connect_timeout: Optional[float], srv_service_name: str, srv_max_hosts: int = 0, + srv_allowed_hosts_suffix: Optional[str] = None, ): self.__fqdn = fqdn self.__srv = srv_service_name self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT self.__srv_max_hosts = srv_max_hosts or 0 + self.__srv_allowed_hosts_suffix = ( + "." + srv_allowed_hosts_suffix.lower().lstrip(".") if srv_allowed_hosts_suffix else None + ) # ensure there's a . at the beginning of the domain # Validate the fully qualified domain name. try: ipaddress.ip_address(fqdn) @@ -134,12 +138,16 @@ def _get_srv_response_and_hosts( raise ConfigurationError( "Invalid SRV host: return address is identical to SRV hostname" ) - try: - nlist = srv_host.split(".")[1:][-self.__slen :] - except Exception as exc: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc - if self.__plist != nlist: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") + if self.__srv_allowed_hosts_suffix is not None: + if not srv_host.endswith(self.__srv_allowed_hosts_suffix): + raise ConfigurationError(f"Invalid SRV host: {node[0]}") + else: + try: + nlist = srv_host.split(".")[1:][-self.__slen :] + except Exception as exc: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") from exc + if self.__plist != nlist: + raise ConfigurationError(f"Invalid SRV host: {node[0]}") if self.__srv_max_hosts: nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) return results, nodes diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py index 45c1752953..2ebf24fb15 100644 --- a/pymongo/synchronous/uri_parser.py +++ b/pymongo/synchronous/uri_parser.py @@ -47,6 +47,7 @@ def parse_uri( connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, + srv_allowed_hosts_suffix: Optional[str] = None, ) -> dict[str, Any]: """Parse and validate a MongoDB URI. @@ -115,6 +116,7 @@ def parse_uri( connect_timeout, srv_service_name, srv_max_hosts, + srv_allowed_hosts_suffix, ) ) result["options"] = _make_options_case_sensitive(result["options"]) @@ -130,6 +132,7 @@ def _parse_srv( connect_timeout: Optional[float] = None, srv_service_name: Optional[str] = None, srv_max_hosts: Optional[int] = None, + srv_allowed_hosts_suffix: Optional[str] = None, ) -> dict[str, Any]: if uri.startswith(SCHEME): is_srv = False @@ -157,6 +160,7 @@ def _parse_srv( hosts = unquote_plus(hosts) srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + srv_allowed_hosts_suffix = srv_allowed_hosts_suffix or options.get("srvAllowedHostsSuffix") if is_srv: nodes = split_hosts(hosts, default_port=None) fqdn, port = nodes[0] @@ -164,7 +168,9 @@ def _parse_srv( # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + dns_resolver = _SrvResolver( + fqdn, connect_timeout, srv_service_name, srv_max_hosts, srv_allowed_hosts_suffix + ) nodes = dns_resolver.get_hosts() dns_options = dns_resolver.get_options() if dns_options: diff --git a/pymongo/uri_parser_shared.py b/pymongo/uri_parser_shared.py index 59168d1e9f..0c9cf5909b 100644 --- a/pymongo/uri_parser_shared.py +++ b/pymongo/uri_parser_shared.py @@ -88,6 +88,7 @@ "socketTimeoutMS", "srvMaxHosts", "srvServiceName", + "srvAllowedHostsSuffix", "ssl", "tls", "tlsAllowInvalidCertificates", diff --git a/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-mismatch.json b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-mismatch.json new file mode 100644 index 0000000000..d8892d2ebe --- /dev/null +++ b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-mismatch.json @@ -0,0 +1,5 @@ +{ + "uri": "mongodb+srv://test12.test.build.10gen.cc/?srvAllowedHostsSuffix=test.build.10gen.cc", + "seeds": [], + "hosts": [] +} \ No newline at end of file diff --git a/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-with_dot.json b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-with_dot.json new file mode 100644 index 0000000000..95f3ae854c --- /dev/null +++ b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-with_dot.json @@ -0,0 +1,11 @@ +{ + "uri": "mongodb+srv://test12.test.build.10gen.cc/?srvAllowedHostsSuffix=.build.10gen.cc", + "seeds": [ + "localhost.build.10gen.cc:27017" + ], + "options": { + "srvAllowedHostsSuffix": ".build.10gen.cc", + "ssl": true + }, + "ping": false +} \ No newline at end of file diff --git a/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-without_dot.json b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-without_dot.json new file mode 100644 index 0000000000..d8a9ec5340 --- /dev/null +++ b/test/srv_seedlist/replica-set/srvAllowedHostsSuffix-without_dot.json @@ -0,0 +1,11 @@ +{ + "uri": "mongodb+srv://test12.test.build.10gen.cc/?srvAllowedHostsSuffix=build.10gen.cc", + "seeds": [ + "localhost.build.10gen.cc:27017" + ], + "options": { + "srvAllowedHostsSuffix": "build.10gen.cc", + "ssl": true + }, + "ping": false +} \ No newline at end of file