diff --git a/CHANGELOG.md b/CHANGELOG.md index 597faec..a424a03 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,14 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [Unreleased] -- Nothing yet +### Added + +- Support closing `Kv` client's `session` via: ([#20](https://github.com/h4l/denokv-python/pull/20)) + - `Kv.aclose()` + - async context manager + - At interpreter exit / garbage collection via `Kv.create_finalizer()` + - Automatically when an interactive console exists: + - `Kv` objects created by `open_kv()` from an interactive console/REPL automatically close at exit. + - The `open_kv()` function has a `finalize` option that controls this. [unreleased]: https://github.com/h4l/denokv-python/commits/main/ diff --git a/src/denokv/kv.py b/src/denokv/kv.py index 354a7d6..ce53323 100644 --- a/src/denokv/kv.py +++ b/src/denokv/kv.py @@ -1,14 +1,17 @@ from __future__ import annotations import asyncio +import weakref from base64 import urlsafe_b64decode from base64 import urlsafe_b64encode from binascii import unhexlify +from contextlib import AbstractAsyncContextManager from dataclasses import dataclass from dataclasses import field from enum import Flag from enum import auto from os import environ +from types import TracebackType from typing import TYPE_CHECKING from typing import AsyncIterator from typing import Awaitable @@ -17,6 +20,7 @@ from typing import Final from typing import Generic from typing import Iterable +from typing import Literal from typing import Protocol from typing import Sequence from typing import TypedDict @@ -32,6 +36,7 @@ from denokv._datapath_pb2 import SnapshotRead from denokv._datapath_pb2 import SnapshotReadOutput from denokv._pycompat.dataclasses import slots_if310 +from denokv._pycompat.typing import override from denokv.asyncio import loop_time from denokv.auth import ConsistencyLevel from denokv.auth import DatabaseMetadata @@ -465,8 +470,8 @@ class KvFlags(Flag): DEFAULT_KV_FLAGS: Final = KvFlags.IntAsNumber -@dataclass(init=False, **slots_if310()) -class Kv: +@dataclass(init=False) +class Kv(AbstractAsyncContextManager["Kv", None]): """ Interface to perform requests against a Deno KV database. @@ -498,6 +503,59 @@ def __init__( self.v8_decoder = v8_decoder or Decoder() self.flags = KvFlags.IntAsNumber if flags is None else flags + @override + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + /, + ) -> None: + await self.aclose() + + @property + def closed(self) -> bool: + return self.session.closed + + async def aclose(self) -> None: + if self.closed: + return + await self._aclose(self.session) + + @classmethod + async def _aclose(cls, session: aiohttp.ClientSession) -> None: + await session.close() + + @classmethod + def _finalize(cls, session: aiohttp.ClientSession) -> None | asyncio.Future[None]: + if session._loop.is_running(): + return session._loop.create_task( + cls._aclose(session), name="denokv.Kv.create_finalizer" + ) + else: + return session._loop.run_until_complete(cls._aclose(session)) + + def create_finalizer(self) -> weakref.finalize: + """ + Automatically close the instance when it goes out of scope, or at exit. + + This creates a Finalizer (`weakref.finalize`) that closes the Kv + instance automatically when garbage collected, or when Python exits. + + If the event loop of the Kv's session is running, the finalizer returns, + an `asyncio.Task` that closes the Kv instance. Otherwise it runs the + session's loop to close it and returns None after it's closed. + + Notes + ----- + It's recommended to close Kv instances explicitly using async context + manager blocks, but Finalizers can be used in situations where a context + manager is not practical, like in an interactive environment. + + `open_kv()` automatically creates a Finalizer in interactive sessions. + """ + return weakref.finalize(self, self._finalize, self.session) + def _prepare_key(self, key: AnyKvKeyT) -> AnyKvKeyT: if self.flags & KvFlags.IntAsNumber and not isinstance(key, KvKeyEncodable): return normalize_key(key, bigints=False) # type: ignore[return-value] @@ -943,12 +1001,24 @@ def _common_prefix_length(a: Sequence[object], b: Sequence[object]) -> int: return match_length +def _is_python_running_in_interactive_environment() -> bool: + import sys + + # sys.ps1 is only set in interactive environments: + # https://stackoverflow.com/a/64523765/693728 + return hasattr(sys, "ps1") + + +OpenKvFinalize: TypeAlias = Literal[True, False, "interactive"] + + async def open_kv( target: URL | str | KvCredentials, *, access_token: str | None = None, session: aiohttp.ClientSession | None = None, flags: KvFlags | None = None, + finalize: OpenKvFinalize | None = None, ) -> Kv: """ Create a connection to a KV database. @@ -968,6 +1038,11 @@ async def open_kv( Default: A new session is created. flags Enable/disable flags that change Kv behaviour. Default: [DEFAULT_KV_FLAGS] + finalize + Whether to create a finalizer to automatically close the Kv instance at + exit, or when out of scope. If set to 'interactive', a finalizer is + created only if Python is running as an interactive session. + Default: ['interactive'] Notes ----- @@ -993,8 +1068,16 @@ async def open_kv( ) target = KvCredentials(server_url=target, access_token=access_token) + if finalize not in (True, False, "interactive", None): + raise ValueError("finalize must be True, False, None or 'interactive'") + finalize = "interactive" if finalize is None else finalize session = session or aiohttp.ClientSession() retry = ExponentialBackoff() auth = Authenticator(session=session, retry_delays=retry, credentials=target) - return Kv(session=session, auth=auth, retry=retry, flags=flags) + kv = Kv(session=session, auth=auth, retry=retry, flags=flags) + if finalize is True or ( + finalize == "interactive" and _is_python_running_in_interactive_environment() + ): + kv.create_finalizer() + return kv diff --git a/test/test_kv.py b/test/test_kv.py index a4745ef..67e61cf 100644 --- a/test/test_kv.py +++ b/test/test_kv.py @@ -1,6 +1,10 @@ from __future__ import annotations import asyncio +import sys +import weakref +from contextlib import asynccontextmanager +from contextlib import contextmanager from datetime import datetime from datetime import timedelta from functools import partial @@ -64,6 +68,7 @@ from denokv.kv import KvFlags from denokv.kv import KvListOptions from denokv.kv import KvU64 +from denokv.kv import OpenKvFinalize from denokv.kv import VersionStamp from denokv.kv import normalize_key from denokv.kv import open_kv @@ -960,6 +965,83 @@ async def test_Kv_list__retries_retryable_snapshot_read_errors( assert len(auth_fn.mock_calls) == 7 +@pytest_mark_asyncio +async def test_aclose() -> None: + authenticator = Mock() + kv = Kv(session=aiohttp.ClientSession(), auth=authenticator) + assert not kv.closed + assert not kv.session.closed + + await kv.aclose() + assert kv.closed + assert kv.session.closed + + +@pytest_mark_asyncio +async def test_close_via_context_manager() -> None: + authenticator = Mock() + async with Kv(session=aiohttp.ClientSession(), auth=authenticator) as kv: + assert not kv.closed + assert not kv.session.closed + assert kv.closed + assert kv.session.closed + + +@pytest_mark_asyncio +async def test_close_via_finalizer__manual() -> None: + session = aiohttp.ClientSession() + authenticator = Mock() + kv = Kv(session=session, auth=authenticator) + + f = kv.create_finalizer() + assert isinstance(f, weakref.finalize) + assert not session.closed + + result = f() + assert isinstance(result, asyncio.Future) + await result + assert session.closed + + +@pytest_mark_asyncio +async def test_close_via_finalizer__loop_running__auto() -> None: + session = aiohttp.ClientSession() + authenticator = Mock() + + def use_kv_and_finalize() -> None: + kv = Kv(session=session, auth=authenticator) + kv.create_finalizer() + + assert not session.closed + + async with all_inner_tasks_awaited(): + use_kv_and_finalize() + + assert session.closed + + +def test_close_via_finalizer__loop_not_running() -> None: + loop = asyncio.new_event_loop() + authenticator = Mock() + + async def create_session() -> aiohttp.ClientSession: + return aiohttp.ClientSession() + + session = loop.run_until_complete(create_session()) + + def use_kv_and_finalize() -> None: + kv = Kv(session=session, auth=authenticator) + kv.create_finalizer() + + assert not session.closed + assert not loop.is_running() + + use_kv_and_finalize() + + assert session.closed + assert not loop.is_running() + + def test_open_kv__requires_event_loop_to_default_session() -> None: with pytest.raises(RuntimeError, match=r"no running event loop"): aiohttp.ClientSession() @@ -1007,3 +1089,71 @@ async def test_open_kv( assert isinstance(kv.metadata_cache.authenticator, Authenticator) credentials = kv.metadata_cache.authenticator.credentials assert credentials.access_token == "argsecret" + + +@pytest.mark.parametrize( + "is_interactive, finalize, is_closed", + [ + (True, None, True), + (True, "interactive", True), + (True, True, True), + (True, False, False), + (False, None, False), + (False, "interactive", False), + (False, True, True), + (False, False, False), + ], +) +@pytest_mark_asyncio +async def test_open_kv__creates_finalizer_when_running_interactively( + is_interactive: bool, finalize: OpenKvFinalize | None, is_closed: bool +) -> None: + async with aiohttp.ClientSession() as session: + + async def open_kv_and_drop_reference() -> None: + await open_kv( + "https://0.0.0.0/example", + finalize=finalize, + session=session, + access_token="example", + ) + + async with all_inner_tasks_awaited(): + if is_interactive: + with interactive_session_active(): + await open_kv_and_drop_reference() + else: + await open_kv_and_drop_reference() + assert session.closed is is_closed + + +@pytest_mark_asyncio +async def test_open_kv__validates_finalize() -> None: + with pytest.raises( + ValueError, match=r"finalize must be True, False, None or 'interactive'" + ): + await open_kv( + "https://0.0.0.0/example", + finalize="sdfdsf", # type: ignore[arg-type] + access_token="example", + ) + + +@contextmanager +def interactive_session_active() -> Generator[None]: + assert not hasattr(sys, "ps1") + + try: + sys.ps1 = "example" + yield + finally: + del sys.ps1 + + +@asynccontextmanager +async def all_inner_tasks_awaited() -> AsyncGenerator[None]: + pre_existing_tasks = asyncio.all_tasks() + yield + inner_tasks = asyncio.all_tasks() - pre_existing_tasks + if inner_tasks: + await asyncio.wait(inner_tasks)