diff --git a/tests/conftest.py b/tests/conftest.py index 7112f987..368ca0e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,9 @@ import ydb from ydb import issues +YDB_ENDPOINT_PORT = 2136 +YDB_SECURE_ENDPOINT_PORT = 2135 + def _docker_client(): """Build a Docker SDK client that works with non-default sockets. @@ -44,31 +47,6 @@ def pytest_addoption(parser): default="docker-compose.yml", help="Path to docker-compose file (relative to project root)", ) - parser.addoption( - "--ydb-endpoint", - action="store", - default=None, - help=( - "Use an already-running YDB endpoint (e.g. localhost:2136) instead of spinning " - "a container via docker-compose. Also honored from the YDB_ENDPOINT env var. " - "Tests that explicitly restart the container via the `docker_project` fixture " - "(chaos-style) are incompatible with this mode." - ), - ) - - -def _running_ydb_endpoint(pytestconfig): - """Return a pre-running endpoint if the user asked for one, else None.""" - existing = pytestconfig.getoption("--ydb-endpoint") or os.environ.get("YDB_ENDPOINT") - if not existing: - return None - # Strip scheme if present — the `endpoint` fixture is expected to return - # the "host:port" form that is_ydb_responsive / driver construction use. - for prefix in ("grpcs://", "grpc://"): - if existing.startswith(prefix): - existing = existing[len(prefix) :] - break - return existing @pytest.fixture(scope="session") @@ -181,26 +159,15 @@ def is_ydb_secure_responsive(endpoint, root_certificates): @pytest.fixture(scope="module") -def endpoint(pytestconfig, request): - """Wait for YDB to be responsive and return endpoint. - - If --ydb-endpoint / YDB_ENDPOINT is set, return it directly without - touching pytest-docker — this lets tests run against an already-running - container. - """ - existing = _running_ydb_endpoint(pytestconfig) - if existing is not None: - if not is_ydb_responsive(existing): - raise RuntimeError(f"--ydb-endpoint={existing} is not responsive") - yield existing - return - - # Pytest-docker path: resolve docker_services lazily so the fixtures are - # only requested when we actually need to spin a container. - docker_ip = request.getfixturevalue("docker_ip") +def endpoint(request): + """Wait for YDB to be responsive and return endpoint.""" + # Resolve docker_services lazily so the fixture is only requested when a + # test actually needs the container. The compose file publishes a fixed + # host port, so avoid `port_for()` — it shells out to `docker compose + # port`, which is exactly the late subprocess path that can trip the + # Python 3.9 gRPC fork race. docker_services = request.getfixturevalue("docker_services") - port = docker_services.port_for("ydb", 2136) - endpoint_url = f"{docker_ip}:{port}" + endpoint_url = f"localhost:{YDB_ENDPOINT_PORT}" docker_services.wait_until_responsive( timeout=60.0, pause=1.0, @@ -210,7 +177,7 @@ def endpoint(pytestconfig, request): @pytest.fixture(scope="session") -def secure_endpoint(pytestconfig, docker_ip, docker_services): +def secure_endpoint(pytestconfig, docker_services): """Wait for YDB TLS endpoint to be responsive.""" ca_path = os.path.join(str(pytestconfig.rootdir), "ydb_certs/ca.pem") @@ -228,9 +195,7 @@ def wait_for_certificate(): os.environ["YDB_SSL_ROOT_CERTIFICATES_FILE"] = ca_path root_certificates = ydb.load_ydb_root_certificate() - port = docker_services.port_for("ydb", 2135) - # Use 'localhost' instead of docker_ip because SSL certificate is issued for 'localhost' - endpoint_url = f"localhost:{port}" + endpoint_url = f"localhost:{YDB_SECURE_ENDPOINT_PORT}" docker_services.wait_until_responsive( timeout=60.0, diff --git a/ydb/aio/query/pool.py b/ydb/aio/query/pool.py index 45fa1acb..a0d9d93c 100644 --- a/ydb/aio/query/pool.py +++ b/ydb/aio/query/pool.py @@ -137,7 +137,6 @@ async def acquire(self, timeout: Optional[float] = None) -> QuerySession: async def release(self, session: QuerySession) -> None: """Release a session back to Session Pool.""" - self._queue.put_nowait(session) logger.debug("Session returned to queue: %s", session.session_id) diff --git a/ydb/aio/query/pool_test.py b/ydb/aio/query/pool_test.py index 3a09bded..de33a8e0 100644 --- a/ydb/aio/query/pool_test.py +++ b/ydb/aio/query/pool_test.py @@ -6,6 +6,7 @@ from ydb import issues from ydb.aio.query.pool import QuerySessionPool +from ydb.aio.query.session import QuerySession def _make_pool(size=1): @@ -92,3 +93,25 @@ async def enqueue_immediately(): await asyncio.sleep(0.05) total = pool._queue.qsize() + len(released_sessions) self.assertGreaterEqual(total, 0) + + async def test_retry_reacquires_invalidated_session_before_first_use(self): + pool = _make_pool(size=1) + + invalidated_session = QuerySession.__new__(QuerySession) + invalidated_session._session_id = "invalidated-session" + invalidated_session._closed = False + invalidated_session._invalidated = False + invalidated_session._stream = None + invalidated_session._close_session(invalidate=True) + + live_session = MagicMock() + live_session.explain = AsyncMock(return_value="ok") + + sessions = iter([invalidated_session, live_session]) + pool.acquire = AsyncMock(side_effect=lambda timeout=None: next(sessions)) + pool.release = AsyncMock() + + result = await pool.retry_operation_async(lambda session: session.explain("SELECT 1")) + + self.assertEqual(result, "ok") + live_session.explain.assert_awaited_once_with("SELECT 1") diff --git a/ydb/aio/query/session.py b/ydb/aio/query/session.py index 67e62ff6..a565b266 100644 --- a/ydb/aio/query/session.py +++ b/ydb/aio/query/session.py @@ -62,7 +62,7 @@ async def _attach(self) -> None: ) issues._process_response(first_response) except Exception as e: - self._invalidate() + self._close_session(invalidate=True) raise e self._loop.create_task(self._check_session_status_loop(), name="check session status task") @@ -76,7 +76,7 @@ async def _check_session_status_loop(self) -> None: logger.debug("Attach stream closed, session_id: %s", self._session_id) except Exception as e: logger.debug("Attach stream error: %s, session_id: %s", e, self._session_id) - self._invalidate() + self._close_session(invalidate=True) async def delete(self, settings: Optional[BaseRequestSettings] = None) -> None: """Deletes a Session of Query Service on server side and releases resources. @@ -92,7 +92,7 @@ async def delete(self, settings: Optional[BaseRequestSettings] = None) -> None: except Exception: pass - self._invalidate() + self._close_session() async def create(self, settings: Optional[BaseRequestSettings] = None) -> "QuerySession": """Creates a Session of Query Service on server side and attaches it. diff --git a/ydb/aio/query/transaction.py b/ydb/aio/query/transaction.py index e4232580..c31d79fb 100644 --- a/ydb/aio/query/transaction.py +++ b/ydb/aio/query/transaction.py @@ -72,7 +72,7 @@ async def __aexit__(self, *args, **kwargs): logger.warning("Failed to rollback leaked tx: %s", self._tx_state.tx_id) except BaseException: logger.warning("Failed to rollback leaked tx: %s", self._tx_state.tx_id) - self.session._invalidate() + self.session._close_session(invalidate=True) async def _ensure_prev_stream_finished(self) -> None: if self._prev_stream is not None: diff --git a/ydb/query/base.py b/ydb/query/base.py index bf0d80b9..7fea6cd0 100644 --- a/ydb/query/base.py +++ b/ydb/query/base.py @@ -205,7 +205,7 @@ def decorator(rpc_state, response_pb, session: "BaseQuerySession", *args, **kwar try: return func(rpc_state, response_pb, session, *args, **kwargs) except issues.BadSession: - session._invalidate() + session._close_session(invalidate=True) raise return decorator diff --git a/ydb/query/pool.py b/ydb/query/pool.py index cba3e47e..44d4d34a 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -125,7 +125,6 @@ def acquire(self, timeout: Optional[float] = None) -> QuerySession: def release(self, session: QuerySession) -> None: """Release a session back to Session Pool.""" - self._queue.put_nowait(session) logger.debug("Session returned to queue: %s", session.session_id) diff --git a/ydb/query/pool_test.py b/ydb/query/pool_test.py index c7ff3e4e..a7423400 100644 --- a/ydb/query/pool_test.py +++ b/ydb/query/pool_test.py @@ -7,6 +7,7 @@ from ydb import issues from ydb.query.pool import QuerySessionPool +from ydb.query.session import QuerySession def _make_pool(size=1): @@ -54,3 +55,24 @@ def release_after_delay(): self.assertIs(acquired, session) finally: t.join() + + +class TestRetryOperationSync(unittest.TestCase): + def test_retry_reacquires_invalidated_session_before_first_use(self): + pool = _make_pool(size=1) + + invalidated_session = QuerySession(pool._driver) + invalidated_session._session_id = "invalidated-session" + invalidated_session._close_session(invalidate=True) + + live_session = MagicMock() + live_session.explain.return_value = "ok" + + sessions = iter([invalidated_session, live_session]) + pool.acquire = MagicMock(side_effect=lambda timeout=None: next(sessions)) + pool.release = MagicMock() + + result = pool.retry_operation_sync(lambda session: session.explain("SELECT 1")) + + self.assertEqual(result, "ok") + live_session.explain.assert_called_once_with("SELECT 1") diff --git a/ydb/query/session.py b/ydb/query/session.py index b28cba8b..f2099f8c 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -56,7 +56,6 @@ def wrapper_delete_session( ) -> "BaseQuerySession": message = _ydb_query.DeleteSessionResponse.from_proto(response_pb) issues._process_response(message.status) - session._closed = True return session @@ -71,6 +70,7 @@ class BaseQuerySession(abc.ABC, Generic[DriverT]): _session_id: Optional[str] = None _node_id: Optional[int] = None _closed: bool = False + _invalidated: bool = False def __init__(self, driver: DriverT, settings: Optional[base.QueryClientSettings] = None): self._driver = driver @@ -122,12 +122,18 @@ def _get_client_settings( return base.QueryClientSettings() def _check_session_ready_to_use(self) -> None: - if not self.is_active: + if self._session_id is None: + raise RuntimeError("Session is not initialized") + if self._invalidated: + raise issues.BadSession(f"Session is not active, session_id: {self._session_id}, closed: {self._closed}") + if self._closed: raise RuntimeError(f"Session is not active, session_id: {self._session_id}, closed: {self._closed}") - def _invalidate(self) -> None: + def _close_session(self, invalidate: bool = False) -> None: if self._closed: return + if invalidate: + self._invalidated = True self._closed = True if self._stream is not None: @@ -161,9 +167,9 @@ def _on_execute_stream_error(self, e: BaseException) -> None: issues.Cancelled, ), ): - self._invalidate() + self._close_session(invalidate=True) else: - self._invalidate() + self._close_session(invalidate=True) # Overloads for _create_call @overload @@ -339,7 +345,7 @@ def _attach(self, first_resp_timeout: int = DEFAULT_INITIAL_RESPONSE_TIMEOUT) -> ) issues._process_response(first_response) except Exception as e: - self._invalidate() + self._close_session(invalidate=True) raise e threading.Thread( @@ -356,7 +362,7 @@ def _check_session_status_loop(self, status_stream: _utilities.SyncResponseItera logger.debug("Attach stream closed, session_id: %s", self._session_id) except Exception as e: logger.debug("Attach stream error: %s, session_id: %s", e, self._session_id) - self._invalidate() + self._close_session(invalidate=True) def delete(self, settings: Optional[BaseRequestSettings] = None) -> None: """Deletes a Session of Query Service on server side and releases resources. @@ -372,7 +378,7 @@ def delete(self, settings: Optional[BaseRequestSettings] = None) -> None: except Exception: pass - self._invalidate() + self._close_session() def create(self, settings: Optional[BaseRequestSettings] = None) -> "QuerySession": """Creates a Session of Query Service on server side and attaches it.