Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions pymongo/asynchronous/command_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,13 @@
NoReturn,
Optional,
Sequence,
Union,
)

from bson import CodecOptions, _convert_raw_document_lists_to_streams
from pymongo.asynchronous.cursor_base import _AsyncCursorBase, _ConnectionManager
from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
from pymongo.message import _GetMore, _OpMsg, _OpReply, _RawBatchGetMore
from pymongo.message import _GetMore, _OpMsg, _RawBatchGetMore
from pymongo.response import PinnedResponse
from pymongo.typings import _Address, _DocumentOut, _DocumentType

Expand Down Expand Up @@ -145,7 +144,7 @@ async def _maybe_pin_connection(self, conn: AsyncConnection) -> None:

def _unpack_response(
self,
response: Union[_OpReply, _OpMsg],
response: _OpMsg,
cursor_id: Optional[int],
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
Expand Down Expand Up @@ -189,15 +188,10 @@ async def _send_message(self, operation: _GetMore) -> None:
if isinstance(response, PinnedResponse):
if not self._sock_mgr:
self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come) # type: ignore[arg-type]
if response.from_command:
cursor = response.docs[0]["cursor"]
documents = cursor["nextBatch"]
self._postbatchresumetoken = cursor.get("postBatchResumeToken")
self._id = cursor["id"]
else:
documents = response.docs
assert isinstance(response.data, _OpReply)
self._id = response.data.cursor_id
cursor = response.docs[0]["cursor"]
documents = cursor["nextBatch"]
self._postbatchresumetoken = cursor.get("postBatchResumeToken")
self._id = cursor["id"]

if self._id == 0:
await self.close()
Expand Down Expand Up @@ -333,7 +327,7 @@ def __init__(

def _unpack_response( # type: ignore[override]
self,
response: Union[_OpReply, _OpMsg],
response: _OpMsg,
cursor_id: Optional[int],
codec_options: CodecOptions[dict[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
Expand Down
39 changes: 16 additions & 23 deletions pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from pymongo.message import (
_GetMore,
_OpMsg,
_OpReply,
_Query,
_RawBatchGetMore,
_RawBatchQuery,
Expand Down Expand Up @@ -864,7 +863,7 @@ def collation(self, collation: Optional[_CollationIn]) -> AsyncCursor[_DocumentT

def _unpack_response(
self,
response: Union[_OpReply, _OpMsg],
response: _OpMsg,
cursor_id: Optional[int],
codec_options: CodecOptions, # type: ignore[type-arg]
user_fields: Optional[Mapping[str, Any]] = None,
Expand Down Expand Up @@ -1020,29 +1019,23 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:

cmd_name = operation.name
docs = response.docs
if response.from_command:
if cmd_name != "explain":
cursor = docs[0]["cursor"]
self._id = cursor["id"]
if cmd_name == "find":
documents = cursor["firstBatch"]
# Update the namespace used for future getMore commands.
ns = cursor.get("ns")
if ns:
self._dbname, self._collname = ns.split(".", 1)
else:
documents = cursor["nextBatch"]
self._data = deque(documents)
self._retrieved += len(documents)
if cmd_name != "explain":
cursor = docs[0]["cursor"]
self._id = cursor["id"]
if cmd_name == "find":
documents = cursor["firstBatch"]
# Update the namespace used for future getMore commands.
ns = cursor.get("ns")
if ns:
self._dbname, self._collname = ns.split(".", 1)
else:
self._id = 0
self._data = deque(docs)
self._retrieved += len(docs)
documents = cursor["nextBatch"]
self._data = deque(documents)
self._retrieved += len(documents)
else:
assert isinstance(response.data, _OpReply)
self._id = response.data.cursor_id
self._id = 0
self._data = deque(docs)
self._retrieved += response.data.number_returned
self._retrieved += len(docs)

if self._id == 0:
# Don't wait for garbage collection to call __del__, return the
Expand Down Expand Up @@ -1195,7 +1188,7 @@ def __init__(

def _unpack_response(
self,
response: Union[_OpReply, _OpMsg],
response: _OpMsg,
cursor_id: Optional[int],
codec_options: CodecOptions[Mapping[str, Any]],
user_fields: Optional[Mapping[str, Any]] = None,
Expand Down
33 changes: 11 additions & 22 deletions pymongo/asynchronous/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def command(
conn: AsyncConnection,
dbname: str,
spec: MutableMapping[str, Any],
is_mongos: bool,
is_mongos: bool, # noqa: ARG001
read_preference: Optional[_ServerMode],
codec_options: CodecOptions[_DocumentType],
session: Optional[AsyncClientSession],
Expand All @@ -76,7 +76,6 @@ async def command(
parse_write_concern_error: bool = False,
collation: Optional[_CollationIn] = None,
compression_ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
use_op_msg: bool = False,
unacknowledged: bool = False,
user_fields: Optional[Mapping[str, Any]] = None,
exhaust_allowed: bool = False,
Expand All @@ -102,22 +101,17 @@ async def command(
field in the command response.
:param collation: The collation for this command.
:param compression_ctx: optional compression Context.
:param use_op_msg: True if we should use OP_MSG.
:param unacknowledged: True if this is an unacknowledged command.
:param user_fields: Response fields that should be decoded
using the TypeDecoders from codec_options, passed to
bson._decode_all_selective.
:param exhaust_allowed: True if we should enable OP_MSG exhaustAllowed.
"""
name = next(iter(spec))
ns = dbname + ".$cmd"
speculative_hello = False

# Publish the original command document, perhaps with lsid and $clusterTime.
orig = spec
if is_mongos and not use_op_msg:
assert read_preference is not None
spec = message._maybe_add_read_preference(spec, read_preference)
if read_concern and not (session and session.in_transaction):
if read_concern.level:
spec["readConcern"] = read_concern.document
Expand All @@ -142,20 +136,15 @@ async def command(
conn.apply_timeout(client, spec)
_csot.apply_write_concern(spec, write_concern)

if use_op_msg:
flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
request_id, msg, size, max_doc_size = message._op_msg(
flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
)
# If this is an unacknowledged write then make sure the encoded doc(s)
# are small enough, otherwise rely on the server to return an error.
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
message._raise_document_too_large(name, size, max_bson_size)
else:
request_id, msg, size = message._query(
0, ns, 0, -1, spec, None, codec_options, compression_ctx
)
flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
request_id, msg, size, max_doc_size = message._op_msg(
flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx
)
# If this is an unacknowledged write then make sure the encoded doc(s)
# are small enough, otherwise rely on the server to return an error.
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size:
message._raise_document_too_large(name, size, max_bson_size)

if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD:
message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD)
Expand Down Expand Up @@ -190,7 +179,7 @@ async def command(

try:
await async_sendall(conn.conn.get_conn, msg)
if use_op_msg and unacknowledged:
if unacknowledged:
# Unacknowledged, fake a successful command response.
reply = None
response_doc: _DocumentOut = {"ok": 1}
Expand Down
18 changes: 6 additions & 12 deletions pymongo/asynchronous/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
ZlibContext,
ZstdContext,
)
from pymongo.message import _OpMsg, _OpReply
from pymongo.message import _OpMsg
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import _ServerMode
from pymongo.typings import _Address, _CollationIn
Expand Down Expand Up @@ -146,7 +146,6 @@ def __init__(
self.supports_sessions = False
self.hello_ok: bool = False
self.is_mongos = False
self.op_msg_enabled = False
self.listeners = pool.opts._event_listeners
self.enabled_for_cmap = pool.enabled_for_cmap
self.enabled_for_logging = pool.enabled_for_logging
Expand Down Expand Up @@ -235,13 +234,11 @@ async def unpin(self) -> None:
await self.close_conn(ConnectionClosedReason.STALE)

def hello_cmd(self) -> dict[str, Any]:
# Handshake spec requires us to use OP_MSG+hello command for the
# initial handshake in load balanced or stable API mode.
# As of PYTHON-5713, always use OP_MSG for the handshake since all
# supported servers (MongoDB 4.2+, wire version >= 8) support it.
if self.opts.server_api or self.hello_ok or self.opts.load_balanced:
self.op_msg_enabled = True
return {HelloCompat.CMD: 1}
else:
return {HelloCompat.LEGACY_CMD: 1, "helloOk": True}
return {HelloCompat.LEGACY_CMD: 1, "helloOk": True}

async def hello(self) -> Hello[dict[str, Any]]:
return await self._hello(None, None)
Expand Down Expand Up @@ -314,7 +311,6 @@ async def _hello(
ctx = self.compression_settings.get_compression_context(hello.compressors)
self.compression_context = ctx

self.op_msg_enabled = True
self.server_connection_id = hello.connection_id
if creds:
self.negotiated_mechs = hello.sasl_supported_mechs
Expand Down Expand Up @@ -397,8 +393,7 @@ async def command(
self.send_cluster_time(spec, session, client)
listeners = self.listeners if publish_events else None
unacknowledged = bool(write_concern and not write_concern.acknowledged)
if self.op_msg_enabled:
self._raise_if_not_writable(unacknowledged)
self._raise_if_not_writable(unacknowledged)
try:
return await command(
self,
Expand All @@ -418,7 +413,6 @@ async def command(
parse_write_concern_error=parse_write_concern_error,
collation=collation,
compression_ctx=self.compression_context,
use_op_msg=self.op_msg_enabled,
unacknowledged=unacknowledged,
user_fields=user_fields,
exhaust_allowed=exhaust_allowed,
Expand Down Expand Up @@ -447,7 +441,7 @@ async def send_message(self, message: bytes, max_doc_size: int) -> None:
except BaseException as error:
await self._raise_connection_failure(error)

async def receive_message(self, request_id: Optional[int]) -> Union[_OpReply, _OpMsg]:
async def receive_message(self, request_id: Optional[int]) -> _OpMsg:
"""Receive a raw BSON message or raise ConnectionFailure.

If any exception is raised, the socket is closed.
Expand Down
11 changes: 1 addition & 10 deletions pymongo/asynchronous/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,16 +265,7 @@ async def run_operation(
duration = datetime.now() - start
# Must publish in find / getMore / explain command response
# format.
if use_cmd:
res = docs[0]
elif operation.name == "explain":
res = docs[0] if docs else {}
else:
res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} # type: ignore[union-attr]
if operation.name == "find":
res["cursor"]["firstBatch"] = docs
else:
res["cursor"]["nextBatch"] = docs
res = docs[0]
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
_debug_log(
_COMMAND_LOGGER,
Expand Down
Loading
Loading