diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index 3291d6ca..fbcf1e58 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -61,6 +61,10 @@ logger = AdapterLogger("sqlserver") +# Attribute used to stash the in-flight pyodbc / mssql-python cursor on a +# Connection so cancel() can reach it from another thread. See cancel(). +_IN_FLIGHT_CURSOR_ATTR = "_dbt_sqlserver_in_flight_cursor" + class SQLServerConnectionManager(SQLConnectionManager): TYPE = "sqlserver" @@ -142,8 +146,40 @@ def connect() -> Any: return conn - def cancel(self, connection: Connection): - logger.debug("Cancel query") + def cancel(self, connection: Connection) -> None: + """Cancel the in-flight query on ``connection``, if any. + + dbt-core's ``cancel_open`` calls this for sibling connections when a + run is interrupted (Ctrl-C) or another thread errors. We cancel by + calling ``Cursor.cancel()`` on the connection's in-flight cursor: + pyodbc exposes it and it is explicitly designed to be called from + another thread (it issues ``SQLCancel``); mssql-python's cursor is + used the same way when it supports it. Cancellation targets statement + execution. If no statement is in flight, the cursor is gone, or the + backend cursor does not support cancellation, this is a best-effort + no-op. + """ + + cursor = getattr(connection, _IN_FLIGHT_CURSOR_ATTR, None) + if cursor is None: + logger.debug(f"No in-flight query to cancel for connection {connection.name}.") + return + + cancel_cursor = getattr(cursor, "cancel", None) + if not callable(cancel_cursor): + logger.debug( + f"Backend cursor for connection {connection.name} does not " + "support cancellation; skipping." + ) + return + + try: + logger.debug(f"Cancelling in-flight query for connection {connection.name}.") + cancel_cursor() + except Exception as exc: + # The statement may have completed between the lookup and the + # cancel; cancellation is best-effort, so swallow and log. + logger.debug(f"Failed to cancel query for connection {connection.name}: {exc}") def add_begin_query(self): if self._dbt_sqlserver_use_dbt_transactions: @@ -270,16 +306,23 @@ def _execute_query_with_retry( pre = time.time() cursor = connection.handle.cursor() + # Track the in-flight cursor so cancel() / cancel_open() can stop it + # from another thread (e.g. on Ctrl-C); cleared once execution + # finishes. See cancel(). + setattr(connection, _IN_FLIGHT_CURSOR_ATTR, cursor) credentials = self.get_credentials(connection.credentials) - _execute_query_with_retry( - cursor=cursor, - sql=sql, - bindings=bindings, - retryable_exceptions=retryable_exceptions, - retry_limit=(credentials.retries if credentials.retries > 3 else retry_limit), - attempt=1, - ) + try: + _execute_query_with_retry( + cursor=cursor, + sql=sql, + bindings=bindings, + retryable_exceptions=retryable_exceptions, + retry_limit=(credentials.retries if credentials.retries > 3 else retry_limit), + attempt=1, + ) + finally: + setattr(connection, _IN_FLIGHT_CURSOR_ATTR, None) if is_pyodbc_handle(connection.handle): connection.handle.add_output_converter(-155, byte_array_to_datetime) diff --git a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py index fb61e875..7e8998cc 100644 --- a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py +++ b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py @@ -1,5 +1,6 @@ import builtins import importlib +from contextlib import contextmanager from types import SimpleNamespace from typing import Any, Dict, List from unittest.mock import MagicMock, patch @@ -1594,3 +1595,122 @@ def test_rollback_handle_disabled_exception_fires_rollback_failed( SQLServerConnectionManager._rollback_handle(connection) mock_fire_event.assert_called_once() + + +@contextmanager +def _passthrough_exception_handler(_sql): + """Stand-in for add_query's exception_handler in in-flight-cursor tests. + + The real exception_handler has its own tests and depends on global runtime + state; here we only care about cursor tracking, so we let exceptions pass + straight through. + """ + + yield + + +def _build_cancel_test_manager(monkeypatch: pytest.MonkeyPatch, cursor: MagicMock): + """Manager + thread connection wired for add_query in-flight-cursor tests.""" + + manager = object.__new__(SQLServerConnectionManager) + + handle = MagicMock() + handle.cursor.return_value = cursor + + credentials = MagicMock() + credentials.retries = 1 + + connection = MagicMock() + connection.handle = handle + connection.credentials = credentials + connection.transaction_open = True + connection.name = "cancel-test" + + monkeypatch.setattr(manager, "get_thread_connection", lambda: connection) + monkeypatch.setattr(manager, "exception_handler", _passthrough_exception_handler) + + return manager, connection + + +def test_cancel_cancels_in_flight_cursor() -> None: + manager = object.__new__(SQLServerConnectionManager) + cursor = MagicMock() + connection = SimpleNamespace(name="conn-1", _dbt_sqlserver_in_flight_cursor=cursor) + + manager.cancel(connection) + + cursor.cancel.assert_called_once_with() + + +def test_cancel_is_noop_without_in_flight_cursor() -> None: + manager = object.__new__(SQLServerConnectionManager) + connection = SimpleNamespace(name="conn-1", _dbt_sqlserver_in_flight_cursor=None) + + # Must not raise when nothing is in flight. + manager.cancel(connection) + + +def test_cancel_is_noop_when_attribute_absent() -> None: + manager = object.__new__(SQLServerConnectionManager) + # A connection that never ran a query has no in-flight cursor attribute. + connection = SimpleNamespace(name="conn-1") + + manager.cancel(connection) + + +def test_cancel_handles_cursor_without_cancel_support() -> None: + manager = object.__new__(SQLServerConnectionManager) + # object() has no .cancel attribute, so cancellation is skipped gracefully. + connection = SimpleNamespace(name="conn-1", _dbt_sqlserver_in_flight_cursor=object()) + + manager.cancel(connection) + + +def test_cancel_swallows_errors_from_cursor_cancel() -> None: + manager = object.__new__(SQLServerConnectionManager) + cursor = MagicMock() + cursor.cancel.side_effect = RuntimeError("statement already completed") + connection = SimpleNamespace(name="conn-1", _dbt_sqlserver_in_flight_cursor=cursor) + + # Best-effort: a failure to cancel must not propagate. + manager.cancel(connection) + + cursor.cancel.assert_called_once_with() + + +def test_add_query_registers_then_clears_in_flight_cursor( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cursor = MagicMock() + cursor.rowcount = 0 + captured: Dict[str, Any] = {} + + manager, connection = _build_cancel_test_manager(monkeypatch, cursor) + + def _record_in_flight(*_args, **_kwargs): + captured["during"] = connection._dbt_sqlserver_in_flight_cursor + + cursor.execute.side_effect = _record_in_flight + + with patch("dbt.adapters.sqlserver.sqlserver_connections.fire_event"): + manager.add_query("select 1", auto_begin=False) + + # Registered while the statement runs, so cancel() can reach it... + assert captured["during"] is cursor + # ...and cleared once execution completes. + assert connection._dbt_sqlserver_in_flight_cursor is None + + +def test_add_query_clears_in_flight_cursor_after_failure( + monkeypatch: pytest.MonkeyPatch, +) -> None: + cursor = MagicMock() + cursor.execute.side_effect = ValueError("boom") # not a retryable exception + + manager, connection = _build_cancel_test_manager(monkeypatch, cursor) + + with patch("dbt.adapters.sqlserver.sqlserver_connections.fire_event"): + with pytest.raises(ValueError): + manager.add_query("select 1", auto_begin=False) + + assert connection._dbt_sqlserver_in_flight_cursor is None