Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 53 additions & 10 deletions dbt/adapters/sqlserver/sqlserver_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@

logger = AdapterLogger("sqlserver")

# Attribute used to stash the in-flight pyodbc / mssql-python cursor on a
# Connection so cancel() can reach it from another thread. See cancel().
_IN_FLIGHT_CURSOR_ATTR = "_dbt_sqlserver_in_flight_cursor"


class SQLServerConnectionManager(SQLConnectionManager):
TYPE = "sqlserver"
Expand Down Expand Up @@ -142,8 +146,40 @@ def connect() -> Any:

return conn

def cancel(self, connection: Connection):
logger.debug("Cancel query")
def cancel(self, connection: Connection) -> None:
"""Cancel the in-flight query on ``connection``, if any.

dbt-core's ``cancel_open`` calls this for sibling connections when a
run is interrupted (Ctrl-C) or another thread errors. We cancel by
calling ``Cursor.cancel()`` on the connection's in-flight cursor:
pyodbc exposes it and it is explicitly designed to be called from
another thread (it issues ``SQLCancel``); mssql-python's cursor is
used the same way when it supports it. Cancellation targets statement
execution. If no statement is in flight, the cursor is gone, or the
backend cursor does not support cancellation, this is a best-effort
no-op.
"""

cursor = getattr(connection, _IN_FLIGHT_CURSOR_ATTR, None)
if cursor is None:
logger.debug(f"No in-flight query to cancel for connection {connection.name}.")
return

cancel_cursor = getattr(cursor, "cancel", None)
if not callable(cancel_cursor):
logger.debug(
f"Backend cursor for connection {connection.name} does not "
"support cancellation; skipping."
)
return

try:
logger.debug(f"Cancelling in-flight query for connection {connection.name}.")
cancel_cursor()
except Exception as exc:
# The statement may have completed between the lookup and the
# cancel; cancellation is best-effort, so swallow and log.
logger.debug(f"Failed to cancel query for connection {connection.name}: {exc}")

def add_begin_query(self):
if self._dbt_sqlserver_use_dbt_transactions:
Expand Down Expand Up @@ -270,16 +306,23 @@ def _execute_query_with_retry(
pre = time.time()

cursor = connection.handle.cursor()
# Track the in-flight cursor so cancel() / cancel_open() can stop it
# from another thread (e.g. on Ctrl-C); cleared once execution
# finishes. See cancel().
setattr(connection, _IN_FLIGHT_CURSOR_ATTR, cursor)
credentials = self.get_credentials(connection.credentials)

_execute_query_with_retry(
cursor=cursor,
sql=sql,
bindings=bindings,
retryable_exceptions=retryable_exceptions,
retry_limit=(credentials.retries if credentials.retries > 3 else retry_limit),
attempt=1,
)
try:
_execute_query_with_retry(
cursor=cursor,
sql=sql,
bindings=bindings,
retryable_exceptions=retryable_exceptions,
retry_limit=(credentials.retries if credentials.retries > 3 else retry_limit),
attempt=1,
)
finally:
setattr(connection, _IN_FLIGHT_CURSOR_ATTR, None)

if is_pyodbc_handle(connection.handle):
connection.handle.add_output_converter(-155, byte_array_to_datetime)
Expand Down
120 changes: 120 additions & 0 deletions tests/unit/adapters/mssql/test_sqlserver_connection_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import builtins
import importlib
from contextlib import contextmanager
from types import SimpleNamespace
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -1594,3 +1595,122 @@ def test_rollback_handle_disabled_exception_fires_rollback_failed(
SQLServerConnectionManager._rollback_handle(connection)

mock_fire_event.assert_called_once()


@contextmanager
def _passthrough_exception_handler(_sql):
"""Stand-in for add_query's exception_handler in in-flight-cursor tests.

The real exception_handler has its own tests and depends on global runtime
state; here we only care about cursor tracking, so we let exceptions pass
straight through.
"""

yield


def _build_cancel_test_manager(monkeypatch: pytest.MonkeyPatch, cursor: MagicMock):
"""Manager + thread connection wired for add_query in-flight-cursor tests."""

manager = object.__new__(SQLServerConnectionManager)

handle = MagicMock()
handle.cursor.return_value = cursor

credentials = MagicMock()
credentials.retries = 1

connection = MagicMock()
connection.handle = handle
connection.credentials = credentials
connection.transaction_open = True
connection.name = "cancel-test"

monkeypatch.setattr(manager, "get_thread_connection", lambda: connection)
monkeypatch.setattr(manager, "exception_handler", _passthrough_exception_handler)

return manager, connection


def test_cancel_cancels_in_flight_cursor() -> None:
manager = object.__new__(SQLServerConnectionManager)
cursor = MagicMock()
connection = SimpleNamespace(name="conn-1", _dbt_sqlserver_in_flight_cursor=cursor)

manager.cancel(connection)

cursor.cancel.assert_called_once_with()


def test_cancel_is_noop_without_in_flight_cursor() -> None:
manager = object.__new__(SQLServerConnectionManager)
connection = SimpleNamespace(name="conn-1", _dbt_sqlserver_in_flight_cursor=None)

# Must not raise when nothing is in flight.
manager.cancel(connection)


def test_cancel_is_noop_when_attribute_absent() -> None:
manager = object.__new__(SQLServerConnectionManager)
# A connection that never ran a query has no in-flight cursor attribute.
connection = SimpleNamespace(name="conn-1")

manager.cancel(connection)


def test_cancel_handles_cursor_without_cancel_support() -> None:
manager = object.__new__(SQLServerConnectionManager)
# object() has no .cancel attribute, so cancellation is skipped gracefully.
connection = SimpleNamespace(name="conn-1", _dbt_sqlserver_in_flight_cursor=object())

manager.cancel(connection)


def test_cancel_swallows_errors_from_cursor_cancel() -> None:
manager = object.__new__(SQLServerConnectionManager)
cursor = MagicMock()
cursor.cancel.side_effect = RuntimeError("statement already completed")
connection = SimpleNamespace(name="conn-1", _dbt_sqlserver_in_flight_cursor=cursor)

# Best-effort: a failure to cancel must not propagate.
manager.cancel(connection)

cursor.cancel.assert_called_once_with()


def test_add_query_registers_then_clears_in_flight_cursor(
monkeypatch: pytest.MonkeyPatch,
) -> None:
cursor = MagicMock()
cursor.rowcount = 0
captured: Dict[str, Any] = {}

manager, connection = _build_cancel_test_manager(monkeypatch, cursor)

def _record_in_flight(*_args, **_kwargs):
captured["during"] = connection._dbt_sqlserver_in_flight_cursor

cursor.execute.side_effect = _record_in_flight

with patch("dbt.adapters.sqlserver.sqlserver_connections.fire_event"):
manager.add_query("select 1", auto_begin=False)

# Registered while the statement runs, so cancel() can reach it...
assert captured["during"] is cursor
# ...and cleared once execution completes.
assert connection._dbt_sqlserver_in_flight_cursor is None


def test_add_query_clears_in_flight_cursor_after_failure(
monkeypatch: pytest.MonkeyPatch,
) -> None:
cursor = MagicMock()
cursor.execute.side_effect = ValueError("boom") # not a retryable exception

manager, connection = _build_cancel_test_manager(monkeypatch, cursor)

with patch("dbt.adapters.sqlserver.sqlserver_connections.fire_event"):
with pytest.raises(ValueError):
manager.add_query("select 1", auto_begin=False)

assert connection._dbt_sqlserver_in_flight_cursor is None