diff --git a/README.md b/README.md index 5666a2c..7ace1e3 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,17 @@ with connection.cursor() as cursor: rows = cursor.fetchall() ``` +For large result sets you can use a server-side cursor that streams result +sets incrementally instead of buffering the whole response in memory: + +```python +with connection.cursor(stream_results=True) as cursor: + cursor.execute("SELECT id, val FROM table") + + for row in iter(cursor.fetchone, None): + ... +``` + Usage of async connection: ```python @@ -47,3 +58,14 @@ async with async_connection.cursor() as cursor: rows = await cursor.fetchmany(size=5) rows = await cursor.fetchall() ``` + +Async streaming cursors are enabled with the same flag: + +```python +async with async_connection.cursor(stream_results=True) as cursor: + await cursor.execute("SELECT id, val FROM table") + + row = await cursor.fetchone() + rows = await cursor.fetchmany(size=5) + rows = await cursor.fetchall() +``` diff --git a/poetry.lock b/poetry.lock index d1752b3..1dbed51 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2760,14 +2760,14 @@ propcache = ">=0.2.1" [[package]] name = "ydb" -version = "3.26.7" +version = "3.28.2" description = "YDB Python SDK" optional = false python-versions = "*" groups = ["main"] files = [ - {file = "ydb-3.26.7-py2.py3-none-any.whl", hash = "sha256:14f7e47ced588f449ac6d7b63de8ab7bf779e6df08d7724ebb62afaaab27024c"}, - {file = "ydb-3.26.7.tar.gz", hash = "sha256:2e4fc1e7be9e225ea7bd7f5a97abbce648c3d89a16d549ad21ddef16e3f00dfd"}, + {file = "ydb-3.28.2-py2.py3-none-any.whl", hash = "sha256:13511d59146a964207c7e13cb3de071a2b95fb5067e4b9e89c669c9e9bd39ee3"}, + {file = "ydb-3.28.2.tar.gz", hash = "sha256:a27325289a2a0f69e7326b14084c5612634b0801aebba695850ed1591c728e07"}, ] [package.dependencies] diff --git a/tests/test_connections.py b/tests/test_connections.py index 0eced67..a8d300e 100644 --- a/tests/test_connections.py +++ b/tests/test_connections.py @@ -65,7 +65,6 @@ def _test_commit_rollback_after_begin( maybe_await(connection.begin()) maybe_await(connection.rollback()) - def _test_connection(self, connection: dbapi.Connection) -> None: maybe_await(connection.commit()) maybe_await(connection.rollback()) @@ -100,9 +99,12 @@ def _test_cursor_raw_query(self, connection: dbapi.Connection) -> None: with suppress(dbapi.DatabaseError): maybe_await(cur.execute_scheme("DROP TABLE test")) - maybe_await(cur.execute_scheme( - "CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))" - )) + maybe_await( + cur.execute_scheme( + "CREATE TABLE test(" + "id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))" + ) + ) maybe_await( cur.execute( @@ -421,9 +423,7 @@ def test_commit_rollback_after_begin( isolation_level: str, connection: dbapi.Connection, ) -> None: - self._test_commit_rollback_after_begin( - connection, isolation_level - ) + self._test_commit_rollback_after_begin(connection, isolation_level) def test_connection(self, connection: dbapi.Connection) -> None: self._test_connection(connection) @@ -442,14 +442,10 @@ def test_errors_with_interactive_tx( ) -> None: self._test_error_with_interactive_tx(connection) - def test_get_view_names( - self, connection: dbapi.Connection - ) -> None: + def test_get_view_names(self, connection: dbapi.Connection) -> None: self._test_get_view_names(connection) - def test_get_table_names( - self, connection: dbapi.Connection - ) -> None: + def test_get_table_names(self, connection: dbapi.Connection) -> None: self._test_get_table_names(connection) @@ -527,9 +523,7 @@ async def test_commit_rollback_after_begin( connection: dbapi.AsyncConnection, ) -> None: await greenlet_spawn( - self._test_commit_rollback_after_begin, - connection, - isolation_level + self._test_commit_rollback_after_begin, connection, isolation_level ) @pytest.mark.asyncio diff --git a/tests/test_cursors.py b/tests/test_cursors.py index 177b394..0b2a211 100644 --- a/tests/test_cursors.py +++ b/tests/test_cursors.py @@ -5,6 +5,7 @@ from inspect import iscoroutine import pytest +import pytest_asyncio import ydb import ydb_dbapi from sqlalchemy.util import await_only @@ -25,10 +26,16 @@ def maybe_await(obj: callable) -> any: class FakeSyncConnection: + def _clear_current_cursor(self, cursor: Cursor | None = None) -> None: ... + def _invalidate_session(self) -> None: ... class FakeAsyncConnection: + def _clear_current_cursor( + self, cursor: AsyncCursor | None = None + ) -> None: ... + async def _invalidate_session(self) -> None: ... @@ -156,6 +163,371 @@ def _test_cursor_state_after_error( assert cursor._state == CursorStatus.finished +class BaseStreamCursorIntegrationTestSuit: + def _make_multi_result_query(self, count: int) -> str: + return ";\n".join( + f'SELECT {index} AS id, CAST("{index}" AS Utf8) AS value' + for index in range(count) + ) + + def _test_stream_cursor_blocks_shared_transaction_session( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + connection.set_isolation_level(ydb_dbapi.IsolationLevel.SERIALIZABLE) + maybe_await(connection.begin()) + + stream_cursor = connection.cursor(stream_results=True) + other_cursor = connection.cursor() + + try: + maybe_await(stream_cursor.execute("SELECT 1 AS id")) + + with pytest.raises(ydb_dbapi.ProgrammingError): + maybe_await(other_cursor.execute("SELECT 2 AS id")) + + with pytest.raises(ydb_dbapi.ProgrammingError): + maybe_await(connection.commit()) + + rows = maybe_await(stream_cursor.fetchall()) + assert rows == [(1,)] + + maybe_await(connection.commit()) + finally: + maybe_await(stream_cursor.close()) + maybe_await(other_cursor.close()) + if connection._tx_context or connection._session: + maybe_await(connection.rollback()) + + def _test_stream_cursor_fetches_real_data( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + stream_cursor = connection.cursor(stream_results=True) + + try: + maybe_await( + stream_cursor.execute( + """ + SELECT 0 AS id, CAST("zero" AS Utf8) AS value + UNION ALL + SELECT 1 AS id, CAST("one" AS Utf8) AS value + UNION ALL + SELECT 2 AS id, CAST("two" AS Utf8) AS value; + + SELECT 10 AS id, CAST("ten" AS Utf8) AS value + UNION ALL + SELECT 11 AS id, CAST("eleven" AS Utf8) AS value; + """ + ) + ) + + assert stream_cursor.rowcount == -1 + + row = maybe_await(stream_cursor.fetchone()) + assert row == (0, "zero") + assert stream_cursor.rowcount == -1 + + rows = maybe_await(stream_cursor.fetchmany(size=2)) + assert rows == [(1, "one"), (2, "two")] + assert stream_cursor.rowcount == -1 + + rows = maybe_await(stream_cursor.fetchall()) + assert rows == [(10, "ten"), (11, "eleven")] + assert stream_cursor.rowcount == 5 + assert maybe_await(stream_cursor.fetchall()) == [] + finally: + maybe_await(stream_cursor.close()) + + def _test_stream_cursor_empty_result_set( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + stream_cursor = connection.cursor(stream_results=True) + + try: + maybe_await( + stream_cursor.execute( + """ + SELECT id + FROM ( + SELECT CAST(1 AS Int64) AS id + ) + WHERE FALSE; + """ + ) + ) + + assert stream_cursor.description is not None + assert stream_cursor.description[0][0] == "id" + assert stream_cursor.rowcount == -1 + assert maybe_await(stream_cursor.fetchone()) is None + assert stream_cursor.rowcount == 0 + assert maybe_await(stream_cursor.fetchmany(size=2)) == [] + assert maybe_await(stream_cursor.fetchall()) == [] + finally: + maybe_await(stream_cursor.close()) + + def _test_stream_cursor_close_releases_session_for_next_query( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + stream_cursor = connection.cursor(stream_results=True) + other_cursor = connection.cursor() + + try: + maybe_await( + stream_cursor.execute( + """ + SELECT number AS id + FROM AS_TABLE([<|number:1|>, <|number:2|>, <|number:3|>]); + """ + ) + ) + + assert maybe_await(stream_cursor.fetchone()) == (1,) + maybe_await(stream_cursor.close()) + + maybe_await(other_cursor.execute("SELECT 99 AS id")) + assert other_cursor.fetchall() == [(99,)] + finally: + maybe_await(stream_cursor.close()) + maybe_await(other_cursor.close()) + + def _test_stream_cursor_execute_while_running_fails( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + stream_cursor = connection.cursor(stream_results=True) + + try: + maybe_await( + stream_cursor.execute( + """ + SELECT number AS id + FROM AS_TABLE([<|number:1|>, <|number:2|>]); + """ + ) + ) + + with pytest.raises(ydb_dbapi.ProgrammingError): + maybe_await(stream_cursor.execute("SELECT 10 AS id")) + + assert maybe_await(stream_cursor.fetchall()) == [(1,), (2,)] + maybe_await(stream_cursor.execute("SELECT 10 AS id")) + assert maybe_await(stream_cursor.fetchall()) == [(10,)] + finally: + maybe_await(stream_cursor.close()) + + def _test_stream_cursor_close_unblocks_shared_transaction_session( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + connection.set_isolation_level(ydb_dbapi.IsolationLevel.SERIALIZABLE) + maybe_await(connection.begin()) + + stream_cursor = connection.cursor(stream_results=True) + other_cursor = connection.cursor() + + try: + maybe_await( + stream_cursor.execute( + """ + SELECT number AS id + FROM AS_TABLE([<|number:1|>, <|number:2|>, <|number:3|>]); + """ + ) + ) + assert maybe_await(stream_cursor.fetchone()) == (1,) + + maybe_await(stream_cursor.close()) + + maybe_await(other_cursor.execute("SELECT 77 AS id")) + assert other_cursor.fetchall() == [(77,)] + maybe_await(connection.commit()) + finally: + maybe_await(stream_cursor.close()) + maybe_await(other_cursor.close()) + if connection._tx_context or connection._session: + maybe_await(connection.rollback()) + + def _test_stream_cursor_many_result_sets_fetchone_state( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + result_set_count = 8 + stream_cursor = connection.cursor(stream_results=True) + + try: + maybe_await( + stream_cursor.execute( + self._make_multi_result_query(result_set_count) + ) + ) + + assert stream_cursor._state == CursorStatus.running + assert stream_cursor.rowcount == -1 + assert stream_cursor.description is not None + assert stream_cursor.description[0][0] == "id" + assert stream_cursor.description[1][0] == "value" + + for index in range(result_set_count): + assert maybe_await(stream_cursor.fetchone()) == ( + index, + str(index), + ) + assert stream_cursor._state == CursorStatus.running + if index < result_set_count - 1: + assert stream_cursor.rowcount == -1 + + assert stream_cursor.rowcount == -1 + assert maybe_await(stream_cursor.fetchone()) is None + assert stream_cursor.rowcount == result_set_count + assert stream_cursor._state == CursorStatus.finished + finally: + maybe_await(stream_cursor.close()) + + def _test_stream_cursor_many_result_sets_fetchmany( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + result_set_count = 9 + stream_cursor = connection.cursor(stream_results=True) + + try: + maybe_await( + stream_cursor.execute( + self._make_multi_result_query(result_set_count) + ) + ) + + rows = maybe_await(stream_cursor.fetchmany(size=4)) + assert rows == [ + (0, "0"), + (1, "1"), + (2, "2"), + (3, "3"), + ] + assert stream_cursor.rowcount == -1 + assert stream_cursor._state == CursorStatus.running + + rows = maybe_await(stream_cursor.fetchmany(size=4)) + assert rows == [ + (4, "4"), + (5, "5"), + (6, "6"), + (7, "7"), + ] + assert stream_cursor.rowcount == -1 + + rows = maybe_await(stream_cursor.fetchmany(size=4)) + assert rows == [(8, "8")] + assert stream_cursor.rowcount == result_set_count + assert stream_cursor._state == CursorStatus.finished + + assert maybe_await(stream_cursor.fetchmany(size=4)) == [] + assert stream_cursor.rowcount == result_set_count + assert stream_cursor._state == CursorStatus.finished + finally: + maybe_await(stream_cursor.close()) + + def _test_stream_cursor_many_result_sets_fetchall( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + result_set_count = 6 + stream_cursor = connection.cursor(stream_results=True) + + try: + maybe_await( + stream_cursor.execute( + self._make_multi_result_query(result_set_count) + ) + ) + + rows = maybe_await(stream_cursor.fetchall()) + assert rows == [ + (index, str(index)) for index in range(result_set_count) + ] + assert stream_cursor.rowcount == result_set_count + assert stream_cursor._state == CursorStatus.finished + assert maybe_await(stream_cursor.fetchall()) == [] + finally: + maybe_await(stream_cursor.close()) + + def _test_other_cursor_reusable_after_blocked_execute( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + connection.set_isolation_level(ydb_dbapi.IsolationLevel.SERIALIZABLE) + maybe_await(connection.begin()) + + stream_cursor = connection.cursor(stream_results=True) + other_cursor = connection.cursor() + + try: + maybe_await(stream_cursor.execute("SELECT 1 AS id")) + + with pytest.raises(ydb_dbapi.ProgrammingError): + maybe_await(other_cursor.execute("SELECT 2 AS id")) + + assert other_cursor._state == CursorStatus.ready + + maybe_await(stream_cursor.fetchall()) + maybe_await(stream_cursor.close()) + + maybe_await(other_cursor.execute("SELECT 3 AS id")) + assert other_cursor.fetchall() == [(3,)] + finally: + maybe_await(stream_cursor.close()) + maybe_await(other_cursor.close()) + if connection._tx_context or connection._session: + maybe_await(connection.rollback()) + + def _test_stream_cursor_reusable_after_blocked_execute( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + connection.set_isolation_level(ydb_dbapi.IsolationLevel.SERIALIZABLE) + maybe_await(connection.begin()) + + first_stream = connection.cursor(stream_results=True) + second_stream = connection.cursor(stream_results=True) + + try: + maybe_await(first_stream.execute("SELECT 1 AS id")) + + with pytest.raises(ydb_dbapi.ProgrammingError): + maybe_await(second_stream.execute("SELECT 2 AS id")) + + assert second_stream._state == CursorStatus.ready + + maybe_await(first_stream.fetchall()) + maybe_await(first_stream.close()) + + maybe_await(second_stream.execute("SELECT 3 AS id")) + assert maybe_await(second_stream.fetchall()) == [(3,)] + finally: + maybe_await(first_stream.close()) + maybe_await(second_stream.close()) + if connection._tx_context or connection._session: + maybe_await(connection.rollback()) + + def _test_connection_close_with_running_stream_cursor( + self, + connection: ydb_dbapi.Connection | ydb_dbapi.AsyncConnection, + ) -> None: + connection.set_isolation_level(ydb_dbapi.IsolationLevel.SERIALIZABLE) + maybe_await(connection.begin()) + + stream_cursor = connection.cursor(stream_results=True) + maybe_await(stream_cursor.execute("SELECT 1 AS id")) + + assert stream_cursor._state == CursorStatus.running + + maybe_await(connection.close()) + + class TestCursor(BaseCursorTestSuit): @pytest.fixture def sync_cursor( @@ -255,3 +627,211 @@ async def test_cursor_state_after_error( self, async_cursor: AsyncCursor ) -> None: await greenlet_spawn(self._test_cursor_state_after_error, async_cursor) + + +class TestStreamCursorIntegration(BaseStreamCursorIntegrationTestSuit): + @pytest.fixture + def connection( + self, connection_kwargs: dict + ) -> Generator[ydb_dbapi.Connection]: + conn = ydb_dbapi.connect(**connection_kwargs) # ignore: typing + try: + yield conn + finally: + conn.close() + + def test_stream_cursor_blocks_shared_transaction_session( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_blocks_shared_transaction_session(connection) + + def test_stream_cursor_fetches_real_data( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_fetches_real_data(connection) + + def test_stream_cursor_empty_result_set( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_empty_result_set(connection) + + def test_stream_cursor_close_releases_session_for_next_query( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_close_releases_session_for_next_query( + connection + ) + + def test_stream_cursor_execute_while_running_fails( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_execute_while_running_fails(connection) + + def test_stream_cursor_close_unblocks_shared_transaction_session( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_close_unblocks_shared_transaction_session( + connection + ) + + def test_stream_cursor_many_result_sets_fetchone_state( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_many_result_sets_fetchone_state(connection) + + def test_stream_cursor_many_result_sets_fetchmany( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_many_result_sets_fetchmany(connection) + + def test_stream_cursor_many_result_sets_fetchall( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_many_result_sets_fetchall(connection) + + def test_other_cursor_reusable_after_blocked_execute( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_other_cursor_reusable_after_blocked_execute(connection) + + def test_stream_cursor_reusable_after_blocked_execute( + self, connection: ydb_dbapi.Connection + ) -> None: + self._test_stream_cursor_reusable_after_blocked_execute(connection) + + def test_connection_close_with_running_stream_cursor( + self, connection_kwargs: dict + ) -> None: + conn = ydb_dbapi.connect(**connection_kwargs) + self._test_connection_close_with_running_stream_cursor(conn) + + +class TestAsyncStreamCursorIntegration(BaseStreamCursorIntegrationTestSuit): + @pytest_asyncio.fixture + async def connection( + self, connection_kwargs: dict + ) -> AsyncGenerator[ydb_dbapi.AsyncConnection]: + def connect() -> ydb_dbapi.AsyncConnection: + return maybe_await(ydb_dbapi.async_connect(**connection_kwargs)) + + conn = await greenlet_spawn(connect) + try: + yield conn + finally: + + def close() -> None: + maybe_await(conn.close()) + + await greenlet_spawn(close) + + @pytest.mark.asyncio + async def test_stream_cursor_blocks_shared_transaction_session( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_blocks_shared_transaction_session, + connection, + ) + + @pytest.mark.asyncio + async def test_stream_cursor_fetches_real_data( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_fetches_real_data, + connection, + ) + + @pytest.mark.asyncio + async def test_stream_cursor_empty_result_set( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_empty_result_set, + connection, + ) + + @pytest.mark.asyncio + async def test_stream_cursor_close_releases_session_for_next_query( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_close_releases_session_for_next_query, + connection, + ) + + @pytest.mark.asyncio + async def test_stream_cursor_execute_while_running_fails( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_execute_while_running_fails, + connection, + ) + + @pytest.mark.asyncio + async def test_stream_cursor_close_unblocks_shared_transaction_session( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_close_unblocks_shared_transaction_session, + connection, + ) + + @pytest.mark.asyncio + async def test_stream_cursor_many_result_sets_fetchone_state( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_many_result_sets_fetchone_state, + connection, + ) + + @pytest.mark.asyncio + async def test_stream_cursor_many_result_sets_fetchmany( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_many_result_sets_fetchmany, + connection, + ) + + @pytest.mark.asyncio + async def test_stream_cursor_many_result_sets_fetchall( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_many_result_sets_fetchall, + connection, + ) + + @pytest.mark.asyncio + async def test_other_cursor_reusable_after_blocked_execute( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_other_cursor_reusable_after_blocked_execute, + connection, + ) + + @pytest.mark.asyncio + async def test_stream_cursor_reusable_after_blocked_execute( + self, connection: ydb_dbapi.AsyncConnection + ) -> None: + await greenlet_spawn( + self._test_stream_cursor_reusable_after_blocked_execute, + connection, + ) + + @pytest.mark.asyncio + async def test_connection_close_with_running_stream_cursor( + self, connection_kwargs: dict + ) -> None: + def connect() -> ydb_dbapi.AsyncConnection: + return maybe_await(ydb_dbapi.async_connect(**connection_kwargs)) + + conn = await greenlet_spawn(connect) + await greenlet_spawn( + self._test_connection_close_with_running_stream_cursor, + conn, + ) diff --git a/tests/test_cursors_unit.py b/tests/test_cursors_unit.py new file mode 100644 index 0000000..81dd918 --- /dev/null +++ b/tests/test_cursors_unit.py @@ -0,0 +1,304 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator +from inspect import iscoroutine + +import pytest +import ydb +from sqlalchemy.util import await_only +from typing_extensions import Self +from ydb_dbapi import AsyncCursor +from ydb_dbapi import AsyncStreamCursor +from ydb_dbapi import Cursor +from ydb_dbapi import StreamCursor + + +def maybe_await(obj: callable) -> any: + if not iscoroutine(obj): + return obj + return await_only(obj) + + +RESULT_SET_LENGTH = 4 +RESULT_SET_COUNT = 3 + + +class FakeResultSet: + def __init__(self, rows: list[tuple], columns: list[object]) -> None: + self.rows = rows + self.columns = columns + + +class FakeSyncResponseContextIterator: + def __init__(self, result_sets: list[FakeResultSet]) -> None: + self._result_sets = iter(result_sets) + self.cancelled = False + + def __iter__(self) -> Self: + return self + + def __next__(self) -> FakeResultSet: + return next(self._result_sets) + + def cancel(self) -> None: + self.cancelled = True + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type: object, exc: object, tb: object) -> None: + for _ in self: + pass + + +class FakeAsyncResponseContextIterator: + def __init__(self, result_sets: list[FakeResultSet]) -> None: + self._result_sets = iter(result_sets) + self.cancelled = False + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> FakeResultSet: + try: + return next(self._result_sets) + except StopIteration as e: + raise StopAsyncIteration from e + + def cancel(self) -> None: + self.cancelled = True + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, exc_type: object, exc: object, tb: object + ) -> None: + async for _ in self: + pass + + +class FakeSyncStreamSession: + def __init__(self, result_sets: list[FakeResultSet]) -> None: + self._result_sets = result_sets + self.deleted = False + + def transaction(self, tx_mode: ydb.BaseQueryTxMode) -> Self: + return self + + def execute(self, **kwargs: object) -> FakeSyncResponseContextIterator: + return FakeSyncResponseContextIterator(self._result_sets) + + def delete(self) -> None: + self.deleted = True + + +class FakeAsyncStreamSession: + def __init__(self, result_sets: list[FakeResultSet]) -> None: + self._result_sets = result_sets + self.deleted = False + + def transaction(self, tx_mode: ydb.BaseQueryTxMode) -> Self: + return self + + async def execute( + self, **kwargs: object + ) -> FakeAsyncResponseContextIterator: + return FakeAsyncResponseContextIterator(self._result_sets) + + async def delete(self) -> None: + self.deleted = True + + +class FakeSyncStreamSessionPool: + def __init__(self, result_sets: list[FakeResultSet]) -> None: + self.session = FakeSyncStreamSession(result_sets) + self.released_sessions: list[FakeSyncStreamSession] = [] + + def acquire(self, timeout: float | None = None) -> FakeSyncStreamSession: + return self.session + + def release(self, session: FakeSyncStreamSession) -> None: + self.released_sessions.append(session) + + +class FakeAsyncStreamSessionPool: + def __init__(self, result_sets: list[FakeResultSet]) -> None: + self.session = FakeAsyncStreamSession(result_sets) + self.released_sessions: list[FakeAsyncStreamSession] = [] + + async def acquire( + self, timeout: float | None = None + ) -> FakeAsyncStreamSession: + return self.session + + async def release(self, session: FakeAsyncStreamSession) -> None: + self.released_sessions.append(session) + + +def make_result_sets(count: int = 1) -> list[FakeResultSet]: + return [ + FakeResultSet( + rows=[(row_id, row_id) for row_id in range(RESULT_SET_LENGTH)], + columns=[], + ) + for _ in range(count) + ] + + +class FakeSyncConnection: + def _clear_current_cursor(self, cursor: Cursor | None = None) -> None: ... + + def _invalidate_session(self) -> None: ... + + def _set_current_cursor(self, cursor: StreamCursor) -> None: ... + + +class FakeAsyncConnection: + def _clear_current_cursor( + self, cursor: AsyncCursor | None = None + ) -> None: ... + + async def _invalidate_session(self) -> None: ... + + def _set_current_cursor(self, cursor: AsyncStreamCursor) -> None: ... + + +class BaseStreamCursorTestSuit: + def _test_stream_cursor_fetch_one(self, cursor: StreamCursor) -> None: + maybe_await(cursor.execute("SELECT id, val FROM table")) + + assert cursor.rowcount == -1 + + for i in range(4): + row = cursor.fetchone() + assert row is not None + assert row[0] == i + + assert cursor.fetchone() is None + assert cursor.rowcount == 4 + + def _test_stream_cursor_fetch_many(self, cursor: StreamCursor) -> None: + maybe_await( + cursor.execute( + """ + SELECT id, val FROM table; + SELECT id, val FROM table1; + SELECT id, val FROM table2; + """ + ) + ) + + assert cursor.rowcount == -1 + + rows = cursor.fetchmany(size=5) + assert len(rows) == 5 + assert cursor.rowcount == -1 + + rows = cursor.fetchmany(size=7) + assert len(rows) == 7 + assert cursor.fetchmany(size=1) == [] + assert cursor.rowcount == 12 + + def _test_stream_cursor_fetch_all(self, cursor: StreamCursor) -> None: + maybe_await(cursor.execute("SELECT id, val FROM table")) + + rows = cursor.fetchall() + assert len(rows) == 4 + assert cursor.rowcount == 4 + assert cursor.fetchall() == [] + + +class TestStreamCursor(BaseStreamCursorTestSuit): + @pytest.fixture + def sync_cursor(self) -> StreamCursor: + cursor = StreamCursor( + FakeSyncConnection(), + FakeSyncStreamSessionPool(make_result_sets()), + ydb.QuerySerializableReadWrite(), + request_settings=ydb.BaseRequestSettings(), + retry_settings=ydb.RetrySettings(), + ) + try: + yield cursor + finally: + cursor.close() + + def test_cursor_fetch_one(self, sync_cursor: StreamCursor) -> None: + self._test_stream_cursor_fetch_one(sync_cursor) + + def test_cursor_fetch_many(self, sync_cursor: StreamCursor) -> None: + sync_cursor._session_pool = FakeSyncStreamSessionPool( + make_result_sets(RESULT_SET_COUNT) + ) + self._test_stream_cursor_fetch_many(sync_cursor) + + def test_cursor_fetch_all(self, sync_cursor: StreamCursor) -> None: + self._test_stream_cursor_fetch_all(sync_cursor) + + +class TestAsyncStreamCursor: + @pytest.fixture + async def async_cursor(self) -> AsyncGenerator[AsyncStreamCursor]: + cursor = AsyncStreamCursor( + FakeAsyncConnection(), + FakeAsyncStreamSessionPool(make_result_sets()), + ydb.QuerySerializableReadWrite(), + request_settings=ydb.BaseRequestSettings(), + retry_settings=ydb.RetrySettings(), + ) + yield cursor + await cursor.close() + + @pytest.mark.asyncio + async def test_cursor_fetch_one( + self, async_cursor: AsyncStreamCursor + ) -> None: + await async_cursor.execute("SELECT id, val FROM table") + + assert async_cursor.rowcount == -1 + + for i in range(4): + row = await async_cursor.fetchone() + assert row is not None + assert row[0] == i + + assert await async_cursor.fetchone() is None + assert async_cursor.rowcount == 4 + + @pytest.mark.asyncio + async def test_cursor_fetch_many( + self, async_cursor: AsyncStreamCursor + ) -> None: + async_cursor._session_pool = FakeAsyncStreamSessionPool( + make_result_sets(RESULT_SET_COUNT) + ) + await async_cursor.execute( + """ + SELECT id, val FROM table; + SELECT id, val FROM table1; + SELECT id, val FROM table2; + """ + ) + + assert async_cursor.rowcount == -1 + + rows = await async_cursor.fetchmany(size=5) + assert len(rows) == 5 + assert async_cursor.rowcount == -1 + + rows = await async_cursor.fetchmany(size=7) + assert len(rows) == 7 + assert await async_cursor.fetchmany(size=1) == [] + assert async_cursor.rowcount == 12 + + @pytest.mark.asyncio + async def test_cursor_fetch_all( + self, async_cursor: AsyncStreamCursor + ) -> None: + await async_cursor.execute("SELECT id, val FROM table") + + rows = await async_cursor.fetchall() + assert len(rows) == 4 + assert async_cursor.rowcount == 4 + assert await async_cursor.fetchall() == [] diff --git a/ydb_dbapi/__init__.py b/ydb_dbapi/__init__.py index 1d2f2dd..bd79565 100644 --- a/ydb_dbapi/__init__.py +++ b/ydb_dbapi/__init__.py @@ -5,7 +5,9 @@ from .connections import connect from .constants import * from .cursors import AsyncCursor +from .cursors import AsyncStreamCursor from .cursors import Cursor +from .cursors import StreamCursor from .errors import * from .version import VERSION diff --git a/ydb_dbapi/connections.py b/ydb_dbapi/connections.py index 1ee215b..b4fec90 100644 --- a/ydb_dbapi/connections.py +++ b/ydb_dbapi/connections.py @@ -15,10 +15,14 @@ from ydb.retries import retry_operation_sync from .cursors import AsyncCursor +from .cursors import AsyncStreamCursor from .cursors import Cursor +from .cursors import StreamCursor from .errors import InterfaceError from .errors import InternalError from .errors import NotSupportedError +from .errors import ProgrammingError +from .utils import CursorStatus from .utils import handle_ydb_errors from .utils import maybe_get_current_trace_id from .utils import prepare_credentials @@ -203,6 +207,7 @@ class Connection(BaseConnection): _driver_cls = ydb.Driver _pool_cls = ydb.QuerySessionPool _cursor_cls = Cursor + _stream_cursor_cls = StreamCursor def __init__( self, @@ -231,10 +236,13 @@ def __init__( driver_config_kwargs=driver_config_kwargs, **kwargs, ) - self._current_cursor: Cursor | None = None + self._current_cursor: StreamCursor | None = None - def cursor(self) -> Cursor: - return self._cursor_cls( + def cursor(self, stream_results: bool = False) -> Cursor | StreamCursor: + cursor_cls = ( + self._stream_cursor_cls if stream_results else self._cursor_cls + ) + return cursor_cls( connection=self, session_pool=self._session_pool, tx_mode=self._tx_mode, @@ -266,6 +274,7 @@ def begin(self) -> None: @handle_ydb_errors def commit(self) -> None: + self._raise_if_current_cursor_running() if self._tx_context: settings = self._get_request_settings() self._tx_context.commit(settings=settings) @@ -273,9 +282,11 @@ def commit(self) -> None: if self._session: self._session_pool.release(self._session) self._session = None + self._clear_current_cursor() @handle_ydb_errors def rollback(self) -> None: + self._raise_if_current_cursor_running() if self._tx_context: settings = self._get_request_settings() self._tx_context.rollback(settings=settings) @@ -283,9 +294,12 @@ def rollback(self) -> None: if self._session: self._session_pool.release(self._session) self._session = None + self._clear_current_cursor() @handle_ydb_errors def close(self) -> None: + if self._current_cursor is not None: + self._current_cursor.close() self.rollback() if self._session: @@ -387,17 +401,38 @@ def bulk_upsert( ) def _invalidate_session(self) -> None: + self._clear_current_cursor() if self._tx_context: self._tx_context = None if self._session: self._session_pool.release(self._session) self._session = None + def _set_current_cursor(self, cursor: StreamCursor) -> None: + self._current_cursor = cursor + + def _clear_current_cursor(self, cursor: Cursor | None = None) -> None: + if cursor is None or self._current_cursor is cursor: + self._current_cursor = None + + def _raise_if_current_cursor_running( + self, cursor: Cursor | None = None + ) -> None: + if self._current_cursor is None or self._current_cursor is cursor: + return + if self._current_cursor._state != CursorStatus.running: + return + raise ProgrammingError( + "Could not perform operation: a server-side cursor is still " + "streaming results for the current transaction session." + ) + class AsyncConnection(BaseConnection): _driver_cls = ydb.aio.Driver _pool_cls = ydb.aio.QuerySessionPool _cursor_cls = AsyncCursor + _stream_cursor_cls = AsyncStreamCursor def __init__( self, @@ -426,10 +461,15 @@ def __init__( driver_config_kwargs=driver_config_kwargs, **kwargs, ) - self._current_cursor: AsyncCursor | None = None + self._current_cursor: AsyncStreamCursor | None = None - def cursor(self) -> AsyncCursor: - return self._cursor_cls( + def cursor( + self, stream_results: bool = False + ) -> AsyncCursor | AsyncStreamCursor: + cursor_cls = ( + self._stream_cursor_cls if stream_results else self._cursor_cls + ) + return cursor_cls( connection=self, session_pool=self._session_pool, tx_mode=self._tx_mode, @@ -461,6 +501,7 @@ async def begin(self) -> None: @handle_ydb_errors async def commit(self) -> None: + self._raise_if_current_cursor_running() if self._tx_context: settings = self._get_request_settings() await self._tx_context.commit(settings=settings) @@ -468,9 +509,11 @@ async def commit(self) -> None: if self._session: await self._session_pool.release(self._session) self._session = None + self._clear_current_cursor() @handle_ydb_errors async def rollback(self) -> None: + self._raise_if_current_cursor_running() if self._tx_context: settings = self._get_request_settings() await self._tx_context.rollback(settings=settings) @@ -478,9 +521,12 @@ async def rollback(self) -> None: if self._session: await self._session_pool.release(self._session) self._session = None + self._clear_current_cursor() @handle_ydb_errors async def close(self) -> None: + if self._current_cursor is not None: + await self._current_cursor.close() await self.rollback() if self._session: @@ -586,12 +632,34 @@ async def bulk_upsert( ) async def _invalidate_session(self) -> None: + self._clear_current_cursor() if self._tx_context: self._tx_context = None if self._session: await self._session_pool.release(self._session) self._session = None + def _set_current_cursor(self, cursor: AsyncStreamCursor) -> None: + self._current_cursor = cursor + + def _clear_current_cursor( + self, cursor: AsyncStreamCursor | None = None + ) -> None: + if cursor is None or self._current_cursor is cursor: + self._current_cursor = None + + def _raise_if_current_cursor_running( + self, cursor: AsyncStreamCursor | None = None + ) -> None: + if self._current_cursor is None or self._current_cursor is cursor: + return + if self._current_cursor._state != CursorStatus.running: + return + raise ProgrammingError( + "Could not perform operation: a server-side cursor is still " + "streaming results for the current transaction session." + ) + def connect(*args: Any, **kwargs: Any) -> Connection: conn = Connection(*args, **kwargs) diff --git a/ydb_dbapi/cursors.py b/ydb_dbapi/cursors.py index 7060401..643c1a5 100644 --- a/ydb_dbapi/cursors.py +++ b/ydb_dbapi/cursors.py @@ -14,6 +14,10 @@ import ydb from typing_extensions import Self +from ydb.aio.query.base import AsyncResponseContextIterator +from ydb.query.base import SyncResponseContextIterator +from ydb.retries import retry_operation_async +from ydb.retries import retry_operation_sync from .errors import DatabaseError from .errors import InterfaceError @@ -51,7 +55,7 @@ async def awrapper( return await func(self, *args, **kwargs) except ydb.Error: self._state = CursorStatus.finished - await self._connection._invalidate_session() + await self._invalidate_active_session() raise return awrapper @@ -62,19 +66,20 @@ def wrapper(self: Cursor, *args: tuple, **kwargs: dict) -> Any: return func(self, *args, **kwargs) except ydb.Error: self._state = CursorStatus.finished - self._connection._invalidate_session() + self._invalidate_active_session() raise return wrapper -class BufferedCursor: +class BaseCursor: def __init__(self) -> None: self.arraysize: int = 1 self._rows: Iterator | None = None self._rows_count: int = -1 self._description: list[tuple] | None = None self._state: CursorStatus = CursorStatus.ready + self._rowcount_accumulator: int = 0 self._table_path_prefix: str = "" @@ -92,13 +97,19 @@ def setinputsizes(self) -> None: def setoutputsize(self) -> None: pass + def _reset_result_state(self) -> None: + self._rows = None + self._rows_count = -1 + self._description = None + self._rowcount_accumulator = 0 + def _rows_iterable( self, result_set: ydb.convert.ResultSet ) -> Generator[tuple]: try: for row in result_set.rows: # returns tuple to be compatible with SqlAlchemy and because - # of this PEP to return a sequence: + # of this PEP to return a sequence: # https://www.python.org/dev/peps/pep-0249/#fetchmany yield row[::] except ydb.Error as e: @@ -110,20 +121,14 @@ def _update_result_set( replace_current: bool = True, ) -> None: self._update_description(result_set) + self._rowcount_accumulator += len(result_set.rows) new_rows_iter = self._rows_iterable(result_set) - new_rows_count = len(result_set.rows) or -1 if self._rows is None or replace_current: self._rows = new_rows_iter - self._rows_count = new_rows_count else: self._rows = itertools.chain(self._rows, new_rows_iter) - if new_rows_count != -1: - if self._rows_count != -1: - self._rows_count += new_rows_count - else: - self._rows_count = new_rows_count def _update_description(self, result_set: ydb.convert.ResultSet) -> None: if not result_set.columns: @@ -148,6 +153,9 @@ def _fill_buffer(self, result_set_list: list) -> None: for result_set in result_set_list: self._update_result_set(result_set, replace_current=False) + def _finalize_rowcount(self) -> None: + self._rows_count = self._rowcount_accumulator + def _raise_if_running(self) -> None: if self._state == CursorStatus.running: raise ProgrammingError( @@ -166,9 +174,11 @@ def is_closed(self) -> bool: return self._state == CursorStatus.closed def _begin_query(self) -> None: + self._reset_result_state() self._state = CursorStatus.running def _finish_query(self) -> None: + self._finalize_rowcount() self._state = CursorStatus.finished def _fetchone_from_buffer(self) -> tuple | None: @@ -192,7 +202,7 @@ def _append_table_path_prefix(self, query: str) -> str: return query -class Cursor(BufferedCursor): +class Cursor(BaseCursor): def __init__( self, connection: Connection, @@ -211,7 +221,6 @@ def __init__( self._retry_settings = retry_settings self._tx_context = tx_context self._table_path_prefix = table_path_prefix - self._stream: Iterator | None = None def fetchone(self) -> tuple | None: return self._fetchone_from_buffer() @@ -223,6 +232,9 @@ def fetchmany(self, size: int | None = None) -> list: def fetchall(self) -> list: return self._fetchall_from_buffer() + def _invalidate_active_session(self) -> None: + self._connection._invalidate_session() + def _get_request_settings(self) -> ydb.BaseRequestSettings: settings = self._request_settings.make_copy() @@ -309,6 +321,7 @@ def execute_scheme( parameters: ParametersType | None = None, ) -> None: self._raise_if_closed() + self._raise_if_running() query = self._append_table_path_prefix(query) self._begin_query() @@ -326,6 +339,8 @@ def execute( ) -> None: self._raise_if_closed() self._raise_if_running() + if self._tx_context is not None: + self._connection._raise_if_current_cursor_running() query = self._append_table_path_prefix(query) self._begin_query() @@ -369,7 +384,232 @@ def __exit__( self.close() -class AsyncCursor(BufferedCursor): +class StreamCursor(Cursor): + def __init__( + self, + connection: Connection, + session_pool: ydb.QuerySessionPool, + tx_mode: ydb.BaseQueryTxMode, + request_settings: ydb.BaseRequestSettings, + retry_settings: ydb.RetrySettings, + tx_context: ydb.QueryTxContext | None = None, + table_path_prefix: str = "", + ) -> None: + super().__init__( + connection=connection, + session_pool=session_pool, + tx_mode=tx_mode, + request_settings=request_settings, + retry_settings=retry_settings, + tx_context=tx_context, + table_path_prefix=table_path_prefix, + ) + self._stream: SyncResponseContextIterator | None = None + self._session_owner: ydb.QuerySession | None = None + + def _invalidate_active_session(self) -> None: + self._clear_current_cursor() + if self._session_owner is None: + self._connection._invalidate_session() + return + + session = self._session_owner + self._session_owner = None + self._stream = None + try: + session.delete() + finally: + self._session_pool.release(session) + + def _clear_current_cursor(self) -> None: + self._connection._clear_current_cursor(self) + + def _register_current_cursor(self) -> None: + if self._tx_context is not None: + self._connection._set_current_cursor(self) + + def _release_owned_session(self) -> None: + if self._session_owner is None: + return + + session = self._session_owner + self._session_owner = None + self._session_pool.release(session) + + def _discard_owned_session(self) -> None: + if self._session_owner is None: + return + + session = self._session_owner + self._session_owner = None + try: + session.delete() + finally: + self._session_pool.release(session) + + def _finish_stream(self) -> None: + self._stream = None + self._release_owned_session() + self._clear_current_cursor() + self._finish_query() + + @handle_ydb_errors + @invalidate_cursor_on_ydb_error + def _load_next_result_set(self) -> bool: + if self._stream is None: + return False + + try: + result_set = next(self._stream) + except StopIteration: + self._finish_stream() + return False + + self._update_result_set(result_set, replace_current=False) + return True + + def _prime_stream(self) -> None: + if ( + not self._load_next_result_set() + and self._state == CursorStatus.running + ): + self._finish_stream() + + @handle_ydb_errors + @invalidate_cursor_on_ydb_error + def _execute_session_query_stream( + self, + query: str, + parameters: ParametersType | None = None, + ) -> SyncResponseContextIterator: + settings = self._get_request_settings() + + def callee() -> SyncResponseContextIterator: + acquire_timeout = getattr( + self._retry_settings, + "max_session_acquire_timeout", + None, + ) + session = self._session_pool.acquire(timeout=acquire_timeout) + try: + stream = session.transaction(self._tx_mode).execute( + query=query, + parameters=parameters, + commit_tx=True, + settings=settings, + ) + except Exception: + self._session_pool.release(session) + raise + + self._session_owner = session + return stream + + return retry_operation_sync( + callee, + self._retry_settings, + ) + + @handle_ydb_errors + @invalidate_cursor_on_ydb_error + def _execute_transactional_query_stream( + self, + tx_context: ydb.QueryTxContext, + query: str, + parameters: ParametersType | None = None, + ) -> SyncResponseContextIterator: + settings = self._get_request_settings() + return tx_context.execute( + query=query, + parameters=parameters, + commit_tx=False, + settings=settings, + ) + + def execute( + self, + query: str, + parameters: ParametersType | None = None, + ) -> None: + self._raise_if_closed() + self._raise_if_running() + if self._tx_context is not None: + self._connection._raise_if_current_cursor_running(self) + + query = self._append_table_path_prefix(query) + self._begin_query() + + if self._tx_context is not None: + self._stream = self._execute_transactional_query_stream( + tx_context=self._tx_context, + query=query, + parameters=parameters, + ) + self._register_current_cursor() + else: + self._stream = self._execute_session_query_stream( + query=query, + parameters=parameters, + ) + + self._prime_stream() + + def fetchone(self) -> tuple | None: + self._raise_if_closed() + + while True: + row = self._fetchone_from_buffer() + if row is not None: + return row + if not self._load_next_result_set(): + return None + + def fetchmany(self, size: int | None = None) -> list: + self._raise_if_closed() + result: list[tuple] = [] + target_size = size or self.arraysize + + while len(result) < target_size: + row = self.fetchone() + if row is None: + break + result.append(row) + + return result + + def fetchall(self) -> list: + self._raise_if_closed() + result = list(self._fetchall_from_buffer()) + + while self._load_next_result_set(): + result.extend(self._fetchall_from_buffer()) + + return result + + def close(self) -> None: + if self._state == CursorStatus.closed: + return + + if self._state == CursorStatus.running: + if self._session_owner is not None and self._stream is not None: + self._stream.cancel() + self._stream = None + self._discard_owned_session() + elif self._stream is not None: + with self._stream: + pass + self._finish_stream() + else: + self._release_owned_session() + self._clear_current_cursor() + + self._stream = None + self._session_owner = None + self._clear_current_cursor() + self._state = CursorStatus.closed + + +class AsyncCursor(BaseCursor): def __init__( self, connection: AsyncConnection, @@ -388,7 +628,6 @@ def __init__( self._retry_settings = retry_settings self._tx_context = tx_context self._table_path_prefix = table_path_prefix - self._stream: AsyncIterator | None = None def fetchone(self) -> tuple | None: return self._fetchone_from_buffer() @@ -400,6 +639,9 @@ def fetchmany(self, size: int | None = None) -> list: def fetchall(self) -> list: return self._fetchall_from_buffer() + async def _invalidate_active_session(self) -> None: + await self._connection._invalidate_session() + def _get_request_settings(self) -> ydb.BaseRequestSettings: settings = self._request_settings.make_copy() @@ -486,6 +728,7 @@ async def execute_scheme( parameters: ParametersType | None = None, ) -> None: self._raise_if_closed() + self._raise_if_running() query = self._append_table_path_prefix(query) self._begin_query() @@ -503,6 +746,8 @@ async def execute( ) -> None: self._raise_if_closed() self._raise_if_running() + if self._tx_context is not None: + self._connection._raise_if_current_cursor_running() query = self._append_table_path_prefix(query) @@ -545,3 +790,284 @@ async def __aexit__( tb: object, ) -> None: self.close() + + +class AsyncStreamCursor(BaseCursor): + def __init__( + self, + connection: AsyncConnection, + session_pool: ydb.aio.QuerySessionPool, + tx_mode: ydb.BaseQueryTxMode, + request_settings: ydb.BaseRequestSettings, + retry_settings: ydb.RetrySettings, + tx_context: ydb.aio.QueryTxContext | None = None, + table_path_prefix: str = "", + ) -> None: + super().__init__() + self._connection = connection + self._session_pool = session_pool + self._tx_mode = tx_mode + self._request_settings = request_settings + self._retry_settings = retry_settings + self._tx_context = tx_context + self._table_path_prefix = table_path_prefix + self._stream: AsyncResponseContextIterator | None = None + self._session_owner: ydb.aio.QuerySession | None = None + + async def _invalidate_active_session(self) -> None: + self._clear_current_cursor() + self._stream = None + if self._session_owner is None: + await self._connection._invalidate_session() + return + + session = self._session_owner + self._session_owner = None + try: + await session.delete() + finally: + await self._session_pool.release(session) + + def _clear_current_cursor(self) -> None: + self._connection._clear_current_cursor(self) + + def _register_current_cursor(self) -> None: + if self._tx_context is not None: + self._connection._set_current_cursor(self) + + def _get_request_settings(self) -> ydb.BaseRequestSettings: + settings = self._request_settings.make_copy() + + if self._request_settings.trace_id is None: + settings = settings.with_trace_id(maybe_get_current_trace_id()) + + return settings + + async def _release_owned_session(self) -> None: + if self._session_owner is None: + return + + session = self._session_owner + self._session_owner = None + await self._session_pool.release(session) + + async def _discard_owned_session(self) -> None: + if self._session_owner is None: + return + + session = self._session_owner + self._session_owner = None + try: + await session.delete() + finally: + await self._session_pool.release(session) + + async def _finish_stream(self) -> None: + self._stream = None + await self._release_owned_session() + self._clear_current_cursor() + self._finish_query() + + @handle_ydb_errors + @invalidate_cursor_on_ydb_error + async def _load_next_result_set(self) -> bool: + if self._stream is None: + return False + + try: + result_set = await self._stream.__anext__() + except StopAsyncIteration: + await self._finish_stream() + return False + + self._update_result_set(result_set, replace_current=False) + return True + + async def _prime_stream(self) -> None: + if ( + not await self._load_next_result_set() + and self._state == CursorStatus.running + ): + await self._finish_stream() + + @handle_ydb_errors + @invalidate_cursor_on_ydb_error + async def _execute_session_query_stream( + self, + query: str, + parameters: ParametersType | None = None, + ) -> AsyncResponseContextIterator: + settings = self._get_request_settings() + + async def callee() -> AsyncResponseContextIterator: + acquire_timeout = getattr( + self._retry_settings, + "max_session_acquire_timeout", + None, + ) + session = await self._session_pool.acquire(timeout=acquire_timeout) + try: + stream = await session.transaction(self._tx_mode).execute( + query=query, + parameters=parameters, + commit_tx=True, + settings=settings, + ) + except Exception: + await self._session_pool.release(session) + raise + + self._session_owner = session + return stream + + return await retry_operation_async( + callee, + self._retry_settings, + ) + + @handle_ydb_errors + @invalidate_cursor_on_ydb_error + async def _execute_transactional_query_stream( + self, + tx_context: ydb.aio.QueryTxContext, + query: str, + parameters: ParametersType | None = None, + ) -> AsyncResponseContextIterator: + settings = self._get_request_settings() + return await tx_context.execute( + query=query, + parameters=parameters, + commit_tx=False, + settings=settings, + ) + + async def execute( + self, + query: str, + parameters: ParametersType | None = None, + ) -> None: + self._raise_if_closed() + self._raise_if_running() + if self._tx_context is not None: + self._connection._raise_if_current_cursor_running(self) + + query = self._append_table_path_prefix(query) + self._begin_query() + + if self._tx_context is not None: + self._stream = await self._execute_transactional_query_stream( + tx_context=self._tx_context, + query=query, + parameters=parameters, + ) + self._register_current_cursor() + else: + self._stream = await self._execute_session_query_stream( + query=query, + parameters=parameters, + ) + + await self._prime_stream() + + async def execute_scheme( + self, + query: str, + parameters: ParametersType | None = None, + ) -> None: + self._raise_if_closed() + self._raise_if_running() + + query = self._append_table_path_prefix(query) + self._begin_query() + + settings = self._get_request_settings() + + @handle_ydb_errors + @invalidate_cursor_on_ydb_error + async def execute_generic_query() -> list[ydb.convert.ResultSet]: + async def callee( + session: ydb.aio.QuerySession, + ) -> list[ydb.convert.ResultSet]: + stream = await session.execute( + query=query, + parameters=parameters, + settings=settings, + ) + return [result_set async for result_set in stream] + + return await self._session_pool.retry_operation_async( + callee, + retry_settings=self._retry_settings, + ) + + result_list = await execute_generic_query() + self._fill_buffer(result_list) + self._finish_query() + + async def fetchone(self) -> tuple | None: + self._raise_if_closed() + + while True: + row = self._fetchone_from_buffer() + if row is not None: + return row + if not await self._load_next_result_set(): + return None + + async def fetchmany(self, size: int | None = None) -> list: + self._raise_if_closed() + result: list[tuple] = [] + target_size = size or self.arraysize + + while len(result) < target_size: + row = await self.fetchone() + if row is None: + break + result.append(row) + + return result + + async def fetchall(self) -> list: + self._raise_if_closed() + result = list(self._fetchall_from_buffer()) + + while await self._load_next_result_set(): + result.extend(self._fetchall_from_buffer()) + + return result + + async def nextset(self) -> bool: + return False + + async def close(self) -> None: + if self._state == CursorStatus.closed: + return + + if self._state == CursorStatus.running: + if self._session_owner is not None and self._stream is not None: + self._stream.cancel() + self._stream = None + await self._discard_owned_session() + elif self._stream is not None: + async with self._stream: + pass + await self._finish_stream() + else: + await self._release_owned_session() + self._clear_current_cursor() + + self._stream = None + self._session_owner = None + self._clear_current_cursor() + self._state = CursorStatus.closed + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: object, + ) -> None: + await self.close() diff --git a/ydb_dbapi/utils.py b/ydb_dbapi/utils.py index 5e0e1f0..c3738db 100644 --- a/ydb_dbapi/utils.py +++ b/ydb_dbapi/utils.py @@ -12,6 +12,7 @@ from .errors import DatabaseError from .errors import DataError +from .errors import Error from .errors import IntegrityError from .errors import InternalError from .errors import NotSupportedError @@ -26,6 +27,8 @@ def handle_ydb_errors(func: Callable) -> Callable: # noqa: C901 async def awrapper(*args: tuple, **kwargs: dict) -> Any: try: return await func(*args, **kwargs) + except Error: + raise except ( ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed, @@ -65,6 +68,8 @@ async def awrapper(*args: tuple, **kwargs: dict) -> Any: def wrapper(*args: tuple, **kwargs: dict) -> Any: try: return func(*args, **kwargs) + except Error: + raise except ( ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed,