diff --git a/pymongo/_telemetry.py b/pymongo/_telemetry.py deleted file mode 100644 index 962e12642a..0000000000 --- a/pymongo/_telemetry.py +++ /dev/null @@ -1,194 +0,0 @@ -# Copyright 2025-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unified per-command telemetry: logging, monitoring, and OpenTelemetry. - -Currently wires the command logging channel; APM event publishing and -OpenTelemetry spans are layered on top of :class:`_CommandTelemetry`. -""" -from __future__ import annotations - -import datetime -import logging -from typing import Any, Mapping, Optional - -from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log -from pymongo.message import _convert_exception - - -class _CommandTelemetry: - """Context manager for per-command telemetry. - - Currently wires the command *logging* channel only; the name reflects the - intended scope as the APM and OpenTelemetry channels are added on top. - - Logs the ``STARTED`` event on entry, then ``SUCCEEDED`` or ``FAILED`` once - the outcome is known. Call :meth:`handle_succeeded` with the server reply - on success or :meth:`handle_failed` with the raised exception on error; if - an exception propagates out of the ``with`` block without either being - called, the ``FAILED`` event is logged automatically from ``__exit__``. - - This consolidates command *logging* only -- APM event publishing remains - at the call site. The context manager owns the duration clock (from the - ``start`` time passed in) and exposes it via :attr:`duration`, and stores - the computed failure document on :attr:`failure`, so callers can reuse both - for APM events. A future change can extend this class to publish monitoring - (and OpenTelemetry) events alongside logging. - - Usage:: - - with _CommandTelemetry(client, conn, cmd, dbname, request_id, start) as cmd_telemetry: - reply = do_network_call() - duration = cmd_telemetry.handle_succeeded(reply) - # Failures are logged automatically in __exit__. - """ - - __slots__ = ( - "_client", - "_conn", - "_spec", - "_dbname", - "_request_id", - "_operation_id", - "_start", - "duration", - "failure", - "_handled", - ) - - def __init__( - self, - client: Any, - conn: Any, - spec: Mapping[str, Any], - dbname: str, - request_id: int, - start: datetime.datetime, - operation_id: Optional[int] = None, - ) -> None: - self._client = client - self._conn = conn - self._spec = spec - self._dbname = dbname - self._request_id = request_id - self._operation_id = request_id if operation_id is None else operation_id - self._start = start - self.duration: Optional[datetime.timedelta] = None - self.failure: Any = None - self._handled = False - - def _enabled(self) -> bool: - return self._client is not None and _COMMAND_LOGGER.isEnabledFor(logging.DEBUG) - - def __enter__(self) -> _CommandTelemetry: - if self._enabled(): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.STARTED, - clientId=self._client._topology_settings._topology_id, - command=self._spec, - commandName=next(iter(self._spec)), - databaseName=self._dbname, - requestId=self._request_id, - operationId=self._operation_id, - driverConnectionId=self._conn.id, - serverConnectionId=self._conn.server_connection_id, - serverHost=self._conn.address[0], - serverPort=self._conn.address[1], - serviceId=self._conn.service_id, - ) - return self - - def handle_succeeded( - self, - reply: Any, - speculative_hello: bool = False, - ) -> datetime.timedelta: - """Log the ``SUCCEEDED`` event and return the elapsed duration.""" - self.duration = datetime.datetime.now() - self._start - if self._enabled(): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.SUCCEEDED, - clientId=self._client._topology_settings._topology_id, - durationMS=self.duration, - reply=reply, - commandName=next(iter(self._spec)), - databaseName=self._dbname, - requestId=self._request_id, - operationId=self._operation_id, - driverConnectionId=self._conn.id, - serverConnectionId=self._conn.server_connection_id, - serverHost=self._conn.address[0], - serverPort=self._conn.address[1], - serviceId=self._conn.service_id, - speculative_authenticate=speculative_hello, - ) - self._handled = True - return self.duration - - def handle_failed( - self, - exc: BaseException, - failure: Optional[Any] = None, - is_server_side_error: Optional[bool] = None, - ) -> datetime.timedelta: - """Log the ``FAILED`` event and return the elapsed duration. - - The failure document and server-side-error flag are derived from *exc* - for the common case. Callers that must transform the failure document - (e.g. unacknowledged bulk writes) pass *failure* explicitly. The - computed failure is stored on :attr:`failure` for reuse by APM events. - """ - self.duration = datetime.datetime.now() - self._start - if failure is None: - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure = exc.details - else: - failure = _convert_exception(exc) # type: ignore[arg-type] - if is_server_side_error is None: - is_server_side_error = isinstance(exc, OperationFailure) - self.failure = failure - if self._enabled(): - _debug_log( - _COMMAND_LOGGER, - message=_CommandStatusMessage.FAILED, - clientId=self._client._topology_settings._topology_id, - durationMS=self.duration, - failure=failure, - commandName=next(iter(self._spec)), - databaseName=self._dbname, - requestId=self._request_id, - operationId=self._operation_id, - driverConnectionId=self._conn.id, - serverConnectionId=self._conn.server_connection_id, - serverHost=self._conn.address[0], - serverPort=self._conn.address[1], - serviceId=self._conn.service_id, - isServerSideError=is_server_side_error, - ) - self._handled = True - return self.duration - - def __exit__( - self, - exc_type: Optional[type], - exc_val: Optional[BaseException], - exc_tb: Any, - ) -> None: - # Safety net: log a failure if an exception propagates without the - # outcome having been recorded explicitly by the caller. - if exc_val is not None and not self._handled: - self.handle_failed(exc_val) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 69f0f8fa87..4a54f9eb3f 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -19,6 +19,8 @@ from __future__ import annotations import copy +import datetime +import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -34,7 +36,6 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo._telemetry import _CommandTelemetry from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern from pymongo.asynchronous.helpers import _handle_reauth from pymongo.bulk_shared import ( @@ -56,6 +57,7 @@ OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, @@ -250,25 +252,78 @@ async def write_command( ) -> dict[str, Any]: """A proxy for SocketInfo.write_command that handles event publishing.""" cmd[bwc.field] = docs - with _CommandTelemetry( - client, bwc.conn, cmd, bwc.db_name, request_id, bwc.start_time - ) as cmd_telemetry: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._start(cmd, request_id, docs) + try: + reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] + duration = datetime.datetime.now() - bwc.start_time + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) if bwc.publish: - bwc._start(cmd, request_id, docs) - try: - reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] - duration = cmd_telemetry.handle_succeeded(reply) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] - await client._process_response(reply, bwc.session) # type: ignore[arg-type] - except Exception as exc: - duration = cmd_telemetry.handle_failed(exc) - if bwc.publish: - bwc._fail(request_id, cmd_telemetry.failure, duration) - # Process the response from the server. - if isinstance(exc, (NotPrimaryError, OperationFailure)): - await client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - raise + bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + await client._process_response(reply, bwc.session) # type: ignore[arg-type] + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + + if bwc.publish: + bwc._fail(request_id, failure, duration) + # Process the response from the server. + if isinstance(exc, (NotPrimaryError, OperationFailure)): + await client._process_response(exc.details, bwc.session) # type: ignore[arg-type] + raise return reply # type: ignore[return-value] async def unack_write( @@ -282,33 +337,81 @@ async def unack_write( client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for AsyncConnection.unack_write that handles event publishing.""" - with _CommandTelemetry( - client, bwc.conn, cmd, bwc.db_name, request_id, bwc.start_time - ) as cmd_telemetry: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + cmd = bwc._start(cmd, request_id, docs) + try: + result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] + duration = datetime.datetime.now() - bwc.start_time + if result is not None: + reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] + else: + # Comply with APM spec. + reply = {"ok": 1} + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) if bwc.publish: - cmd = bwc._start(cmd, request_id, docs) - try: - result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - duration = cmd_telemetry.handle_succeeded(reply) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - duration = cmd_telemetry.handle_failed(exc, failure=failure) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - raise + bwc._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, OperationFailure): + failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] + elif isinstance(exc, NotPrimaryError): + failure = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if bwc.publish: + assert bwc.start_time is not None + bwc._fail(request_id, failure, duration) + raise return result # type: ignore[return-value] async def _execute_batch_unack( diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 2e195696af..015947d7ef 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -19,6 +19,8 @@ from __future__ import annotations import copy +import datetime +import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -46,7 +48,6 @@ _merge_command, _throw_client_bulk_write_exception, ) -from pymongo._telemetry import _CommandTelemetry from pymongo.common import ( validate_is_document_type, validate_ok_for_replace, @@ -62,6 +63,7 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, @@ -237,29 +239,82 @@ async def write_command( """A proxy for AsyncConnection.write_command that handles event publishing.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - with _CommandTelemetry( - client, bwc.conn, cmd, bwc.db_name, request_id, bwc.start_time - ) as cmd_telemetry: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._start(cmd, request_id, op_docs, ns_docs) + try: + reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] + duration = datetime.datetime.now() - bwc.start_time + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) if bwc.publish: - bwc._start(cmd, request_id, op_docs, ns_docs) - try: - reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] - duration = cmd_telemetry.handle_succeeded(reply) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] - # Process the response from the server. - await self.client._process_response(reply, bwc.session) # type: ignore[arg-type] - except Exception as exc: - duration = cmd_telemetry.handle_failed(exc) - if bwc.publish: - bwc._fail(request_id, cmd_telemetry.failure, duration) - # Top-level error will be embedded in ClientBulkWriteException. - reply = {"error": exc} - # Process the response from the server. - if isinstance(exc, OperationFailure): - await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - else: - await self.client._process_response({}, bwc.session) # type: ignore[arg-type] + bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + # Process the response from the server. + await self.client._process_response(reply, bwc.session) # type: ignore[arg-type] + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + + if bwc.publish: + bwc._fail(request_id, failure, duration) + # Top-level error will be embedded in ClientBulkWriteException. + reply = {"error": exc} + # Process the response from the server. + if isinstance(exc, OperationFailure): + await self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] + else: + await self.client._process_response({}, bwc.session) # type: ignore[arg-type] return reply # type: ignore[return-value] async def unack_write( @@ -273,34 +328,82 @@ async def unack_write( client: AsyncMongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for AsyncConnection.unack_write that handles event publishing.""" - with _CommandTelemetry( - client, bwc.conn, cmd, bwc.db_name, request_id, bwc.start_time - ) as cmd_telemetry: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + cmd = bwc._start(cmd, request_id, op_docs, ns_docs) + try: + result = await bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] + duration = datetime.datetime.now() - bwc.start_time + if result is not None: + reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] + else: + # Comply with APM spec. + reply = {"ok": 1} + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) if bwc.publish: - cmd = bwc._start(cmd, request_id, op_docs, ns_docs) - try: - result = await bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - duration = cmd_telemetry.handle_succeeded(reply) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - duration = cmd_telemetry.handle_failed(exc, failure=failure) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - # Top-level error will be embedded in ClientBulkWriteException. - reply = {"error": exc} + bwc._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, OperationFailure): + failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] + elif isinstance(exc, NotPrimaryError): + failure = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if bwc.publish: + assert bwc.start_time is not None + bwc._fail(request_id, failure, duration) + # Top-level error will be embedded in ClientBulkWriteException. + reply = {"error": exc} return reply async def _execute_batch_unack( diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index 34194899e1..5a59c67a15 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -189,10 +189,15 @@ async def _send_message(self, operation: _GetMore) -> None: if isinstance(response, PinnedResponse): if not self._sock_mgr: self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type] - cursor = response.docs[0]["cursor"] - documents = cursor["nextBatch"] - self._postbatchresumetoken = cursor.get("postBatchResumeToken") - self._id = cursor["id"] + if response.from_command: + cursor = response.docs[0]["cursor"] + documents = cursor["nextBatch"] + self._postbatchresumetoken = cursor.get("postBatchResumeToken") + self._id = cursor["id"] + else: + documents = response.docs + assert isinstance(response.data, _OpReply) + self._id = response.data.cursor_id if self._id == 0: await self.close() diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index f7c1671777..a60c082ade 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1020,23 +1020,29 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None: cmd_name = operation.name docs = response.docs - if cmd_name != "explain": - cursor = docs[0]["cursor"] - self._id = cursor["id"] - if cmd_name == "find": - documents = cursor["firstBatch"] - # Update the namespace used for future getMore commands. - ns = cursor.get("ns") - if ns: - self._dbname, self._collname = ns.split(".", 1) + if response.from_command: + if cmd_name != "explain": + cursor = docs[0]["cursor"] + self._id = cursor["id"] + if cmd_name == "find": + documents = cursor["firstBatch"] + # Update the namespace used for future getMore commands. + ns = cursor.get("ns") + if ns: + self._dbname, self._collname = ns.split(".", 1) + else: + documents = cursor["nextBatch"] + self._data = deque(documents) + self._retrieved += len(documents) else: - documents = cursor["nextBatch"] - self._data = deque(documents) - self._retrieved += len(documents) + self._id = 0 + self._data = deque(docs) + self._retrieved += len(docs) else: - self._id = 0 + assert isinstance(response.data, _OpReply) + self._id = response.data.cursor_id self._data = deque(docs) - self._retrieved += len(docs) + self._retrieved += response.data.number_returned if self._id == 0: # Don't wait for garbage collection to call __del__, return the diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 3e73492efa..5a5dc7fa2c 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -16,10 +16,10 @@ from __future__ import annotations import datetime +import logging from typing import ( TYPE_CHECKING, Any, - Callable, Mapping, MutableMapping, Optional, @@ -30,10 +30,18 @@ from bson import _decode_all_selective from pymongo import _csot, helpers_shared, message -from pymongo._telemetry import _CommandTelemetry from pymongo.compression_support import _NO_COMPRESSION -from pymongo.message import _OpMsg, _OpReply +from pymongo.errors import ( + NotPrimaryError, + OperationFailure, +) +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate +from pymongo.network_layer import ( + async_receive_message, + async_sendall, +) if TYPE_CHECKING: from bson import CodecOptions @@ -49,143 +57,6 @@ _IS_SYNC = False -_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} - - -async def _network_command_core( - conn: AsyncConnection, - dbname: str, - spec: MutableMapping[str, Any], - request_id: int, - msg: Optional[bytes], - max_doc_size: int, - codec_options: CodecOptions[_DocumentType], - session: Optional[AsyncClientSession], - client: Optional[AsyncMongoClient[Any]], - listeners: Optional[_EventListeners], - address: Optional[_Address], - start: datetime.datetime, - check: bool = True, - allowable_errors: Optional[Sequence[Union[str, int]]] = None, - parse_write_concern_error: bool = False, - user_fields: Optional[Mapping[str, Any]] = None, - unacknowledged: bool = False, - more_to_come: bool = False, - unpack_res: Optional[Callable[..., list[_DocumentOut]]] = None, - cursor_id: Optional[int] = None, - orig: Optional[MutableMapping[str, Any]] = None, - speculative_hello: bool = False, -) -> tuple[list[_DocumentOut], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: - """Send/receive a command and return (docs, raw_reply, duration). - - Handles APM logging, send/receive, unpacking, response processing, - and decryption. Both the standard command path and the cursor - (find/getMore) path go through this function. - """ - publish = listeners is not None and listeners.enabled_for_commands - name = next(iter(spec)) - reply: Optional[Union[_OpReply, _OpMsg]] = None - docs: list[_DocumentOut] = [] - - with _CommandTelemetry(client, conn, spec, dbname, request_id, start) as cmd_telemetry: - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig if orig is not None else spec, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = await conn.receive_message(None) - else: - assert msg is not None - # Only enforce the client-side document size limit for - # unacknowledged writes (where the server cannot report an - # error). For acknowledged writes, pass 0 so an oversized - # document is sent and the server returns an OperationFailure, - # matching the historical behavior. - await conn.send_message(msg, max_doc_size if unacknowledged else 0) - if unacknowledged: - # Unacknowledged write: fake a successful command response. - docs = [{"ok": 1}] # type: ignore[list-item] - else: - reply = await conn.receive_message(request_id) - - if reply is not None: - conn.more_to_come = reply.more_to_come - if unpack_res is not None: - docs = unpack_res( - reply, - cursor_id, - codec_options, - legacy_response=False, - user_fields=_CURSOR_DOC_FIELDS, - ) - else: - docs = list( - reply.unpack_response(codec_options=codec_options, user_fields=user_fields) - ) - response_doc = docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - await client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - pool_opts=conn.opts, - ) - except Exception as exc: - duration = cmd_telemetry.handle_failed(exc) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - cmd_telemetry.failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = cmd_telemetry.handle_succeeded(docs[0], speculative_hello) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - docs[0], - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - # Decrypt response. - if client and client._encrypter and reply is not None: - decrypted = await client._encrypter.decrypt(reply.raw_command_response()) - decrypt_fields = _CURSOR_DOC_FIELDS if unpack_res is not None else user_fields - docs = list(_decode_all_selective(decrypted, codec_options, decrypt_fields)) # type: ignore[arg-type] - - return docs, reply, duration - async def command( conn: AsyncConnection, @@ -285,30 +156,143 @@ async def command( request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) - max_doc_size = 0 if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) + if client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=spec, + commandName=next(iter(spec)), + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig, + dbname, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + await async_sendall(conn.conn.get_conn, msg) + if use_op_msg and unacknowledged: + # Unacknowledged, fake a successful command response. + reply = None + response_doc: _DocumentOut = {"ok": 1} + else: + reply = await async_receive_message(conn, request_id) + conn.more_to_come = reply.more_to_come + unpacked_docs = reply.unpack_response( + codec_options=codec_options, user_fields=user_fields + ) + + response_doc = unpacked_docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + await client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = message._convert_exception(exc) + if client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(spec)), + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + duration = datetime.datetime.now() - start + if client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=response_doc, + commandName=next(iter(spec)), + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + speculative_authenticate="speculativeAuthenticate" in orig, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + response_doc, + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + if client and client._encrypter and reply: + decrypted = await client._encrypter.decrypt(reply.raw_command_response()) + response_doc = cast( + "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] + ) - docs, _reply, _duration = await _network_command_core( - conn=conn, - dbname=dbname, - spec=spec, - request_id=request_id, - msg=msg, - max_doc_size=max_doc_size, - codec_options=codec_options, - session=session, - client=client, - listeners=listeners, - address=address, - start=start, - check=check, - allowable_errors=allowable_errors, - parse_write_concern_error=parse_write_concern_error, - user_fields=user_fields, - unacknowledged=unacknowledged, - orig=orig, - speculative_hello=speculative_hello, - ) - return cast("_DocumentType", docs[0]) + return response_doc # type: ignore[return-value] diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 92855fbaa2..f212306174 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -26,14 +26,18 @@ Union, ) +from bson import _decode_all_selective from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.asynchronous.network import _network_command_core +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.helpers_shared import _check_command_response from pymongo.logger import ( + _COMMAND_LOGGER, _SDAM_LOGGER, + _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _GetMore, _Query +from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: @@ -51,6 +55,8 @@ _IS_SYNC = False +_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} + class Server: def __init__( @@ -152,51 +158,171 @@ async def run_operation( :param client: An AsyncMongoClient instance. """ assert listeners is not None + publish = listeners.enabled_for_commands start = datetime.now() - # All supported servers have wire version >= 8, so use_command() always - # returns True and the command (OP_MSG) path is always used; the legacy - # OP_QUERY/OP_GET_MORE path is dead. Call it for its session-validation - # side effect and assert the invariant rather than ignoring the result. use_cmd = operation.use_command(conn) - assert use_cmd - more_to_come = bool(operation.conn_mgr and operation.conn_mgr.more_to_come) + more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come cmd, dbn = await self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 - msg = None - max_doc_size = 0 else: - op_message = operation.get_message(read_preference, conn, use_cmd) - request_id, msg, max_doc_size = self._split_message(op_message) - - if listeners.enabled_for_commands and "$db" not in cmd: - cmd["$db"] = dbn - - docs, reply, duration = await _network_command_core( - conn=conn, - dbname=dbn, - spec=cmd, - request_id=request_id, - msg=msg, - max_doc_size=max_doc_size, - codec_options=operation.codec_options, - session=operation.session, # type: ignore[arg-type] - client=client, - listeners=listeners, - address=conn.address, - start=start, - more_to_come=more_to_come, - unpack_res=unpack_res, - cursor_id=operation.cursor_id, - ) - - assert reply is not None - response: Response + message = operation.get_message(read_preference, conn, use_cmd) + request_id, data, max_doc_size = self._split_message(message) + + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=dbn, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + + if publish: + if "$db" not in cmd: + cmd["$db"] = dbn + assert listeners is not None + listeners.publish_command_start( + cmd, + dbn, + request_id, + conn.address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + if more_to_come: + reply = await conn.receive_message(None) + else: + await conn.send_message(data, max_doc_size) + reply = await conn.receive_message(request_id) + + # Unpack and check for command errors. + if use_cmd: + user_fields = _CURSOR_DOC_FIELDS + legacy_response = False + else: + user_fields = None + legacy_response = True + docs = unpack_res( + reply, + operation.cursor_id, + operation.codec_options, + legacy_response=legacy_response, + user_fields=user_fields, + ) + if use_cmd: + first = docs[0] + await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] + _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] + except Exception as exc: + duration = datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=dbn, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + listeners.publish_command_failure( + duration, + failure, + operation.name, + request_id, + conn.address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbn, + ) + raise + duration = datetime.now() - start + # Must publish in find / getMore / explain command response + # format. + if use_cmd: + res = docs[0] + elif operation.name == "explain": + res = docs[0] if docs else {} + else: + res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr] + if operation.name == "find": + res["cursor"]["firstBatch"] = docs + else: + res["cursor"]["nextBatch"] = docs + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=res, + commandName=next(iter(cmd)), + databaseName=dbn, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + listeners.publish_command_success( + duration, + res, + operation.name, + request_id, + conn.address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbn, + ) + + # Decrypt response. client = operation.client # type: ignore[assignment] + if client and client._encrypter: + if use_cmd: + decrypted = await client._encrypter.decrypt(reply.raw_command_response()) + docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) + + response: Response + if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() - more_to_come = reply.more_to_come # type: ignore[union-attr] + if isinstance(reply, _OpMsg): + # In OP_MSG, the server keeps sending only if the + # more_to_come flag is set. + more_to_come = reply.more_to_come + else: + # In OP_REPLY, the server keeps sending until cursor_id is 0. + more_to_come = bool(operation.exhaust and reply.cursor_id) if operation.conn_mgr: operation.conn_mgr.update_exhaust(more_to_come) response = PinnedResponse( diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 6fc337c3e0..22d6a7a76a 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -19,6 +19,8 @@ from __future__ import annotations import copy +import datetime +import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -34,7 +36,6 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common -from pymongo._telemetry import _CommandTelemetry from pymongo.bulk_shared import ( _COMMANDS, _DELETE_ALL, @@ -54,6 +55,7 @@ OperationFailure, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, @@ -250,25 +252,78 @@ def write_command( ) -> dict[str, Any]: """A proxy for SocketInfo.write_command that handles event publishing.""" cmd[bwc.field] = docs - with _CommandTelemetry( - client, bwc.conn, cmd, bwc.db_name, request_id, bwc.start_time - ) as cmd_telemetry: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._start(cmd, request_id, docs) + try: + reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] + duration = datetime.datetime.now() - bwc.start_time + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) if bwc.publish: - bwc._start(cmd, request_id, docs) - try: - reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] - duration = cmd_telemetry.handle_succeeded(reply) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] - client._process_response(reply, bwc.session) # type: ignore[arg-type] - except Exception as exc: - duration = cmd_telemetry.handle_failed(exc) - if bwc.publish: - bwc._fail(request_id, cmd_telemetry.failure, duration) - # Process the response from the server. - if isinstance(exc, (NotPrimaryError, OperationFailure)): - client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - raise + bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + client._process_response(reply, bwc.session) # type: ignore[arg-type] + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + + if bwc.publish: + bwc._fail(request_id, failure, duration) + # Process the response from the server. + if isinstance(exc, (NotPrimaryError, OperationFailure)): + client._process_response(exc.details, bwc.session) # type: ignore[arg-type] + raise return reply # type: ignore[return-value] def unack_write( @@ -282,33 +337,81 @@ def unack_write( client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for Connection.unack_write that handles event publishing.""" - with _CommandTelemetry( - client, bwc.conn, cmd, bwc.db_name, request_id, bwc.start_time - ) as cmd_telemetry: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + cmd = bwc._start(cmd, request_id, docs) + try: + result = bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] + duration = datetime.datetime.now() - bwc.start_time + if result is not None: + reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] + else: + # Comply with APM spec. + reply = {"ok": 1} + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) if bwc.publish: - cmd = bwc._start(cmd, request_id, docs) - try: - result = bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - duration = cmd_telemetry.handle_succeeded(reply) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - duration = cmd_telemetry.handle_failed(exc, failure=failure) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - raise + bwc._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, OperationFailure): + failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] + elif isinstance(exc, NotPrimaryError): + failure = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if bwc.publish: + assert bwc.start_time is not None + bwc._fail(request_id, failure, duration) + raise return result # type: ignore[return-value] def _execute_batch_unack( diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index d80e26dee4..1134594ae9 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -19,6 +19,8 @@ from __future__ import annotations import copy +import datetime +import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -46,7 +48,6 @@ _merge_command, _throw_client_bulk_write_exception, ) -from pymongo._telemetry import _CommandTelemetry from pymongo.common import ( validate_is_document_type, validate_ok_for_replace, @@ -62,6 +63,7 @@ WaitQueueTimeoutError, ) from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _ClientBulkWriteContext, _convert_client_bulk_exception, @@ -237,29 +239,82 @@ def write_command( """A proxy for Connection.write_command that handles event publishing.""" cmd["ops"] = op_docs cmd["nsInfo"] = ns_docs - with _CommandTelemetry( - client, bwc.conn, cmd, bwc.db_name, request_id, bwc.start_time - ) as cmd_telemetry: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._start(cmd, request_id, op_docs, ns_docs) + try: + reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] + duration = datetime.datetime.now() - bwc.start_time + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) if bwc.publish: - bwc._start(cmd, request_id, op_docs, ns_docs) - try: - reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc, arg-type] - duration = cmd_telemetry.handle_succeeded(reply) - if bwc.publish: - bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] - # Process the response from the server. - self.client._process_response(reply, bwc.session) # type: ignore[arg-type] - except Exception as exc: - duration = cmd_telemetry.handle_failed(exc) - if bwc.publish: - bwc._fail(request_id, cmd_telemetry.failure, duration) - # Top-level error will be embedded in ClientBulkWriteException. - reply = {"error": exc} - # Process the response from the server. - if isinstance(exc, OperationFailure): - self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] - else: - self.client._process_response({}, bwc.session) # type: ignore[arg-type] + bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + # Process the response from the server. + self.client._process_response(reply, bwc.session) # type: ignore[arg-type] + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + + if bwc.publish: + bwc._fail(request_id, failure, duration) + # Top-level error will be embedded in ClientBulkWriteException. + reply = {"error": exc} + # Process the response from the server. + if isinstance(exc, OperationFailure): + self.client._process_response(exc.details, bwc.session) # type: ignore[arg-type] + else: + self.client._process_response({}, bwc.session) # type: ignore[arg-type] return reply # type: ignore[return-value] def unack_write( @@ -273,34 +328,82 @@ def unack_write( client: MongoClient[Any], ) -> Optional[Mapping[str, Any]]: """A proxy for Connection.unack_write that handles event publishing.""" - with _CommandTelemetry( - client, bwc.conn, cmd, bwc.db_name, request_id, bwc.start_time - ) as cmd_telemetry: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + cmd = bwc._start(cmd, request_id, op_docs, ns_docs) + try: + result = bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] + duration = datetime.datetime.now() - bwc.start_time + if result is not None: + reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] + else: + # Comply with APM spec. + reply = {"ok": 1} + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) if bwc.publish: - cmd = bwc._start(cmd, request_id, op_docs, ns_docs) - try: - result = bwc.conn.unack_write(msg, bwc.max_bson_size) # type: ignore[func-returns-value, misc, override] - if result is not None: - reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] - else: - # Comply with APM spec. - reply = {"ok": 1} - duration = cmd_telemetry.handle_succeeded(reply) - if bwc.publish: - bwc._succeed(request_id, reply, duration) - except Exception as exc: - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - duration = cmd_telemetry.handle_failed(exc, failure=failure) - if bwc.publish: - assert bwc.start_time is not None - bwc._fail(request_id, failure, duration) - # Top-level error will be embedded in ClientBulkWriteException. - reply = {"error": exc} + bwc._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, OperationFailure): + failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] + elif isinstance(exc, NotPrimaryError): + failure = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if bwc.publish: + assert bwc.start_time is not None + bwc._fail(request_id, failure, duration) + # Top-level error will be embedded in ClientBulkWriteException. + reply = {"error": exc} return reply def _execute_batch_unack( diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index b2023ad5de..34f60c6540 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -189,10 +189,15 @@ def _send_message(self, operation: _GetMore) -> None: if isinstance(response, PinnedResponse): if not self._sock_mgr: self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type] - cursor = response.docs[0]["cursor"] - documents = cursor["nextBatch"] - self._postbatchresumetoken = cursor.get("postBatchResumeToken") - self._id = cursor["id"] + if response.from_command: + cursor = response.docs[0]["cursor"] + documents = cursor["nextBatch"] + self._postbatchresumetoken = cursor.get("postBatchResumeToken") + self._id = cursor["id"] + else: + documents = response.docs + assert isinstance(response.data, _OpReply) + self._id = response.data.cursor_id if self._id == 0: self.close() diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 909671dac5..5a721d8e06 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1018,23 +1018,29 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None: cmd_name = operation.name docs = response.docs - if cmd_name != "explain": - cursor = docs[0]["cursor"] - self._id = cursor["id"] - if cmd_name == "find": - documents = cursor["firstBatch"] - # Update the namespace used for future getMore commands. - ns = cursor.get("ns") - if ns: - self._dbname, self._collname = ns.split(".", 1) + if response.from_command: + if cmd_name != "explain": + cursor = docs[0]["cursor"] + self._id = cursor["id"] + if cmd_name == "find": + documents = cursor["firstBatch"] + # Update the namespace used for future getMore commands. + ns = cursor.get("ns") + if ns: + self._dbname, self._collname = ns.split(".", 1) + else: + documents = cursor["nextBatch"] + self._data = deque(documents) + self._retrieved += len(documents) else: - documents = cursor["nextBatch"] - self._data = deque(documents) - self._retrieved += len(documents) + self._id = 0 + self._data = deque(docs) + self._retrieved += len(docs) else: - self._id = 0 + assert isinstance(response.data, _OpReply) + self._id = response.data.cursor_id self._data = deque(docs) - self._retrieved += len(docs) + self._retrieved += response.data.number_returned if self._id == 0: # Don't wait for garbage collection to call __del__, return the diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index a7def8d89a..7d9bca4d58 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -16,10 +16,10 @@ from __future__ import annotations import datetime +import logging from typing import ( TYPE_CHECKING, Any, - Callable, Mapping, MutableMapping, Optional, @@ -30,10 +30,18 @@ from bson import _decode_all_selective from pymongo import _csot, helpers_shared, message -from pymongo._telemetry import _CommandTelemetry from pymongo.compression_support import _NO_COMPRESSION -from pymongo.message import _OpMsg, _OpReply +from pymongo.errors import ( + NotPrimaryError, + OperationFailure, +) +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _OpMsg from pymongo.monitoring import _is_speculative_authenticate +from pymongo.network_layer import ( + receive_message, + sendall, +) if TYPE_CHECKING: from bson import CodecOptions @@ -49,143 +57,6 @@ _IS_SYNC = True -_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} - - -def _network_command_core( - conn: Connection, - dbname: str, - spec: MutableMapping[str, Any], - request_id: int, - msg: Optional[bytes], - max_doc_size: int, - codec_options: CodecOptions[_DocumentType], - session: Optional[ClientSession], - client: Optional[MongoClient[Any]], - listeners: Optional[_EventListeners], - address: Optional[_Address], - start: datetime.datetime, - check: bool = True, - allowable_errors: Optional[Sequence[Union[str, int]]] = None, - parse_write_concern_error: bool = False, - user_fields: Optional[Mapping[str, Any]] = None, - unacknowledged: bool = False, - more_to_come: bool = False, - unpack_res: Optional[Callable[..., list[_DocumentOut]]] = None, - cursor_id: Optional[int] = None, - orig: Optional[MutableMapping[str, Any]] = None, - speculative_hello: bool = False, -) -> tuple[list[_DocumentOut], Optional[Union[_OpReply, _OpMsg]], datetime.timedelta]: - """Send/receive a command and return (docs, raw_reply, duration). - - Handles APM logging, send/receive, unpacking, response processing, - and decryption. Both the standard command path and the cursor - (find/getMore) path go through this function. - """ - publish = listeners is not None and listeners.enabled_for_commands - name = next(iter(spec)) - reply: Optional[Union[_OpReply, _OpMsg]] = None - docs: list[_DocumentOut] = [] - - with _CommandTelemetry(client, conn, spec, dbname, request_id, start) as cmd_telemetry: - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_start( - orig if orig is not None else spec, - dbname, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - ) - - try: - if more_to_come: - reply = conn.receive_message(None) - else: - assert msg is not None - # Only enforce the client-side document size limit for - # unacknowledged writes (where the server cannot report an - # error). For acknowledged writes, pass 0 so an oversized - # document is sent and the server returns an OperationFailure, - # matching the historical behavior. - conn.send_message(msg, max_doc_size if unacknowledged else 0) - if unacknowledged: - # Unacknowledged write: fake a successful command response. - docs = [{"ok": 1}] # type: ignore[list-item] - else: - reply = conn.receive_message(request_id) - - if reply is not None: - conn.more_to_come = reply.more_to_come - if unpack_res is not None: - docs = unpack_res( - reply, - cursor_id, - codec_options, - legacy_response=False, - user_fields=_CURSOR_DOC_FIELDS, - ) - else: - docs = list( - reply.unpack_response(codec_options=codec_options, user_fields=user_fields) - ) - response_doc = docs[0] - if not conn.ready: - cluster_time = response_doc.get("$clusterTime") - if cluster_time: - conn._cluster_time = cluster_time - if client: - client._process_response(response_doc, session) - if check: - helpers_shared._check_command_response( - response_doc, - conn.max_wire_version, - allowable_errors, - parse_write_concern_error=parse_write_concern_error, - pool_opts=conn.opts, - ) - except Exception as exc: - duration = cmd_telemetry.handle_failed(exc) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_failure( - duration, - cmd_telemetry.failure, - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - database_name=dbname, - ) - raise - duration = cmd_telemetry.handle_succeeded(docs[0], speculative_hello) - if publish: - assert listeners is not None - assert address is not None - listeners.publish_command_success( - duration, - docs[0], - name, - request_id, - address, - conn.server_connection_id, - service_id=conn.service_id, - speculative_hello=speculative_hello, - database_name=dbname, - ) - - # Decrypt response. - if client and client._encrypter and reply is not None: - decrypted = client._encrypter.decrypt(reply.raw_command_response()) - decrypt_fields = _CURSOR_DOC_FIELDS if unpack_res is not None else user_fields - docs = list(_decode_all_selective(decrypted, codec_options, decrypt_fields)) # type: ignore[arg-type] - - return docs, reply, duration - def command( conn: Connection, @@ -285,30 +156,143 @@ def command( request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) - max_doc_size = 0 if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) + if client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=spec, + commandName=next(iter(spec)), + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_start( + orig, + dbname, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + sendall(conn.conn.get_conn, msg) + if use_op_msg and unacknowledged: + # Unacknowledged, fake a successful command response. + reply = None + response_doc: _DocumentOut = {"ok": 1} + else: + reply = receive_message(conn, request_id) + conn.more_to_come = reply.more_to_come + unpacked_docs = reply.unpack_response( + codec_options=codec_options, user_fields=user_fields + ) + + response_doc = unpacked_docs[0] + if not conn.ready: + cluster_time = response_doc.get("$clusterTime") + if cluster_time: + conn._cluster_time = cluster_time + if client: + client._process_response(response_doc, session) + if check: + helpers_shared._check_command_response( + response_doc, + conn.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) + except Exception as exc: + duration = datetime.datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = message._convert_exception(exc) + if client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(spec)), + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_failure( + duration, + failure, + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbname, + ) + raise + duration = datetime.datetime.now() - start + if client is not None: + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=response_doc, + commandName=next(iter(spec)), + databaseName=dbname, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + speculative_authenticate="speculativeAuthenticate" in orig, + ) + if publish: + assert listeners is not None + assert address is not None + listeners.publish_command_success( + duration, + response_doc, + name, + request_id, + address, + conn.server_connection_id, + service_id=conn.service_id, + speculative_hello=speculative_hello, + database_name=dbname, + ) + + if client and client._encrypter and reply: + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + response_doc = cast( + "_DocumentOut", _decode_all_selective(decrypted, codec_options, user_fields)[0] + ) - docs, _reply, _duration = _network_command_core( - conn=conn, - dbname=dbname, - spec=spec, - request_id=request_id, - msg=msg, - max_doc_size=max_doc_size, - codec_options=codec_options, - session=session, - client=client, - listeners=listeners, - address=address, - start=start, - check=check, - allowable_errors=allowable_errors, - parse_write_concern_error=parse_write_concern_error, - user_fields=user_fields, - unacknowledged=unacknowledged, - orig=orig, - speculative_hello=speculative_hello, - ) - return cast("_DocumentType", docs[0]) + return response_doc # type: ignore[return-value] diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index ba18fec35d..f57420918b 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -26,15 +26,19 @@ Union, ) +from bson import _decode_all_selective +from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.helpers_shared import _check_command_response from pymongo.logger import ( + _COMMAND_LOGGER, _SDAM_LOGGER, + _CommandStatusMessage, _debug_log, _SDAMStatusMessage, ) -from pymongo.message import _GetMore, _Query +from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response from pymongo.synchronous.helpers import _handle_reauth -from pymongo.synchronous.network import _network_command_core if TYPE_CHECKING: from queue import Queue @@ -51,6 +55,8 @@ _IS_SYNC = True +_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} + class Server: def __init__( @@ -152,51 +158,171 @@ def run_operation( :param client: A MongoClient instance. """ assert listeners is not None + publish = listeners.enabled_for_commands start = datetime.now() - # All supported servers have wire version >= 8, so use_command() always - # returns True and the command (OP_MSG) path is always used; the legacy - # OP_QUERY/OP_GET_MORE path is dead. Call it for its session-validation - # side effect and assert the invariant rather than ignoring the result. use_cmd = operation.use_command(conn) - assert use_cmd - more_to_come = bool(operation.conn_mgr and operation.conn_mgr.more_to_come) + more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come cmd, dbn = self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 - msg = None - max_doc_size = 0 else: - op_message = operation.get_message(read_preference, conn, use_cmd) - request_id, msg, max_doc_size = self._split_message(op_message) - - if listeners.enabled_for_commands and "$db" not in cmd: - cmd["$db"] = dbn - - docs, reply, duration = _network_command_core( - conn=conn, - dbname=dbn, - spec=cmd, - request_id=request_id, - msg=msg, - max_doc_size=max_doc_size, - codec_options=operation.codec_options, - session=operation.session, # type: ignore[arg-type] - client=client, - listeners=listeners, - address=conn.address, - start=start, - more_to_come=more_to_come, - unpack_res=unpack_res, - cursor_id=operation.cursor_id, - ) - - assert reply is not None - response: Response + message = operation.get_message(read_preference, conn, use_cmd) + request_id, data, max_doc_size = self._split_message(message) + + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.STARTED, + clientId=client._topology_settings._topology_id, + command=cmd, + commandName=next(iter(cmd)), + databaseName=dbn, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + + if publish: + if "$db" not in cmd: + cmd["$db"] = dbn + assert listeners is not None + listeners.publish_command_start( + cmd, + dbn, + request_id, + conn.address, + conn.server_connection_id, + service_id=conn.service_id, + ) + + try: + if more_to_come: + reply = conn.receive_message(None) + else: + conn.send_message(data, max_doc_size) + reply = conn.receive_message(request_id) + + # Unpack and check for command errors. + if use_cmd: + user_fields = _CURSOR_DOC_FIELDS + legacy_response = False + else: + user_fields = None + legacy_response = True + docs = unpack_res( + reply, + operation.cursor_id, + operation.codec_options, + legacy_response=legacy_response, + user_fields=user_fields, + ) + if use_cmd: + first = docs[0] + operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] + _check_command_response(first, conn.max_wire_version, pool_opts=conn.opts) # type:ignore[has-type] + except Exception as exc: + duration = datetime.now() - start + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.FAILED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=dbn, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if publish: + assert listeners is not None + listeners.publish_command_failure( + duration, + failure, + operation.name, + request_id, + conn.address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbn, + ) + raise + duration = datetime.now() - start + # Must publish in find / getMore / explain command response + # format. + if use_cmd: + res = docs[0] + elif operation.name == "explain": + res = docs[0] if docs else {} + else: + res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr] + if operation.name == "find": + res["cursor"]["firstBatch"] = docs + else: + res["cursor"]["nextBatch"] = docs + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + message=_CommandStatusMessage.SUCCEEDED, + clientId=client._topology_settings._topology_id, + durationMS=duration, + reply=res, + commandName=next(iter(cmd)), + databaseName=dbn, + requestId=request_id, + operationId=request_id, + driverConnectionId=conn.id, + serverConnectionId=conn.server_connection_id, + serverHost=conn.address[0], + serverPort=conn.address[1], + serviceId=conn.service_id, + ) + if publish: + assert listeners is not None + listeners.publish_command_success( + duration, + res, + operation.name, + request_id, + conn.address, + conn.server_connection_id, + service_id=conn.service_id, + database_name=dbn, + ) + + # Decrypt response. client = operation.client # type: ignore[assignment] + if client and client._encrypter: + if use_cmd: + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) + + response: Response + if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() - more_to_come = reply.more_to_come # type: ignore[union-attr] + if isinstance(reply, _OpMsg): + # In OP_MSG, the server keeps sending only if the + # more_to_come flag is set. + more_to_come = reply.more_to_come + else: + # In OP_REPLY, the server keeps sending until cursor_id is 0. + more_to_come = bool(operation.exhaust and reply.cursor_id) if operation.conn_mgr: operation.conn_mgr.update_exhaust(more_to_come) response = PinnedResponse(