diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 265da9c..60bff01 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -11,12 +11,7 @@ "postCreateCommand": "pipx install poetry", "customizations": { "vscode": { - "extensions": [ - "charliermarsh.ruff", - "ms-python.python", - "ms-python.isort", - "ms-python.black-formatter" - ] + "extensions": ["charliermarsh.ruff", "ms-python.python"] } } } diff --git a/.prettierrc b/.prettierrc index 222861c..e659e61 100644 --- a/.prettierrc +++ b/.prettierrc @@ -1,4 +1,5 @@ { "tabWidth": 2, - "useTabs": false + "useTabs": false, + "proseWrap": "always" } diff --git a/CHANGELOG.md b/CHANGELOG.md index a424a03..2618018 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,12 +8,22 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ### Added -- Support closing `Kv` client's `session` via: ([#20](https://github.com/h4l/denokv-python/pull/20)) +- Support writing to KV databases with `Kv.set()`, `Kv.delete()`, `Kv.sum()`, + `Kv.min()`, `Kv.max()`, `Kv.enqueue()` and `Kv.check()`. + ([#16](https://github.com/h4l/denokv-python/pull/16)) + + - These methods are available on `Kv` itself for one-off operations, and + `Kv.atomic()` can chain these methods to group write operations to apply + together in a transaction. + +- Support closing `Kv` client's `session` via: + ([#20](https://github.com/h4l/denokv-python/pull/20)) - `Kv.aclose()` - async context manager - At interpreter exit / garbage collection via `Kv.create_finalizer()` - Automatically when an interactive console exists: - - `Kv` objects created by `open_kv()` from an interactive console/REPL automatically close at exit. + - `Kv` objects created by `open_kv()` from an interactive console/REPL + automatically close at exit. - The `open_kv()` function has a `finalize` option that controls this. 
[unreleased]: https://github.com/h4l/denokv-python/commits/main/ diff --git a/README.md b/README.md index f5d73f2..7106e1b 100644 --- a/README.md +++ b/README.md @@ -6,21 +6,31 @@ _Connect to [Deno KV] cloud and [self-hosted] databases from Python._ [self-hosted]: https://deno.com/blog/kv-is-open-source-with-continuous-backup [denokv server]: https://github.com/denoland/denokv -The `denokv` package is an unofficial Python client for the Deno KV database. It can connect to -both the distributed cloud KV service, or self-hosted [denokv server] (which can be a replica of a cloud KV database, or standalone). +The `denokv` package is an unofficial Python client for the Deno KV database. It +can connect to both the distributed cloud KV service, or self-hosted [denokv +server] (which can be a replica of a cloud KV database, or standalone). -It implements version 3 of the [KV Connect protocol spec, published by Deno](https://github.com/denoland/denokv/blob/main/proto/kv-connect.md). +It implements version 3 of the +[KV Connect protocol spec, published by Deno](https://github.com/denoland/denokv/blob/main/proto/kv-connect.md). ## Status -The package is under active development and is not yet stable or feature-complete. +The package is under active development and is not yet stable or +feature-complete. **Working**: -- [x] Reading data with kv.get(), kv.list() +- [x] Reading data with `Kv.get()`, `Kv.list()` + - The read APIs are being reworked to improve ergonomics and functionality +- [x] Writing data with with `Kv.set()`, `Kv.delete()`, `Kv.sum()`, `Kv.min()`, + `Kv.max()`, `Kv.enqueue()` and `Kv.check()`. + - These methods are available on `Kv` itself for one-off operations, and + `Kv.atomic()` can chain these methods to group write operations to apply + together in a transaction. 
**To-do**: -- [ ] [Writing data / transactions](https://docs.deno.com/deploy/kv/manual/transactions/) - [ ] [Watching for changes](https://docs.deno.com/deploy/kv/manual/operations/#watch) - [ ] [Queues](https://deno.com/blog/queues) + - This is uncertain: The KV Connect protocol does not support Queues, but they + could be implemented using watching in theory. diff --git a/docker-bake.hcl b/docker-bake.hcl index e2c4387..bb3941a 100644 --- a/docker-bake.hcl +++ b/docker-bake.hcl @@ -30,6 +30,20 @@ function "get_py_image_tag" { py_versions = ["3.9", "3.10", "3.11", "3.12", "3.13"] +target "dev" { + name = "dev_py${replace(py, ".", "")}" + matrix = { + py = py_versions, + } + args = { + PYTHON_VER = get_py_image_tag(py) + REPORT_CODE_COVERAGE = REPORT_CODE_COVERAGE + REPORT_CODE_BRANCH_COVERAGE = REPORT_CODE_BRANCH_COVERAGE + } + target = "poetry" + tags = ["ghcr.io/h4l/denokv-python/dev:py${replace(py, ".", "")}"] +} + target "test" { name = "test_py${replace(py, ".", "")}" matrix = { diff --git a/docker-compose.yml b/docker-compose.yml index 175ffa2..1f70ae7 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,6 +11,19 @@ services: networks: - devcontainer_denokv_python + dev-py39: + profiles: [dev-py39] + image: ghcr.io/h4l/denokv-python/dev:py39 + volumes: + - workspace:/workspaces + working_dir: /workspaces/denokv-python + environment: + PYTHONPATH: /workspaces/denokv-python/src + command: poetry run ipython + networks: + - devcontainer_denokv_python + + networks: devcontainer_denokv_python: external: true diff --git a/poetry.lock b/poetry.lock index 97f3665..5c55821 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2919,13 +2919,13 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "v8serialize" -version = "0.1.0" +version = "0.2.0a0" description = "Read & write JavaScript values from Python with the V8 serialization format." 
optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "v8serialize-0.1.0-py3-none-any.whl", hash = "sha256:5136e50c24308f9ddc7b8083ca34e7c65f57cd321dd703b9667708a8552eebed"}, - {file = "v8serialize-0.1.0.tar.gz", hash = "sha256:bd330fb925be9c395d82ed4f048b78f0d560d4358b3061d149665d8a8cc60d86"}, + {file = "v8serialize-0.2.0a0-py3-none-any.whl", hash = "sha256:643557ed38757a5ddaac008469bc87c47e1e89df062941ce133323e60122f4f0"}, + {file = "v8serialize-0.2.0a0.tar.gz", hash = "sha256:57dc3262a089ba5917da4d53300fa392e50e19ab2476ef3e507d6061f70ee332"}, ] [package.dependencies] @@ -3098,4 +3098,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "4581fc6eb84b066a30991807676d038a115ea4e6a24093059d9dc46f2cab1bed" +content-hash = "eceacf465958af528559937236cb0ab10ae6d4a8d492c1e075650c451a85b5db" diff --git a/pyproject.toml b/pyproject.toml index ac52135..b1454b8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ protobuf = ">=4.22.0,<6" # change was in 6.2.4 (which is late 2019) # https://github.com/apple/foundationdb/commits/main/bindings/python/fdb/tuple.py foundationdb = ">=6.2.4,<8" -v8serialize = "^0.1.0" +v8serialize = "^0.2.0-alpha.0" [tool.poetry.group.dev.dependencies] pytest = "^8.3.2" @@ -57,6 +57,9 @@ extra_standard_library = ["typing_extensions"] [tool.mypy] strict = true +enable_error_code = [ + 'possibly-undefined' +] mypy_path = "./stubs" [[tool.mypy.overrides]] @@ -84,6 +87,7 @@ select = [ "FA", # flake8-future-annotations "PYI", # flake8-pyi "I", + "TID", # flake8-tidy-imports ] ignore = [ @@ -99,6 +103,10 @@ ignore = [ "PYI041", ] +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"typing_extensions".msg = "use denokv._pycompat.typing instead, typing_extensions is not a runtime dependency." +"typing".msg = "Use denokv._pycompat.typing instead (apart from overload and Literal), using typing is error-prone as it has many differences between python versions." 
+ [tool.ruff.lint.pydocstyle] convention = "numpy" @@ -114,6 +122,7 @@ filterwarnings = [ ] [tool.coverage.run] +source = ["denokv"] omit = [ # generated Protocol Buffers module "*/denokv/_datapath_pb2.py", diff --git a/src/denokv/__init__.py b/src/denokv/__init__.py index 1340fdb..ce13266 100644 --- a/src/denokv/__init__.py +++ b/src/denokv/__init__.py @@ -1,5 +1,8 @@ from __future__ import annotations +from denokv._kv_values import KvEntry as KvEntry +from denokv._kv_values import KvU64 as KvU64 +from denokv._kv_values import VersionStamp as VersionStamp from denokv.auth import ConsistencyLevel as ConsistencyLevel from denokv.auth import MetadataExchangeDenoKvError as MetadataExchangeDenoKvError from denokv.datapath import AnyKvKey as AnyKvKey @@ -13,10 +16,7 @@ from denokv.kv import CursorFormatType as CursorFormatType from denokv.kv import Kv as Kv from denokv.kv import KvCredentials as KvCredentials -from denokv.kv import KvEntry as KvEntry from denokv.kv import KvListOptions as KvListOptions -from denokv.kv import KvU64 as KvU64 from denokv.kv import ListKvEntry as ListKvEntry -from denokv.kv import VersionStamp as VersionStamp from denokv.kv import open_kv as open_kv from denokv.kv_keys import KvKey as KvKey diff --git a/src/denokv/_kv_types.py b/src/denokv/_kv_types.py new file mode 100644 index 0000000..77ea8a9 --- /dev/null +++ b/src/denokv/_kv_types.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from abc import ABC +from abc import abstractmethod + +from google.protobuf.message import Message +from v8serialize import Encoder + +from denokv._datapath_pb2 import AtomicWrite +from denokv._kv_values import VersionStamp +from denokv._pycompat.typing import Generic +from denokv._pycompat.typing import Protocol +from denokv._pycompat.typing import Sequence +from denokv._pycompat.typing import TypeAlias +from denokv._pycompat.typing import TypeGuard +from denokv._pycompat.typing import TypeVar +from denokv._pycompat.typing import Union +from 
denokv.auth import EndpointInfo +from denokv.datapath import CheckFailure +from denokv.datapath import DataPathError +from denokv.result import Nothing +from denokv.result import Option +from denokv.result import Result +from denokv.result import Some + +WriteResultT = TypeVar("WriteResultT") +WriteResultT_co = TypeVar("WriteResultT_co", covariant=True) +MessageT_co = TypeVar("MessageT_co", bound=Message, covariant=True) + + +class ProtobufMessageRepresentation(Generic[MessageT_co], ABC): + """An object that can represent itself as a protobuf Messages.""" + + __slots__ = () + + @abstractmethod + def as_protobuf(self, *, v8_encoder: Encoder) -> Sequence[MessageT_co]: ... + + +class SingleProtobufMessageRepresentation(ProtobufMessageRepresentation[MessageT_co]): + """An object that can represent itself as a single protobuf Message.""" + + __slots__ = () + + @abstractmethod + def as_protobuf(self, *, v8_encoder: Encoder) -> tuple[MessageT_co]: ... + + +class AtomicWriteRepresentation(SingleProtobufMessageRepresentation[AtomicWrite]): + __slots__ = () + + +class AtomicWriteRepresentationWriter( + AtomicWriteRepresentation, Generic[WriteResultT_co] +): + __slots__ = () + + @abstractmethod + async def write(self, kv: KvWriter, *, v8_encoder: Encoder) -> WriteResultT_co: ... + + +KvWriterWriteResult: TypeAlias = Result[ + tuple[VersionStamp, EndpointInfo], Union[CheckFailure, DataPathError] +] + + +class KvWriter(ABC): + """A low-level interface for objects that can perform KV writes.""" + + @abstractmethod + async def write(self, *, protobuf_atomic_write: AtomicWrite) -> KvWriterWriteResult: + """Write a protobuf AtomicWrite message to the database.""" + + +class V8EncoderProvider(Protocol): + @property + def v8_encoder(self) -> Encoder: ... 
+ + +def is_v8_encoder_provider(obj: object) -> TypeGuard[V8EncoderProvider]: + return isinstance(getattr(obj, "v8_encoder", None), Encoder) + + +def get_v8_encoder(maybe_v8_encoder_provider: object) -> Option[Encoder]: + if is_v8_encoder_provider(maybe_v8_encoder_provider): + return Some(maybe_v8_encoder_provider.v8_encoder) + return Nothing() diff --git a/src/denokv/_kv_values.py b/src/denokv/_kv_values.py new file mode 100644 index 0000000..e0ce994 --- /dev/null +++ b/src/denokv/_kv_values.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from binascii import unhexlify +from dataclasses import dataclass + +from denokv._pycompat.dataclasses import slots_if310 +from denokv._pycompat.typing import ClassVar +from denokv._pycompat.typing import Generic +from denokv._pycompat.typing import Self +from denokv._pycompat.typing import TypeVar +from denokv._pycompat.typing import TypeVarTuple +from denokv._pycompat.typing import Unpack +from denokv.datapath import AnyKvKeyT +from denokv.datapath import KvKeyPiece + +T = TypeVar("T", default=object) +# Note that the default arg doesn't seem to work with MyPy yet. The +# DefaultKvKey alias is what this should behave as when defaulted. +Pieces = TypeVarTuple("Pieces", default=Unpack[tuple[KvKeyPiece, ...]]) + + +@dataclass(frozen=True, **slots_if310()) +class KvEntry(Generic[AnyKvKeyT, T]): + """A value read from the Deno KV database, along with its key and version.""" + + key: AnyKvKeyT + value: T + versionstamp: VersionStamp + + +class VersionStamp(bytes): + r""" + A 20-hex-char / (10 byte) version identifier. + + This value represents the relative age of a KvEntry. A VersionStamp that + compares larger than another is newer. 
+ + Examples + -------- + >>> VersionStamp(0xff << 16) + VersionStamp('00000000000000ff0000') + >>> int(VersionStamp('000000000000000000ff')) + 255 + >>> bytes(VersionStamp('00000000000000ff0000')) + b'\x00\x00\x00\x00\x00\x00\x00\xff\x00\x00' + >>> VersionStamp(b'\x00\x00\x00\x00\x00\x00\x00\xff\x00\x00') + VersionStamp('00000000000000ff0000') + >>> isinstance(VersionStamp(0), bytes) + True + >>> str(VersionStamp(0xff << 16)) + '00000000000000ff0000' + """ + + __slots__ = () + + RANGE: ClassVar = range(0, 2**80) + + def __new__(cls, value: str | bytes | int) -> Self: + if isinstance(value, int): + if value not in VersionStamp.RANGE: + raise ValueError("value not in range for 80-bit unsigned int") + # Unlike most others, versionstamp uses big-endian as it needs to + # sort lexicographically as bytes. + value = value.to_bytes(length=10, byteorder="big") + if isinstance(value, str): + try: + value = unhexlify(value) + except Exception: + value = b"" + if len(value) != 10: + raise ValueError("value is not a 20 char hex string") + else: + if len(value) != 10: + raise ValueError("value is not 10 bytes long") + return bytes.__new__(cls, value) + + def __index__(self) -> int: + return int.from_bytes(self, byteorder="big") + + def __bytes__(self) -> bytes: + return self[:] + + def __str__(self) -> str: + return self.hex() + + def __repr__(self) -> str: + return f"{type(self).__name__}({str(self)!r})" + + +@dataclass(frozen=True, **slots_if310()) +class KvU64: + """ + An special int value that supports operations like `sum`, `max`, and `min`. + + Notes + ----- + This type is not an int subtype to avoid it being mistakenly flattened into + a regular int and loosing its special meaning when written back to the DB. 
+ + Examples + -------- + >>> KvU64(bytes([0, 0, 0, 0, 0, 0, 0, 0])) + KvU64(0) + >>> KvU64(bytes([1, 0, 0, 0, 0, 0, 0, 0])) + KvU64(1) + >>> KvU64(bytes([1, 1, 0, 0, 0, 0, 0, 0])) + KvU64(257) + >>> KvU64(2**64 - 1) + KvU64(18446744073709551615) + >>> KvU64(2**64) + Traceback (most recent call last): + ... + ValueError: value not in range for 64-bit unsigned int + >>> KvU64(-1) + Traceback (most recent call last): + ... + ValueError: value not in range for 64-bit unsigned int + """ + + RANGE: ClassVar[range] = range(0, 2**64) + value: int + + def __init__(self, value: bytes | int) -> None: + if isinstance(value, bytes): + if len(value) != 8: + raise ValueError("value must be a 8 bytes") + value = int.from_bytes(value, byteorder="little") + elif isinstance(value, int): + if value not in KvU64.RANGE: + raise ValueError("value not in range for 64-bit unsigned int") + else: + raise TypeError("value must be 8 bytes or a 64-bit unsigned int") + object.__setattr__(self, "value", value) + + def __index__(self) -> int: + return self.value + + def __bytes__(self) -> bytes: + return self.to_bytes() + + def to_bytes(self) -> bytes: + return self.value.to_bytes(8, byteorder="little") + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.value})" diff --git a/src/denokv/_kv_writes.py b/src/denokv/_kv_writes.py new file mode 100644 index 0000000..4defc53 --- /dev/null +++ b/src/denokv/_kv_writes.py @@ -0,0 +1,2190 @@ +from __future__ import annotations + +from abc import abstractmethod +from builtins import float as float_ +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime +from enum import Enum +from functools import total_ordering +from itertools import islice +from types import MappingProxyType +from typing import Literal # noqa: TID251 +from typing import overload # noqa: TID251 + +from v8serialize import Encoder +from v8serialize.constants import FLOAT64_SAFE_INT_RANGE +from v8serialize.encode import 
WritableTagStream +from v8serialize.jstypes import JSBigInt + +from denokv import _datapath_pb2 as dp_protobuf +from denokv._datapath_pb2 import AtomicWrite +from denokv._kv_types import AtomicWriteRepresentationWriter +from denokv._kv_types import KvWriter +from denokv._kv_types import ProtobufMessageRepresentation +from denokv._kv_types import SingleProtobufMessageRepresentation +from denokv._kv_types import get_v8_encoder +from denokv._kv_values import KvEntry as KvEntry +from denokv._kv_values import KvU64 as KvU64 +from denokv._kv_values import VersionStamp as VersionStamp +from denokv._pycompat.dataclasses import FrozenAfterInitDataclass +from denokv._pycompat.dataclasses import slots_if310 +from denokv._pycompat.enum import EvalEnumRepr +from denokv._pycompat.exceptions import with_notes +from denokv._pycompat.typing import TYPE_CHECKING +from denokv._pycompat.typing import Any +from denokv._pycompat.typing import ClassVar +from denokv._pycompat.typing import Container +from denokv._pycompat.typing import Final +from denokv._pycompat.typing import Generic +from denokv._pycompat.typing import Iterable +from denokv._pycompat.typing import Mapping +from denokv._pycompat.typing import MutableSequence +from denokv._pycompat.typing import Never +from denokv._pycompat.typing import Protocol +from denokv._pycompat.typing import Self +from denokv._pycompat.typing import Sequence +from denokv._pycompat.typing import TypeAlias +from denokv._pycompat.typing import TypedDict +from denokv._pycompat.typing import TypeGuard +from denokv._pycompat.typing import TypeIs +from denokv._pycompat.typing import TypeVar +from denokv._pycompat.typing import Union +from denokv._pycompat.typing import Unpack +from denokv._pycompat.typing import assert_never +from denokv._pycompat.typing import cast +from denokv._pycompat.typing import override +from denokv._pycompat.typing import runtime_checkable +from denokv._utils import frozen +from denokv.auth import EndpointInfo +from 
denokv.backoff import Backoff +from denokv.backoff import ExponentialBackoff +from denokv.datapath import AnyKvKey +from denokv.datapath import CheckFailure +from denokv.datapath import pack_key +from denokv.errors import DenoKvError +from denokv.kv_keys import KvKey +from denokv.result import AnyFailure +from denokv.result import AnySuccess +from denokv.result import is_err + +KvNumberNameT = TypeVar("KvNumberNameT", bound=str, default=str) +NumberT = TypeVar("NumberT", bound=Union[int, float], default=Union[int, float]) +KvNumberTypeT = TypeVar("KvNumberTypeT", default=object) + +KvNumberNameT_co = TypeVar("KvNumberNameT_co", bound=str, covariant=True, default=str) +NumberT_co = TypeVar( + "NumberT_co", bound=Union[int, float], covariant=True, default=Union[int, float] +) +KvNumberTypeT_co = TypeVar("KvNumberTypeT_co", covariant=True, default=object) + +U = TypeVar("U") +MutateResultT = TypeVar("MutateResultT") +EnqueueResultT = TypeVar("EnqueueResultT") +CheckResultT = TypeVar("CheckResultT") + + +@total_ordering +@dataclass(frozen=True, unsafe_hash=True, **slots_if310()) +class KvNumberInfo(Generic[KvNumberNameT_co, NumberT, KvNumberTypeT]): + name: KvNumberNameT_co = field(init=False) + py_type: type[NumberT] = field(init=False) + kv_type: type[KvNumberTypeT] = field(init=False) + + @property + @abstractmethod + def default_limit(self) -> Limit[NumberT]: ... + + @abstractmethod + def validate_limit(self, limit: Limit[NumberT]) -> Limit[NumberT]: ... 
+ + def __lt__(self, other: object) -> bool: + if isinstance(other, KvNumberInfo): + self_name: str = self.name # mypy needs help with inferring str + other_name: str = other.name + + if self_name == other_name and self != other: + raise RuntimeError("KvNumberInfo instances must have unique names") + return self_name < other_name + return NotImplemented + + def as_py_number(self, number: KvNumberTypeT | NumberT | int) -> NumberT: + if self.is_py_number(number): + return number + if self.is_kv_number(number) or self._is_compatible_int(number, target="py"): + return self.py_type(number) # type: ignore[arg-type,return-value] + raise self._describe_invalid_number(number, target="py") + + def as_kv_number(self, number: KvNumberTypeT | NumberT | int) -> KvNumberTypeT: + if self.is_kv_number(number): + return number + if self.is_py_number(number) or self._is_compatible_int(number, target="kv"): + return self.kv_type(number) # type: ignore[call-arg] + raise self._describe_invalid_number(number, target="kv") + + def _is_compatible_int( + self, number: object, *, target: Literal["py", "kv"] + ) -> TypeGuard[int]: + return type(number) is int + + def _describe_invalid_number( + self, number: object, *, target: Literal["py", "kv"] + ) -> Exception: + return with_notes( + TypeError( + f"number is not compatible with {self.name} {target} number type" + ), + f"number: {number!r} ({type(number)}), " f"{self.name}={self}", + ) + + def is_py_number(self, value: object) -> TypeGuard[NumberT]: + return isinstance(value, self.py_type) + + def is_kv_number(self, value: object) -> TypeGuard[KvNumberTypeT]: + return isinstance(value, self.kv_type) + + @abstractmethod + def get_sum_mutations( + self, + sum: Sum[KvNumberNameT_co, NumberT, KvNumberTypeT], + *, + v8_encoder: Encoder | None = None, + ) -> Sequence[dp_protobuf.Mutation]: ... 
+ + @abstractmethod + def get_min_mutations( + self, + min: Min[KvNumberNameT_co, NumberT, KvNumberTypeT], + *, + v8_encoder: Encoder | None = None, + ) -> Sequence[dp_protobuf.Mutation]: ... + + @abstractmethod + def get_max_mutations( + self, + max: Max[KvNumberNameT_co, NumberT, KvNumberTypeT], + *, + v8_encoder: Encoder | None = None, + ) -> Sequence[dp_protobuf.Mutation]: ... + + +class V8KvNumberInfo(KvNumberInfo[KvNumberNameT_co, NumberT, KvNumberTypeT]): + @property + def default_limit(self) -> Limit[NumberT]: + return LIMIT_UNLIMITED + + def validate_limit(self, limit: Limit[NumberT]) -> Limit[NumberT]: + if limit.limit_exceeded not in ( + LimitExceededPolicy.ABORT, + LimitExceededPolicy.CLAMP, + ): + raise with_notes( + ValueError(f"Number type {self.name!r} does not support wrap limits"), + "Use 'u64' (KvU64) to wrap on 0, 2^64 - 1 bounds.", + ) + return limit + + @abstractmethod + def v8_encode_kv_number(self, value: KvNumberTypeT) -> bytes: ... + + @override + def get_sum_mutations( + self, + sum: Sum[KvNumberNameT_co, NumberT, KvNumberTypeT], + *, + v8_encoder: Encoder | None = None, + ) -> Sequence[dp_protobuf.Mutation]: + encoded_min = b"" + encoded_max = b"" + self.validate_limit(sum.limit) + if sum.limit.min is not None: + encoded_min = self.v8_encode_kv_number(self.as_kv_number(sum.limit.min)) + if sum.limit.max is not None: + encoded_max = self.v8_encode_kv_number(self.as_kv_number(sum.limit.max)) + + mutation = dp_protobuf.Mutation( + mutation_type=dp_protobuf.MutationType.M_SUM, + key=pack_key(sum.key), + value=dp_protobuf.KvValue( + data=self.v8_encode_kv_number(self.as_kv_number(sum.delta)), + encoding=dp_protobuf.ValueEncoding.VE_V8, + ), + expire_at_ms=sum.expire_at_ms(), + sum_min=encoded_min, + sum_max=encoded_max, + sum_clamp=sum.limit.limit_exceeded is LimitExceededPolicy.CLAMP, + ) + + return [mutation] + + @override + def get_min_mutations( + self, + min: Min[KvNumberNameT_co, NumberT, KvNumberTypeT], + *, + v8_encoder: Encoder | 
None = None, + ) -> Sequence[dp_protobuf.Mutation]: + mutation = dp_protobuf.Mutation( + mutation_type=dp_protobuf.MutationType.M_SUM, + key=pack_key(min.key), + value=dp_protobuf.KvValue( + data=self.v8_encode_kv_number(self.as_kv_number(self.as_py_number(0))), + encoding=dp_protobuf.ValueEncoding.VE_V8, + ), + sum_max=self.v8_encode_kv_number(self.as_kv_number(min.value)), + sum_clamp=True, + expire_at_ms=min.expire_at_ms(), + ) + return [mutation] + + @override + def get_max_mutations( + self, + max: Max[KvNumberNameT_co, NumberT, KvNumberTypeT], + *, + v8_encoder: Encoder | None = None, + ) -> Sequence[dp_protobuf.Mutation]: + mutation = dp_protobuf.Mutation( + mutation_type=dp_protobuf.MutationType.M_SUM, + key=pack_key(max.key), + value=dp_protobuf.KvValue( + data=self.v8_encode_kv_number(self.as_kv_number(self.as_py_number(0))), + encoding=dp_protobuf.ValueEncoding.VE_V8, + ), + sum_min=self.v8_encode_kv_number(self.as_kv_number(max.value)), + sum_clamp=True, + expire_at_ms=max.expire_at_ms(), + ) + return [mutation] + + +class BigIntKvNumberInfo(V8KvNumberInfo[Literal["bigint"], int, JSBigInt]): + __slots__ = () + name = "bigint" + py_type = int + kv_type = JSBigInt + + def v8_encode_kv_number(self, value: JSBigInt) -> bytes: + return encode_v8_bigint(value) + + @override + def is_py_number(self, value: object) -> TypeGuard[int]: + # Don't treat JSBigInt instances as being py numbers so that we downcast + # JSBigInt to plain int in as_py_number(). This is important to allow + # other things to not treat JSBigInt as the same as int. 
+ return (not self.is_kv_number(value)) and super().is_py_number(value) + + +class FloatKvNumberInfo(V8KvNumberInfo[Literal["float"], float, float]): + __slots__ = () + name = "float" + py_type = float + kv_type = float + + def v8_encode_kv_number(self, value: float) -> bytes: + return encode_v8_number(value) + + def _is_int_in_float_safe_range(self, value: object) -> TypeGuard[int]: + # int is assignable to float in Python's type system, but + # isinstance(int(x), float) is False. We don't allow subclasses of int, + # because JSBigInt is a subclass of int, and we don't want to treat them + # as FloatKvNumberInfo values. + return type(value) is int and value in FLOAT64_SAFE_INT_RANGE + + @override + def is_kv_number(self, value: object) -> TypeGuard[float]: + return self._is_int_in_float_safe_range(value) or super().is_kv_number(value) + + @override + def is_py_number(self, value: object) -> TypeGuard[float]: + return self._is_int_in_float_safe_range(value) or super().is_py_number(value) + + @override + def _is_compatible_int( + self, number: object, *, target: Literal["py", "kv"] + ) -> TypeGuard[int]: + # only allow conversions from plain int that is in the safe range. 
+ return self._is_int_in_float_safe_range(number) + + @override + def _describe_invalid_number( + self, number: object, *, target: Literal["py", "kv"] + ) -> Exception: + err = super()._describe_invalid_number(number, target=target) + if type(number) is int and not self._is_int_in_float_safe_range(number): + return with_notes( + ValueError(*err.args), + "The int is too large to represent as a 64-bit floating point value.", + from_exception=err, + ) + return err + + +class U64KvNumberInfo(KvNumberInfo[Literal["u64"], int, KvU64]): + __slots__ = () + name = "u64" + py_type = int + kv_type = KvU64 + + @property + def default_limit(self) -> Limit[int]: + return LIMIT_KVU64 + + @override + def validate_limit(self, limit: Limit[int]) -> Limit[int]: + if limit.limit_exceeded is LimitExceededPolicy.ABORT: + raise with_notes( + ValueError(f"Number type {self.name!r} does not support abort limits"), + "Use 'bigint' (JSBigInt) or 'float' (int/float) to wrap on " + "0, 2^64 - 1 bounds.", + ) + + if limit.limit_exceeded is LimitExceededPolicy.WRAP and limit != LIMIT_KVU64: + raise with_notes( + ValueError( + f"Number type {self.name!r} wrap limit's min, max " + f"bounds cannot be changed" + ), + "'u64' (KvU64) can only wrap at 0 and 2^64 - 1. 
It can use " + "clamp with custom bounds through.", + ) + return limit + + @override + def get_sum_mutations( + self, + sum: Sum[Literal["u64"], int, KvU64], + *, + v8_encoder: Encoder | None = None, + ) -> Sequence[dp_protobuf.Mutation]: + self.validate_limit(sum.limit) + assert sum.limit.limit_exceeded is not LimitExceededPolicy.ABORT + if sum.limit.limit_exceeded is LimitExceededPolicy.WRAP: + return self._get_sum_wrap_mutations(sum) + elif sum.limit.limit_exceeded is LimitExceededPolicy.CLAMP: + return self._get_sum_clamp_mutations(sum) + else: + assert_never(sum.limit.limit_exceeded) + + def _get_sum_clamp_mutations( + self, sum: Sum[Literal["u64"], int, KvU64] + ) -> Sequence[dp_protobuf.Mutation]: + assert sum.limit.limit_exceeded is LimitExceededPolicy.CLAMP + + limit_min = 0 if sum.limit.min is None else sum.limit.min + limit_max = KvU64.RANGE.stop - 1 if sum.limit.max is None else sum.limit.max + if limit_min not in KvU64.RANGE: + raise with_notes( + ValueError("sum.limit.min must be in KvU64.RANGE"), + f"sum.limit.min: {limit_min}", + ) + if limit_max not in KvU64.RANGE: + raise with_notes( + ValueError("sum.limit.max must be in KvU64.RANGE"), + f"sum.limit.max: {limit_max}", + ) + + delta = self._normalise_clamp_delta(sum.delta) + + if delta < 0: + return self._get_negative_sum_clamp_mutations( + sum, delta, limit_min, limit_max + ) + else: + return self._get_positive_sum_clamp_mutations( + sum, delta, limit_min, limit_max + ) + + def _get_positive_sum_clamp_mutations( + self, + sum: Sum[Literal["u64"], int, KvU64], + delta: int, + limit_min: int, + limit_max: int, + ) -> Sequence[dp_protobuf.Mutation]: + assert delta in KvU64.RANGE + assert limit_min in KvU64.RANGE + assert limit_max in KvU64.RANGE + + # When the upper limit is <= the delta, the result is always clamped at the + # upper limit. Likewise if the lower limit pushes the result above the upper + # limit, the upper limit is used (it's applied last). 
+ min_result = delta + if limit_max <= min_result or limit_max <= (limit_min or 0): + return [self._mutate_set(sum, KvU64(limit_max))] # result is constant + + mutations = list[dp_protobuf.Mutation]() + + if limit_min >= limit_max or limit_min <= delta: + limit_min = 0 # lower bound can have no effect on the result + + # We clamp the final result to be <= the limit_max by clamping the db + # value to the highest value that won't exceed the limit_max when the + # delta is added. + + # delta is always < limit_max, otherwise the result is constant, which + # is handled above. + max_start = limit_max - delta + assert max_start > 0 + mutations.append(self._mutate_min(sum, KvU64(max_start))) + + if delta != 0: + mutations.append(self._mutate_sum(sum, KvU64(delta))) + + if limit_min > 0: + mutations.append(self._mutate_max(sum, KvU64(limit_min))) + + return mutations + + def _get_negative_sum_clamp_mutations( + self, + sum: Sum[Literal["u64"], int, KvU64], + delta: int, + limit_min: int, + limit_max: int, + ) -> Sequence[dp_protobuf.Mutation]: + assert -delta in KvU64.RANGE + assert limit_min in KvU64.RANGE + assert limit_max in KvU64.RANGE + + # If value after adding the (negative) delta is always <= the lower + # limit, the lower limit is always the result. However the upper limit + # applies last, so if the upper limit is lower than the lower limit, it + # applies instead. 
+ if limit_max <= limit_min: + return [self._mutate_set(sum, KvU64(limit_max))] + max_result = (KvU64.RANGE.stop - 1) + delta + if limit_min >= max_result: + assert limit_max > limit_min + return [self._mutate_set(sum, KvU64(limit_min))] + + mutations = list[dp_protobuf.Mutation]() + + # Offset the start to prevent it going negative after adding the delta + min_start = abs(delta) + limit_min + # min_start cannot exceed the range, because abs(delta) values >= the + # difference between limit_min and the top of the range trigger the + # constant result short-circuit above, as the result is always limit_min + assert min_start in KvU64.RANGE + mutations.append(self._mutate_max(sum, KvU64(min_start))) + + # Make the negative delta a positive delta that overflows to the result + # of applying the original negative delta offset. + if delta != 0: + delta = KvU64.RANGE.stop + delta + assert delta in KvU64.RANGE + + # Apply the delta (effectively subtracting) + mutations.append(self._mutate_sum(sum, KvU64(delta))) + + if limit_max >= max_result: + # limit_max can have no effect on the result + assert limit_max > limit_min + else: + mutations.append(self._mutate_min(sum, KvU64(limit_max))) + + return mutations + + def _get_sum_wrap_mutations( + self, sum: Sum[Literal["u64"], int, KvU64] + ) -> Sequence[dp_protobuf.Mutation]: + assert sum.limit.limit_exceeded is LimitExceededPolicy.WRAP + # Only one wrapping limit is available for KvU64 + # (the default 64-bit uint bounds). + if sum.limit != LIMIT_KVU64: + raise with_notes( + ValueError( + f"Deno KV does not support {LimitExceededPolicy.WRAP} with " + f"non-default min/max for KvU64 values" + ), + f"sum.limit: {sum.limit}", + ) + + delta = self._normalise_wrap_delta(sum.delta) + # M_SUM mutations for KvU64 only support positive delta values, because + # KvU64 is unsigned. 
We support negative effective deltas by taking + # advantage of integer overflow/wrapping — we add a positive value that + # overflows to the equivalent of subtracting delta. + # + # For example to subtract 2 from 10, we are calculating + # (10 + (2**64 - 2)) % 2**64 = 8 + if delta < 0: + delta = KvU64.RANGE.stop + delta + assert delta in KvU64.RANGE + + return [self._mutate_sum(sum, KvU64(delta))] + + @override + def get_min_mutations( + self, + min: Min[Literal["u64"], int, KvU64], + *, + v8_encoder: Encoder | None = None, + ) -> Sequence[dp_protobuf.Mutation]: + mutation = dp_protobuf.Mutation( + mutation_type=dp_protobuf.MutationType.M_MIN, + key=pack_key(min.key), + value=dp_protobuf.KvValue( + data=bytes(KvU64(min.value)), + encoding=dp_protobuf.ValueEncoding.VE_LE64, + ), + expire_at_ms=min.expire_at_ms(), + ) + return [mutation] + + @override + def get_max_mutations( + self, + max: Max[Literal["u64"], int, KvU64], + *, + v8_encoder: Encoder | None = None, + ) -> Sequence[dp_protobuf.Mutation]: + mutation = dp_protobuf.Mutation( + mutation_type=dp_protobuf.MutationType.M_MAX, + key=pack_key(max.key), + value=dp_protobuf.KvValue( + data=bytes(KvU64(max.value)), + encoding=dp_protobuf.ValueEncoding.VE_LE64, + ), + expire_at_ms=max.expire_at_ms(), + ) + return [mutation] + + @staticmethod + def _normalise_wrap_delta(delta: int) -> int: + """ + Normalise a sum delta value to be within +/- 2**64 for limit type wrap. + + This method wraps delta values larger than 2**64 - 1, in contrast with + _normalise_clamp_delta(), which clamps at the max value. 
+
+        Examples
+        --------
+        >>> U64KvNumberInfo._normalise_wrap_delta(-5)
+        -5
+        >>> U64KvNumberInfo._normalise_wrap_delta(-5 - 2**64)
+        -5
+        >>> U64KvNumberInfo._normalise_wrap_delta(5)
+        5
+        >>> U64KvNumberInfo._normalise_wrap_delta(5 + 2**64)
+        5
+        >>> U64KvNumberInfo._normalise_wrap_delta(2**64)
+        0
+        >>> U64KvNumberInfo._normalise_wrap_delta(-2**64)
+        0
+        """
+        pos_wrapped_delta = abs(delta) % KvU64.RANGE.stop
+        return -pos_wrapped_delta if delta < 0 else pos_wrapped_delta
+
+    @staticmethod
+    def _normalise_clamp_delta(delta: int) -> int:
+        """
+        Normalise a sum delta value to be within +/- 2**64 for limit type clamp.
+
+        This method clamps delta values larger than 2**64 - 1 at the max value,
+        in contrast with _normalise_wrap_delta(), which wraps over the max value.
+
+        Examples
+        --------
+        >>> U64KvNumberInfo._normalise_clamp_delta(-5)
+        -5
+        >>> U64KvNumberInfo._normalise_clamp_delta(-5 - 2**64)
+        -18446744073709551615
+        >>> U64KvNumberInfo._normalise_clamp_delta(5)
+        5
+        >>> U64KvNumberInfo._normalise_clamp_delta(5 + 2**64)
+        18446744073709551615
+        >>> U64KvNumberInfo._normalise_clamp_delta(2**64)
+        18446744073709551615
+        >>> U64KvNumberInfo._normalise_clamp_delta(-2**64)
+        -18446744073709551615
+        """
+        pos_clamped_delta = min(abs(delta), KvU64.RANGE.stop - 1)
+        return -pos_clamped_delta if delta < 0 else pos_clamped_delta
+
+    def _mutate(
+        self,
+        sum: Sum[Literal["u64"], int, KvU64],
+        mutation_type: dp_protobuf.MutationType,
+        value: KvU64,
+    ) -> dp_protobuf.Mutation:
+        return dp_protobuf.Mutation(
+            key=pack_key(sum.key),
+            expire_at_ms=sum.expire_at_ms(),
+            mutation_type=mutation_type,
+            value=dp_protobuf.KvValue(data=bytes(value), encoding=dp_protobuf.ValueEncoding.VE_LE64),
+        )
+
+    def _mutate_set(
+        self, sum: Sum[Literal["u64"], int, KvU64], value: KvU64
+    ) -> dp_protobuf.Mutation:
+        return self._mutate(sum, dp_protobuf.MutationType.M_SET, value)
+
+    def _mutate_max(
+        self, sum: Sum[Literal["u64"], int, KvU64], value: KvU64
+    ) -> dp_protobuf.Mutation:
+
return self._mutate(sum, dp_protobuf.MutationType.M_MAX, value) + + def _mutate_min( + self, sum: Sum[Literal["u64"], int, KvU64], value: KvU64 + ) -> dp_protobuf.Mutation: + return self._mutate(sum, dp_protobuf.MutationType.M_MIN, value) + + def _mutate_sum( + self, sum: Sum[Literal["u64"], int, KvU64], value: KvU64 + ) -> dp_protobuf.Mutation: + return self._mutate(sum, dp_protobuf.MutationType.M_SUM, value) + + +@frozen +@total_ordering +class KvNumber(Enum): + """The types of numbers that the atomic sum/min/max operations can be used with.""" + + # _value_: KvNumberInfo + + bigint = BigIntKvNumberInfo() + """A JavaScript bigint — arbitrary-precision integer.""" + float = FloatKvNumberInfo() + """A JavaScript number — 64-bit floating-point number.""" + u64 = U64KvNumberInfo() + """A Deno KV-specific 64-bit unsigned integer.""" + + @overload + @classmethod + def resolve( + cls, identifier: BigIntKvNumberIdentifier + ) -> Literal[KvNumber.bigint]: ... + + @overload + @classmethod + def resolve( + cls, identifier: FloatKvNumberIdentifier + ) -> Literal[KvNumber.float]: ... + + @overload + @classmethod + def resolve(cls, identifier: U64KvNumberIdentifier) -> Literal[KvNumber.u64]: ... + + @overload + @classmethod + def resolve(cls, identifier: KvNumber) -> LiteralKvNumber: ... + + @overload + @classmethod + def resolve(cls, /, *, number: KvU64) -> Literal[KvNumber.u64]: ... + + @overload + @classmethod + def resolve(cls, /, *, number: JSBigInt) -> Literal[KvNumber.bigint]: ... # pyright: ignore[reportOverlappingOverload] + + @overload + @classmethod + def resolve(cls, /, *, number: float_) -> Literal[KvNumber.float]: ... 
+ + @classmethod + def resolve( + cls, + identifier: KvNumberIdentifier | None = None, + *, + number: KvU64 | JSBigInt | __builtins__.float | None = None, + ) -> LiteralKvNumber: + if identifier is not None: + return cast(LiteralKvNumber, KvNumber(identifier)) + + if number is None: + raise TypeError("resolve() missing 1 required argument: 'identifier'") + try: + return cast(LiteralKvNumber, KvNumber(type(number))) + except Exception as e: + raise TypeError( + f"number is not supported by any KvNumber: {number!r}" + ) from e + + @classmethod + def _missing_(cls, value: Any) -> KvNumber | None: + return cls.__members__.get(value) + + def __lt__(self, other: object) -> bool: + if type(other) is KvNumber: + self_value: KvNumberInfo[Any, Any, Any] = self.value + other_value: KvNumberInfo[Any, Any, Any] = other.value + return self_value < other_value + return NotImplemented + + +LiteralKvNumber: TypeAlias = Literal[KvNumber.bigint, KvNumber.float, KvNumber.u64] +KvNumber._value2member_map_[JSBigInt] = KvNumber.bigint +KvNumber._value2member_map_[float] = KvNumber.float +# int values correspond to KvNumber (float64) because JavaScript integer values +# are float64, and v8serialize by default encodes and decodes int values as +# Number not Bigint (JSBigInt is used for BigInt). 
+KvNumber._value2member_map_[int] = KvNumber.float +KvNumber._value2member_map_[KvU64] = KvNumber.u64 + +BigIntKvNumberIdentifier: TypeAlias = Union[ + Literal["bigint", KvNumber.bigint], type[JSBigInt] +] +FloatKvNumberIdentifier: TypeAlias = Union[ + Literal["float", KvNumber.float], type[float] +] +U64KvNumberIdentifier: TypeAlias = Union[Literal["u64", KvNumber.u64], type[KvU64]] +KvNumberIdentifier: TypeAlias = Union[ + BigIntKvNumberIdentifier, FloatKvNumberIdentifier, U64KvNumberIdentifier, KvNumber +] + + +def encode_v8_number(number: float, /) -> bytes: + """Encode a Python float as a JavaScript Number in V8 serialization format.""" + if not KvNumber.float.value.is_kv_number(number): + raise with_notes( + TypeError("number must be a float or int in the float-safe range"), + f"number: {number!r} ({type(number)})", + ) + wts = WritableTagStream() + wts.write_header() + # It's OK to pass an int, they'll be encoded as float64 + wts.write_double(number) + return bytes(wts.data) + + +def encode_v8_bigint(number: JSBigInt, /) -> bytes: + """Encode a Python JSBigInt as a JavaScript BigInt in V8 serialization format.""" + if not KvNumber.bigint.value.is_kv_number(number): + raise TypeError(f"number must be a JSBigInt, not {type(number)}") + wts = WritableTagStream() + wts.write_header() + wts.write_bigint(number) + return bytes(wts.data) + + +@overload +def encode_kv_write_value( + value: KvU64 | bytes | JSBigInt | float, *, v8_encoder: Encoder | None = None +) -> dp_protobuf.KvValue: ... + + +@overload +def encode_kv_write_value( + value: object, *, v8_encoder: Encoder +) -> dp_protobuf.KvValue: ... 
+ + +def encode_kv_write_value( + value: object, *, v8_encoder: Encoder | None = None +) -> dp_protobuf.KvValue: + if isinstance(value, KvU64): + return dp_protobuf.KvValue( + data=bytes(value), + encoding=dp_protobuf.ValueEncoding.VE_LE64, + ) + elif isinstance(value, bytes): + return dp_protobuf.KvValue( + data=value, encoding=dp_protobuf.ValueEncoding.VE_BYTES + ) + elif isinstance(value, JSBigInt): + return dp_protobuf.KvValue( + data=encode_v8_bigint(value), encoding=dp_protobuf.ValueEncoding.VE_V8 + ) + elif isinstance(value, float): + return dp_protobuf.KvValue( + data=encode_v8_number(value), encoding=dp_protobuf.ValueEncoding.VE_V8 + ) + else: + if v8_encoder is None: + raise TypeError( + "v8_encoder cannot be None when encoding an arbitrary object" + ) + return dp_protobuf.KvValue( + data=bytes(v8_encoder.encode(value)), + encoding=dp_protobuf.ValueEncoding.VE_V8, + ) + + +class MutationOptions(TypedDict, total=False): + expire_at: datetime | None + + +class LimitOptions(Generic[NumberT], TypedDict, total=False): + clamp_over: NumberT | None + clamp_under: NumberT | None + abort_over: NumberT | None + abort_under: NumberT | None + limit: Limit[NumberT] | None + + +class SumOptions(LimitOptions[NumberT_co], MutationOptions): + """Keyword arguments accepted by `sum()`/`Sum()`.""" + + +class SumArgs( + SumOptions[NumberT], Generic[KvNumberNameT, NumberT, KvNumberTypeT], total=False +): + """All arguments accepted by `sum()`/`Sum()`.""" + + key: AnyKvKey + delta: JSBigInt | float | KvU64 | NumberT | KvNumberTypeT + number_type: ( + KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT] | KvNumberIdentifier | None + ) + + +class CheckMixin(Generic[CheckResultT]): + @abstractmethod + def _check(self, check: CheckRepresentation, /) -> CheckResultT: + raise NotImplementedError + + @overload + def check( + self, key: AnyKvKey, versionstamp: VersionStamp | None = None + ) -> CheckResultT: ... 
+
+    @overload
+    def check(self, check: CheckRepresentation, /) -> CheckResultT: ...
+
+    @overload
+    def check(self, check: AnyKeyVersion, /) -> CheckResultT: ...
+
+    def check(
+        self,
+        key: CheckRepresentation | AnyKeyVersion | AnyKvKey,
+        versionstamp: VersionStamp | None = None,
+    ) -> CheckResultT:
+        if isinstance(key, CheckRepresentation):
+            if versionstamp is not None:
+                raise TypeError(
+                    "'versionstamp' argument cannot be set when the first argument "
+                    "to check() is an object with an 'as_protobuf' method"
+                )
+            return self._check(key)
+        elif isinstance(key, AnyKeyVersion):
+            if versionstamp is not None:
+                raise TypeError(
+                    "'versionstamp' argument cannot be set when the first argument "
+                    "to check() is an object with 'key' and 'versionstamp' attributes"
+                )
+            return self._check(Check(key.key, key.versionstamp))
+        else:
+            return self._check(Check(key, versionstamp))
+
+    def check_key_has_version(
+        self, key: AnyKvKey, versionstamp: VersionStamp
+    ) -> CheckResultT:
+        return self._check(Check.for_key_with_version(key, versionstamp))
+
+    def check_key_not_set(self, key: AnyKvKey) -> CheckResultT:
+        return self._check(Check.for_key_not_set(key))
+
+
+class MutatorMixin(Generic[MutateResultT]):
+    @abstractmethod
+    def mutate(self, mutation: MutationRepresentation) -> MutateResultT:
+        raise NotImplementedError
+
+
+class SetMutatorMixin(MutatorMixin[MutateResultT]):
+    def set(
+        self, key: AnyKvKey, value: object, *, versioned: bool = False
+    ) -> MutateResultT:
+        return self.mutate(Set(key, value, versioned=versioned))
+
+
+class SumMutatorMixin(MutatorMixin[MutateResultT]):
+    # The overloads here have two categories: Firstly, overloads based on the
+    # known KvNumber enum numbers — bigint, float and u64. Secondly,
+    # generic/catch-all for any KvNumberInfo instance.
+    @overload
+    def sum(
+        self,
+        key: AnyKvKey,
+        delta: JSBigInt,
+        number_type: None = None,
+        **options: Unpack[SumOptions[int]],
+    ) -> MutateResultT: ...
+ + @overload + def sum( + self, + key: AnyKvKey, + delta: int | JSBigInt, + number_type: BigIntKvNumberIdentifier, + **options: Unpack[SumOptions[int]], + ) -> MutateResultT: ... + + @overload + def sum( + self, + key: AnyKvKey, + delta: KvU64, + number_type: None = None, + **options: Unpack[SumOptions[int]], + ) -> MutateResultT: ... + + @overload + def sum( + self, + key: AnyKvKey, + delta: int | KvU64, + number_type: U64KvNumberIdentifier, + **options: Unpack[SumOptions[int]], + ) -> MutateResultT: ... + + @overload + def sum( + self, + key: AnyKvKey, + delta: float, + number_type: FloatKvNumberIdentifier | None = None, + **options: Unpack[SumOptions[float]], + ) -> MutateResultT: ... + + @overload + def sum( + self, + key: AnyKvKey, + delta: NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], + # Can't use float limits unless the float type is explicitly being used, + # as float is incompatible with the other number types, but int is + # compatible. + **options: Unpack[SumOptions[NumberT]], + ) -> MutateResultT: ... 
+ + def sum( + self, + key: AnyKvKey, + delta: JSBigInt | float | KvU64 | NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT] + | KvNumberIdentifier + | None = None, + **options: Unpack[SumOptions[NumberT]], + ) -> MutateResultT: + delta = cast(Union[NumberT, KvNumberTypeT], delta) + number_type = cast( + KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], number_type + ) + return self.mutate(Sum(key, delta, number_type, **options)) + + def sum_bigint( + self, + key: AnyKvKey, + delta: int | JSBigInt, + **options: Unpack[SumOptions[int]], + ) -> MutateResultT: + return self.sum(key, delta, number_type=KvNumber.bigint, **options) + + def sum_float( + self, + key: AnyKvKey, + delta: float, + **options: Unpack[SumOptions[float]], + ) -> MutateResultT: + return self.sum(key, delta, number_type=KvNumber.float, **options) + + def sum_kvu64( + self, + key: AnyKvKey, + delta: int | KvU64, + **options: Unpack[SumOptions[int]], + ) -> MutateResultT: + return self.sum(key, delta, number_type=KvNumber.u64, **options) + + +class MinMutatorMixin(MutatorMixin[MutateResultT]): + @overload + def min( + self, + key: AnyKvKey, + value: JSBigInt, + number_type: None = None, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + @overload + def min( + self, + key: AnyKvKey, + value: int | JSBigInt, + number_type: BigIntKvNumberIdentifier, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + @overload + def min( + self, + key: AnyKvKey, + value: KvU64, + number_type: None = None, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + @overload + def min( + self, + key: AnyKvKey, + value: int | KvU64, + number_type: U64KvNumberIdentifier, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + @overload + def min( + self, + key: AnyKvKey, + value: float, + number_type: FloatKvNumberIdentifier | None = None, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... 
+ + @overload + def min( + self, + key: AnyKvKey, + value: NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], + # Can't use float limits unless the float type is explicitly being used, + # as float is incompatible with the other number types, but int is + # compatible. + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + def min( + self, + key: AnyKvKey, + value: JSBigInt | float | KvU64 | NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT] + | KvNumberIdentifier + | None = None, + **options: Unpack[MutationOptions], + ) -> MutateResultT: + value = cast(Union[NumberT, KvNumberTypeT], value) + number_type = cast( + KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], number_type + ) + return self.mutate(Min(key, value, number_type, **options)) + + def min_bigint( + self, + key: AnyKvKey, + value: int | JSBigInt, + **options: Unpack[MutationOptions], + ) -> MutateResultT: + return self.min(key, value, number_type=KvNumber.bigint, **options) + + def min_float( + self, + key: AnyKvKey, + value: float, + **options: Unpack[MutationOptions], + ) -> MutateResultT: + return self.min(key, value, number_type=KvNumber.float, **options) + + def min_kvu64( + self, + key: AnyKvKey, + value: int | KvU64, + **options: Unpack[MutationOptions], + ) -> MutateResultT: + return self.min(key, value, number_type=KvNumber.u64, **options) + + +class MaxMutatorMixin(MutatorMixin[MutateResultT]): + @overload + def max( + self, + key: AnyKvKey, + value: JSBigInt, + number_type: None = None, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + @overload + def max( + self, + key: AnyKvKey, + value: int | JSBigInt, + number_type: BigIntKvNumberIdentifier, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + @overload + def max( + self, + key: AnyKvKey, + value: KvU64, + number_type: None = None, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... 
+ + @overload + def max( + self, + key: AnyKvKey, + value: int | KvU64, + number_type: U64KvNumberIdentifier, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + @overload + def max( + self, + key: AnyKvKey, + value: float, + number_type: FloatKvNumberIdentifier | None = None, + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + @overload + def max( + self, + key: AnyKvKey, + value: NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], + # Can't use float limits unless the float type is explicitly being used, + # as float is incompatible with the other number types, but int is + # compatible. + **options: Unpack[MutationOptions], + ) -> MutateResultT: ... + + def max( + self, + key: AnyKvKey, + value: JSBigInt | float | KvU64 | NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT] + | KvNumberIdentifier + | None = None, + **options: Unpack[MutationOptions], + ) -> MutateResultT: + value = cast(Union[NumberT, KvNumberTypeT], value) + number_type = cast( + KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], number_type + ) + return self.mutate(Max(key, value, number_type, **options)) + + def max_bigint( + self, + key: AnyKvKey, + value: int | JSBigInt, + **options: Unpack[MutationOptions], + ) -> MutateResultT: + return self.max(key, value, number_type=KvNumber.bigint, **options) + + def max_float( + self, + key: AnyKvKey, + value: float, + **options: Unpack[MutationOptions], + ) -> MutateResultT: + return self.max(key, value, number_type=KvNumber.float, **options) + + def max_kvu64( + self, + key: AnyKvKey, + value: int | KvU64, + **options: Unpack[MutationOptions], + ) -> MutateResultT: + return self.max(key, value, number_type=KvNumber.u64, **options) + + +class DeleteMutatorMixin(MutatorMixin[MutateResultT]): + def delete(self, key: AnyKvKey) -> MutateResultT: + if isinstance(key, Delete): + return self.mutate(key) + return self.mutate(Delete(key)) + + 
+class EnqueueMixin(Generic[EnqueueResultT]): + @abstractmethod + def _enqueue(self, enqueue: Enqueue, /) -> EnqueueResultT: + raise NotImplementedError + + @overload + def enqueue(self, enqueue: Enqueue, /) -> EnqueueResultT: ... + + @overload + def enqueue( + self, + message: object, + *, + delivery_time: datetime | None = None, + retry_delays: Backoff | None = None, + dead_letter_keys: Sequence[AnyKvKey] | None = None, + ) -> EnqueueResultT: ... + + def enqueue( + self, + message: object | Enqueue, + *, + delivery_time: datetime | None = None, + retry_delays: Backoff | None = None, + dead_letter_keys: Sequence[AnyKvKey] | None = None, + ) -> EnqueueResultT: + if isinstance(message, Enqueue): + enqueue = message + else: + enqueue = Enqueue( + message, + delivery_time=delivery_time, + retry_delays=retry_delays, + dead_letter_keys=dead_letter_keys, + ) + return self._enqueue(enqueue) + + +@dataclass(init=False) +class PlannedWrite( + CheckMixin["PlannedWrite"], + SetMutatorMixin["PlannedWrite"], + SumMutatorMixin["PlannedWrite"], + MinMutatorMixin["PlannedWrite"], + MaxMutatorMixin["PlannedWrite"], + DeleteMutatorMixin["PlannedWrite"], + EnqueueMixin["PlannedWrite"], + AtomicWriteRepresentationWriter["CompletedWrite"], +): + kv: KvWriter | None + checks: MutableSequence[CheckRepresentation] + mutations: MutableSequence[MutationRepresentation] + enqueues: MutableSequence[EnqueueRepresentation] + v8_encoder: Encoder | None + + def __init__( + self, + kv: KvWriter | None = None, + checks: MutableSequence[CheckRepresentation] | None = None, + mutations: MutableSequence[MutationRepresentation] | None = None, + enqueues: MutableSequence[EnqueueRepresentation] | None = None, + *, + v8_encoder: Encoder | None = None, + ) -> None: + self.kv = kv + self.checks = list(checks or ()) + self.mutations = list(mutations or ()) + self.enqueues = list(enqueues or ()) + self.v8_encoder = v8_encoder + + @override + async def write( + self, kv: KvWriter | None = None, *, v8_encoder: 
Encoder | None = None + ) -> CompletedWrite: + _kv = self.kv if kv is None else kv + if _kv is None: + raise TypeError( + f"{type(self).__name__}.write() must get a value for its 'kv' " + "argument when 'self.kv' isn't set" + ) + + _v8_encoder = self.v8_encoder if v8_encoder is None else v8_encoder + if _v8_encoder is None: + _v8_encoder = get_v8_encoder(_kv).value_or(None) + if _v8_encoder is None: + raise TypeError( + f"{type(self).__name__}.write() must get a value for its " + "'v8_encoder' keyword argument when 'self.v8_encoder' isn't " + "set and 'kv' does not provide one." + ) + + (pb_atomic_write,) = self.as_protobuf(v8_encoder=_v8_encoder) + # Copy the write components so that the results are not affected if the + # PlannedWrite is modified during this write. + checks = tuple(self.checks) + mutations = tuple(self.mutations) + enqueues = tuple(self.enqueues) + result = await _kv.write(protobuf_atomic_write=pb_atomic_write) + + if is_err(result): + if isinstance(result.error, CheckFailure): + check_failure = result.error + return ConflictedWrite( + failed_checks=check_failure.failed_check_indexes, + checks=checks, + mutations=mutations, + enqueues=enqueues, + endpoint=check_failure.endpoint, + cause=check_failure, + ) + raise FailedWrite( + checks=checks, + mutations=mutations, + enqueues=enqueues, + endpoint=result.error.endpoint, + ) from result.error + + versionstamp, endpoint = result.value + return CommittedWrite( + versionstamp=versionstamp, + checks=checks, + mutations=mutations, + enqueues=enqueues, + endpoint=endpoint, + ) + + def as_protobuf(self, *, v8_encoder: Encoder) -> tuple[AtomicWrite]: + return ( + AtomicWrite( + checks=[ + pb_msg + for check in self.checks + for pb_msg in check.as_protobuf(v8_encoder=v8_encoder) + ], + mutations=[ + pb_msg + for mut in self.mutations + for pb_msg in mut.as_protobuf(v8_encoder=v8_encoder) + ], + enqueues=[ + pb_msg + for enq in self.enqueues + for pb_msg in enq.as_protobuf(v8_encoder=v8_encoder) + ], + ), + 
) + + @override + def _check(self, check: CheckRepresentation, /) -> Self: + self.checks.append(check) + return self + + @override + def mutate(self, mutation: MutationRepresentation) -> Self: + self.mutations.append(mutation) + return self + + @override + def _enqueue(self, enqueue: Enqueue, /) -> Self: + self.enqueues.append(enqueue) + return self + + +EMPTY_MAP: Final[Mapping[Any, Any]] = MappingProxyType({}) + + +# TODO: Support capturing retries in the FailedWrite/CommittedWrite? +@dataclass(init=False, unsafe_hash=True) +class FailedWrite(FrozenAfterInitDataclass, AnyFailure, DenoKvError): + if TYPE_CHECKING: + + def _AnyFailure_marker(self, no_call: Never) -> Never: ... + + checks: Final[Sequence[CheckRepresentation]] = field() + failed_checks: Final[Sequence[int]] = field() + has_unknown_conflicts: Final[bool] = field() + """ + Whether the check(s) that failed are unknown. + + KV servers may or may not report which check(s) failed when a write + fails due to a check conflict. + """ + mutations: Final[Sequence[MutationRepresentation]] = field() + enqueues: Final[Sequence[EnqueueRepresentation]] = field() + endpoint: Final[EndpointInfo] = field() + ok: Final[Literal[False]] = False # noqa: PYI064 + versionstamp: Final[None] = None + + def __init__( + self, + checks: Iterable[CheckRepresentation], + mutations: Iterable[MutationRepresentation], + enqueues: Iterable[EnqueueRepresentation], + endpoint: EndpointInfo, + *, + cause: BaseException | None = None, + ) -> None: + super(FailedWrite, self).__init__() + self.checks = tuple(checks) # type: ignore[misc] # Cannot assign to final + # Allow subclass to initialise failed_checks + if not hasattr(self, "failed_checks"): + self.failed_checks = tuple() # type: ignore[misc] # Cannot assign to final + self.has_unknown_conflicts = False # type: ignore[misc] # Cannot assign to final + self.mutations = tuple(mutations) # type: ignore[misc] # Cannot assign to final + self.enqueues = tuple(enqueues) # type: ignore[misc] # 
Cannot assign to final + self.endpoint = endpoint # type: ignore[misc] # Cannot assign to final + self.__cause__ = cause + + @property + def conflicts(self) -> Mapping[AnyKvKey, CheckRepresentation]: + checks = self.checks + return {checks[i].key: checks[i] for i in self.failed_checks} + + def _get_cause_description(self) -> str: + if self.__cause__: + return type(self.__cause__).__name__ + return "unspecified cause" + + @property + def message(self) -> str: + # TODO: after xxx attempts? + return ( + f"to {str(self.endpoint.url)!r} " + f"due to {self._get_cause_description()}, " + f"with {len(self.checks)} checks, " + f"{len(self.mutations)} mutations, " + f"{len(self.enqueues)} enqueues" + ) + + def __str__(self) -> str: + return f"Write failed {self.message}" + + def __repr__(self) -> str: + return f"<{type(self).__name__} {self.message}>" + + +def _normalise_failed_checks( + failed_checks: Iterable[int], checks: tuple[CheckRepresentation, ...] +) -> tuple[int, ...]: + failed_checks = tuple(sorted(failed_checks)) + # If the server didn't report failed checks and there was only one check, we + # know the single check must have failed, so report that. 
+ if len(failed_checks) == 0 and len(checks) == 1: + return (0,) + if failed_checks and (failed_checks[0] < 0 or failed_checks[-1] >= len(checks)): + raise ValueError("failed_checks contains out-of-bounds index") + return failed_checks + + +class ConflictedWrite(FailedWrite): + def __init__( + self, + failed_checks: Iterable[int] | None, + checks: Iterable[CheckRepresentation], + mutations: Iterable[MutationRepresentation], + enqueues: Iterable[EnqueueRepresentation], + endpoint: EndpointInfo, + *, + cause: BaseException | None = None, + ) -> None: + _checks = tuple(checks) + self.failed_checks = _normalise_failed_checks( # type: ignore[misc] # Cannot assign to final attribute "failed_checks" + failed_checks or [], + checks=_checks, + ) + self.has_unknown_conflicts = len(self.failed_checks) == 0 # type: ignore[misc] # Cannot assign to final attribute + super(ConflictedWrite, self).__init__( + _checks, mutations, enqueues, endpoint, cause=cause + ) + + @property + def message(self) -> str: + return ( + f"NOT APPLIED to {str(self.endpoint.url)!r} with " + f"{len(self.conflicts)}/{len(self.checks)} checks CONFLICTING, " + f"{len(self.mutations)} mutations, " + f"{len(self.enqueues)} enqueues" + ) + + def __str__(self) -> str: + return f"Write {self.message}" + + +@dataclass(init=False, unsafe_hash=True, **slots_if310()) +class CommittedWrite(FrozenAfterInitDataclass, AnySuccess): + if TYPE_CHECKING: + + def _AnySuccess_marker(self, no_call: Never) -> Never: ... 
+ + ok: Final[Literal[True]] # noqa: PYI064 + conflicts: Final[Mapping[KvKey, CheckRepresentation]] # empty + has_unknown_conflicts: Final[Literal[False]] + versionstamp: Final[VersionStamp] + checks: Final[Sequence[CheckRepresentation]] + mutations: Final[Sequence[MutationRepresentation]] + enqueues: Final[Sequence[EnqueueRepresentation]] + endpoint: Final[EndpointInfo] + + def __init__( + self, + versionstamp: VersionStamp, + checks: Sequence[CheckRepresentation], + mutations: Sequence[MutationRepresentation], + enqueues: Sequence[EnqueueRepresentation], + endpoint: EndpointInfo, + ) -> None: + self.ok = True + self.conflicts = EMPTY_MAP + self.has_unknown_conflicts = False + self.versionstamp = versionstamp + self.checks = tuple(checks) + self.mutations = tuple(mutations) + self.enqueues = tuple(enqueues) + self.endpoint = endpoint + + @property + def _message(self) -> str: + return ( + f"version 0x{self.versionstamp} to {str(self.endpoint.url)!r} with " + f"{len(self.checks)} checks, " + f"{len(self.mutations)} mutations, " + f"{len(self.enqueues)} enqueues" + ) + + def __str__(self) -> str: + return f"Write committed {self._message}" + + def __repr__(self) -> str: + return f"<{type(self).__name__} {self._message}>" + + +CompletedWrite: TypeAlias = Union[CommittedWrite, ConflictedWrite] + + +def is_applied(write: CompletedWrite) -> TypeIs[CommittedWrite]: + return isinstance(write, CommittedWrite) + + +@runtime_checkable +class AnyKeyVersion(Protocol): + __slots__ = () + + if TYPE_CHECKING: + + @property + def key(self) -> AnyKvKey: ... + @property + def versionstamp(self) -> VersionStamp | None: ... + else: + key = ... + versionstamp = ... + + +class CheckRepresentation( + SingleProtobufMessageRepresentation[dp_protobuf.Check], AnyKeyVersion +): + __slots__ = () + + # Check never needs an Encoder, so override the signature to make it optional. 
+    @override
+    @abstractmethod
+    def as_protobuf(
+        self, *, v8_encoder: Encoder | None = None
+    ) -> tuple[dp_protobuf.Check]: ...
+
+
+@dataclass(frozen=True, **slots_if310())
+class Check(CheckRepresentation, AnyKeyVersion):
+    """
+    A condition that must hold for a database write operation to be applied.
+
+    By applying checks to a write operation, writes can ensure that the changes
+    they make are changing the existing values they expect. Without appropriate
+    checks, write operations could overwrite another writer's changes to the
+    database.
+
+    Checks are part of Deno KV's
+    [Multi-version concurrency control](https://en.wikipedia.org/wiki/Multiversion_concurrency_control)
+    support.
+    """
+
+    key: AnyKvKey
+    """The key that the check applies to."""
+    versionstamp: VersionStamp | None
+    """
+    The version that the key's value must have for the check to succeed.
+
+    `None` means the key must not have a value set for the check to succeed.
+    """
+
+    @classmethod
+    def for_key_with_version(cls, key: AnyKvKey, versionstamp: VersionStamp) -> Self:
+        return cls(key, versionstamp)
+
+    @classmethod
+    def for_key_not_set(cls, key: AnyKvKey) -> Self:
+        return cls(key, versionstamp=None)
+
+    @override
+    def as_protobuf(
+        self, *, v8_encoder: Encoder | None = None
+    ) -> tuple[dp_protobuf.Check]:
+        return (
+            dp_protobuf.Check(key=pack_key(self.key), versionstamp=self.versionstamp),
+        )
+
+
+class MutationRepresentation(ProtobufMessageRepresentation[dp_protobuf.Mutation]):
+    __slots__ = ()
+
+    @abstractmethod
+    def as_protobuf(self, *, v8_encoder: Encoder) -> Sequence[dp_protobuf.Mutation]: ...
+ + +@dataclass(init=False, **slots_if310()) +class Mutation(FrozenAfterInitDataclass, MutationRepresentation): + key: AnyKvKey + expire_at: datetime | None + + def __init__(self, key: AnyKvKey, **options: Unpack[MutationOptions]) -> None: + if type(self) is Mutation: + raise TypeError("cannot create Mutation instances directly") + self.key = key + self.expire_at = options.get("expire_at") + + def expire_at_ms(self) -> int: + return 0 if self.expire_at is None else int(self.expire_at.timestamp() * 1000) + + +@dataclass(init=False, **slots_if310()) +class Set(Mutation): + value: object + versioned: bool + + def __init__( + self, + key: AnyKvKey, + value: object, + *, + expire_at: datetime | None = None, + versioned: bool = False, + ) -> None: + super(Set, self).__init__(key, expire_at=expire_at) + self.value = value + self.versioned = versioned + + @override + def as_protobuf(self, *, v8_encoder: Encoder) -> tuple[dp_protobuf.Mutation]: + return ( + dp_protobuf.Mutation( + mutation_type=dp_protobuf.MutationType.M_SET_SUFFIX_VERSIONSTAMPED_KEY + if self.versioned + else dp_protobuf.MutationType.M_SET, + key=pack_key(self.key), + value=encode_kv_write_value(self.value, v8_encoder=v8_encoder), + expire_at_ms=self.expire_at_ms(), + ), + ) + + +class LimitExceededPolicy(EvalEnumRepr, Enum): + ABORT = "abort" + CLAMP = "clamp" + WRAP = "wrap" + + +LimitExceededInput = Literal[ + "abort", + "clamp", + LimitExceededPolicy.ABORT, + LimitExceededPolicy.CLAMP, +] + + +@dataclass(frozen=True, **slots_if310()) +class Limit(Container[NumberT_co]): + """ + A range of numbers used to define the allowed range of `Sum` operations. 
+ + Examples + -------- + >>> lim = Limit(0, 100, limit_exceeded='clamp') + >>> lim + Limit(min=0, max=100, limit_exceeded=LimitExceededPolicy.CLAMP) + >>> -10 in lim + False + >>> 110 in lim + False + >>> 10 in lim + True + >>> 9000 in Limit(min=0) + True + """ + + min: NumberT_co | None = field(default=None) + max: NumberT_co | None = field(default=None) + limit_exceeded: LimitExceededPolicy = field(default=LimitExceededPolicy.ABORT) + + if TYPE_CHECKING: + # Customise the init signature to: + # - accept string values to init limit_exceeded + # - Hide the LimitExceededPolicy.WRAP option from the init signature so + # that using it is a type error. There's no way to use a custom wrap + # limit, only LIMIT_KVU64 is supported. + def __init__( + self, + min: NumberT_co | None = None, + max: NumberT_co | None = None, + limit_exceeded: LimitExceededInput | None = LimitExceededPolicy.ABORT, + ) -> None: + pass + + def __post_init__(self) -> None: + # Support specifying limit_exceeded via the enum's string values. + if not isinstance(self.limit_exceeded, LimitExceededPolicy): + object.__setattr__( + self, "limit_exceeded", LimitExceededPolicy(self.limit_exceeded) + ) + + def __contains__(self, x: object) -> bool: + if not isinstance(x, (int, float)): + return False + return (self.min is None or self.min <= x) and ( + self.max is None or self.max >= x + ) + + +LIMIT_KVU64 = Limit( + min=KvU64.RANGE[0], + max=KvU64.RANGE[-1], + # Not normally allowed by types because only LIMIT_KVU64 can use WRAP. 
+ limit_exceeded=cast(LimitExceededInput, LimitExceededPolicy.WRAP), +) +LIMIT_UNLIMITED = Limit[Any]() + + +class AmbiguousNumberWarning(UserWarning): + pass + + +@dataclass(init=False, **slots_if310()) +class NumberMutation(Mutation, Generic[KvNumberNameT_co, NumberT_co, KvNumberTypeT_co]): + number_type: KvNumberInfo[KvNumberNameT_co, NumberT_co, KvNumberTypeT_co] + + def __init__( + self, + *, + key: AnyKvKey, + expire_at: datetime | None = None, + number_type: KvNumberInfo[KvNumberNameT_co, NumberT_co, KvNumberTypeT_co], + ) -> None: + super(NumberMutation, self).__init__(key, expire_at=expire_at) + self.number_type = number_type + + @classmethod + def _resolve_number_value_type( + cls, + value: JSBigInt | KvU64 | float | NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT] + | KvNumberIdentifier + | None = None, + ) -> tuple[NumberT, KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT]]: + resolved_number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT] + if isinstance(number_type, KvNumberInfo): + resolved_number_type = number_type + elif number_type is not None: + number_identifier: KvNumberIdentifier = number_type + resolved_number_type = KvNumber.resolve(number_identifier).value # pyright: ignore[reportAssignmentType] + else: + known_number = cast(Union[KvU64, JSBigInt, float], value) + resolved_number_type = KvNumber.resolve(number=known_number).value # pyright: ignore[reportAssignmentType] + + resolved_value = cast(Union[KvNumberTypeT, NumberT], value) + + return ( + resolved_number_type.as_py_number(resolved_value), + resolved_number_type, + ) + + +@dataclass(init=False, **slots_if310()) +class Sum(NumberMutation[KvNumberNameT_co, NumberT_co, KvNumberTypeT_co]): + _INIT_OPTIONS: ClassVar = frozenset( + ["clamp_over", "clamp_under", "abort_over", "abort_under", "limit", "expire_at"] + ) + delta: Final[NumberT_co] # type: ignore[misc] + limit: Final[Limit[NumberT_co]] # type: ignore[misc] + + @override + def 
as_protobuf( + self, *, v8_encoder: Encoder | None = None + ) -> Sequence[dp_protobuf.Mutation]: + return self.number_type.get_sum_mutations(self, v8_encoder=v8_encoder) + + @overload + def __init__( # pyright: ignore[reportOverlappingOverload] + self: BigIntSum, + key: AnyKvKey, + delta: JSBigInt, + number_type: None = None, + **options: Unpack[SumOptions[int]], + ) -> None: ... + + @overload + def __init__( + self: BigIntSum, + key: AnyKvKey, + delta: int | JSBigInt, + number_type: BigIntKvNumberIdentifier, + **options: Unpack[SumOptions[int]], + ) -> None: ... + + @overload + def __init__( + self: U64Sum, + key: AnyKvKey, + delta: KvU64, + number_type: None = None, + **options: Unpack[SumOptions[int]], + ) -> None: ... + + @overload + def __init__( + self: U64Sum, + key: AnyKvKey, + delta: int | KvU64, + number_type: U64KvNumberIdentifier, + **options: Unpack[SumOptions[int]], + ) -> None: ... + + @overload + def __init__( + self: FloatSum, + key: AnyKvKey, + delta: float, + number_type: FloatKvNumberIdentifier | None = None, + **options: Unpack[SumOptions[float]], + ) -> None: ... + + @overload + def __init__( + self: Sum[KvNumberNameT, NumberT, KvNumberTypeT], + key: AnyKvKey, + delta: NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], + # Can't use float limits unless the float type is explicitly being used, + # as float is incompatible with the other number types, but int is + # compatible. + **options: Unpack[SumOptions[NumberT]], + ) -> None: ... 
+ + def __init__( + self: Sum[KvNumberNameT, NumberT, KvNumberTypeT], + key: AnyKvKey, + delta: JSBigInt | KvU64 | float | NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT] + | KvNumberIdentifier + | None = None, + **options: Unpack[SumOptions[int | float | NumberT]], + ) -> None: + if options.keys() - self._INIT_OPTIONS: + arg = next(iter(options.keys() - self._INIT_OPTIONS)) + raise TypeError( + f"Sum.__init__() got an unexpected keyword argument {arg!r}" + ) + resolved_delta, resolved_number_type = Sum._resolve_number_value_type( + delta, number_type + ) + super(Sum, self).__init__( + key=key, + expire_at=options.pop("expire_at", None), + number_type=resolved_number_type, + ) + self.limit = ( + Sum._create_limit(**cast(LimitOptions[NumberT], options)) + or resolved_number_type.default_limit + ) + resolved_number_type.validate_limit(self.limit) + self.delta = resolved_delta + + @classmethod + def _create_limit( + cls, **options: Unpack[LimitOptions[NumberT]] + ) -> Limit[NumberT] | None: + limits = dict[Literal["limit=", "clamp_*=", "abort_*="], Limit[NumberT]]() + + if limit := options.get("limit"): + limits["limit="] = limit + + if "clamp_under" in options or "clamp_over" in options: + limits["clamp_*="] = Limit( + min=options.get("clamp_under"), + max=options.get("clamp_over"), + limit_exceeded=LimitExceededPolicy.CLAMP, + ) + + if "abort_under" in options or "abort_over" in options: + limits["abort_*="] = Limit( + min=options.get("abort_under"), + max=options.get("abort_over"), + limit_exceeded=LimitExceededPolicy.ABORT, + ) + + if len(limits) > 1: + options_used = ", ".join(sorted(limits)) + raise with_notes( + ValueError( + f"Limit keyword arguments in conflict: " + f"Options {options_used} cannot be used together." + ), + "Use limit=Limit(limit_exceeded=..., ...) 
to create a limit " + "with a dynamic type.", + ) + return next(iter(limits.values()), None) + + +BigIntSum: TypeAlias = Sum[Literal["bigint"], int, JSBigInt] +FloatSum: TypeAlias = Sum[Literal["float"], float, float] +U64Sum: TypeAlias = Sum[Literal["u64"], int, KvU64] + + +@dataclass(init=False, **slots_if310()) +class Min(NumberMutation[KvNumberNameT_co, NumberT_co, KvNumberTypeT_co]): + value: Final[NumberT_co] # type: ignore[misc] + + @overload + def __init__( # pyright: ignore[reportOverlappingOverload] + self: BigIntMin, + key: AnyKvKey, + value: JSBigInt, + number_type: None = None, + **options: Unpack[MutationOptions], + ) -> None: ... + + @overload + def __init__( + self: BigIntMin, + key: AnyKvKey, + value: int | JSBigInt, + number_type: BigIntKvNumberIdentifier, + **options: Unpack[MutationOptions], + ) -> None: ... + + @overload + def __init__( + self: U64Min, + key: AnyKvKey, + value: KvU64, + number_type: None = None, + **options: Unpack[MutationOptions], + ) -> None: ... + + @overload + def __init__( + self: U64Min, + key: AnyKvKey, + value: int | KvU64, + number_type: U64KvNumberIdentifier, + **options: Unpack[MutationOptions], + ) -> None: ... + + @overload + def __init__( + self: FloatMin, + key: AnyKvKey, + value: float, + number_type: FloatKvNumberIdentifier | None = None, + **options: Unpack[MutationOptions], + ) -> None: ... + + @overload + def __init__( + self: Min[KvNumberNameT, NumberT, KvNumberTypeT], + key: AnyKvKey, + value: NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], + # Can't use float limits unless the float type is explicitly being used, + # as float is incompatible with the other number types, but int is + # compatible. + **options: Unpack[MutationOptions], + ) -> None: ... 
+ + def __init__( + self: Min[KvNumberNameT, NumberT, KvNumberTypeT], + key: AnyKvKey, + value: JSBigInt | KvU64 | float | NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT] + | KvNumberIdentifier + | None = None, + **options: Unpack[MutationOptions], + ) -> None: + resolved_number, resolved_number_type = Min._resolve_number_value_type( + value, number_type + ) + super(Min, self).__init__(key=key, number_type=resolved_number_type, **options) + self.value = resolved_number + + @override + def as_protobuf(self, *, v8_encoder: Encoder) -> Sequence[dp_protobuf.Mutation]: + return self.number_type.get_min_mutations(self, v8_encoder=v8_encoder) + + +BigIntMin: TypeAlias = Min[Literal["bigint"], int, JSBigInt] +FloatMin: TypeAlias = Min[Literal["float"], float, float] +U64Min: TypeAlias = Min[Literal["u64"], int, KvU64] + + +@dataclass(init=False, **slots_if310()) +class Max(NumberMutation[KvNumberNameT_co, NumberT_co, KvNumberTypeT_co]): + value: Final[NumberT_co] # type: ignore[misc] + + @overload + def __init__( # pyright: ignore[reportOverlappingOverload] + self: BigIntMax, + key: AnyKvKey, + value: JSBigInt, + number_type: None = None, + **options: Unpack[MutationOptions], + ) -> None: ... + + @overload + def __init__( + self: BigIntMax, + key: AnyKvKey, + value: int | JSBigInt, + number_type: BigIntKvNumberIdentifier, + **options: Unpack[MutationOptions], + ) -> None: ... + + @overload + def __init__( + self: U64Max, + key: AnyKvKey, + value: KvU64, + number_type: None = None, + **options: Unpack[MutationOptions], + ) -> None: ... + + @overload + def __init__( + self: U64Max, + key: AnyKvKey, + value: int | KvU64, + number_type: U64KvNumberIdentifier, + **options: Unpack[MutationOptions], + ) -> None: ... + + @overload + def __init__( + self: FloatMax, + key: AnyKvKey, + value: float, + number_type: FloatKvNumberIdentifier | None = None, + **options: Unpack[MutationOptions], + ) -> None: ... 
+ + @overload + def __init__( + self: Max[KvNumberNameT, NumberT, KvNumberTypeT], + key: AnyKvKey, + value: NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], + # Can't use float limits unless the float type is explicitly being used, + # as float is incompatible with the other number types, but int is + # compatible. + **options: Unpack[MutationOptions], + ) -> None: ... + + def __init__( + self: Max[KvNumberNameT, NumberT, KvNumberTypeT], + key: AnyKvKey, + value: JSBigInt | KvU64 | float | NumberT | KvNumberTypeT, + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT] + | KvNumberIdentifier + | None = None, + **options: Unpack[MutationOptions], + ) -> None: + resolved_number, resolved_number_type = Max._resolve_number_value_type( + value, number_type + ) + super(Max, self).__init__(key=key, number_type=resolved_number_type, **options) + self.value = resolved_number + + @override + def as_protobuf(self, *, v8_encoder: Encoder) -> Sequence[dp_protobuf.Mutation]: + return self.number_type.get_max_mutations(self, v8_encoder=v8_encoder) + + +BigIntMax: TypeAlias = Max[Literal["bigint"], int, JSBigInt] +FloatMax: TypeAlias = Max[Literal["float"], float, float] +U64Max: TypeAlias = Max[Literal["u64"], int, KvU64] + + +@dataclass(**slots_if310()) +class Delete(Mutation): + def __init__(self, key: AnyKvKey) -> None: + super(Delete, self).__init__(key, expire_at=None) + + @override + def as_protobuf( + self, *, v8_encoder: Encoder | None = None + ) -> tuple[dp_protobuf.Mutation]: + return ( + dp_protobuf.Mutation( + mutation_type=dp_protobuf.MutationType.M_DELETE, key=pack_key(self.key) + ), + ) + + +DEFAULT_ENQUEUE_RETRY_DELAYS = ExponentialBackoff( + initial_interval_seconds=1, multiplier=3 +) +DEFAULT_ENQUEUE_RETRY_DELAY_COUNT = 10 + + +class EnqueueRepresentation(SingleProtobufMessageRepresentation[dp_protobuf.Enqueue]): + __slots__ = () + + +@dataclass(init=False, **slots_if310()) +class 
Enqueue(FrozenAfterInitDataclass, EnqueueRepresentation): + """ + A message to be async-delivered to a Deno app listening to the Kv's queue. + + Parameters + ---------- + message: + The message to deliver. Can be any value that can be written to the database. + delivery_time: + Delay the message delivery until this time. + + If the time is None or in the past, the message is delivered as soon as + possible. + retry_delays: + Delivery attempts that fail will be retried after these delays. + + If the value is an Iterable, a fixed number of values will be drawn to retry + with. Use a fixed-length Sequence to specify a precise number of retries. + Default: DEFAULT_ENQUEUE_RETRY_DELAYS + dead_letter_keys: + Messages that cannot be delivered will be written to these keys. + + Notes + ----- + See [Deno.Kv.listenQueue()](https://docs.deno.com/api/deno/~/Deno.Kv#method_listenqueue_0) + """ + + message: object + delivery_time: datetime | None + retry_delays: Backoff + dead_letter_keys: Sequence[AnyKvKey] + + def __init__( + self, + message: object, + *, + delivery_time: datetime | None = None, + retry_delays: Backoff | None = None, + dead_letter_keys: Sequence[AnyKvKey] | None = None, + ): + self.message = message + self.delivery_time = delivery_time + self.retry_delays = ( + DEFAULT_ENQUEUE_RETRY_DELAYS if retry_delays is None else retry_delays + ) + self.dead_letter_keys = () if dead_letter_keys is None else dead_letter_keys + + @override + def as_protobuf(self, *, v8_encoder: Encoder) -> tuple[dp_protobuf.Enqueue]: + deadline_ms = None + if self.delivery_time is not None: + deadline_ms = int(self.delivery_time.timestamp() * 1000) + return ( + dp_protobuf.Enqueue( + payload=bytes(v8_encoder.encode(self.message)), + keys_if_undelivered=[pack_key(k) for k in self.dead_letter_keys], + deadline_ms=deadline_ms, + backoff_schedule=self._evaluate_backoff_schedule(), + ), + ) + + def _evaluate_backoff_schedule(self) -> Sequence[int]: + # Sample a fixed max number from 
unknown-length iterables. + delay_seconds = ( + self.retry_delays + if isinstance(self.retry_delays, Sequence) + else islice(self.retry_delays, DEFAULT_ENQUEUE_RETRY_DELAY_COUNT) + ) + # Backoff times are in seconds, but we need milliseconds + return [int(delay * 1000) for delay in delay_seconds] + + +WriteOperation: TypeAlias = Union[ + CheckRepresentation, MutationRepresentation, EnqueueRepresentation +] diff --git a/src/denokv/_pycompat/dataclasses.py b/src/denokv/_pycompat/dataclasses.py index 3767436..cede4a0 100644 --- a/src/denokv/_pycompat/dataclasses.py +++ b/src/denokv/_pycompat/dataclasses.py @@ -4,8 +4,10 @@ from dataclasses import FrozenInstanceError from dataclasses import dataclass from dataclasses import fields as dataclass_fields -from typing import Literal -from typing import TypedDict # avoid circular reference with _pycompat.typing +from typing import Literal # noqa: TID251 + +# avoid circular reference with _pycompat.typing +from typing import TypedDict # noqa: TID251 class NoArg(TypedDict): @@ -51,6 +53,8 @@ class FrozenAfterInitDataclass: doesn't affect non-dataclass fields, such as typing.Generic's dunder fields. """ + __slots__ = () + def __delattr__(self, name: str) -> None: if name in (f.name for f in dataclass_fields(self)): raise FrozenInstanceError(f"cannot delete field {name}") diff --git a/src/denokv/_pycompat/enum.py b/src/denokv/_pycompat/enum.py index 36b38ff..f570813 100644 --- a/src/denokv/_pycompat/enum.py +++ b/src/denokv/_pycompat/enum.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +from enum import Enum from enum import EnumMeta from enum import Flag from enum import IntFlag @@ -65,3 +66,19 @@ def __str__(self) -> str: else: from enum import IntEnum as IntEnum # noqa: F401 # re-export + + +class EvalEnumRepr(Enum): + """ + An Enum mixin that uses 'EnumName.FIELD' as the repr. + + Example + ------- + >>> class EnumName(EvalEnumRepr, Enum): + ... 
FIELD = 'a' + >>> EnumName.FIELD + EnumName.FIELD + """ + + def __repr__(self) -> str: + return f"{type(self).__name__}.{self.name}" diff --git a/src/denokv/_pycompat/exceptions.py b/src/denokv/_pycompat/exceptions.py new file mode 100644 index 0000000..acfd02a --- /dev/null +++ b/src/denokv/_pycompat/exceptions.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from denokv._pycompat.typing import Protocol +from denokv._pycompat.typing import TypeGuard +from denokv._pycompat.typing import TypeVar +from denokv._pycompat.typing import cast + + +class Notes(Protocol): + __notes__: list[str] + + +def has_notes(exc: BaseException) -> TypeGuard[Notes]: + return isinstance(getattr(exc, "__notes__", None), list) + + +def add_note(exc: BaseException, note: str) -> None: + if not isinstance(note, str): + raise TypeError("note must be a str") + if not has_notes(exc): + exc_with_notes = cast(Notes, exc) + exc_with_notes.__notes__ = notes = [] + else: + notes = exc.__notes__ + notes.append(note) + + +ExceptionT = TypeVar("ExceptionT", bound=BaseException) + + +def with_notes( + exc: ExceptionT, *notes: str, from_exception: BaseException | None = None +) -> ExceptionT: + if from_exception and has_notes(from_exception): + for note in from_exception.__notes__: + add_note(exc, note) + for note in notes: + add_note(exc, note) + return exc diff --git a/src/denokv/_pycompat/protobuf.py b/src/denokv/_pycompat/protobuf.py index 4b3afbb..cd11027 100644 --- a/src/denokv/_pycompat/protobuf.py +++ b/src/denokv/_pycompat/protobuf.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import overload +from typing import overload # noqa: TID251 from denokv._datapath_pb2 import AtomicWriteStatus from denokv._datapath_pb2 import MutationType diff --git a/src/denokv/_pycompat/types.py b/src/denokv/_pycompat/types.py new file mode 100644 index 0000000..e60a52f --- /dev/null +++ b/src/denokv/_pycompat/types.py @@ -0,0 +1,18 @@ +from enum import Enum +from typing import Literal # 
noqa: TID251 + +from denokv._pycompat.typing import TypeAlias + + +class NotSetEnum(Enum): + NotSet = "NotSet" + """ + Sentinel value to use as an argument default. + + It's purpose is to differentiate the argument not being set from an explicit + None value (or similar). + """ + + +NotSetType: TypeAlias = Literal[NotSetEnum.NotSet] +NotSet = NotSetEnum.NotSet diff --git a/src/denokv/_pycompat/typing.py b/src/denokv/_pycompat/typing.py index f809a42..bf8941f 100644 --- a/src/denokv/_pycompat/typing.py +++ b/src/denokv/_pycompat/typing.py @@ -6,19 +6,22 @@ without needing if TYPE_CHECKING everywhere. """ +# ruff: noqa: TID251 + from __future__ import annotations from dataclasses import dataclass from dataclasses import field -from typing import IO as IO -from typing import TYPE_CHECKING as TYPE_CHECKING # Everything that exist in typing >=py39 except: # - ByteString (deprecated) # - overload (ruff does not recognise it when re-exported) # - Literal (ruff does not recognise it when re-exported) # - Handled below due to runtime differences: -# - TypeVar +# - TypeVar (does not support default argument pre py313) +# - TypedDict (does not support generics pre py311) +from typing import IO as IO +from typing import TYPE_CHECKING as TYPE_CHECKING from typing import AbstractSet as AbstractSet from typing import Annotated as Annotated from typing import Any as Any @@ -79,7 +82,6 @@ from typing import TextIO as TextIO from typing import Tuple as Tuple from typing import Type as Type -from typing import TypedDict as TypedDict from typing import Union as Union from typing import ValuesView as ValuesView from typing import cast as cast @@ -167,6 +169,18 @@ def override(method, /): return method +if TYPE_CHECKING: + from typing_extensions import TypedDict as TypedDict +else: + + class TypedDict(dict): + def __new__(cls, *args, **kwargs): + return dict(*args, **kwargs) + + @classmethod + def __init_subclass__(cls, total: bool = True) -> None: ... 
+ + def assert_never(value: Never, /) -> Never: """Assert to the type checker that a line of code is unreachable.""" raise AssertionError(f"Expected code to be unreachable, but got: {value!r}") diff --git a/src/denokv/_utils.py b/src/denokv/_utils.py new file mode 100644 index 0000000..41cb365 --- /dev/null +++ b/src/denokv/_utils.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from dataclasses import FrozenInstanceError + +from denokv._pycompat.typing import TypeVar + +TypeT = TypeVar("TypeT", bound=type) + + +def frozen_setattr(cls: type, name: str, value: object) -> None: + raise FrozenInstanceError(f"Cannot assign to field {name!r}") + + +def frozen_delattr(cls: type, name: str) -> None: + raise FrozenInstanceError(f"Cannot delete field {name!r}") + + +def frozen(cls: TypeT) -> TypeT: + """Disable `__setattr__` and `__delattr__`, much like @dataclass(frozen=True).""" + cls.__setattr__ = frozen_setattr # type: ignore[method-assign,assignment] + cls.__delattr__ = frozen_delattr # type: ignore[method-assign,assignment] + return cls diff --git a/src/denokv/auth.py b/src/denokv/auth.py index a18219e..1e460eb 100644 --- a/src/denokv/auth.py +++ b/src/denokv/auth.py @@ -1,5 +1,6 @@ from __future__ import annotations +import functools from dataclasses import dataclass from datetime import datetime from uuid import UUID @@ -54,10 +55,27 @@ class DatabaseMetadata: expires_at: datetime +@functools.total_ordering class ConsistencyLevel(StrEnum): + """ + A read consistency requirement for a Deno KV Database server endpoint. + + Examples + -------- + Levels are ordered by amount of consistency — strong greater than eventual. 
+ + >>> assert ConsistencyLevel.STRONG > ConsistencyLevel.EVENTUAL + >>> assert ConsistencyLevel.EVENTUAL < ConsistencyLevel.STRONG + """ + STRONG = "strong" EVENTUAL = "eventual" + def __lt__(self, value: object) -> bool: + if not isinstance(value, ConsistencyLevel): + return NotImplemented + return self is ConsistencyLevel.EVENTUAL and value is ConsistencyLevel.STRONG + @dataclass(frozen=True, **slots_if310()) class EndpointInfo: diff --git a/src/denokv/backoff.py b/src/denokv/backoff.py index 61380f8..fae62c5 100644 --- a/src/denokv/backoff.py +++ b/src/denokv/backoff.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from enum import IntEnum from itertools import count -from typing import Literal +from typing import Literal # noqa: TID251 from denokv._pycompat.typing import Callable from denokv._pycompat.typing import Iterable diff --git a/src/denokv/datapath.py b/src/denokv/datapath.py index fd8372a..c868697 100644 --- a/src/denokv/datapath.py +++ b/src/denokv/datapath.py @@ -9,7 +9,7 @@ from dataclasses import dataclass from enum import Enum from enum import auto -from typing import overload +from typing import overload # noqa: TID251 import aiohttp import aiohttp.client_exceptions @@ -18,16 +18,22 @@ from google.protobuf.message import Error as ProtobufMessageError from v8serialize import Decoder +from denokv._datapath_pb2 import AtomicWrite +from denokv._datapath_pb2 import AtomicWriteOutput +from denokv._datapath_pb2 import AtomicWriteStatus +from denokv._datapath_pb2 import Check from denokv._datapath_pb2 import KvEntry from denokv._datapath_pb2 import ReadRange from denokv._datapath_pb2 import SnapshotRead from denokv._datapath_pb2 import SnapshotReadOutput from denokv._datapath_pb2 import SnapshotReadStatus from denokv._datapath_pb2 import ValueEncoding +from denokv._pycompat.typing import AbstractSet from denokv._pycompat.typing import Awaitable from denokv._pycompat.typing import Callable from denokv._pycompat.typing import Container from 
denokv._pycompat.typing import Final +from denokv._pycompat.typing import Iterable from denokv._pycompat.typing import Protocol from denokv._pycompat.typing import Type from denokv._pycompat.typing import TypeAlias @@ -78,6 +84,8 @@ def kv_key_bytes(self) -> bytes: ... @runtime_checkable class KvKeyRangeEncodable(Container[AnyKvKey], Protocol): + __slots__ = () + def kv_key_range_bytes(self) -> tuple[bytes, bytes]: ... @@ -167,6 +175,12 @@ def __init__( self.status = status self.body_text = body_text + def __str__(self) -> str: + return ( + f"{super().__str__()}: HTTP response: status={self.status}, " + f"body_text={self.body_text!r}" + ) + class RequestUnsuccessful(DataPathDenoKvError): """Unable to make a Data Path request to the KV server.""" @@ -174,6 +188,54 @@ class RequestUnsuccessful(DataPathDenoKvError): pass +@dataclass(init=False) +class CheckFailure(DataPathDenoKvError): + """ + The KV server could not complete an Atomic Write because of a concurrent change. + + This is an expected response to Atomic Write requests that occurs when one + or more of the checks an Atomic Write is conditional on are found to not + hold at the point that the database attempts to commit the write, because + another Atomic Write has written new version(s) of the key(s) referenced by + the check(s). The client must re-read the keys it was attempting to write, + and submit a new Atomic Write if necessary that reflects the latest state of + the keys. + """ + + all_checks: tuple[Check, ...] + """All of the Checks sent with the AtomicWrite.""" + failed_check_indexes: AbstractSet[int] | None + """ + The indexes of Checks in all_checks keys whose versionstamp check failed. + + The set is sorted with ascending iteration order. Will be None if the + database does not support reporting which checks failed. 
+ """ + + def __init__( + self, + message: str, + all_checks: Iterable[Check], + failed_check_indexes: Iterable[int] | None, + *args: object, + endpoint: EndpointInfo, + ) -> None: + super().__init__(message, *args, endpoint=endpoint, auto_retry=AutoRetry.NEVER) + + self.all_checks = tuple(all_checks) + if len(self.all_checks) == 0: + raise ValueError("all_checks is empty") + + ordered_indexes = sorted(failed_check_indexes) if failed_check_indexes else [] + if len(ordered_indexes) > 0 and ( + ordered_indexes[0] < 0 or ordered_indexes[-1] >= len(self.all_checks) + ): + raise IndexError("failed_check_indexes contains out-of-bounds index") + self.failed_check_indexes = ( + {i: True for i in ordered_indexes}.keys() if ordered_indexes else None + ) + + DataPathError: TypeAlias = Union[ EndpointNotUsable, RequestUnsuccessful, ResponseUnsuccessful, ProtocolViolation ] @@ -181,7 +243,7 @@ class RequestUnsuccessful(DataPathDenoKvError): class _DataPathRequestKind(Enum): SnapshotRead = "snapshot_read" - SnapshotWrite = "snapshot_write" + AtomicWrite = "atomic_write" Watch = "watch" @@ -373,6 +435,141 @@ async def snapshot_read( return Ok(read_output) +AtomicWriteResult: TypeAlias = Result[bytes, Union[CheckFailure, DataPathError]] + + +async def atomic_write( + *, + session: aiohttp.ClientSession, + meta: DatabaseMetadata, + endpoint: EndpointInfo, + write: AtomicWrite, +) -> AtomicWriteResult: + """ + Perform a Data Path Atomic Write request against a database endpoint. + + The endpoint must have strong consistency. The write is conditional on the + checks of the provided AtomicWrite passing. Callers must expect to need to + retry a write when these checks are not satisfied due to another write + having modified a checked key. The result is an Err containing a + [CheckFailure](`denokv.datapath.CheckFailure`) when checks fail. + + When the write succeeds, the return value is the 10-byte versionstamp of the + committed version. 
+
+    The request does not retry on error conditions, the caller is responsible
+    for retrying if they wish. The Err results report whether retries are
+    permitted by the Data Path protocol spec using their `auto_retry: AutoRetry`
+    field.
+
+    Returns
+    -------
+    Ok[bytes]:
+        10-byte versionstamp when the write succeeds
+    Err[CheckFailure]:
+        When one or more of the AtomicWrite's checks are not satisfied.
+    Err[ProtocolViolation]:
+        When the endpoint sends an unexpected response violating the protocol
+        spec.
+    Err[RequestUnsuccessful]:
+        When the request cannot be sent, e.g. due to a network error.
+    Err[ResponseUnsuccessful]:
+        When the request is not handled successfully by the endpoint, e.g. due
+        to the service being unavailable.
+    """
+    if endpoint.consistency is not ConsistencyLevel.STRONG:
+        raise ValueError(
+            f"endpoints used with atomic_write must be "
+            f"{ConsistencyLevel.STRONG!r}: {endpoint}"
+        )
+
+    result = await _datapath_request(
+        kind=_DataPathRequestKind.AtomicWrite,
+        session=session,
+        meta=meta,
+        endpoint=endpoint,
+        request_body=write.SerializeToString(),
+        handle_response=_response_body_bytes,
+    )
+    if isinstance(result, Err):
+        return result
+    response_bytes = result.value
+
+    try:
+        write_output = AtomicWriteOutput.FromString(response_bytes)
+    except ProtobufMessageError as e:
+        err = ProtocolViolation(
+            "Server responded to Data Path request with invalid AtomicWriteOutput",
+            data=response_bytes,
+            endpoint=endpoint,
+        )
+        err.__cause__ = e
+        return Err(err)
+
+    if write_output.status == AtomicWriteStatus.AW_SUCCESS:
+        if len(write_output.failed_checks) != 0:
+            return Err(
+                ProtocolViolation(
+                    "Server responded to Data Path Atomic Write with "
+                    "SUCCESS containing failed checks",
+                    data=write_output,
+                    endpoint=endpoint,
+                )
+            )
+        if len(write_output.versionstamp) != 10:
+            return Err(
+                ProtocolViolation(
+                    "Server responded to Data Path Atomic Write with "
+                    "SUCCESS containing an invalid versionstamp",
+                    data=write_output,
+                    endpoint=endpoint,
+                )
+            )
+        return Ok(write_output.versionstamp)
+    elif write_output.status == AtomicWriteStatus.AW_CHECK_FAILURE:
+        try:
+            return Err(
+                CheckFailure(
+                    "Not all checks required by the Atomic Write passed",
+                    all_checks=write.checks,
+                    failed_check_indexes=write_output.failed_checks,
+                    endpoint=endpoint,
+                )
+            )
+        except IndexError as e:
+            err = ProtocolViolation(
+                "Server responded to Data Path Atomic Write with "
+                "CHECK_FAILURE referencing out-of-bounds check index",
+                data=write_output,
+                endpoint=endpoint,
+            )
+            err.__cause__ = e
+            return Err(err)
+    elif write_output.status == AtomicWriteStatus.AW_WRITE_DISABLED:
+        return Err(
+            EndpointNotUsable(
+                "Server responded to Data Path request indicating it cannot "
+                "write this database",
+                endpoint=endpoint,
+                reason=EndpointNotUsableReason.DISABLED,
+            )
+        )
+    else:
+        msg = (
+            "UNSPECIFIED"
+            if write_output.status == AtomicWriteStatus.AW_UNSPECIFIED
+            else f"unknown: {write_output.status}"
+        )
+        return Err(
+            ProtocolViolation(
+                f"Server responded to Data Path Atomic Write request with "
+                f"status {msg}",
+                data=write_output,
+                endpoint=endpoint,
+            )
+        )
+
+
 def is_kv_key_tuple(tup: object) -> TypeGuard[KvKeyTuple]:
     """Check if a tuple only contains valid KV key tuple type values."""
     return isinstance(tup, tuple) and all(
diff --git a/src/denokv/errors.py b/src/denokv/errors.py
index 83734a0..db59de0 100644
--- a/src/denokv/errors.py
+++ b/src/denokv/errors.py
@@ -1,20 +1,22 @@
 from dataclasses import dataclass
 
-from denokv._pycompat.typing import cast
+from denokv._pycompat.typing import TYPE_CHECKING
 
 
 @dataclass(init=False)
-class DenoKvError(BaseException):
-    message: str
+class DenoKvError(Exception):
+    # Define message for dataclass field metadata only, not type annotation.
+ if not TYPE_CHECKING: + message: str - def __init__(self, message: str, *args: object) -> None: - super().__init__(message, *args) - if not isinstance(message, str): - raise TypeError(f"first argument must be a str message: {message!r}") + def __init__(self, *args: object) -> None: + super(DenoKvError, self).__init__(*args) - @property # type: ignore[no-redef] + @property def message(self) -> str: - return cast(str, self.args[0]) + if args := self.args: + return str(args[0]) + return type(self).__name__ class DenoKvValidationError(ValueError, DenoKvError): diff --git a/src/denokv/kv.py b/src/denokv/kv.py index cde6e88..3a1866d 100644 --- a/src/denokv/kv.py +++ b/src/denokv/kv.py @@ -4,31 +4,59 @@ import weakref from base64 import urlsafe_b64decode from base64 import urlsafe_b64encode -from binascii import unhexlify from contextlib import AbstractAsyncContextManager from dataclasses import dataclass from dataclasses import field from enum import Flag from enum import auto +from functools import partial from os import environ from types import TracebackType -from typing import Literal -from typing import overload +from typing import Literal # noqa: TID251 +from typing import overload # noqa: TID251 import aiohttp +import v8serialize from fdb.tuple import unpack from v8serialize import Decoder +from v8serialize import Encoder from yarl import URL +from denokv import _datapath_pb2 as dp_protobuf from denokv import datapath -from denokv._datapath_pb2 import ReadRange +from denokv._datapath_pb2 import AtomicWrite from denokv._datapath_pb2 import SnapshotRead from denokv._datapath_pb2 import SnapshotReadOutput +from denokv._kv_types import AtomicWriteRepresentationWriter +from denokv._kv_types import KvWriter +from denokv._kv_types import KvWriterWriteResult +from denokv._kv_types import WriteResultT +from denokv._kv_values import KvEntry +from denokv._kv_values import KvU64 +from denokv._kv_values import VersionStamp +from denokv._kv_writes import Check +from 
denokv._kv_writes import CheckMixin +from denokv._kv_writes import CheckRepresentation +from denokv._kv_writes import CompletedWrite +from denokv._kv_writes import DeleteMutatorMixin +from denokv._kv_writes import Enqueue +from denokv._kv_writes import EnqueueMixin +from denokv._kv_writes import MaxMutatorMixin +from denokv._kv_writes import MinMutatorMixin +from denokv._kv_writes import Mutation +from denokv._kv_writes import MutationRepresentation +from denokv._kv_writes import PlannedWrite +from denokv._kv_writes import SetMutatorMixin +from denokv._kv_writes import SumMutatorMixin +from denokv._kv_writes import WriteOperation from denokv._pycompat.dataclasses import slots_if310 +from denokv._pycompat.types import NotSet +from denokv._pycompat.types import NotSetType +from denokv._pycompat.typing import Any from denokv._pycompat.typing import AsyncIterator from denokv._pycompat.typing import Awaitable from denokv._pycompat.typing import Callable -from denokv._pycompat.typing import ClassVar +from denokv._pycompat.typing import Coroutine from denokv._pycompat.typing import Final from denokv._pycompat.typing import Generic from denokv._pycompat.typing import Iterable @@ -39,6 +67,7 @@ from denokv._pycompat.typing import TypedDict from denokv._pycompat.typing import TypeVar from denokv._pycompat.typing import TypeVarTuple +from denokv._pycompat.typing import Union from denokv._pycompat.typing import Unpack from denokv._pycompat.typing import override from denokv.asyncio import loop_time @@ -53,12 +82,13 @@ from denokv.datapath import AnyKvKey from denokv.datapath import AnyKvKeyT from denokv.datapath import AutoRetry +from denokv.datapath import CheckFailure +from denokv.datapath import DataPathDenoKvError from denokv.datapath import DataPathError from denokv.datapath import KvKeyEncodable from denokv.datapath import KvKeyPiece from denokv.datapath import KvKeyTuple from denokv.datapath import ProtocolViolation -from denokv.datapath import SnapshotReadResult from 
denokv.datapath import is_kv_key_tuple from denokv.datapath import pack_key from denokv.datapath import parse_protobuf_kv_entry @@ -69,17 +99,30 @@ from denokv.result import Err from denokv.result import Ok from denokv.result import Result +from denokv.result import is_ok T = TypeVar("T", default=object) # Note that the default arg doesn't seem to work with MyPy yet. The # DefaultKvKey alias is what this should behave as when defaulted. Pieces = TypeVarTuple("Pieces", default=Unpack[tuple[KvKeyPiece, ...]]) +_DataPathErrorT = TypeVar("_DataPathErrorT", bound=DataPathDenoKvError) SAFE_FLOAT_INT_RANGE: Final = range(-(2**53 - 1), 2**53) # 2**53 - 1 is max safe CursorFormatType: TypeAlias = Callable[["ListContext"], "AnyCursorFormat"] +def v8_encode_int_as_bigint( + value: object, + ctx: v8serialize.encode.EncodeContext, + next: v8serialize.encode.EncodeNextFn, +) -> None: + if isinstance(value, int): + ctx.stream.write_bigint(value) + else: + next(value) + + class KvListOptions(TypedDict, total=False): """Keyword arguments of `Kv.list()`.""" @@ -91,15 +134,6 @@ class KvListOptions(TypedDict, total=False): cursor_format_type: CursorFormatType | None -@dataclass(frozen=True, **slots_if310()) -class KvEntry(Generic[AnyKvKeyT, T]): - """A value read from the Deno KV database, along with its key and version.""" - - key: AnyKvKeyT - value: T - versionstamp: VersionStamp - - @dataclass(frozen=True, **slots_if310()) class ListKvEntry(KvEntry[AnyKvKeyT, T]): """ @@ -120,121 +154,6 @@ def cursor(self) -> str: return result.value -class VersionStamp(bytes): - r""" - A 20-hex-char / (10 byte) version identifier. - - This value represents the relative age of a KvEntry. A VersionStamp that - compares larger than another is newer. 
- - Examples - -------- - >>> VersionStamp(0xff << 16) - VersionStamp('00000000000000ff0000') - >>> int(VersionStamp('000000000000000000ff')) - 255 - >>> bytes(VersionStamp('00000000000000ff0000')) - b'\x00\x00\x00\x00\x00\x00\x00\xff\x00\x00' - >>> VersionStamp(b'\x00\x00\x00\x00\x00\x00\x00\xff\x00\x00') - VersionStamp('00000000000000ff0000') - >>> isinstance(VersionStamp(0), bytes) - True - >>> str(VersionStamp(0xff << 16)) - '00000000000000ff0000' - """ - - RANGE: ClassVar = range(0, 2**80) - - def __new__(cls, value: str | bytes | int) -> Self: - if isinstance(value, int): - if value not in VersionStamp.RANGE: - raise ValueError("value not in range for 80-bit unsigned int") - # Unlike most others, versionstamp uses big-endian as it needs to - # sort lexicographically as bytes. - value = value.to_bytes(length=10, byteorder="big") - if isinstance(value, str): - try: - value = unhexlify(value) - except Exception: - value = b"" - if len(value) != 10: - raise ValueError("value is not a 20 char hex string") - else: - if len(value) != 10: - raise ValueError("value is not 10 bytes long") - return bytes.__new__(cls, value) - - def __index__(self) -> int: - return int.from_bytes(self, byteorder="big") - - def __bytes__(self) -> bytes: - return self[:] - - def __str__(self) -> str: - return self.hex() - - def __repr__(self) -> str: - return f"{type(self).__name__}({str(self)!r})" - - -@dataclass(frozen=True, **slots_if310()) -class KvU64: - """ - An special int value that supports operations like `sum`, `max`, and `min`. - - Notes - ----- - This type is not an int subtype to avoid it being mistakenly flattened into - a regular int and loosing its special meaning when written back to the DB. 
- - Examples - -------- - >>> KvU64(bytes([0, 0, 0, 0, 0, 0, 0, 0])) - KvU64(0) - >>> KvU64(bytes([1, 0, 0, 0, 0, 0, 0, 0])) - KvU64(1) - >>> KvU64(bytes([1, 1, 0, 0, 0, 0, 0, 0])) - KvU64(257) - >>> KvU64(2**64 - 1) - KvU64(18446744073709551615) - >>> KvU64(2**64) - Traceback (most recent call last): - ... - ValueError: value not in range for 64-bit unsigned int - >>> KvU64(-1) - Traceback (most recent call last): - ... - ValueError: value not in range for 64-bit unsigned int - """ - - RANGE: ClassVar = range(0, 2**64) - value: int - - def __init__(self, value: bytes | int) -> None: - if isinstance(value, bytes): - if len(value) != 8: - raise ValueError("value must be a 8 bytes") - value = int.from_bytes(value, byteorder="little") - elif isinstance(value, int): - if value not in KvU64.RANGE: - raise ValueError("value not in range for 64-bit unsigned int") - else: - raise TypeError("value must be 8 bytes or a 64-bit unsigned int") - object.__setattr__(self, "value", value) - - def __index__(self) -> int: - return self.value - - def __bytes__(self) -> bytes: - return self.to_bytes() - - def to_bytes(self) -> bytes: - return self.value.to_bytes(8, byteorder="little") - - def __repr__(self) -> str: - return f"{type(self).__name__}({self.value})" - - @dataclass(frozen=True, **slots_if310()) class EndpointSelector: # Right now this is very simple, which is fine for the local SQLite-backed @@ -294,7 +213,7 @@ class KvCredentials: access_token: str -@dataclass +@dataclass(frozen=True, **slots_if310()) class Authenticator: """ Authenticates with a KV database server and returns its metadata. 
@@ -462,7 +381,17 @@ class KvFlags(Flag): @dataclass(init=False) -class Kv(AbstractAsyncContextManager["Kv", None]): +class Kv( + CheckMixin[Awaitable[bool]], + SetMutatorMixin[Awaitable[VersionStamp]], + SumMutatorMixin[Awaitable[VersionStamp]], + MinMutatorMixin[Awaitable[VersionStamp]], + MaxMutatorMixin[Awaitable[VersionStamp]], + DeleteMutatorMixin[Awaitable[VersionStamp]], + EnqueueMixin[Awaitable[VersionStamp]], + KvWriter, + AbstractAsyncContextManager["Kv", None], +): """ Interface to perform requests against a Deno KV database. @@ -477,6 +406,7 @@ class Kv(AbstractAsyncContextManager["Kv", None]): session: aiohttp.ClientSession retry_delays: Backoff metadata_cache: DatabaseMetadataCache + v8_encoder: Encoder v8_decoder: Decoder flags: KvFlags @@ -485,12 +415,14 @@ def __init__( session: aiohttp.ClientSession, auth: AuthenticatorFn, retry: Backoff | None = None, + v8_encoder: Encoder | None = None, v8_decoder: Decoder | None = None, flags: KvFlags | None = None, ) -> None: self.session = session self.metadata_cache = DatabaseMetadataCache(authenticator=auth) self.retry_delays = ExponentialBackoff() if retry is None else retry + self.v8_encoder = v8_encoder or Encoder() self.v8_decoder = v8_decoder or Decoder() self.flags = KvFlags.IntAsNumber if flags is None else flags @@ -604,7 +536,7 @@ async def get( args = tuple(self._prepare_key(key) for key in args) ranges = [read_range_single(key) for key in args] snapshot_read_result = await self._snapshot_read( - ranges, consistency=consistency + dp_protobuf.SnapshotRead(ranges=ranges), consistency=consistency ) if isinstance(snapshot_read_result, Err): raise snapshot_read_result.error @@ -754,7 +686,7 @@ async def list( ) snapshot_read_result = await self._snapshot_read( - ranges=[read_range], consistency=consistency + dp_protobuf.SnapshotRead(ranges=[read_range]), consistency=consistency ) if isinstance(snapshot_read_result, Err): raise snapshot_read_result.error @@ -806,10 +738,34 @@ async def list( 
batch_start = parsed_key async def _snapshot_read( - self, ranges: Sequence[ReadRange], *, consistency: ConsistencyLevel + self, read: SnapshotRead, *, consistency: ConsistencyLevel ) -> _KvSnapshotReadResult: - read = SnapshotRead(ranges=ranges) - result: SnapshotReadResult + return await self._datapath_request( + partial(datapath.snapshot_read, read=read), consistency=consistency + ) + + @staticmethod + def _parse_versionstamp( + value: tuple[bytes, EndpointInfo], + ) -> tuple[VersionStamp, EndpointInfo]: + raw_versionstamp, endpoint = value + return VersionStamp(raw_versionstamp), endpoint + + async def _atomic_write(self, write: AtomicWrite) -> _KvAtomicWriteResult: + return ( + await self._datapath_request( + partial(datapath.atomic_write, write=write), + consistency=ConsistencyLevel.STRONG, + ) + ).map(self._parse_versionstamp) + + async def _datapath_request( + self, + datapath_request: partial[Coroutine[Any, Any, Result[T, _DataPathErrorT]]], + *, + consistency: ConsistencyLevel, + ) -> Result[tuple[T, EndpointInfo], _DataPathErrorT]: + result: Result[T, _DataPathErrorT] endpoint: EndpointInfo for delay in attempts(self.retry_delays): # return error from this? 
@@ -828,11 +784,8 @@ async def _snapshot_read( endpoints = EndpointSelector(meta=cached_meta.value) endpoint = endpoints.get_endpoint(consistency) - result = await datapath.snapshot_read( - session=self.session, - meta=cached_meta.value, - endpoint=endpoint, - read=read, + result = await datapath_request( + session=self.session, meta=cached_meta.value, endpoint=endpoint ) if isinstance(result, Err): if result.error.auto_retry is AutoRetry.AFTER_BACKOFF: @@ -850,13 +803,106 @@ async def _snapshot_read( assert isinstance(result, Ok) return Ok((result.value, endpoint)) + def atomic(self, *operations: WriteOperation) -> PlannedWrite: + write = PlannedWrite(kv=self) + for op in operations: + if isinstance(op, Check): + write.check(op) + elif isinstance(op, Mutation): + write.mutate(op) + else: + assert isinstance(op, Enqueue) + write.enqueue(op) + return write + + @overload + async def write(self, *operations: WriteOperation) -> CompletedWrite: ... + + @overload + async def write(self, planned_write: PlannedWrite, /) -> CompletedWrite: ... + + @overload + async def write( + self, atomic_write: AtomicWriteRepresentationWriter[WriteResultT], / + ) -> WriteResultT: ... + + @overload + async def write( + self, *, protobuf_atomic_write: dp_protobuf.AtomicWrite + ) -> KvWriterWriteResult: ... 
+ + @override + async def write( + self, + arg: AtomicWriteRepresentationWriter[WriteResultT] + | WriteOperation + | NotSetType = NotSet, # NotSet is a sentinel to detect 0 args + *args: WriteOperation, + protobuf_atomic_write: dp_protobuf.AtomicWrite | None = None, + ) -> CompletedWrite | WriteResultT | KvWriterWriteResult: + if protobuf_atomic_write is not None: + if arg is not NotSet or len(args) > 0: + raise TypeError( + "Kv.write() got an unexpected positional argument with " + "keyword argument 'protobuf_atomic_write'" + ) + + return await self._atomic_write(protobuf_atomic_write) + + planned_write: PlannedWrite | AtomicWriteRepresentationWriter[WriteResultT] + if arg is NotSet: + # arg is NotSet when 0 args were passed, which is OK (no operations). + # But NotSet when args are provided means it was passed explicitly. + if args: + raise TypeError("Kv.write() got an unexpected 'NotSet'") + # Note that it's OK to submit a write with no operations. We get a + # versionstamp back. Submitting a write with only checks could be + # used to check if a key has been changed without reading the value. + planned_write = PlannedWrite() + elif isinstance(arg, AtomicWriteRepresentationWriter): + planned_write = arg + if args: + raise TypeError( + "Kv.write() got unexpected arguments after 'planned_write'" + ) + else: + planned_write = self.atomic(arg, *args) + + return await planned_write.write(kv=self, v8_encoder=self.v8_encoder) + + @override + async def _check(self, check: CheckRepresentation, /) -> bool: + return is_ok(await self.write(check)) + + @override + async def mutate(self, mutation: MutationRepresentation) -> VersionStamp: + result = await self.write(mutation) + if is_ok(result): + return result.versionstamp + # This is a write conflict which we don't expect to occur, because the + # shortcut mutation methods (like set(), sum(), etc) don't include + # checks. 
+ raise result + + @override + async def _enqueue(self, enqueue: Enqueue, /) -> VersionStamp: + result = await self.write(enqueue) + if is_ok(result): + return result.versionstamp + # This is a write conflict which we don't expect to occur, because the + # enqueue() shortcut doesn't include checks. + raise result + _KvSnapshotReadResult: TypeAlias = Result[ tuple[SnapshotReadOutput, EndpointInfo], DataPathError ] +_KvAtomicWriteResult: TypeAlias = Result[ + tuple[VersionStamp, EndpointInfo], Union[CheckFailure, DataPathError] +] -@dataclass(frozen=True) +@dataclass(frozen=True, **slots_if310()) class ListContext: prefix: AnyKvKey | None start: AnyKvKey | None @@ -879,12 +925,14 @@ def __post_init__(self) -> None: class AnyCursorFormat(Protocol): + __slots__ = () + def get_key_for_cursor(self, cursor: str) -> Result[KvKeyTuple, InvalidCursor]: ... def get_cursor_for_key(self, key: AnyKvKey) -> Result[str, ValueError]: ... -@dataclass(frozen=True) +@dataclass(frozen=True, **slots_if310()) class Base64KeySuffixCursorFormat(AnyCursorFormat): r""" A cursor format that encodes keys as URL-safe base64. 
diff --git a/src/denokv/kv_keys.py b/src/denokv/kv_keys.py index 0221c5d..a9736fe 100644 --- a/src/denokv/kv_keys.py +++ b/src/denokv/kv_keys.py @@ -4,7 +4,7 @@ import sys from dataclasses import dataclass from dataclasses import field -from typing import overload +from typing import overload # noqa: TID251 from fdb.tuple import pack from fdb.tuple import unpack @@ -333,6 +333,8 @@ def __repr__(self) -> str: class Include(_KeyBoundary[AnyKvKeyT_co]): """KvKeyRange boundary that includes its key in the range.""" + __slots__ = () + if TYPE_CHECKING: # For some reason mypy only infers types of Pieces using new not init @overload @@ -351,6 +353,7 @@ def range(self) -> KvKeyRange[Self, Self]: class IncludePrefix(_KeyBoundary[AnyKvKeyT_co]): """KvKeyRange boundary that includes keys prefixed by its key in the range.""" + __slots__ = () if TYPE_CHECKING: # For some reason mypy only infers types of Pieces using new not init @overload @@ -388,6 +391,7 @@ def range(self) -> KvKeyRange[Include[AnyKvKeyT_co], Self]: class Exclude(_KeyBoundary[AnyKvKeyT_co]): """KvKeyRange boundary that excludes its key from the range.""" + __slots__ = () if TYPE_CHECKING: # For some reason mypy only infers types of Pieces using new not init @overload diff --git a/src/denokv/result.py b/src/denokv/result.py index 68ba208..d88b2a4 100644 --- a/src/denokv/result.py +++ b/src/denokv/result.py @@ -2,7 +2,7 @@ from abc import ABCMeta from dataclasses import dataclass -from typing import overload +from typing import overload # noqa: TID251 from denokv._pycompat.dataclasses import slots_if310 from denokv._pycompat.typing import TYPE_CHECKING @@ -29,11 +29,15 @@ @runtime_checkable class AnySuccess(Protocol, metaclass=ABCMeta): + __slots__ = () + def _AnySuccess_marker(self, no_call: Never) -> Never: ... @runtime_checkable class AnyFailure(Protocol, metaclass=ABCMeta): + __slots__ = () + def _AnyFailure_marker(self, no_call: Never) -> Never: ... 
@@ -902,6 +906,23 @@ def or_else(self, fn: Callable[[], Result[T, U]]) -> Result[T_co | T, U]: >>> assert Err('error a').or_else(lambda: Err('error b')) == Err('error b') """ + def or_raise(self) -> Ok[T_co]: + """ + Return the Ok as-is, or raise the Err.error value if this is Err. + + Examples + -------- + >>> assert Ok(1).or_raise() == Ok(1) + + >>> Err(ValueError('bad')).or_raise() + Traceback (most recent call last): + ValueError: bad + + >>> Err('foo').or_raise() + Traceback (most recent call last): + Exception: foo + """ + def value_or(self, default: U) -> T_co | U: """ Return the Ok's value, or default if this is Err. @@ -922,6 +943,23 @@ def value_or_else(self, fn: Callable[[], U]) -> T_co | U: >>> assert Err('x').value_or_else(lambda: 2) == 2 """ + def value_or_raise(self) -> T_co: + """ + Return the Ok's value, or raise the Err.error value if this is Err. + + Examples + -------- + >>> assert Ok(1).value_or_raise() == 1 + + >>> Err(ValueError('bad')).value_or_raise() + Traceback (most recent call last): + ValueError: bad + + >>> Err('foo').value_or_raise() + Traceback (most recent call last): + Exception: foo + """ + def __iter__(self) -> Iterator[T_co]: """ Return an iterator containing the Ok's value or no values if this is Err. 
@@ -1111,6 +1149,10 @@ def or_(self, default: Result[T, U]) -> Result[T_co, U]: def or_else(self, fn: Callable[[], Result[T, U]]) -> Result[T_co, U]: return self + @doc_from(ResultMethods) + def or_raise(self) -> Ok[T_co]: + return self + @doc_from(ResultMethods) def value_or(self, default: U) -> T_co: return self.value @@ -1119,6 +1161,10 @@ def value_or(self, default: U) -> T_co: def value_or_else(self, fn: Callable[[], U]) -> T_co: return self.value + @doc_from(ResultMethods) + def value_or_raise(self) -> T_co: + return self.value + def __iter__(self) -> Iterator[T_co]: return iter((self.value,)) @@ -1202,6 +1248,12 @@ def or_(self, default: Result[T_co, U]) -> Result[T_co, U]: def or_else(self, fn: Callable[[], Result[T_co, U]]) -> Result[T_co, U]: return fn() + @doc_from(ResultMethods) + def or_raise(self) -> Never: + if isinstance(self.error, BaseException): + raise self.error + raise Exception(self.error) + if not TYPE_CHECKING: @property @@ -1216,6 +1268,12 @@ def value_or(self, x_default: U) -> U: def value_or_else(self, fn: Callable[[], U]) -> U: return fn() + @doc_from(ResultMethods) + def value_or_raise(self) -> Never: + if isinstance(self.error, BaseException): + raise self.error + raise Exception(self.error) + def __iter__(self) -> Iterator[Never]: return iter(()) diff --git a/stubs/fdb/tuple.pyi b/stubs/fdb/tuple.pyi index ba32142..c211cca 100644 --- a/stubs/fdb/tuple.pyi +++ b/stubs/fdb/tuple.pyi @@ -1,3 +1,5 @@ +# ruff: noqa: TID251 + import ctypes from typing import Hashable from uuid import UUID diff --git a/test/conftest.py b/test/conftest.py index e6c8d8a..9cecf7c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,3 +1,69 @@ +from __future__ import annotations + +import os + +import pytest +from google.protobuf.message import Message +from hypothesis import Verbosity +from hypothesis import settings +from pytest import Config +from v8serialize import Encoder + +from denokv._pycompat.typing import Sequence from test import advance_time 
+from test.denokv_testing import diff_protobuf_messages + +settings.register_profile("ci", max_examples=1000) +settings.register_profile("dev", max_examples=10) +settings.register_profile("debug", max_examples=10, verbosity=Verbosity.verbose) +settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "default").lower()) advance_time_time = advance_time.advance_time_time + +_pytest_assertion_verbosity: int = 0 + + +def pytest_configure(config: Config) -> None: + global _pytest_assertion_verbosity + _pytest_assertion_verbosity = config.get_verbosity(Config.VERBOSITY_ASSERTIONS) + + +# Provide descriptive diffs for failed protobuf message equality assertions +def pytest_assertrepr_compare( + op: str, left: object, right: object +) -> Sequence[str] | None: + if isinstance(left, Message) and isinstance(right, Message): + repr_left = f"<{left.DESCRIPTOR.full_name} protobuf message at {hex(id(left))}>" + repr_right = ( + f"<{right.DESCRIPTOR.full_name} protobuf message at {hex(id(right))}>" + ) + comparison = [f"{repr_left} {op} {repr_right}"] + + if left == right: + comparison.append("Protobuf messages are equal") + return comparison + if type(left) is not type(right): + comparison.append("Protobuf messages are different types") + return comparison + + end = ( + " (use -v for diff)" + if _pytest_assertion_verbosity == 0 + else " (repeat -v for more context):" + ) + comparison.append(f"Protobuf messages are not equal{end}") + if _pytest_assertion_verbosity == 0: + return comparison + + # Scale context lines with verbosity level: 1=3, 2=9, 3=27, 4=81, 5=243 + context = 3**_pytest_assertion_verbosity + comparison.extend( + diff_protobuf_messages(left, right, context_line_count=context, lineterm="") + ) + return comparison + return None + + +@pytest.fixture(scope="session") +def v8_encoder() -> Encoder: + return Encoder() diff --git a/test/denokv_testing.py b/test/denokv_testing.py index e257ca2..f388ffc 100644 --- a/test/denokv_testing.py +++ b/test/denokv_testing.py @@ -1,44 
+1,81 @@ from __future__ import annotations +import difflib import math +import re import sys from base64 import b16decode from base64 import b16encode +from collections import deque from dataclasses import dataclass +from dataclasses import field from datetime import datetime from datetime import timedelta from itertools import groupby -from typing import overload +from typing import Literal # noqa: TID251 +from typing import overload # noqa: TID251 +from unittest.mock import Mock from uuid import UUID +import pytest import v8serialize +import v8serialize.encode +from aiohttp import web +from fdb.tuple import pack from fdb.tuple import unpack - +from google.protobuf.message import Message +from v8serialize.constants import SerializationTag +from v8serialize.decode import DecodeContext +from v8serialize.decode import DecodeNextFn +from v8serialize.jstypes import JSBigInt +from yarl import URL + +from denokv._datapath_pb2 import AtomicWrite +from denokv._datapath_pb2 import AtomicWriteOutput +from denokv._datapath_pb2 import AtomicWriteStatus +from denokv._datapath_pb2 import Enqueue from denokv._datapath_pb2 import KvEntry as ProtobufKvEntry from denokv._datapath_pb2 import KvValue +from denokv._datapath_pb2 import Mutation +from denokv._datapath_pb2 import MutationType from denokv._datapath_pb2 import ReadRange from denokv._datapath_pb2 import ReadRangeOutput +from denokv._datapath_pb2 import SnapshotRead +from denokv._datapath_pb2 import SnapshotReadOutput +from denokv._datapath_pb2 import SnapshotReadStatus from denokv._datapath_pb2 import ValueEncoding +from denokv._kv_values import KvEntry +from denokv._kv_values import KvU64 +from denokv._kv_values import VersionStamp +from denokv._kv_writes import LimitExceededPolicy +from denokv._pycompat.dataclasses import slots_if310 +from denokv._pycompat.enum import EvalEnumRepr +from denokv._pycompat.protobuf import enum_name from denokv._pycompat.typing import Any +from denokv._pycompat.typing import Callable from 
denokv._pycompat.typing import ClassVar +from denokv._pycompat.typing import Final from denokv._pycompat.typing import Iterable from denokv._pycompat.typing import Mapping from denokv._pycompat.typing import NamedTuple +from denokv._pycompat.typing import Never from denokv._pycompat.typing import Sequence +from denokv._pycompat.typing import TypeIs from denokv._pycompat.typing import TypeVar +from denokv._pycompat.typing import Union +from denokv._pycompat.typing import cast +from denokv.auth import ConsistencyLevel from denokv.auth import DatabaseMetadata from denokv.auth import EndpointInfo from denokv.datapath import AnyKvKey from denokv.datapath import KvKeyTuple +from denokv.datapath import increment_packed_key from denokv.datapath import is_kv_key_tuple from denokv.datapath import pack_key from denokv.datapath import parse_protobuf_kv_entry from denokv.errors import InvalidCursor from denokv.kv import AnyCursorFormat -from denokv.kv import KvEntry -from denokv.kv import KvU64 from denokv.kv import ListContext -from denokv.kv import VersionStamp from denokv.kv_keys import KvKey from denokv.result import Err from denokv.result import Ok @@ -49,8 +86,60 @@ T = TypeVar("T") E = TypeVar("E") E2 = TypeVar("E2") +MessageT = TypeVar("MessageT", bound=Message) + + +def mocked(mocked_value: object) -> Mock: + """Type-safely cast `mocked_value` to a Mock.""" + assert isinstance(mocked_value, Mock) + return mocked_value + + +def diff_protobuf_messages( + left: Message, + right: Message, + *, + left_name: str | None = None, + right_name: str | None = None, + context_line_count: int = 3, + lineterm: str = "\n", +) -> Sequence[str]: + if left_name is None: + left_name = f"left: {left.DESCRIPTOR.full_name}" + if right_name is None: + right_name = f"right: {right.DESCRIPTOR.full_name}" + left_lines = str(left).splitlines(keepends=False) + right_lines = str(right).splitlines(keepends=False) + return [ + *difflib.unified_diff( + left_lines, + right_lines, + fromfile=left_name, 
+ tofile=right_name, + n=context_line_count, + lineterm=lineterm, + ) + ] + -v8_decoder = v8serialize.Decoder() +def decode_js_number_as_float( + tag: SerializationTag, /, ctx: DecodeContext, next: DecodeNextFn +) -> object: + if tag in { + SerializationTag.kInt32, + SerializationTag.kDouble, + SerializationTag.kUint32, + SerializationTag.kNumberObject, + }: + number = next(tag) + if isinstance(number, int): + return float(number) + return number + return next(tag) + + +default_v8_decoder = v8serialize.Decoder() +default_v8_encoder = v8serialize.Encoder() def assume_ok(result: Result[T, E]) -> T: @@ -77,38 +166,84 @@ def assume_err(result: Result[T, E], type: type[E2] | None = None) -> E | E2: ) -def mk_db_meta(endpoints: Sequence[EndpointInfo]) -> DatabaseMetadata: - """Create a placeholder DB meta object with the provided endpoints.""" - return DatabaseMetadata( - version=3, - database_id=UUID("00000000-0000-0000-0000-000000000000"), - expires_at=datetime.now() + timedelta(hours=1), - endpoints=[*endpoints], - token="secret", - ) +@dataclass(repr=False) +class KvNumberEncoding: + py_name: str + js_name: str | None + value_encoding: ValueEncoding + + +class KvNumber(KvNumberEncoding, EvalEnumRepr): + """ + The number types supported by datapath Sum/Min/Max operations. 
+ + Examples + -------- + >>> f'Foo bar: {KvNumber.bigint}' + 'Foo bar: JSBigInt (VE_V8 BigInt)' + >>> f'Foo bar: {KvNumber.u64}' + 'Foo bar: KvU64 (VE_LE64)' + >>> repr(KvNumber.bigint) + 'KvNumber.bigint' + """ + + bigint = "JSBigInt", "BigInt", ValueEncoding.VE_V8 + float = "int/float", "Number", ValueEncoding.VE_V8 + u64 = "KvU64", None, ValueEncoding.VE_LE64 + + def __str__(self) -> str: + ve_name = enum_name(ValueEncoding, self.value_encoding) + if self.js_name is None: + return f"{self.py_name} ({ve_name})" + return f"{self.py_name} ({ve_name} {self.js_name})" + + +@dataclass(**slots_if310(), frozen=True) +class KvWriteValue: + data: bytes + encoding: ValueEncoding + expire_at_ms: int = field(default=0) + + @staticmethod + def tombstone() -> KvWriteValue: + return KvWriteValue(b"", ValueEncoding.VE_UNSPECIFIED, expire_at_ms=-1) class MockKvDbEntry(NamedTuple): key: bytes versionstamp: int encoding: ValueEncoding - value: bytes + data: bytes + expire_at_ms: int + + +class MockKvDbMessage(NamedTuple): + payload: object + deadline_ms: int + keys_if_undelivered: Sequence[KvKey] + backoff_schedule: Sequence[int] + + +class SumLimitExceeded(ValueError): + pass @dataclass class MockKvDb: entries: list[MockKvDbEntry] next_version: int + queued_messages: deque[MockKvDbMessage] - def __init__(self, entries: Iterable[tuple[bytes, KvValue]] = ()) -> None: + def __init__(self, entries: Iterable[tuple[bytes, KvWriteValue]] = ()) -> None: self.clear() self.extend(entries) def clear(self) -> None: self.entries = [] self.next_version = 0 + self.queued_messages = deque() - def extend(self, entries: Iterable[tuple[bytes, KvValue]]) -> None: + def extend(self, entries: Iterable[tuple[bytes, KvWriteValue]]) -> None: version = self.next_version self.next_version += 1 @@ -117,28 +252,39 @@ def extend(self, entries: Iterable[tuple[bytes, KvValue]]) -> None: key=key, versionstamp=version, encoding=kv_value.encoding, - value=kv_value.data, + data=kv_value.data, + 
expire_at_ms=kv_value.expire_at_ms, ) for (key, kv_value) in entries ) self.entries.sort(key=lambda e: (e.key, e.versionstamp)) def _read_range( - self, start: bytes, end: bytes, limit: int, reverse: bool + self, start: bytes, end: bytes, limit: int, reverse: bool, current_time_ms: int ) -> Sequence[MockKvDbEntry]: assert limit >= 0 matches = [e for e in self.entries if start <= e.key < end] latest_matches = [ - list(versions)[-1] - for (k, versions) in groupby(matches, key=lambda m: m.key) + ver + for ver in ( + list(versions)[-1] + for (k, versions) in groupby(matches, key=lambda m: m.key) + ) + if (ver.expire_at_ms == 0 or ver.expire_at_ms > current_time_ms) ] if reverse: latest_matches = list(reversed(latest_matches)) return latest_matches[:limit] - def snapshot_read_range(self, read: ReadRange) -> ReadRangeOutput: + def snapshot_read_range( + self, read: ReadRange, current_time_ms: int = 0 + ) -> ReadRangeOutput: entries = self._read_range( - start=read.start, end=read.end, limit=read.limit, reverse=read.reverse + start=read.start, + end=read.end, + limit=read.limit, + reverse=read.reverse, + current_time_ms=current_time_ms, ) return ReadRangeOutput( values=[ @@ -146,38 +292,532 @@ def snapshot_read_range(self, read: ReadRange) -> ReadRangeOutput: key=e.key, versionstamp=bytes(VersionStamp(e.versionstamp)), encoding=e.encoding, - value=e.value, + value=e.data, ) for e in entries ] ) + @overload + def _read_single( + self, key: bytes, current_time_ms: int, pending_entries: None = None + ) -> MockKvDbEntry | None: ... + + @overload + def _read_single( + self, + key: bytes, + current_time_ms: int, + pending_entries: Mapping[bytes, KvWriteValue], + ) -> MockKvDbEntry | KvWriteValue | None: ... 
+ + def _read_single( + self, + key: bytes, + current_time_ms: int, + pending_entries: Mapping[bytes, KvWriteValue] | None = None, + ) -> MockKvDbEntry | KvWriteValue | None: + if pending_entries and (pending := pending_entries.get(key)): + if pending.expire_at_ms != 0 and pending.expire_at_ms <= current_time_ms: + return None + return pending + matches = self._read_range( + start=key, + end=increment_packed_key(key), + limit=1, + reverse=False, + current_time_ms=current_time_ms, + ) + assert len(matches) < 2 + return matches[0] if matches else None + + def atomic_write( + self, write: AtomicWrite, current_time_ms: int = 0 + ) -> AtomicWriteOutput: + failed_checks: list[int] = [] + for i, check in enumerate(write.checks): + checked_entry = self._read_single( + check.key, current_time_ms=current_time_ms + ) + + if len(check.versionstamp) == 0: + if checked_entry is not None: + failed_checks.append(i) + elif len(check.versionstamp) == 10: + if checked_entry is None: + failed_checks.append(i) + else: + if VersionStamp(checked_entry.versionstamp) != VersionStamp( + check.versionstamp + ): + failed_checks.append(i) + else: + raise ValueError( + f"Check versionstamp is not valid: {check.versionstamp!r}" + ) + + if len(failed_checks) > 0: + return AtomicWriteOutput( + status=AtomicWriteStatus.AW_CHECK_FAILURE, failed_checks=failed_checks + ) + + messages = [decode_enqueue_message(enqueue) for enqueue in write.enqueues] + + versionstamp = VersionStamp(self.next_version) + mutation_entries: dict[bytes, KvWriteValue] = {} + for mut in write.mutations: + cause: Exception | None = None + try: + key_tuple = unpack(mut.key) + key_bytes = pack(unpack(mut.key)) + except Exception as e: + key_tuple = None + key_bytes = None + cause = e + if key_bytes != mut.key: + raise ValueError(f"Mutation key is not valid: {mut.key!r}") from cause + assert key_tuple is not None + + expires_at_ms = mut.expire_at_ms + if expires_at_ms < 0: + raise ValueError( + f"Mutation expire_at_ms cannot be 
negative: {mut.expire_at_ms}" + ) + + if mut.mutation_type == MutationType.M_SET: + mutation_entries[key_bytes] = KvWriteValue( + data=mut.value.data, + encoding=mut.value.encoding, + expire_at_ms=expires_at_ms, + ) + elif mut.mutation_type == MutationType.M_DELETE: + mutation_entries[key_bytes] = KvWriteValue.tombstone() + elif ( + mut.mutation_type == MutationType.M_SUM + or mut.mutation_type == MutationType.M_MIN + or mut.mutation_type == MutationType.M_MAX + ): + # FIXME: Check this again, I don't think KvU64 actually allows + # bigint operands with the sqlite implementation. Does + # the FoundationDB impl act differently? + # + # Deno KV allows sum(left, right) with certain combinations of + # types: + # + # (Left is stored in the database, right is the value sent in + # the atomic operation.) + # + # Left | Right | + # ———— | KvU64 | bigint | number | + # KvU64 | yes | yes | no | + # bigint | no | yes | no | + # number | no | no | yes | + # + # min() and max() can only use KvU64 values. + + # We need to also read from mutation_entries to take into + # account values changed by preceding mutations within this + # AtomicWrite operation. + current = self._read_single( + mut.key, + current_time_ms=current_time_ms, + pending_entries=mutation_entries, + ) + operand_encoding, operand_value = decode_number_value(mut.value) + if current is None: + # float operands must have 0.0 not 0 as the default as + # cross-type sum operations are not allowed. 
+ current_encoding, current_value = None, (type(operand_value)(0)) + else: + current_encoding, current_value = decode_number_value(current) + + op = _get_number_operator(mut, operand_encoding=operand_encoding) + + if not _is_allowed_op_combination( + op, + (current_encoding, current_value), + (operand_encoding, operand_value), + ): + raise ValueError( + f"Cannot apply operation " + f"{enum_name(MutationType, mut.mutation_type)}, " + f"number types are incompatible: " + f"current type: {current_encoding}, " + f"operand type: {operand_encoding}" + ) + + try: + result = op(current_value, operand_value) + except Exception as e: + raise ValueError( + f"Mutation is not a valid " + f"{enum_name(MutationType, mut.mutation_type)} operation: {e}" + ) from e + result_encoding = current_encoding or operand_encoding + assert result_encoding is not None + + mutation_entries[key_bytes] = KvWriteValue( + data=encode_number_value(result, result_encoding), + encoding=result_encoding.value_encoding, + expire_at_ms=expires_at_ms, + ) + elif mut.mutation_type == MutationType.M_SET_SUFFIX_VERSIONSTAMPED_KEY: + suffix_key = pack((*key_tuple, str(versionstamp))) + mutation_entries[suffix_key] = KvWriteValue( + data=mut.value.data, + encoding=mut.value.encoding, + expire_at_ms=expires_at_ms, + ) + else: + raise ValueError( + f"Mutation mutation_type is not valid: {mut.mutation_type}" + ) + self.extend(mutation_entries.items()) + self.queued_messages.extend(messages) + + return AtomicWriteOutput( + status=AtomicWriteStatus.AW_SUCCESS, versionstamp=versionstamp + ) + + +def _is_allowed_op_combination( + op: Callable[[int | float, int | float], int | float] | MutationSumOperator | None, + left: tuple[KvNumber | None, int | float], + right: tuple[KvNumber, int | float], +) -> TypeIs[Callable[[int | float, int | float], int | float] | MutationSumOperator]: + left_encoding, left_value = left + right_encoding, right_value = right + + def values_are(types: type | tuple[type, ...]) -> bool: + 
return isinstance(left_value, types) and isinstance(right_value, types) + + if isinstance(op, MutationSumOperator): + if left_encoding is KvNumber.u64: + return ( + right_encoding is KvNumber.u64 + or (right_encoding is KvNumber.bigint) + and values_are(int) + ) + elif left_encoding is KvNumber.bigint: + return right_encoding is KvNumber.bigint and values_are(JSBigInt) + elif left_encoding is KvNumber.float: + return right_encoding is KvNumber.float + elif left_encoding is None: + # Sum can be used with a missing left operand. + return right_encoding is KvNumber.float or values_are(int) + elif op is min or op is max: + return ( + left_encoding in (KvNumber.u64, None) + and right_encoding is KvNumber.u64 + and values_are(int) + ) + raise AssertionError(f"Unexpected op combinations: {op=}, {left=}, {right=}") + + +def _get_number_operator( + mut: Mutation, *, operand_encoding: KvNumber +) -> Callable[[float, float], float] | None: + if mut.mutation_type == MutationType.M_SUM: + if ( + mut.sum_min or mut.sum_max or mut.sum_clamp + ) and mut.value.encoding != ValueEncoding.VE_V8: + raise ValueError( + "Mutation used sum_min/sum_max/sum_clamp with non-V8 encoding" + ) + if mut.value.encoding == ValueEncoding.VE_LE64: + return MutationSumOperator(0, 2**64 - 1, LimitExceededPolicy.WRAP) + + min_enc, min_ = decode_v8_number(mut.sum_min) if mut.sum_min else (None, None) + max_enc, max_ = decode_v8_number(mut.sum_max) if mut.sum_max else (None, None) + if (min_enc and min_enc is not operand_encoding) or ( + max_enc and max_enc is not operand_encoding + ): + raise ValueError( + "Mutation used different number types for value/sum_min/sum_max" + ) + boundary = ( + LimitExceededPolicy.CLAMP if mut.sum_clamp else LimitExceededPolicy.ABORT + ) + return MutationSumOperator(min=min_, max=max_, boundary=boundary) + elif mut.mutation_type == MutationType.M_MAX: + return max + elif mut.mutation_type == MutationType.M_MIN: + return min + return None + + +@dataclass +class 
MutationSumOperator: + min: int | float | None + max: int | float | None + boundary: LimitExceededPolicy + + def __call__(self, left: int | float, right: int | float) -> int | float: + min, max = self.min, self.max + + result = left + right + if self.boundary is LimitExceededPolicy.WRAP: + # wrap is only used for uint64 + assert min == 0 + assert max is not None and max >= 0 + result = result % (max + 1) + elif min is not None and result < min: + if self.boundary is LimitExceededPolicy.CLAMP: + result = min + else: + assert self.boundary is LimitExceededPolicy.ABORT + raise SumLimitExceeded( + f"result of sum({left}, {right}) = {result}, which is less " + f"than the minimum {min}" + ) + if max is not None and result > max: + if self.boundary is LimitExceededPolicy.CLAMP: + result = max + else: + assert self.boundary is LimitExceededPolicy.ABORT + raise SumLimitExceeded( + f"result of sum({left}, {right}) = {result}, which is " + f"greater than the maximum {max}" + ) + return result + + +def decode_number_value( + entry: MockKvDbEntry | KvWriteValue | KvValue, +) -> tuple[KvNumber, int | float]: + if entry.encoding == ValueEncoding.VE_LE64: + return KvNumber.u64, KvU64(entry.data).value + elif entry.encoding == ValueEncoding.VE_V8: + return decode_v8_number(entry.data) + else: + raise ValueError("entry value is not an LE64 or V8-encoded BigInt or Number") + + +def decode_v8_number(data: bytes) -> tuple[KvNumber, int | float]: + value = default_v8_decoder.decodes(data) + if type(value) is JSBigInt: + return KvNumber.bigint, value + if type(value) in (int, float): + return KvNumber.float, cast(Union[int, float], value) + raise ValueError("V8-serialized value is not a BigInt or Number") + + +def encode_number_value(value: int | float, encoding: KvNumber) -> bytes: + if encoding is KvNumber.float: + return bytes(default_v8_encoder.encode(float(value))) + elif encoding is KvNumber.bigint: + if isinstance(value, float): + raise TypeError("Cannot encode float as V8 
BigInt") + return bytes(default_v8_encoder.encode(JSBigInt(value))) + else: + assert encoding is KvNumber.u64 + if isinstance(value, float): + raise TypeError("Cannot encode float as LE64") + return KvU64(value).to_bytes() + -def encode_protobuf_kv_value(value: object) -> KvValue: +def encode_kv_write_value(value: object, expires_at_ms: int = 0) -> KvWriteValue: if isinstance(value, KvU64): - return KvValue(data=bytes(value), encoding=ValueEncoding.VE_LE64) + return KvWriteValue( + data=bytes(value), + encoding=ValueEncoding.VE_LE64, + expire_at_ms=expires_at_ms, + ) elif isinstance(value, bytes): - return KvValue(data=value, encoding=ValueEncoding.VE_BYTES) + return KvWriteValue( + data=value, encoding=ValueEncoding.VE_BYTES, expire_at_ms=expires_at_ms + ) else: - return KvValue(data=v8serialize.dumps(value), encoding=ValueEncoding.VE_V8) + return KvWriteValue( + data=v8serialize.dumps(value), + encoding=ValueEncoding.VE_V8, + expire_at_ms=expires_at_ms, + ) + + +def decode_enqueue_message(enqueue: Enqueue) -> MockKvDbMessage: + try: + payload_value = default_v8_decoder.decodes(enqueue.payload) + except v8serialize.V8SerializeError as e: + raise ValueError("Enqueue payload is not a valid V8-encoded value") from e + keys_if_undelivered = list[KvKey]() + for k in enqueue.keys_if_undelivered: + try: + keys_if_undelivered.append(KvKey.from_kv_key_bytes(k)) + except ValueError as e: + raise ValueError( + f"Enqueue keys_if_undelivered contains invalid key: {k!r}" + ) from e + return MockKvDbMessage( + payload=payload_value, + backoff_schedule=list(enqueue.backoff_schedule), + deadline_ms=enqueue.deadline_ms, + keys_if_undelivered=keys_if_undelivered, + ) + + +def mock_db_api(mock_db: MockKvDb) -> web.Application: + """HTTP endpoints implementing the KV Data Path protocol against MockKvDb.""" + + def get_server_version(request: web.Request) -> Literal[1, 2, 3]: + match = re.match(r"^/v([123])/", request.path) + version: Final = int(match.group(1)) if match else -1 + if 
version not in (1, 2, 3): + raise AssertionError("handler is not registered at /v[123]/ URL path") + return cast(Literal[1, 2, 3], version) + + def validate_request(request: web.Request) -> None: + server_version = get_server_version(request) + + if request.method != "POST": + raise web.HTTPBadRequest(text="method must be POST") + if request.content_type != "application/x-protobuf": + raise web.HTTPBadRequest(text="content-type must be application/x-protobuf") + + db_id_header = ( + "x-transaction-domain-id" if server_version == 1 else "x-denokv-database-id" + ) + try: + UUID(request.headers.get(db_id_header, "")) + except Exception: + raise web.HTTPBadRequest( + text=f"client did not set a valid {db_id_header} when talking to a " + f"v{server_version} server" + ) from None + + if server_version > 2: + try: + client_version = int(request.headers.get("x-denokv-version", "")) + if client_version not in (2, 3): + raise ValueError(f"invalid client_version: {client_version}") + except Exception: + raise web.HTTPBadRequest( + text=f"client did not set a valid x-denokv-version header when " + f"talking to a v{server_version} server" + ) from None + + def parse_protobuf_body( + body_bytes: bytes, message_type: type[MessageT] + ) -> MessageT: + message = message_type() + try: + count = message.ParseFromString(body_bytes) + if len(body_bytes) != count: + raise ValueError( + f"{len(body_bytes) - count} trailing bytes after " + f"{message_type.__name__}" + ) + except Exception as e: + raise web.HTTPBadRequest( + text=f"body is not a valid {message_type.__name__} message: {e}" + ) from e + return message + + # Valid snapshot_read handler + async def strong_snapshot_read(request: web.Request) -> web.Response: + validate_request(request) + read = parse_protobuf_body(await request.read(), SnapshotRead) + + read_result = SnapshotReadOutput( + status=SnapshotReadStatus.SR_SUCCESS, + read_is_strongly_consistent=True, + ranges=[mock_db.snapshot_read_range(r) for r in read.ranges], + ) 
+ return web.Response( + status=200, + content_type="application/x-protobuf", + body=read_result.SerializeToString(), + ) + + # Valid atomic_write handler + async def atomic_write(request: web.Request) -> web.Response: + validate_request(request) + + write = parse_protobuf_body(await request.read(), AtomicWrite) + + try: + write_result = mock_db.atomic_write(write) + except ValueError as e: + raise web.HTTPBadRequest(text=f"SnapshotWrite is not valid: {e}") from e + + return web.Response( + status=200, + content_type="application/x-protobuf", + body=write_result.SerializeToString(), + ) + + app = web.Application() + + # Working endpoints + app.router.add_post("/v1/consistency/strong/snapshot_read", strong_snapshot_read) + app.router.add_post("/v2/consistency/strong/snapshot_read", strong_snapshot_read) + app.router.add_post("/v3/consistency/strong/snapshot_read", strong_snapshot_read) + app.router.add_post("/v1/consistency/strong/atomic_write", atomic_write) + app.router.add_post("/v2/consistency/strong/atomic_write", atomic_write) + app.router.add_post("/v3/consistency/strong/atomic_write", atomic_write) + return app + + +def make_database_metadata( + endpoints: URL | Sequence[EndpointInfo], + *, + endpoint_consistency: ConsistencyLevel | None = None, + version: Literal[1, 2, 3] = 3, + database_id: UUID | None = None, + expires_at: datetime | None = None, + token: str = "hunter2.123", +) -> DatabaseMetadata: + if isinstance(endpoints, URL): + if endpoint_consistency is None: + endpoint_consistency = ConsistencyLevel.STRONG + endpoints = [EndpointInfo(url=endpoints, consistency=endpoint_consistency)] + else: + if endpoint_consistency is not None: + raise TypeError( + "cannot set endpoint_consistency argument wen endpoints is a Sequence" + ) + + if database_id is None: + database_id = UUID("00000000-0000-0000-0000-000000000000") + if expires_at is None: + expires_at = datetime.now() + timedelta(minutes=30) + + meta = DatabaseMetadata( + version=version, + 
database_id=database_id, + endpoints=endpoints, + expires_at=expires_at, + token=token, + ) + return meta + + +def meta_endpoint(meta: DatabaseMetadata) -> tuple[DatabaseMetadata, EndpointInfo]: + return meta, meta.endpoints[0] def add_entries( db: MockKvDb, - entries: Mapping[KvKeyTuple, object] | Iterable[tuple[KvKeyTuple, object]], + entries: Mapping[AnyKvKey, object] + | Mapping[KvKeyTuple, object] + | Iterable[tuple[AnyKvKey, object]], ) -> VersionStamp: if isinstance(entries, Mapping): entries = entries.items() version = VersionStamp(db.next_version) encoded_entries = [ - (pack_key(key), encode_protobuf_kv_value(value)) for (key, value) in entries + (pack_key(key), encode_kv_write_value(value)) for (key, value) in entries ] db.extend(encoded_entries) return version -def unsafe_parse_protobuf_kv_entry(raw: ProtobufKvEntry) -> KvEntry: +def unsafe_parse_protobuf_kv_entry( + raw: ProtobufKvEntry, v8_decoder: v8serialize.Decoder | None = None +) -> KvEntry: + if v8_decoder is None: + v8_decoder = default_v8_decoder key, value, versionstamp = assume_ok( parse_protobuf_kv_entry(raw, v8_decoder=v8_decoder, le64_type=KvU64) ) @@ -237,3 +877,18 @@ def nextafter(x: float, y: float, *, steps: int = 1) -> float: for _ in range(steps): x = math.nextafter(x, y) return x + + +def typeval(value: T) -> tuple[type[T], T]: + return type(value), value + + +def create_dataclass_slots_test() -> Callable[[Never], None]: + @pytest.mark.skipif( + sys.version_info < (3, 10), reason="<3.10 does not use slots for dataclass" + ) + def test_instances_dont_have_dict_because_of_slots(instance: object) -> None: + with pytest.raises(AttributeError): + _ = instance.__dict__ + + return test_instances_dont_have_dict_because_of_slots diff --git a/test/test__kv_values.py b/test/test__kv_values.py new file mode 100644 index 0000000..3e67731 --- /dev/null +++ b/test/test__kv_values.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +import pytest +from hypothesis import given +from hypothesis 
import strategies as st + +from denokv._kv_values import KvEntry +from denokv._kv_values import KvU64 +from denokv._kv_values import VersionStamp +from denokv._pycompat.typing import Callable +from denokv.kv_keys import KvKey +from test.denokv_testing import create_dataclass_slots_test + + +@pytest.fixture( + params=[ + pytest.param(lambda: KvEntry(KvKey("a"), 42, VersionStamp(1)), id="KvEntry"), + pytest.param(lambda: VersionStamp(1), id="VersionStamp"), + pytest.param(lambda: KvU64(1), id="KvU64"), + ] +) +def instance(request: pytest.FixtureRequest) -> object: + param: Callable[[], object] = request.param + return param() + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() + + +@given(v=st.integers(min_value=0, max_value=2**80 - 1)) +def test_VersionStamp_init(v: int) -> None: + vs_int = VersionStamp(v) + assert int(vs_int) == v + assert VersionStamp(str(vs_int)) == vs_int + assert VersionStamp(bytes(vs_int)) == vs_int + assert bytes(vs_int) == vs_int + assert isinstance(vs_int, bytes) + + +@given(i=st.integers(min_value=0, max_value=2**64 - 1)) +def test_KvU64_init(i: int) -> None: + u64 = KvU64(i) + assert int(u64) == i + assert KvU64(bytes(u64)) == u64 + assert u64.to_bytes() == bytes(u64) + assert u64.to_bytes() == i.to_bytes(8, "little") + + +@given( + v1=st.integers(min_value=0, max_value=2**80 - 1), + v2=st.integers(min_value=0, max_value=2**80 - 1), +) +def test_VersionStamp_ordering(v1: int, v2: int) -> None: + vs1, vs2 = VersionStamp(v1), VersionStamp(v2) + if v1 < v2: + assert vs1 < vs2 + elif v1 > v2: + assert vs1 > vs2 + else: + assert vs1 == vs2 + + +def test_KVU64__bytes() -> None: + assert KvU64(bytes(KvU64(123456789))).value == 123456789 + assert KvU64(KvU64(123456789).to_bytes()).value == 123456789 diff --git a/test/test__kv_writes__Check.py b/test/test__kv_writes__Check.py new file mode 100644 index 0000000..b70bf36 --- /dev/null +++ b/test/test__kv_writes__Check.py @@ -0,0 +1,46 @@ +import pytest +from 
v8serialize import Encoder + +from denokv import _datapath_pb2 as datapath_pb2 +from denokv._kv_values import VersionStamp +from denokv._kv_writes import Check +from denokv.kv_keys import KvKey +from test.denokv_testing import create_dataclass_slots_test + + +@pytest.fixture +def instance() -> Check: + return Check(KvKey("a"), None) + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() + + +def test_constructors() -> None: + assert Check(KvKey("a"), VersionStamp(1)) == Check.for_key_with_version( + KvKey("a"), VersionStamp(1) + ) + assert Check(KvKey("a"), None) == Check.for_key_not_set(KvKey("a")) + + +def test_as_protobuf(v8_encoder: Encoder) -> None: + protobuf = ( + datapath_pb2.Check(key=bytes(KvKey("a")), versionstamp=bytes(VersionStamp(1))), + ) + # v8_encoder is optional + assert Check(KvKey("a"), VersionStamp(1)).as_protobuf() == protobuf + assert Check(KvKey("a"), VersionStamp(1)).as_protobuf(v8_encoder=None) == protobuf + + assert ( + Check(KvKey("a"), VersionStamp(1)).as_protobuf(v8_encoder=v8_encoder) + == protobuf + ) + + +def test_as_protobuf__empty_version_is_different_to_zero_version() -> None: + assert datapath_pb2.Check(key=bytes(KvKey("a"))) == datapath_pb2.Check( + key=bytes(KvKey("a")), versionstamp=b"" + ) + assert datapath_pb2.Check(key=bytes(KvKey("a"))) != datapath_pb2.Check( + versionstamp=VersionStamp(0) + ) diff --git a/test/test__kv_writes__CommittedWrite.py b/test/test__kv_writes__CommittedWrite.py new file mode 100644 index 0000000..b30869d --- /dev/null +++ b/test/test__kv_writes__CommittedWrite.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +import pytest +from yarl import URL + +from denokv._kv_values import VersionStamp +from denokv._kv_writes import Check +from denokv._kv_writes import CommittedWrite +from denokv._kv_writes import Enqueue +from denokv._kv_writes import Set +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.auth import ConsistencyLevel +from denokv.auth 
import EndpointInfo +from denokv.kv_keys import KvKey +from denokv.result import is_ok +from test.denokv_testing import create_dataclass_slots_test + +T1 = parse_rfc3339_datetime("2000-01-02T03:04:05.6Z").value_or_raise() +EP = EndpointInfo(URL("https://example.com/"), consistency=ConsistencyLevel.STRONG) + + +@pytest.fixture +def instance() -> CommittedWrite: + return CommittedWrite( + VersionStamp(1), checks=[], mutations=[], enqueues=[], endpoint=EP + ) + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() + + +def test_is_AnySuccess() -> None: + assert is_ok( + CommittedWrite( + VersionStamp(1), checks=[], mutations=[], enqueues=[], endpoint=EP + ) + ) + + +def test_constructors() -> None: + instance = CommittedWrite( + VersionStamp(1), checks=[], mutations=[], enqueues=[], endpoint=EP + ) + assert instance.ok + assert instance.versionstamp == VersionStamp(1) + assert instance.checks == () + assert instance.mutations == () + assert instance.enqueues == () + assert instance.endpoint is EP + + instance = CommittedWrite( + VersionStamp(1), + checks=[Check.for_key_not_set(key=KvKey("a"))], + mutations=[Set(KvKey("a"), 42)], + enqueues=[Enqueue("Hi")], + endpoint=EP, + ) + assert instance.ok + assert instance.versionstamp == VersionStamp(1) + assert instance.checks == (Check.for_key_not_set(key=KvKey("a")),) + assert instance.mutations == (Set(KvKey("a"), 42),) + assert instance.enqueues == (Enqueue("Hi"),) + assert instance.endpoint is EP + + assert instance.conflicts == {} + assert not instance.has_unknown_conflicts + + +def test_str_repr() -> None: + instance = CommittedWrite( + VersionStamp(1), + checks=[Check.for_key_not_set(key=KvKey("a"))], + mutations=[Set(KvKey("a"), 42)], + enqueues=[Enqueue("Hi")], + endpoint=EP, + ) + assert ( + str(instance) == "Write committed version 0x00000000000000000001 " + "to 'https://example.com/' with 1 checks, 1 mutations, 1 enqueues" + ) + assert ( + repr(instance) == "" + ) diff --git 
a/test/test__kv_writes__ConflictedWrite.py b/test/test__kv_writes__ConflictedWrite.py new file mode 100644 index 0000000..c0587d3 --- /dev/null +++ b/test/test__kv_writes__ConflictedWrite.py @@ -0,0 +1,176 @@ +from __future__ import annotations + +import traceback + +import pytest +from yarl import URL + +from denokv import _datapath_pb2 as datapath_pb2 +from denokv._kv_writes import Check +from denokv._kv_writes import ConflictedWrite +from denokv._kv_writes import Enqueue +from denokv._kv_writes import Set +from denokv._pycompat.typing import Iterable +from denokv._pycompat.typing import Sequence +from denokv._pycompat.typing import cast +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.auth import ConsistencyLevel +from denokv.auth import EndpointInfo +from denokv.datapath import CheckFailure +from denokv.kv_keys import KvKey +from denokv.result import is_err + +T1 = parse_rfc3339_datetime("2000-01-02T03:04:05.6Z").value_or_raise() # noqa: F821 +EP = EndpointInfo(URL("https://example.com/"), consistency=ConsistencyLevel.STRONG) + + +@pytest.fixture +def checks() -> tuple[Check, Check, Check]: + return ( + Check.for_key_not_set(KvKey("a")), + Check.for_key_not_set(KvKey("b")), + Check.for_key_not_set(KvKey("c")), + ) + + +@pytest.fixture +def instance(checks: Iterable[Check]) -> ConflictedWrite: + pb_checks = [ + datapath_pb2.Check(key=bytes(KvKey("a")), versionstamp=None), + datapath_pb2.Check(key=bytes(KvKey("b")), versionstamp=None), + datapath_pb2.Check(key=bytes(KvKey("c")), versionstamp=None), + ] + failed_checks = [0, 2] + + cause = CheckFailure( + "Not all checks required by the Atomic Write passed", + all_checks=pb_checks, + failed_check_indexes=failed_checks, + endpoint=EP, + ) + + return ConflictedWrite( + failed_checks=failed_checks, + checks=checks, + mutations=[Set(KvKey("a"), 42)], + enqueues=[Enqueue("Hi")], + endpoint=EP, + cause=cause, + ) + + +def test_constructor(checks: Sequence[Check]) -> None: + instance = ConflictedWrite( + 
failed_checks=[0, 2], + checks=cast(Iterable[Check], checks), + mutations=[Set(KvKey("a"), 42)], + enqueues=[Enqueue("Hi")], + endpoint=EP, + ) + + assert not instance.ok + assert instance.versionstamp is None + assert instance.checks == tuple(checks) + assert instance.mutations == (Set(KvKey("a"), 42),) + assert instance.enqueues == (Enqueue("Hi"),) + assert instance.endpoint is EP + + assert instance.conflicts == {KvKey("a"): checks[0], KvKey("c"): checks[2]} + assert instance.conflicts[KvKey("a")] is checks[0] + + +@pytest.mark.parametrize("failed_checks", [None, [], [0]]) +def test_constructor__conflicts_are_always_known_with_single_check( + failed_checks: Iterable[int] | None, +) -> None: + instance = ConflictedWrite( + failed_checks=failed_checks, + checks=iter([Check.for_key_not_set(KvKey("a"))]), + mutations=[Set(KvKey("a"), 42)], + enqueues=[Enqueue("Hi")], + endpoint=EP, + ) + + assert KvKey("a") in instance.conflicts + assert instance.conflicts[KvKey("a")].key == KvKey("a") + assert not instance.has_unknown_conflicts + + +@pytest.mark.parametrize("failed_checks", [None, []]) +def test_constructor__conflicts_are_unknown_with_multiple_checks_without_failed_checks( + failed_checks: Iterable[int] | None, checks: Iterable[Check] +) -> None: + instance = ConflictedWrite( + failed_checks=failed_checks, + checks=checks, + mutations=[Set(KvKey("a"), 42)], + enqueues=[Enqueue("Hi")], + endpoint=EP, + ) + + assert len(instance.conflicts) == 0 + assert instance.has_unknown_conflicts + + +def test_constructor__rejects_out_of_bounds_failed_checks( + checks: tuple[Check, Check, Check], +) -> None: + assert len(checks) == 3 + with pytest.raises(ValueError, match=r"failed_checks contains out-of-bounds index"): + ConflictedWrite( + failed_checks=[0, 10], + checks=checks, + mutations=[Set(KvKey("a"), 42)], + enqueues=[], + endpoint=EP, + ) + + +def test_changes_to_conflicts_do_not_persist(instance: ConflictedWrite) -> None: + assert isinstance(instance.conflicts, dict) + # 
Changes to conflicts do not persist + assert KvKey("a") in instance.conflicts + del instance.conflicts[KvKey("a")] + assert KvKey("a") in instance.conflicts + + +def test_is_AnyFailure(instance: ConflictedWrite) -> None: + assert is_err(instance) + + +@pytest.mark.parametrize("with_cause", [True, False]) +def test_str(instance: ConflictedWrite, with_cause: bool) -> None: + assert instance.__cause__ + if not with_cause: + instance.__cause__ = None + assert ( + str(instance) == "Write NOT APPLIED to 'https://example.com/' " + "with 2/3 checks CONFLICTING, 1 mutations, 1 enqueues" + ) + + +@pytest.mark.parametrize("with_cause", [True, False]) +def test_repr(instance: ConflictedWrite, with_cause: bool) -> None: + assert instance.__cause__ + if not with_cause: + instance.__cause__ = None + + assert ( + repr(instance) == "" + ) + + +@pytest.mark.parametrize("with_cause", [True, False]) +def test_traceback_presentation(instance: ConflictedWrite, with_cause: bool) -> None: + assert instance.__cause__ + if not with_cause: + instance.__cause__ = None + + assert "\n".join( + traceback.format_exception_only(type(instance), instance) + ).strip() == ( + "denokv._kv_writes.ConflictedWrite: " + "Write NOT APPLIED to 'https://example.com/' " + "with 2/3 checks CONFLICTING, 1 mutations, 1 enqueues" + ) diff --git a/test/test__kv_writes__Delete.py b/test/test__kv_writes__Delete.py new file mode 100644 index 0000000..d023cb9 --- /dev/null +++ b/test/test__kv_writes__Delete.py @@ -0,0 +1,32 @@ +import pytest +from v8serialize import Encoder + +from denokv import _datapath_pb2 as datapath_pb2 +from denokv._kv_writes import Delete +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.kv_keys import KvKey +from test.denokv_testing import create_dataclass_slots_test + +T1 = parse_rfc3339_datetime("2000-01-02T03:04:05.6Z").value_or_raise() + + +@pytest.fixture +def instance() -> Delete: + return Delete(KvKey("a")) + + +test_instances_dont_have_dict_because_of_slots = 
create_dataclass_slots_test() + + +def test_constructors() -> None: + instance = Delete(KvKey("a")) + assert instance.key == KvKey("a") + + +def test_as_protobuf(v8_encoder: Encoder) -> None: + delete = Delete(KvKey("a")) + assert delete.as_protobuf(v8_encoder=v8_encoder) == ( + datapath_pb2.Mutation( + key=bytes(KvKey("a")), mutation_type=datapath_pb2.M_DELETE + ), + ) diff --git a/test/test__kv_writes__Enqueue.py b/test/test__kv_writes__Enqueue.py new file mode 100644 index 0000000..2f2b1cb --- /dev/null +++ b/test/test__kv_writes__Enqueue.py @@ -0,0 +1,80 @@ +from itertools import count +from itertools import islice + +import pytest +from v8serialize import Encoder + +from denokv import _datapath_pb2 as datapath_pb2 +from denokv._kv_writes import DEFAULT_ENQUEUE_RETRY_DELAY_COUNT +from denokv._kv_writes import DEFAULT_ENQUEUE_RETRY_DELAYS +from denokv._kv_writes import Enqueue +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.kv_keys import KvKey +from test.denokv_testing import create_dataclass_slots_test + +T1 = parse_rfc3339_datetime("2000-01-02T03:04:05.6Z").value_or_raise() + + +@pytest.fixture +def instance() -> Enqueue: + return Enqueue(42) + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() + + +def test_constructors() -> None: + message = {"msg": "Hi"} + instance = Enqueue(message) + assert instance.message is message + assert instance.delivery_time is None + assert instance.retry_delays == DEFAULT_ENQUEUE_RETRY_DELAYS + assert len(instance.dead_letter_keys) == 0 + + retry_delays = [1, 2, 3] + dead_letter_keys = (KvKey("a"),) + instance = Enqueue( + message, + delivery_time=T1, + retry_delays=retry_delays, + dead_letter_keys=dead_letter_keys, + ) + assert instance.message is message + assert instance.delivery_time == T1 + assert instance.retry_delays == retry_delays + assert instance.dead_letter_keys == dead_letter_keys + + +def test_as_protobuf__default_retry_delays(v8_encoder: Encoder) -> None: + message 
= {"msg": "Hi"} + instance = Enqueue(message) + (protobuf,) = instance.as_protobuf(v8_encoder=v8_encoder) + + # Default retry delays have random jitter. A fixed number are drawn from the + # backoff provider. + assert len(protobuf.backoff_schedule) == DEFAULT_ENQUEUE_RETRY_DELAY_COUNT + assert all(delay > 0 for delay in protobuf.backoff_schedule) + + +def test_as_protobuf(v8_encoder: Encoder) -> None: + message = {"msg": "Hi"} + instance = Enqueue(message, retry_delays=[]) + (protobuf,) = instance.as_protobuf(v8_encoder=v8_encoder) + assert protobuf == datapath_pb2.Enqueue(payload=bytes(v8_encoder.encode(message))) + + evaluated_backoff = [ + i * 1000 for i in islice(count(1), DEFAULT_ENQUEUE_RETRY_DELAY_COUNT) + ] + instance = Enqueue( + message, + retry_delays=count(1), + delivery_time=T1, + dead_letter_keys=[KvKey("a"), KvKey("b")], + ) + (protobuf,) = instance.as_protobuf(v8_encoder=v8_encoder) + assert protobuf == datapath_pb2.Enqueue( + payload=bytes(v8_encoder.encode(message)), + backoff_schedule=evaluated_backoff, + keys_if_undelivered=[bytes(KvKey("a")), bytes(KvKey("b"))], + deadline_ms=int(T1.timestamp() * 1000), + ) diff --git a/test/test__kv_writes__FailedWrite.py b/test/test__kv_writes__FailedWrite.py new file mode 100644 index 0000000..cc9aa31 --- /dev/null +++ b/test/test__kv_writes__FailedWrite.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import traceback + +import pytest +from yarl import URL + +from denokv._kv_writes import Check +from denokv._kv_writes import Enqueue +from denokv._kv_writes import FailedWrite +from denokv._kv_writes import Set +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.auth import ConsistencyLevel +from denokv.auth import EndpointInfo +from denokv.datapath import ProtocolViolation +from denokv.kv_keys import KvKey +from denokv.result import is_err + +T1 = parse_rfc3339_datetime("2000-01-02T03:04:05.6Z").value_or_raise() +EP = EndpointInfo(URL("https://example.com/"), 
consistency=ConsistencyLevel.STRONG) + + +@pytest.fixture +def instance() -> FailedWrite: + checks = [ + Check.for_key_not_set(KvKey("a")), + Check.for_key_not_set(KvKey("b")), + Check.for_key_not_set(KvKey("c")), + ] + return FailedWrite( + checks=list(checks), + mutations=[Set(KvKey("a"), 42)], + enqueues=[Enqueue("Hi")], + endpoint=EP, + cause=ProtocolViolation("Server misbehaved", data=None, endpoint=EP), + ) + + +@pytest.mark.parametrize( + "cause", [None, ProtocolViolation("Server misbehaved", data=None, endpoint=EP)] +) +def test_constructor(cause: BaseException | None) -> None: + checks = [ + Check.for_key_not_set(KvKey("a")), + Check.for_key_not_set(KvKey("b")), + Check.for_key_not_set(KvKey("c")), + ] + instance = FailedWrite( + checks=list(checks), + mutations=[Set(KvKey("a"), 42)], + enqueues=[Enqueue("Hi")], + endpoint=EP, + cause=cause, + ) + + assert not instance.ok + assert instance.versionstamp is None + assert instance.checks == tuple(checks) + assert instance.mutations == (Set(KvKey("a"), 42),) + assert instance.enqueues == (Enqueue("Hi"),) + assert instance.endpoint is EP + assert instance.__cause__ is cause + + assert instance.conflicts == {} + assert not instance.has_unknown_conflicts + + +def test_exception_attributes(instance: FailedWrite) -> None: + assert instance.args == () + + +def test_changes_to_conflicts_do_not_persist(instance: FailedWrite) -> None: + conflicts = instance.conflicts + assert isinstance(conflicts, dict) + # Changes to conflicts do not persist + conflicts[KvKey("a")] = instance.checks[0] + assert instance.conflicts == {} + + +def test_is_AnyFailure(instance: FailedWrite) -> None: + assert is_err(instance) + + +def test_str(instance: FailedWrite) -> None: + assert ( + str(instance) == "Write failed to 'https://example.com/' " + "due to ProtocolViolation, with 3 checks, 1 mutations, 1 enqueues" + ) + + instance.__cause__ = None + assert ( + str(instance) == "Write failed to 'https://example.com/' " + "due to unspecified 
cause, with 3 checks, 1 mutations, 1 enqueues" + ) + + +def test_repr(instance: FailedWrite) -> None: + assert ( + repr(instance) == "" + ) + + instance.__cause__ = None + assert ( + repr(instance) == "" + ) + + +def test_traceback_presentation(instance: FailedWrite) -> None: + assert "\n".join( + traceback.format_exception_only(type(instance), instance) + ).strip() == ( + "denokv._kv_writes.FailedWrite: Write failed " + "to 'https://example.com/' " + "due to ProtocolViolation, with 3 checks, 1 mutations, 1 enqueues" + ) diff --git a/test/test__kv_writes__KvNumber.py b/test/test__kv_writes__KvNumber.py new file mode 100644 index 0000000..ae4b8c1 --- /dev/null +++ b/test/test__kv_writes__KvNumber.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from dataclasses import FrozenInstanceError +from typing import Literal # noqa: TID251 + +import pytest +from v8serialize.jstypes import JSBigInt + +from denokv._kv_values import KvU64 +from denokv._kv_writes import BigIntKvNumberInfo +from denokv._kv_writes import FloatKvNumberInfo +from denokv._kv_writes import KvNumber +from denokv._kv_writes import KvNumberInfo +from denokv._kv_writes import U64KvNumberInfo +from denokv._pycompat.typing import assert_type + + +def test_dataclass_behaviours() -> None: + assert KvNumber.bigint < KvNumber.float + assert KvNumber.float < KvNumber.u64 + assert {KvNumber.bigint: "foo"}[KvNumber.bigint] == "foo" + + assert sorted(KvNumber) == [ + KvNumber.bigint, + KvNumber.float, + KvNumber.u64, + ] + + with pytest.raises(FrozenInstanceError): + KvNumber.bigint.foo = "bar" # type: ignore[attr-defined] + + +def test_resolve() -> None: + assert KvNumber.bigint.name == "bigint" + assert ( + assert_type(KvNumber.resolve("bigint"), Literal[KvNumber.bigint]) + is KvNumber.bigint + ) + assert ( + assert_type(KvNumber.resolve("float"), Literal[KvNumber.float]) + is KvNumber.float + ) + assert assert_type(KvNumber.resolve("u64"), Literal[KvNumber.u64]) is KvNumber.u64 + assert ( + 
assert_type(KvNumber.resolve(JSBigInt), Literal[KvNumber.bigint]) + is KvNumber.bigint + ) + assert ( + assert_type(KvNumber.resolve(float), Literal[KvNumber.float]) is KvNumber.float + ) + assert assert_type(KvNumber.resolve(KvU64), Literal[KvNumber.u64]) is KvNumber.u64 + + +def test_types() -> None: + assert_type(KvNumber.bigint.value, BigIntKvNumberInfo) + assert_type(KvNumber.float.value, FloatKvNumberInfo) + assert_type(KvNumber.u64.value, U64KvNumberInfo) + + _t1: KvNumberInfo[Literal["bigint"], int, JSBigInt] = KvNumber.bigint.value + _t2: KvNumberInfo[Literal["float"], float, float] = KvNumber.float.value + _t3: KvNumberInfo[Literal["u64"], int, KvU64] = KvNumber.u64.value + + # name is covariant — can treat the name as str + _t4: KvNumberInfo[str, int, KvU64] = KvNumber.u64.value + # number params are invariant — cannot broaden the types + _t_err1: KvNumberInfo[str, int | float, KvU64] = KvNumber.u64.value # type: ignore[assignment] + _t_err2: KvNumberInfo[str, int, KvU64 | float] = KvNumber.u64.value # type: ignore[assignment] diff --git a/test/test__kv_writes__Limit.py b/test/test__kv_writes__Limit.py new file mode 100644 index 0000000..a0e2d8b --- /dev/null +++ b/test/test__kv_writes__Limit.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import pytest +from v8serialize.constants import FLOAT64_SAFE_INT_RANGE +from v8serialize.jstypes import JSBigInt + +from denokv._kv_values import KvU64 +from denokv._kv_writes import LIMIT_KVU64 +from denokv._kv_writes import LIMIT_UNLIMITED +from denokv._kv_writes import Limit +from denokv._kv_writes import LimitExceededPolicy +from test.denokv_testing import create_dataclass_slots_test + + +@pytest.fixture +def instance() -> Limit: + return Limit(1, 5, "clamp") + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() + + +def test_constructor() -> None: + assert Limit(1, 5, "clamp").limit_exceeded is LimitExceededPolicy.CLAMP + assert Limit(1, 5, "abort").limit_exceeded is 
LimitExceededPolicy.ABORT + assert Limit(1, 5).limit_exceeded is LimitExceededPolicy.ABORT + assert ( + Limit(1, 5, LimitExceededPolicy.CLAMP).limit_exceeded + is LimitExceededPolicy.CLAMP + ) + assert Limit(1, 5).min == 1 + assert type(Limit(1, 5).min) is int + assert Limit(1, 5).max == 5 + assert type(Limit(1, 5).max) is int + + +def test_contains() -> None: + assert 3 in Limit(max=5) + assert -10 in Limit(max=5) + assert 5 in Limit(max=5) + assert 6 not in Limit(max=5) + assert 1 in Limit(1, 5) + assert -10 not in Limit(0, 5) + assert 10 not in Limit(0, 5) + assert 5 in Limit() + # Non-numbers are not contained + assert object() not in Limit() + + # contains works across types + assert 1.0 in Limit(FLOAT64_SAFE_INT_RANGE.start, FLOAT64_SAFE_INT_RANGE.stop - 1) + assert JSBigInt(1) in Limit( + float(FLOAT64_SAFE_INT_RANGE.start), float(FLOAT64_SAFE_INT_RANGE.stop - 1) + ) + + +def test_LIMIT_KVU64() -> None: + assert LIMIT_KVU64.limit_exceeded is LimitExceededPolicy.WRAP + assert KvU64.RANGE[0] in LIMIT_KVU64 + assert KvU64.RANGE[-1] in LIMIT_KVU64 + + +def test_LIMIT_UNLIMITED() -> None: + assert -(2**256) in LIMIT_UNLIMITED + assert 0 in LIMIT_UNLIMITED + assert 2**256 in LIMIT_UNLIMITED + _limit1: Limit[int] = LIMIT_UNLIMITED + _limit2: Limit[float] = LIMIT_UNLIMITED diff --git a/test/test__kv_writes__Max.py b/test/test__kv_writes__Max.py new file mode 100644 index 0000000..8256f21 --- /dev/null +++ b/test/test__kv_writes__Max.py @@ -0,0 +1,120 @@ +import builtins +from typing import Literal # noqa: TID251 + +import pytest +from v8serialize import Encoder +from v8serialize.jstypes import JSBigInt + +from denokv import _datapath_pb2 as pb2 +from denokv._kv_values import KvU64 +from denokv._kv_writes import BigIntMax +from denokv._kv_writes import FloatMax +from denokv._kv_writes import KvNumber +from denokv._kv_writes import KvNumberInfo +from denokv._kv_writes import Max +from denokv._kv_writes import U64Max +from denokv._pycompat.typing import Any +from 
denokv._pycompat.typing import NewType +from denokv._pycompat.typing import assert_type +from denokv._pycompat.typing import cast +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.kv_keys import KvKey +from test.denokv_testing import create_dataclass_slots_test +from test.denokv_testing import typeval + +T1 = parse_rfc3339_datetime("2000-01-02T03:04:05.6Z").value_or_raise() +k = KvKey("a") + + +@pytest.fixture +def instance() -> Max: + return Max(k, 9, KvNumber.float) + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() + + +def test_init__float() -> None: + float_max = Max(k, 9, KvNumber.float.value) + assert float_max.key is k + assert typeval(float_max.value) == (int, 9) + assert float_max.number_type is KvNumber.float.value + + assert typeval(Max(k, 9.0).value) == (float, 9.0) + assert Max(k, 9) == float_max + assert Max(k, 9.0) == float_max + + for nt in ("float", KvNumber.float, builtins.float, KvNumber.float.value): + assert Max(k, 9, nt) == float_max + + assert Max(k, 9, "float", expire_at=T1).expire_at == T1 + + +def test_init__bigint() -> None: + bigint_max = Max(k, 9, KvNumber.bigint.value) + assert bigint_max.key is k + assert typeval(bigint_max.value) == (int, 9) + assert bigint_max.number_type is KvNumber.bigint.value + + assert Max(k, JSBigInt(9)) == bigint_max + + for nt in ("bigint", KvNumber.bigint, JSBigInt, KvNumber.bigint.value): + assert Max(k, 9, nt) == bigint_max + + assert Max(k, 9, "bigint", expire_at=T1).expire_at == T1 + + +def test_init__u64() -> None: + u64_max = Max(k, 9, KvNumber.u64.value) + assert u64_max.key is k + assert typeval(u64_max.value) == (int, 9) + assert u64_max.number_type is KvNumber.u64.value + + assert Max(k, KvU64(9)) == u64_max + + for nt in ("u64", KvNumber.u64, KvU64, KvNumber.u64.value): + assert Max(k, 9, nt) == u64_max + + assert Max(k, 9, "u64", expire_at=T1).expire_at == T1 + + +def test_init__overloads() -> None: + k = KvKey("a") + bigint, float, u64 = 
KvNumber.bigint.value, KvNumber.float.value, KvNumber.u64.value + assert assert_type(Max(k, 9), FloatMax).number_type == float + assert assert_type(Max(k, 9.0), FloatMax).number_type == float + assert assert_type(Max(k, 9, "float"), FloatMax).number_type == float + assert assert_type(Max(k, 9, KvNumber.float), FloatMax).number_type == float + assert assert_type(Max(k, 9, builtins.float), FloatMax).number_type == float + assert assert_type(Max(k, 9, float), FloatMax).number_type == float + + assert assert_type(Max(k, 9, "bigint"), BigIntMax).number_type == bigint + assert assert_type(Max(k, 9, KvNumber.bigint), BigIntMax).number_type == bigint + assert assert_type(Max(k, 9, JSBigInt), BigIntMax).number_type == bigint + assert assert_type(Max(k, 9, bigint), BigIntMax).number_type == bigint + assert assert_type(Max(k, JSBigInt(9)), BigIntMax).number_type == bigint + + assert assert_type(Max(k, 9, "u64"), U64Max).number_type == u64 + assert assert_type(Max(k, 9, KvNumber.u64), U64Max).number_type == u64 + assert assert_type(Max(k, 9, KvU64), U64Max).number_type == u64 + assert assert_type(Max(k, 9, u64), U64Max).number_type == u64 + assert assert_type(Max(k, KvU64(9)), U64Max).number_type == u64 + + FooInt = NewType("FooInt", int) + BarInt = NewType("BarInt", int) + number_info: KvNumberInfo[Literal["test"], FooInt, BarInt] = cast(Any, bigint) + assert ( + assert_type( + Max(k, FooInt(1), number_info), Max[Literal["test"], FooInt, BarInt] + ) + ).number_type == number_info + + +@pytest.mark.parametrize("number_type", KvNumber) +def test_as_protobuf__float(number_type: KvNumber, v8_encoder: Encoder) -> None: + mutations = Max(k, 9, number_type.value, expire_at=T1).as_protobuf( + v8_encoder=v8_encoder + ) + assert len(mutations) > 0 + # We test the effect of Max mutations elsewhere, e.g. in test_kv. 
+ assert all(isinstance(m, pb2.Mutation) for m in mutations) diff --git a/test/test__kv_writes__Min.py b/test/test__kv_writes__Min.py new file mode 100644 index 0000000..989cf7c --- /dev/null +++ b/test/test__kv_writes__Min.py @@ -0,0 +1,120 @@ +import builtins +from typing import Literal # noqa: TID251 + +import pytest +from v8serialize import Encoder +from v8serialize.jstypes import JSBigInt + +from denokv import _datapath_pb2 as pb2 +from denokv._kv_values import KvU64 +from denokv._kv_writes import BigIntMin +from denokv._kv_writes import FloatMin +from denokv._kv_writes import KvNumber +from denokv._kv_writes import KvNumberInfo +from denokv._kv_writes import Min +from denokv._kv_writes import U64Min +from denokv._pycompat.typing import Any +from denokv._pycompat.typing import NewType +from denokv._pycompat.typing import assert_type +from denokv._pycompat.typing import cast +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.kv_keys import KvKey +from test.denokv_testing import create_dataclass_slots_test +from test.denokv_testing import typeval + +T1 = parse_rfc3339_datetime("2000-01-02T03:04:05.6Z").value_or_raise() +k = KvKey("a") + + +@pytest.fixture +def instance() -> Min: + return Min(k, 9, KvNumber.float) + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() + + +def test_init__float() -> None: + float_min = Min(k, 9, KvNumber.float.value) + assert float_min.key is k + assert typeval(float_min.value) == (int, 9) + assert float_min.number_type is KvNumber.float.value + + assert typeval(Min(k, 9.0).value) == (float, 9.0) + assert Min(k, 9) == float_min + assert Min(k, 9.0) == float_min + + for nt in ("float", KvNumber.float, builtins.float, KvNumber.float.value): + assert Min(k, 9, nt) == float_min + + assert Min(k, 9, "float", expire_at=T1).expire_at == T1 + + +def test_init__bigint() -> None: + bigint_min = Min(k, 9, KvNumber.bigint.value) + assert bigint_min.key is k + assert typeval(bigint_min.value) == (int, 
9) + assert bigint_min.number_type is KvNumber.bigint.value + + assert Min(k, JSBigInt(9)) == bigint_min + + for nt in ("bigint", KvNumber.bigint, JSBigInt, KvNumber.bigint.value): + assert Min(k, 9, nt) == bigint_min + + assert Min(k, 9, "bigint", expire_at=T1).expire_at == T1 + + +def test_init__u64() -> None: + u64_min = Min(k, 9, KvNumber.u64.value) + assert u64_min.key is k + assert typeval(u64_min.value) == (int, 9) + assert u64_min.number_type is KvNumber.u64.value + + assert Min(k, KvU64(9)) == u64_min + + for nt in ("u64", KvNumber.u64, KvU64, KvNumber.u64.value): + assert Min(k, 9, nt) == u64_min + + assert Min(k, 9, "u64", expire_at=T1).expire_at == T1 + + +def test_init__overloads() -> None: + k = KvKey("a") + bigint, float, u64 = KvNumber.bigint.value, KvNumber.float.value, KvNumber.u64.value + assert assert_type(Min(k, 9), FloatMin).number_type == float + assert assert_type(Min(k, 9.0), FloatMin).number_type == float + assert assert_type(Min(k, 9, "float"), FloatMin).number_type == float + assert assert_type(Min(k, 9, KvNumber.float), FloatMin).number_type == float + assert assert_type(Min(k, 9, builtins.float), FloatMin).number_type == float + assert assert_type(Min(k, 9, float), FloatMin).number_type == float + + assert assert_type(Min(k, 9, "bigint"), BigIntMin).number_type == bigint + assert assert_type(Min(k, 9, KvNumber.bigint), BigIntMin).number_type == bigint + assert assert_type(Min(k, 9, JSBigInt), BigIntMin).number_type == bigint + assert assert_type(Min(k, 9, bigint), BigIntMin).number_type == bigint + assert assert_type(Min(k, JSBigInt(9)), BigIntMin).number_type == bigint + + assert assert_type(Min(k, 9, "u64"), U64Min).number_type == u64 + assert assert_type(Min(k, 9, KvNumber.u64), U64Min).number_type == u64 + assert assert_type(Min(k, 9, KvU64), U64Min).number_type == u64 + assert assert_type(Min(k, 9, u64), U64Min).number_type == u64 + assert assert_type(Min(k, KvU64(9)), U64Min).number_type == u64 + + FooInt = NewType("FooInt", int) 
+ BarInt = NewType("BarInt", int) + number_info: KvNumberInfo[Literal["test"], FooInt, BarInt] = cast(Any, bigint) + assert ( + assert_type( + Min(k, FooInt(1), number_info), Min[Literal["test"], FooInt, BarInt] + ) + ).number_type == number_info + + +@pytest.mark.parametrize("number_type", KvNumber) +def test_as_protobuf__float(number_type: KvNumber, v8_encoder: Encoder) -> None: + mutations = Min(k, 9, number_type.value, expire_at=T1).as_protobuf( + v8_encoder=v8_encoder + ) + assert len(mutations) > 0 + # We test the effect of Min mutations elsewhere, e.g. in test_kv. + assert all(isinstance(m, pb2.Mutation) for m in mutations) diff --git a/test/test__kv_writes__PlannedWrite.py b/test/test__kv_writes__PlannedWrite.py new file mode 100644 index 0000000..6a0e116 --- /dev/null +++ b/test/test__kv_writes__PlannedWrite.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import re +from unittest.mock import create_autospec + +import pytest +from v8serialize import Encoder +from v8serialize.jstypes import JSBigInt +from yarl import URL + +from denokv import _datapath_pb2 as datapath_pb2 +from denokv._kv_types import KvWriter +from denokv._kv_types import KvWriterWriteResult +from denokv._kv_values import KvEntry +from denokv._kv_values import KvU64 +from denokv._kv_values import VersionStamp +from denokv._kv_writes import Check +from denokv._kv_writes import CommittedWrite +from denokv._kv_writes import ConflictedWrite +from denokv._kv_writes import Delete +from denokv._kv_writes import Enqueue +from denokv._kv_writes import FailedWrite +from denokv._kv_writes import Limit +from denokv._kv_writes import LimitExceededPolicy +from denokv._kv_writes import Max +from denokv._kv_writes import Min +from denokv._kv_writes import PlannedWrite +from denokv._kv_writes import Sum +from denokv._pycompat.typing import TypedDict +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.auth import ConsistencyLevel +from denokv.auth import EndpointInfo +from 
denokv.datapath import AutoRetry +from denokv.datapath import CheckFailure +from denokv.datapath import ResponseUnsuccessful +from denokv.kv_keys import KvKey +from denokv.result import Err +from denokv.result import Ok +from test.denokv_testing import mocked + +EP = EndpointInfo(URL("https://example.com/"), consistency=ConsistencyLevel.STRONG) + + +@pytest.fixture +def planned_write() -> PlannedWrite: + return ( + PlannedWrite() + .check(KvKey("check1"), VersionStamp(1)) + .check(KvKey("check2"), VersionStamp(2)) + .check(KvKey("check3"), None) + .sum(KvKey("sum1"), 1, clamp_under=0, clamp_over=2) + .sum(KvKey("sum2"), 2, "bigint", clamp_under=0) + .sum(KvKey("sum3"), 4, "u64") + .delete(KvKey("delete1")) + .enqueue(message="Hi", retry_delays=(1, 2, 3)) + ) + + +@pytest.mark.asyncio() +async def test_as_protobuf( + v8_encoder: Encoder, +) -> None: + assert PlannedWrite().as_protobuf(v8_encoder=v8_encoder) == ( + datapath_pb2.AtomicWrite(), + ) + + T1 = parse_rfc3339_datetime("2000-01-01T00:00:00Z").value_or_raise() + + planned_write_start = PlannedWrite() + planned_write = ( + planned_write_start.check(KvEntry(KvKey("check1"), None, VersionStamp(1))) + .check(Check(KvKey("check2"), VersionStamp(2))) + .check_key_not_set(KvKey("check3")) + .check_key_has_version(KvKey("check4"), VersionStamp(4)) + .check(KvKey("check5")) + .sum(KvKey("sum1"), KvU64(1)) + .sum(KvKey("sum2"), 2.0, abort_under=0) + .sum(KvKey("sum3"), 0.2, clamp_under=0, clamp_over=1) + .sum(KvKey("sum4"), 4.0, limit=Limit(1.0, 3.0, LimitExceededPolicy.CLAMP)) + .mutate(Sum(KvKey("sum5"), JSBigInt(42))) + .min(KvKey("min1"), 1) + .min(KvKey("min2"), KvU64(2)) + .mutate(Min(KvKey("min3"), 3, expire_at=T1)) + .max(KvKey("max1"), 1) + .max(KvKey("max2"), KvU64(2)) + .mutate(Max(KvKey("max3"), 3, expire_at=T1)) + .delete(KvKey("delete1")) + .mutate(Delete(KvKey("delete2"))) + .enqueue({"event": "example1"}, delivery_time=T1, retry_delays=[1, 10, 100]) + .enqueue( + Enqueue({"event": "example2"}, 
delivery_time=T1, retry_delays=[1, 10, 100]) + ) + ) + assert planned_write is planned_write_start # builder methods update in-place + + assert planned_write.as_protobuf(v8_encoder=v8_encoder) == ( + datapath_pb2.AtomicWrite( + checks=[ + pb_msg + for check in [ + Check(KvKey("check1"), VersionStamp(1)), + Check(KvKey("check2"), VersionStamp(2)), + Check(KvKey("check3"), None), + Check(KvKey("check4"), VersionStamp(4)), + Check(KvKey("check5"), None), + ] + for pb_msg in check.as_protobuf(v8_encoder=v8_encoder) + ], + mutations=[ + pb_msg + for mutation in [ + Sum(KvKey("sum1"), KvU64(1)), + Sum(KvKey("sum2"), 2.0, abort_under=0), + Sum(KvKey("sum3"), 0.2, clamp_under=0, clamp_over=1), + Sum( + KvKey("sum4"), + 4.0, + limit=Limit(1.0, 3.0, LimitExceededPolicy.CLAMP), + ), + Sum(KvKey("sum5"), JSBigInt(42)), + Min(KvKey("min1"), 1), + Min(KvKey("min2"), KvU64(2)), + Min(KvKey("min3"), 3, expire_at=T1), + Max(KvKey("max1"), 1), + Max(KvKey("max2"), KvU64(2)), + Max(KvKey("max3"), 3, expire_at=T1), + Delete(KvKey("delete1")), + Delete(KvKey("delete2")), + ] + for pb_msg in mutation.as_protobuf(v8_encoder=v8_encoder) + ], + enqueues=[ + pb_msg + for enqueue in [ + Enqueue( + {"event": "example1"}, + delivery_time=T1, + retry_delays=[1, 10, 100], + ), + Enqueue( + {"event": "example2"}, + delivery_time=T1, + retry_delays=[1, 10, 100], + ), + ] + for pb_msg in enqueue.as_protobuf(v8_encoder=v8_encoder) + ], + ), + ) + + +class AtomicWriteRepresentationWriterWriteOptions(TypedDict, total=False): + kv: KvWriter + v8_encoder: Encoder + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("kv_via_write_arg", [False, True]) +@pytest.mark.parametrize("v8_encoder_via_write_arg", [False, True]) +async def test_write__handles_successful_write( + kv_via_write_arg: bool, + v8_encoder_via_write_arg: bool, + planned_write: PlannedWrite, + v8_encoder: Encoder, +) -> None: + writer: KvWriter = create_autospec(KvWriter) + successful_write: KvWriterWriteResult = Ok((VersionStamp(1), EP)) 
+ mocked(writer.write).return_value = successful_write + + kwargs = AtomicWriteRepresentationWriterWriteOptions() + if kv_via_write_arg: + kwargs["kv"] = writer + else: + planned_write.kv = writer + if v8_encoder_via_write_arg: + kwargs["v8_encoder"] = v8_encoder + else: + planned_write.v8_encoder = v8_encoder + + result = await planned_write.write(**kwargs) + + versionstamp, endpoint = successful_write.value_or_raise() + assert result == CommittedWrite( + versionstamp=versionstamp, + endpoint=endpoint, + checks=planned_write.checks, + mutations=planned_write.mutations, + enqueues=planned_write.enqueues, + ) + mocked(writer.write).assert_called_once_with( + protobuf_atomic_write=planned_write.as_protobuf(v8_encoder=v8_encoder)[0] + ) + + +@pytest.mark.asyncio() +async def test_write__handles_unsuccessful_conflicted_write( + planned_write: PlannedWrite, + v8_encoder: Encoder, +) -> None: + writer: KvWriter = create_autospec(KvWriter) + failed_write: KvWriterWriteResult = Err( + error := CheckFailure( + "Not all checks required by the Atomic Write passed", + all_checks=[pb for c in planned_write.checks for pb in c.as_protobuf()], + failed_check_indexes=[0, 2], + endpoint=EP, + ) + ) + mocked(writer.write).return_value = failed_write + + result = await planned_write.write(kv=writer, v8_encoder=v8_encoder) + + assert result == ConflictedWrite( + failed_checks=error.failed_check_indexes, + checks=planned_write.checks, + mutations=planned_write.mutations, + enqueues=planned_write.enqueues, + endpoint=error.endpoint, + ) + + +@pytest.mark.asyncio() +async def test_write__handles_write_request_failure( + planned_write: PlannedWrite, v8_encoder: Encoder +) -> None: + writer: KvWriter = create_autospec(KvWriter) + failed_write: KvWriterWriteResult = Err( + error := ResponseUnsuccessful( + "Server rejected Data Path request indicating client error", + status=403, + body_text="Permission denied", + endpoint=EP, + auto_retry=AutoRetry.NEVER, + ) + ) + 
mocked(writer.write).return_value = failed_write + + with pytest.raises(FailedWrite) as exc_info: + await planned_write.write(kv=writer, v8_encoder=v8_encoder) + + assert exc_info.value.__cause__ == error + + +@pytest.mark.asyncio() +async def test_write__requires_Kv() -> None: + with pytest.raises( + TypeError, + match=re.escape( + "PlannedWrite.write() must get a value for its 'kv' argument when " + "'self.kv' isn't set" + ), + ): + await PlannedWrite().write() + + +@pytest.mark.asyncio() +async def test_check__raises_on_invalid_use() -> None: + with pytest.raises( + TypeError, + match=r"'versionstamp' argument cannot be set when the first argument to " + r"check\(\) is an object with 'key' and 'versionstamp' attributes", + ): + PlannedWrite().check( + KvEntry(KvKey("a"), None, versionstamp=VersionStamp(1)), + versionstamp=VersionStamp(2), + ) # type: ignore[call-overload] + + with pytest.raises( + TypeError, + match=r"'versionstamp' argument cannot be set when the first argument to " + r"check\(\) is an object with an 'as_protobuf' method", + ): + PlannedWrite().check( + Check(KvKey("a"), versionstamp=VersionStamp(1)), + versionstamp=VersionStamp(2), + ) # type: ignore[call-overload] diff --git a/test/test__kv_writes__Set.py b/test/test__kv_writes__Set.py new file mode 100644 index 0000000..f264b4b --- /dev/null +++ b/test/test__kv_writes__Set.py @@ -0,0 +1,86 @@ +import pytest +from v8serialize import Encoder + +from denokv import _datapath_pb2 as datapath_pb2 +from denokv._kv_values import KvU64 +from denokv._kv_writes import Set +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.kv_keys import KvKey +from test.denokv_testing import create_dataclass_slots_test + +T1 = parse_rfc3339_datetime("2000-01-02T03:04:05.6Z").value_or_raise() + + +@pytest.fixture +def instance() -> Set: + return Set(KvKey("a"), "foo") + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() + + +def test_constructors() -> None: + value = {"foo": 
"bar"} + instance = Set(KvKey("a"), value) + assert instance.key == KvKey("a") + assert instance.value is value + assert instance.versioned is False + assert instance.expire_at is None + + instance = Set(KvKey("a"), value, expire_at=T1, versioned=True) + assert instance.key == KvKey("a") + assert instance.value is value + assert instance.versioned is True + assert instance.expire_at == T1 + + +def test_as_protobuf(v8_encoder: Encoder) -> None: + v8_value = {"foo": "bar"} + instance = Set(KvKey("a"), v8_value) + + assert instance.as_protobuf(v8_encoder=v8_encoder) == ( + datapath_pb2.Mutation( + mutation_type=datapath_pb2.M_SET, + key=bytes(KvKey("a")), + value=datapath_pb2.KvValue( + data=bytes(v8_encoder.encode(v8_value)), encoding=datapath_pb2.VE_V8 + ), + ), + ) + + byte_value = b"\x00\xff" + instance = Set(KvKey("a"), byte_value) + + assert instance.as_protobuf(v8_encoder=v8_encoder) == ( + datapath_pb2.Mutation( + mutation_type=datapath_pb2.M_SET, + key=bytes(KvKey("a")), + value=datapath_pb2.KvValue(data=byte_value, encoding=datapath_pb2.VE_BYTES), + ), + ) + + kvu64_value = KvU64(2) + instance = Set(KvKey("a"), kvu64_value) + + assert instance.as_protobuf(v8_encoder=v8_encoder) == ( + datapath_pb2.Mutation( + mutation_type=datapath_pb2.M_SET, + key=bytes(KvKey("a")), + value=datapath_pb2.KvValue( + data=bytes(kvu64_value), encoding=datapath_pb2.VE_LE64 + ), + ), + ) + + instance = Set(KvKey("a"), v8_value, expire_at=T1, versioned=True) + + assert instance.as_protobuf(v8_encoder=v8_encoder) == ( + datapath_pb2.Mutation( + mutation_type=datapath_pb2.M_SET_SUFFIX_VERSIONSTAMPED_KEY, + key=bytes(KvKey("a")), + value=datapath_pb2.KvValue( + data=bytes(v8_encoder.encode(v8_value)), encoding=datapath_pb2.VE_V8 + ), + expire_at_ms=int(T1.timestamp() * 1000), + ), + ) diff --git a/test/test__kv_writes__Sum.py b/test/test__kv_writes__Sum.py new file mode 100644 index 0000000..43ba371 --- /dev/null +++ b/test/test__kv_writes__Sum.py @@ -0,0 +1,403 @@ +from __future__ 
import annotations + +import re +from datetime import datetime +from decimal import Decimal +from math import isnan +from typing import Literal # noqa: TID251 + +import pytest +from hypothesis import example +from hypothesis import given +from hypothesis import strategies as st +from v8serialize.constants import FLOAT64_SAFE_INT_RANGE +from v8serialize.jstypes import JSBigInt + +from denokv import _datapath_pb2 as datapath_pb2 +from denokv._kv_values import KvU64 +from denokv._kv_writes import LIMIT_KVU64 +from denokv._kv_writes import LIMIT_UNLIMITED +from denokv._kv_writes import BigIntSum +from denokv._kv_writes import FloatSum +from denokv._kv_writes import KvNumber +from denokv._kv_writes import KvNumberIdentifier +from denokv._kv_writes import KvNumberInfo +from denokv._kv_writes import KvNumberNameT +from denokv._kv_writes import KvNumberTypeT +from denokv._kv_writes import Limit +from denokv._kv_writes import LimitExceededPolicy +from denokv._kv_writes import NumberT +from denokv._kv_writes import Sum +from denokv._kv_writes import U64Sum +from denokv._pycompat.typing import Any +from denokv._pycompat.typing import NewType +from denokv._pycompat.typing import assert_type +from denokv._pycompat.typing import cast +from denokv._rfc3339 import parse_rfc3339_datetime +from denokv.datapath import read_range_single +from denokv.kv_keys import KvKey +from denokv.result import Err +from denokv.result import Ok +from denokv.result import Result +from denokv.result import is_err +from test.denokv_testing import MockKvDb +from test.denokv_testing import SumLimitExceeded +from test.denokv_testing import add_entries +from test.denokv_testing import create_dataclass_slots_test +from test.denokv_testing import typeval +from test.denokv_testing import unsafe_parse_protobuf_kv_entry + +T1 = parse_rfc3339_datetime("2000-01-02T03:04:05.6Z").value_or_raise() + +u64 = st.integers(min_value=0, max_value=KvU64.RANGE.stop - 1) +neg_u64 = st.integers(min_value=-(KvU64.RANGE.stop - 
1), max_value=0) + + +@pytest.fixture +def instance() -> Sum: + return Sum(KvKey("a"), 1) + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() + + +def test_init__limits() -> None: + sum1 = Sum(KvKey("a"), 1) + assert sum1.limit == LIMIT_UNLIMITED + + sum2 = Sum(KvKey("a"), 1, number_type="u64") + assert sum2.limit == LIMIT_KVU64 + + with pytest.raises( + ValueError, + match=re.escape( + "Limit keyword arguments in conflict: " + "Options abort_*=, clamp_*=, limit= cannot be used together.\n" + "Use limit=Limit(limit_exceeded=..., ...) to create a limit with a " + "dynamic type." + ), + ): + Sum(KvKey("a"), 1, limit=Limit(), clamp_over=1, abort_under=3) + + with pytest.raises( + ValueError, + match=re.escape( + "Limit keyword arguments in conflict: " + "Options abort_*=, clamp_*= cannot be used together.\n" + "Use limit=Limit(limit_exceeded=..., ...) to create a limit with a " + "dynamic type." + ), + ): + Sum(KvKey("a"), 1, clamp_over=1, abort_under=3) + + sum4 = Sum(KvKey("a"), 1, clamp_over=98, clamp_under=2) + assert sum4.limit == Limit(min=2, max=98, limit_exceeded="clamp") + + sum4 = Sum(KvKey("a"), 1, abort_under=3, abort_over=100) + assert sum4.limit == Limit(min=3, max=100, limit_exceeded="abort") + + sum6 = Sum( + KvKey("a"), + 1, + limit=Limit(min=3, limit_exceeded="clamp"), + expire_at=datetime.now(), + ) + assert sum6.limit == Limit(min=3, limit_exceeded="clamp") + + # Passing None as a clamp/abort enables that limit type with the default + sum7 = Sum(KvKey("a"), 1, "u64", clamp_under=None) + assert sum7.limit == Limit(min=None, max=None, limit_exceeded="clamp") + + sum8 = Sum(KvKey("a"), 1, abort_over=None) + assert sum8.limit == Limit(min=None, max=None, limit_exceeded="abort") + + sum9 = Sum(KvKey("a"), 1, "u64", limit=None) + assert sum9.limit == LIMIT_KVU64 + + +def test_init__overloads() -> None: + k = KvKey("a") + bigint, float, u64 = KvNumber.bigint.value, KvNumber.float.value, KvNumber.u64.value + assert 
assert_type(Sum(k, JSBigInt(1)), BigIntSum).number_type == bigint + assert assert_type(Sum(k, 1, "bigint"), BigIntSum).number_type == bigint + assert assert_type(Sum(k, KvU64(1)), U64Sum).number_type == u64 + assert assert_type(Sum(k, 1, "u64"), U64Sum).number_type == u64 + assert assert_type(Sum(k, 1), FloatSum).number_type == float + assert assert_type(Sum(k, 1.0), FloatSum).number_type == float + assert assert_type(Sum(k, 1.0, "float"), FloatSum).number_type == float + + FooInt = NewType("FooInt", int) + BarInt = NewType("BarInt", int) + number_info: KvNumberInfo[Literal["test"], FooInt, BarInt] = cast(Any, bigint) + assert ( + assert_type( + Sum(k, FooInt(1), number_info), Sum[Literal["test"], FooInt, BarInt] + ) + ).number_type == number_info + + +@pytest.mark.parametrize( + "delta,number_type,expected_delta,expected_number_type", + [ + (1, None, (int, 1), KvNumber.float.value), + (1.0, None, (float, 1.0), KvNumber.float.value), + (1, "float", (int, 1), KvNumber.float.value), + (1.0, float, (float, 1.0), KvNumber.float.value), + (1, KvNumber.float, (int, 1), KvNumber.float.value), + (1, KvNumber.float.value, (int, 1), KvNumber.float.value), + (JSBigInt(1), None, (int, 1), KvNumber.bigint.value), + (1, "bigint", (int, 1), KvNumber.bigint.value), + (1, JSBigInt, (int, 1), KvNumber.bigint.value), + (1, KvNumber.bigint, (int, 1), KvNumber.bigint.value), + (1, KvNumber.bigint.value, (int, 1), KvNumber.bigint.value), + (KvU64(1), None, (int, 1), KvNumber.u64.value), + (1, "u64", (int, 1), KvNumber.u64.value), + (1, KvU64, (int, 1), KvNumber.u64.value), + (1, KvNumber.u64, (int, 1), KvNumber.u64.value), + (1, KvNumber.u64.value, (int, 1), KvNumber.u64.value), + ], +) +def test_init__number_types( + delta: int | float | KvU64 | JSBigInt, + number_type: KvNumberInfo | KvNumberIdentifier | None, + expected_delta: tuple[type[int], int] | tuple[type[float], float], + expected_number_type: KvNumberInfo, +) -> None: + sum = Sum(KvKey("a"), delta, cast(KvNumberInfo, 
number_type)) + assert typeval(sum.delta) == expected_delta + assert sum.number_type is expected_number_type + + +def test_init() -> None: + k = KvKey("a") + limit1 = Limit(0, 10, "abort") + + sum1 = Sum(k, 1.0) + assert sum1.key is k + assert typeval(sum1.delta) == (float, 1.0) + assert sum1.number_type is KvNumber.float.value + assert sum1.expire_at is None + assert sum1.limit == LIMIT_UNLIMITED + + sum2 = Sum(k, 1, "float", expire_at=T1, limit=limit1) + assert sum2.key == k + assert typeval(sum2.delta) == (int, 1) + assert sum2.number_type is KvNumber.float.value + assert sum2.expire_at is T1 + assert sum2.limit is limit1 + + sum3 = Sum( + key=k, delta=JSBigInt(1), number_type="bigint", expire_at=T1, limit=limit1 + ) + assert sum3.key == k + assert typeval(sum3.delta) == (int, 1) + assert sum3.number_type is KvNumber.bigint.value + assert sum3.expire_at is T1 + assert sum3.limit is limit1 + + with pytest.raises( + TypeError, + match=re.escape("Sum.__init__() got an unexpected keyword argument 'foo'"), + ): + Sum(KvKey("a"), 0, foo="bar") # type: ignore[call-overload] + + +def test_init__unsupported_value_type_is_type_error() -> None: + with pytest.raises( + TypeError, + match=re.escape("number is not supported by any KvNumber: Decimal('42')"), + ): + Sum(KvKey("a"), Decimal(42)) # type: ignore[call-overload] + + with pytest.raises( + TypeError, + match=re.escape( + "number is not compatible with bigint py number type\n" + "number: Decimal('42') (), " + "bigint=BigIntKvNumberInfo(name='bigint', py_type=, " + "kv_type=)" + ), + ): + Sum(KvKey("a"), Decimal(42), "bigint") # type: ignore[call-overload] + + +def test_init__float_number_type_rejects_out_of_range_int_values() -> None: + with pytest.raises( + ValueError, + match=re.escape( + "number is not compatible with float py number type\n" + "number: 9007199254740992 (), " + "float=FloatKvNumberInfo(name='float', py_type=, " + "kv_type=)\n" + "The int is too large to represent as a 64-bit floating point value." 
+ ), + ): + Sum(KvKey("a"), FLOAT64_SAFE_INT_RANGE.stop, "float") + + +def test_init__kvu64_limit_cannot_be_changed() -> None: + assert Sum(KvKey("a"), KvU64(1)).limit == LIMIT_KVU64 + assert Sum(KvKey("a"), KvU64(1), limit=LIMIT_KVU64) == Sum(KvKey("a"), KvU64(1)) + + custom_wrap_limit = Limit(max=42, limit_exceeded=LimitExceededPolicy.WRAP) # type: ignore[arg-type] + with pytest.raises( + ValueError, + match=re.escape( + "Number type 'u64' wrap limit's min, max bounds cannot be changed\n" + "'u64' (KvU64) can only wrap at 0 and 2^64 - 1. It can use clamp " + "with custom bounds through." + ), + ): + Sum(KvKey("a"), KvU64(1), limit=custom_wrap_limit) + + with pytest.raises( + ValueError, + match=re.escape( + "Number type 'bigint' does not support wrap limits\n" + "Use 'u64' (KvU64) to wrap on 0, 2^64 - 1 bounds." + ), + ): + Sum(KvKey("a"), 1, "bigint", limit=LIMIT_KVU64) + + +# delta values beyond +/-2^64 are wrapped to this range. We still include them +# as inputs, to ensure that we are handling them correctly though. We don't just +# use st.integers() as the input, as using the two separate u64 int classes +# should probe 64-bit boundary values more effectively than just using +# st.integers(). +@given(value=u64, delta=u64 | neg_u64 | st.integers()) +def test_as_protobuf__u64_wrap(value: int, delta: int) -> None: + expected = KvU64((value + delta) % KvU64.RANGE.stop) + sum = Sum(KvKey("a"), delta, "u64") + + actual = apply_sum_mutation(sum, value).value_or_raise() + assert actual == expected + + +@given( + value=u64, + delta=u64 | neg_u64 | st.integers(), + clamp_under=st.none() | u64, + clamp_over=st.none() | u64, +) +# Include examples to always hit branches, to avoid random coverage misses. 
+# constant result as clamp_over <= clamp_under +@example(value=0, delta=-1, clamp_under=0, clamp_over=0) +# constant result as result always meets clamp_under +@example(value=0, delta=-1, clamp_under=KvU64.RANGE.stop - 2, clamp_over=None) +def test_as_protobuf__u64_clamp( + value: int, delta: int, clamp_under: int | None, clamp_over: int | None +) -> None: + expected = KvU64( + min( + KvU64.RANGE.stop - 1 if clamp_over is None else clamp_over, + max( + 0 if clamp_under is None else clamp_under, + value + delta, + ), + ) + ) + sum = Sum(KvKey("a"), delta, "u64", clamp_under=clamp_under, clamp_over=clamp_over) + actual = apply_sum_mutation(sum, value).value_or_raise() + assert actual == expected + + +floats = st.floats(allow_nan=True) +float_safe_integers = st.integers( + min_value=FLOAT64_SAFE_INT_RANGE.start, max_value=FLOAT64_SAFE_INT_RANGE.stop - 1 +) +v8_sum_limits_bigint: st.SearchStrategy[Limit[int]] = st.builds( + Limit, + min=st.none() | st.integers(), + max=st.none() | st.integers(), + limit_exceeded=st.sampled_from( + [LimitExceededPolicy.ABORT, LimitExceededPolicy.CLAMP] + ), +) +v8_sum_limits_float: st.SearchStrategy[Limit[float]] = st.builds( + Limit, + max=st.none() | float_safe_integers | floats, + min=st.none() | float_safe_integers | floats, + limit_exceeded=st.sampled_from( + [ + LimitExceededPolicy.ABORT, + LimitExceededPolicy.CLAMP, + ] + ), +) + + +@given(value=st.integers(), delta=st.integers(), limit=v8_sum_limits_bigint) +def test_as_protobuf__v8_bigint(value: int, delta: int, limit: Limit[int]) -> None: + _test_as_protobuf__v8(KvNumber.bigint.value, value, delta, limit) + + +@given( + value=float_safe_integers | floats, + delta=float_safe_integers | floats, + limit=v8_sum_limits_float, +) +def test_as_protobuf__v8_float(value: float, delta: float, limit: Limit[float]) -> None: + _test_as_protobuf__v8(KvNumber.float.value, value, delta, limit) + + +def _test_as_protobuf__v8( + number_type: KvNumberInfo[KvNumberNameT, NumberT, KvNumberTypeT], 
+ value: NumberT, + delta: NumberT, + limit: Limit[NumberT], +) -> None: + if limit.limit_exceeded == LimitExceededPolicy.ABORT: + # Explicitly calculate the expected result in the kv type, as with + # floats, we can add int values and get greater precision than we would + # with actual floats. (Normally as_kv_type() preserves ints in + # float-safe range.) + expected_value = number_type.kv_type(value) + number_type.kv_type(delta) # type: ignore[call-arg,operator] + should_abort = False + if limit.min is not None and expected_value < limit.min: + should_abort = True + if limit.max is not None and expected_value > limit.max: + should_abort = True + else: + should_abort = False + expected_value = number_type.kv_type(value) + number_type.kv_type(delta) # type: ignore[call-arg,operator] + if limit.min is not None and expected_value < limit.min: + expected_value = limit.min + if limit.max is not None and expected_value > limit.max: + expected_value = limit.max + + sum = Sum(KvKey("a"), delta, number_type, limit=limit) + actual_result = apply_sum_mutation(sum, value) + + if should_abort: + assert is_err(actual_result) + assert isinstance(actual_result.error, SumLimitExceeded) + else: + actual_value = actual_result.value_or_raise() + assert actual_value == expected_value or all( + # We allow nan as an input, and nan can occur independently as a + # result, e.g. inf + -inf = nan. 
+ isinstance(x, float) and isnan(x) + for x in (actual_value, expected_value) + ) + + +def apply_sum_mutation( + sum: Sum[str, NumberT, KvNumberTypeT], value: NumberT +) -> Result[KvNumberTypeT, SumLimitExceeded]: + db = MockKvDb() + add_entries(db, {sum.key: sum.number_type.as_kv_number(value)}) + mutations = sum.as_protobuf() + + try: + write_result = db.atomic_write(datapath_pb2.AtomicWrite(mutations=mutations)) + except Exception as e: + if isinstance((cause := e.__cause__), SumLimitExceeded): + return Err(cause) + raise e + + assert write_result.status == datapath_pb2.AW_SUCCESS + raw_entry = db.snapshot_read_range(read_range_single(sum.key)).values[0] + entry = unsafe_parse_protobuf_kv_entry(raw_entry) + assert sum.number_type.is_kv_number(entry.value) + return Ok(entry.value) diff --git a/test/test__kv_writes__U64KvNumberType.py b/test/test__kv_writes__U64KvNumberType.py new file mode 100644 index 0000000..625b3bd --- /dev/null +++ b/test/test__kv_writes__U64KvNumberType.py @@ -0,0 +1,202 @@ +""" +Proof-of-concepts for extended atomic mutations for KvU64 numbers. + +The default KvU64 supports sum() with positive delta and wrapping at 2**64. +This module also implements: + +- sum() with negative delta and wrapping at 2**64 +- sum() with positive and negative delta, with clamping at user-defined bounds + - This is the same as BigInt + - Which makes KvU64 more powerful than BigInt to some extent + +It does not support wrapping on custom bounds, or error limits. 
+""" + +from __future__ import annotations + +from hypothesis import example +from hypothesis import given +from hypothesis import strategies as st + +from denokv._kv_values import KvU64 + +u64 = st.integers(min_value=0, max_value=KvU64.RANGE.stop - 1) +neg_u64 = st.integers(min_value=-(KvU64.RANGE.stop - 1), max_value=0) + + +def sum_with_clamp__no_overflow( + value: int, delta: int, limit_min: int | None, limit_max: int | None +) -> int: + if limit_min is None: + limit_min = 0 + if limit_max is None: + limit_max = KvU64.RANGE.stop - 1 + + assert all(x in KvU64.RANGE for x in [value, abs(delta), limit_min, limit_max]) + + result = value + delta + if limit_min is not None: + result = max(limit_min, result) + if limit_max is not None: + result = min(limit_max, result) + return result + + +def sum_with_clamp__overflow( + value: int, delta: int, limit_min: int | None, limit_max: int | None +) -> int: + if limit_min is None: + limit_min = 0 + if limit_max is None: + limit_max = KvU64.RANGE.stop - 1 + assert all( + x in KvU64.RANGE for x in [value, delta, limit_min, limit_max] if x is not None + ) + + # When the upper limit is <= the delta, the result is always clamped at the + # upper limit. Likewise if the lower limit pushes the result above the upper + # limit, the upper limit is used (it's applied last). + min_result = 0 + delta + if limit_max <= min_result or limit_max <= limit_min: + return limit_max # set to max as mutation + + if limit_min >= limit_max or limit_min <= delta: + limit_min = None + + # Can be < 0 which is not allowed in practice. 
+ # We can use high positive numbers like negative and rely on wrapping + max_start = limit_max - delta + if max_start < 0: + start = max(max_start % 2**64, value) + else: + start = min(max_start, value) + assert start in KvU64.RANGE + result = start + delta + result = result % KvU64.RANGE.stop + if limit_min is not None: + result = max(limit_min, result) + return result + + +@given(value=u64, delta=u64, limit_min=u64 | st.none(), limit_max=u64 | st.none()) +def test_sum_min_max( + value: int, delta: int, limit_min: int | None, limit_max: int | None +) -> None: + """Implement sum() with clamp for KvU64 (which can only wrap normally).""" + expected = sum_with_clamp__no_overflow(value, delta, limit_min, limit_max) + actual = sum_with_clamp__overflow(value, delta, limit_min, limit_max) + + assert actual == expected + + +# --------------------------- + + +def neg_sum_with_clamp__overflow( + value: int, + delta: int, + limit_min: int | None, + limit_max: int | None, +) -> int: + if limit_max is None: + limit_max = KvU64.RANGE.stop - 1 + if limit_min is None: + limit_min = 0 + assert delta <= 0 + + assert all(x in KvU64.RANGE for x in [value, abs(delta), limit_min, limit_max]) + + # If value after adding the delta is always <= the lower limit, the lower + # limit is always the result. However the upper limit applies last, so if + # the upper limit is lower than the lower limit, it applies instead. 
+    if limit_max <= limit_min:
+        return limit_max  # set to limit_max as mutation
+    max_result = (KvU64.RANGE.stop - 1) + delta
+    if limit_min >= max_result:
+        assert limit_max > limit_min
+        return limit_min  # set to limit_min as mutation
+
+    if limit_max >= max_result:
+        assert limit_max > limit_min
+        # limit_max can have no effect on the result
+        limit_max = None
+
+    # Offset the start to prevent it going negative after adding the delta
+    min_start = abs(delta) + limit_min
+    if min_start >= KvU64.RANGE.stop:
+        start = min(min_start % KvU64.RANGE.stop, value)
+    else:
+        start = max(min_start, value)
+    assert start in KvU64.RANGE
+
+    # Convert the negative delta to a positive value that overflows to the
+    # original negative delta offset.
+    if delta < 0:
+        delta = KvU64.RANGE.stop + delta
+    assert delta in KvU64.RANGE
+
+    # Apply the delta (effectively subtracting)
+    result = (start + delta) % (KvU64.RANGE.stop)
+    assert result in KvU64.RANGE
+
+    if limit_max is not None:
+        result = min(limit_max, result)
+    return result
+
+
+@given(value=u64, delta=neg_u64, limit_min=u64 | st.none(), limit_max=u64 | st.none())
+@example(value=2**64 - 1, delta=-1, limit_min=0, limit_max=2**64 - 3)
+def test_negative_sum_min_max(
+    value: int, delta: int, limit_min: int | None, limit_max: int | None
+) -> None:
+    """
+    Implement sum() with negative delta for KvU64 with clamp min/max.
+
+    KvU64 sum can only add positive values with wrapping normally.
+ """ + expected = sum_with_clamp__no_overflow(value, delta, limit_min, limit_max) + actual = neg_sum_with_clamp__overflow(value, delta, limit_min, limit_max) + + assert actual == expected + + +# --------------------------- + + +def neg_sum_with_wrap__no_overflow(value: int, delta: int) -> int: + assert delta <= 0 + assert all(x in KvU64.RANGE for x in [value, abs(delta)] if x is not None) + + result = (value + delta) % (2**64) + assert result in KvU64.RANGE + return result + + +def neg_sum_with_wrap__overflow( + value: int, + delta: int, +) -> int: + assert delta <= 0 + + assert all(x in KvU64.RANGE for x in [value, abs(delta)]) + + if delta == 0: + return value # no mutation + + delta = 2**64 + delta + assert delta in KvU64.RANGE + + result = (value + delta) % (2**64) + assert result in KvU64.RANGE + + return result + + +# TODO: can we do wrapping on custom limits, not just 0 and 2**64? +@given(value=u64, delta=neg_u64) +def test_negative_sum_with_wrap(value: int, delta: int) -> None: + """Implement sum() with negative delta for KvU64 (with wrapping).""" + expected = neg_sum_with_wrap__no_overflow(value, delta) + actual = neg_sum_with_wrap__overflow(value, delta) + + assert actual == expected diff --git a/test/test__utils.py b/test/test__utils.py new file mode 100644 index 0000000..4ae0c79 --- /dev/null +++ b/test/test__utils.py @@ -0,0 +1,54 @@ +from dataclasses import FrozenInstanceError +from dataclasses import dataclass +from enum import Enum + +import pytest + +from denokv._pycompat.typing import Final +from denokv._utils import frozen + + +def test_frozen_decorator() -> None: + @dataclass + class Info: + label: Final[str] # type: ignore[misc] + size: Final[int] # type: ignore[misc] + + class Things(Info, Enum): + FOO = "foo", 42 + BAR = "bar", 100 + + # Enums prevent assigning to special enum fields + with pytest.raises(AttributeError): + Things.FOO.name = "XXX" # type: ignore[misc] + + # But custom fields are writable + assert Things.FOO.label == "foo" + 
assert Things.FOO.size == 42 + + Things.FOO.label = "lol" # type: ignore[misc] + del Things.FOO.size + + assert Things.FOO.label == "lol" + assert not hasattr(Things.FOO, "size") + + # By not when using @frozen + + @frozen + class FrozenThings(Info, Enum): + FOO = "foo", 42 + BAR = "bar", 100 + + with pytest.raises(FrozenInstanceError): + FrozenThings.FOO.name = "XXX" # type: ignore[misc] + + assert FrozenThings.FOO.label == "foo" + assert FrozenThings.FOO.size == 42 + + with pytest.raises(FrozenInstanceError): + FrozenThings.FOO.label = "lol" # type: ignore[misc] + with pytest.raises(FrozenInstanceError): + del FrozenThings.FOO.size + + assert FrozenThings.FOO.label == "foo" + assert FrozenThings.FOO.size == 42 diff --git a/test/test_auth.py b/test/test_auth.py index 00de3a9..e430165 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -2,6 +2,7 @@ import json from copy import deepcopy +from datetime import datetime from uuid import UUID import aiohttp @@ -30,12 +31,47 @@ from denokv.auth import read_metadata_exchange_response from test.denokv_testing import assume_err from test.denokv_testing import assume_ok +from test.denokv_testing import create_dataclass_slots_test TestClient: TypeAlias = _TestClient[web.Request, web.Application] pytest_mark_asyncio = pytest.mark.asyncio() +@pytest.fixture( + params=[ + pytest.param( + lambda: DatabaseMetadata( + version=2, + database_id=UUID("AD50A341-5351-4FC3-82D0-72CFEE369A09"), + token="thisisnotasecret", + expires_at=datetime.now(), + endpoints=( + EndpointInfo( + url=URL("https://db.example.com/v2"), + consistency=ConsistencyLevel.STRONG, + ), + ), + ), + id="DatabaseMetadata", + ), + pytest.param( + lambda: EndpointInfo( + url=URL("https://db.example.com/v2"), + consistency=ConsistencyLevel.STRONG, + ), + id="EndpointInfo", + ), + ] +) +def instance(request: pytest.FixtureRequest) -> object: + param: Callable[[], object] = request.param + return param() + + +test_instances_dont_have_dict_because_of_slots = 
create_dataclass_slots_test() + + @pytest.fixture def valid_metadata_exchange_response() -> dict[str, object]: return { diff --git a/test/test_datapath.py b/test/test_datapath.py index 8ab2248..00994fd 100644 --- a/test/test_datapath.py +++ b/test/test_datapath.py @@ -3,43 +3,55 @@ import functools import re import struct -from datetime import datetime -from datetime import timedelta -from typing import Literal -from uuid import UUID +from typing import Literal # noqa: TID251 import pytest import pytest_asyncio import v8serialize from aiohttp import web from aiohttp.test_utils import TestClient as _TestClient +from aiohttp.typedefs import Handler from fdb.tuple import pack from fdb.tuple import unpack +from google.protobuf.message import Message from hypothesis import example from hypothesis import given from hypothesis import strategies as st from v8serialize import Decoder +from v8serialize.jstypes import JSBigInt from yarl import URL from denokv import datapath +from denokv._datapath_pb2 import AtomicWrite +from denokv._datapath_pb2 import AtomicWriteOutput +from denokv._datapath_pb2 import AtomicWriteStatus +from denokv._datapath_pb2 import Check from denokv._datapath_pb2 import KvEntry as ProtobufKvEntry +from denokv._datapath_pb2 import KvValue +from denokv._datapath_pb2 import Mutation +from denokv._datapath_pb2 import MutationType from denokv._datapath_pb2 import ReadRange from denokv._datapath_pb2 import ReadRangeOutput from denokv._datapath_pb2 import SnapshotRead from denokv._datapath_pb2 import SnapshotReadOutput from denokv._datapath_pb2 import SnapshotReadStatus from denokv._datapath_pb2 import ValueEncoding +from denokv._kv_values import KvEntry +from denokv._kv_values import KvU64 +from denokv._kv_values import VersionStamp from denokv._pycompat.typing import Awaitable from denokv._pycompat.typing import Callable -from denokv._pycompat.typing import Final +from denokv._pycompat.typing import Iterable from denokv._pycompat.typing import Mapping +from 
denokv._pycompat.typing import Sequence from denokv._pycompat.typing import TypeAlias +from denokv._pycompat.typing import TypeVar from denokv._pycompat.typing import cast from denokv.auth import ConsistencyLevel -from denokv.auth import DatabaseMetadata from denokv.auth import EndpointInfo from denokv.datapath import KV_KEY_PIECE_TYPES from denokv.datapath import AutoRetry +from denokv.datapath import CheckFailure from denokv.datapath import DataPathDenoKvError from denokv.datapath import EndpointNotUsable from denokv.datapath import EndpointNotUsableReason @@ -48,6 +60,8 @@ from denokv.datapath import ProtocolViolation from denokv.datapath import RequestUnsuccessful from denokv.datapath import ResponseUnsuccessful +from denokv.datapath import _DataPathRequestKind +from denokv.datapath import atomic_write from denokv.datapath import increment_packed_key from denokv.datapath import is_any_kv_key from denokv.datapath import is_kv_key_tuple @@ -56,14 +70,17 @@ from denokv.datapath import parse_protobuf_kv_entry from denokv.datapath import read_range_single from denokv.datapath import snapshot_read -from denokv.kv import KvEntry -from denokv.kv import KvU64 -from denokv.kv import VersionStamp from denokv.kv_keys import KvKey from denokv.result import Err from denokv.result import Ok +from denokv.result import Result +from denokv.result import is_ok from test.denokv_testing import MockKvDb from test.denokv_testing import add_entries +from test.denokv_testing import default_v8_encoder +from test.denokv_testing import make_database_metadata +from test.denokv_testing import meta_endpoint +from test.denokv_testing import mock_db_api from test.denokv_testing import nextafter from test.denokv_testing import unsafe_parse_protobuf_kv_entry @@ -71,6 +88,8 @@ pytest_mark_asyncio = pytest.mark.asyncio() +MessageT = TypeVar("MessageT", bound=Message) + @pytest.fixture def mock_db() -> MockKvDb: @@ -94,67 +113,9 @@ def example_entries() -> Mapping[KvKeyTuple, object]: 
@pytest.fixture def db_api(mock_db: MockKvDb) -> web.Application: - def get_server_version(request: web.Request) -> Literal[1, 2, 3]: - match = re.match(r"^/v([123])/", request.path) - version: Final = int(match.group(1)) if match else -1 - if version not in (1, 2, 3): - raise AssertionError("handler is not registered at /v[123]/ URL path") - return cast(Literal[1, 2, 3], version) - - async def strong_snapshot_read(request: web.Request) -> web.Response: - server_version = get_server_version(request) - - if request.method != "POST": - raise web.HTTPBadRequest(body="method must be POST") - if request.content_type != "application/x-protobuf": - raise web.HTTPBadRequest(body="content-type must be application/x-protobuf") - - db_id_header = ( - "x-transaction-domain-id" if server_version == 1 else "x-denokv-database-id" - ) - try: - UUID(request.headers.get(db_id_header, "")) - except Exception: - raise web.HTTPBadRequest( - body=f"client did not set a valid {db_id_header} when talking to a " - f"v{server_version} server" - ) from None - - if server_version > 2: - try: - client_version = int(request.headers.get("x-denokv-version", "")) - if client_version not in (2, 3): - raise ValueError(f"invalid client_version: {client_version}") - except Exception: - raise web.HTTPBadRequest( - body=f"client did not set a valid x-denokv-version header when " - f"talking to a v{server_version} server" - ) from None - - req_body_bytes = await request.read() - try: - read = SnapshotRead() - count = read.ParseFromString(req_body_bytes) - if len(req_body_bytes) != count: - raise ValueError( - f"{len(req_body_bytes) - count} trailing bytes after SnapshotRead" - ) - except Exception as e: - raise web.HTTPBadRequest( - body=f"body is not a valid SnapshotRead message: {e}" - ) from e - - read_result = SnapshotReadOutput( - status=SnapshotReadStatus.SR_SUCCESS, - read_is_strongly_consistent=True, - ranges=[mock_db.snapshot_read_range(r) for r in read.ranges], - ) - return web.Response( - 
status=200, - content_type="application/x-protobuf", - body=read_result.SerializeToString(), - ) + """HTTP endpoints backed by a MockKvDb, plus various misbehaving endpoints.""" + # Generic Data Path errors async def violation_2xx_text_body(request: web.Request) -> web.Response: """Only 200, not 2xx is the permitted successful response status.""" return web.Response( @@ -197,6 +158,7 @@ async def violation_invalid_protobuf_body(request: web.Request) -> web.Response: body=b"\x00foo", ) + # Invalid snapshot_read handlers async def unusable_disabled_via_read_disabled(request: web.Request) -> web.Response: return web.Response( status=200, @@ -255,23 +217,120 @@ async def violation_wrong_ranges(request: web.Request) -> web.Response: ).SerializeToString(), ) - app = web.Application() - app.router.add_post( - "/violation_2xx_text_body/snapshot_read", violation_2xx_text_body - ) - app.router.add_post( - "/violation_2xx_protobuf_body/snapshot_read", violation_2xx_protobuf_body - ) - app.router.add_post("/violation_307/snapshot_read", violation_307) - app.router.add_post("/errors_401/snapshot_read", errors_401) - app.router.add_post("/errors_503/snapshot_read", errors_503) - app.router.add_post( - "/violation_bad_content_type/snapshot_read", violation_bad_content_type - ) - app.router.add_post( - "/violation_invalid_protobuf_body/snapshot_read", - violation_invalid_protobuf_body, + # Invalid atomic_write handlers + async def violation_atomic_write_success_with_failed_checks( + request: web.Request, + ) -> web.Response: + write = AtomicWrite() + write.ParseFromString(await request.read()) + assert len(write.checks) > 0, "write request must have at least one check" + + return web.Response( + status=200, + content_type="application/x-protobuf", + body=AtomicWriteOutput( + status=AtomicWriteStatus.AW_SUCCESS, + failed_checks=[0], + versionstamp=VersionStamp(0), + ).SerializeToString(), + ) + + async def violation_atomic_write_success_with_invalid_versionstamp( + request: 
web.Request, + ) -> web.Response: + return web.Response( + status=200, + content_type="application/x-protobuf", + body=AtomicWriteOutput( + status=AtomicWriteStatus.AW_SUCCESS, versionstamp=b"\xff" + ).SerializeToString(), + ) + + async def violation_atomic_write_check_failure_with_out_of_bounds_index( + request: web.Request, + ) -> web.Response: + write = AtomicWrite() + write.ParseFromString(await request.read()) + assert len(write.checks) > 0, "write request must have at least one check" + + return web.Response( + status=200, + content_type="application/x-protobuf", + body=AtomicWriteOutput( + status=AtomicWriteStatus.AW_CHECK_FAILURE, + failed_checks=[len(write.checks)], + ).SerializeToString(), + ) + + # The denokv self-hosted implementation does not return indexes of failed + # checks. + # https://github.com/denoland/denokv/issues/110 + async def quirk_atomic_write_check_failure_without_failed_checks( + request: web.Request, + ) -> web.Response: + write = AtomicWrite() + write.ParseFromString(await request.read()) + assert len(write.checks) > 0, "write request must have at least one check" + + return web.Response( + status=200, + content_type="application/x-protobuf", + body=AtomicWriteOutput( + status=AtomicWriteStatus.AW_CHECK_FAILURE + ).SerializeToString(), + ) + + async def violation_atomic_write_unspecified_status( + request: web.Request, + ) -> web.Response: + return web.Response( + status=200, + content_type="application/x-protobuf", + body=AtomicWriteOutput( + status=AtomicWriteStatus.AW_UNSPECIFIED + ).SerializeToString(), + ) + + async def violation_atomic_write_invalid_status( + request: web.Request, + ) -> web.Response: + return web.Response( + status=200, + content_type="application/x-protobuf", + body=AtomicWriteOutput( + status=max(AtomicWriteStatus.values()) + 1 # type: ignore[attr-defined] + ).SerializeToString(), + ) + + async def unusable_atomic_write(request: web.Request) -> web.Response: + return web.Response( + status=200, + 
content_type="application/x-protobuf", + body=AtomicWriteOutput( + status=AtomicWriteStatus.AW_WRITE_DISABLED + ).SerializeToString(), + ) + + def add_datapath_post(app: web.Application, path: str, handler: Handler) -> None: + assert path.startswith("/") and not path.endswith("/") + for req_kind in _DataPathRequestKind: + app.router.add_post(f"{path}/{req_kind.value}", handler) + + app = mock_db_api(mock_db) + + # Generic error endpoints + add_datapath_post(app, "/violation_2xx_text_body", violation_2xx_text_body) + + add_datapath_post(app, "/violation_2xx_protobuf_body", violation_2xx_protobuf_body) + add_datapath_post(app, "/violation_307", violation_307) + add_datapath_post(app, "/errors_401", errors_401) + add_datapath_post(app, "/errors_503", errors_503) + add_datapath_post(app, "/violation_bad_content_type", violation_bad_content_type) + add_datapath_post( + app, "/violation_invalid_protobuf_body", violation_invalid_protobuf_body ) + + # snapshot_read only error endpoints app.router.add_post( "/unusable_disabled_via_read_disabled/snapshot_read", unusable_disabled_via_read_disabled, @@ -292,9 +351,32 @@ async def violation_wrong_ranges(request: web.Request) -> web.Response: "/violation_wrong_ranges/snapshot_read", violation_wrong_ranges, ) - app.router.add_post("/v1/consistency/strong/snapshot_read", strong_snapshot_read) - app.router.add_post("/v2/consistency/strong/snapshot_read", strong_snapshot_read) - app.router.add_post("/v3/consistency/strong/snapshot_read", strong_snapshot_read) + + # atomic_write only error endpoints + app.router.add_post( + "/success_with_failed_checks/atomic_write", + violation_atomic_write_success_with_failed_checks, + ) + app.router.add_post( + "/success_with_invalid_versionstamp/atomic_write", + violation_atomic_write_success_with_invalid_versionstamp, + ) + app.router.add_post( + "/check_failure_with_out_of_bounds_index/atomic_write", + violation_atomic_write_check_failure_with_out_of_bounds_index, + ) + app.router.add_post( + 
"/check_failure_without_failed_checks/atomic_write", + quirk_atomic_write_check_failure_without_failed_checks, + ) + app.router.add_post("/unusable/atomic_write", unusable_atomic_write) + app.router.add_post( + "/unspecified_status/atomic_write", violation_atomic_write_unspecified_status + ) + app.router.add_post( + "/invalid_status/atomic_write", violation_atomic_write_invalid_status + ) + return app @@ -306,92 +388,110 @@ async def client( return await aiohttp_client(db_api) -def make_database_metadata_for_endpoint( - endpoint_url: URL, - endpoint_consistency: ConsistencyLevel = ConsistencyLevel.STRONG, - version: Literal[1, 2, 3] = 3, - database_id: UUID | None = None, - expires_at: datetime | None = None, - token: str = "hunter2.123", -) -> tuple[DatabaseMetadata, EndpointInfo]: - if database_id is None: - database_id = UUID("00000000-0000-0000-0000-000000000000") - if expires_at is None: - expires_at = datetime.now() + timedelta(minutes=30) - - endpoint = EndpointInfo(url=endpoint_url, consistency=endpoint_consistency) - - meta = DatabaseMetadata( - version=version, - database_id=database_id, - endpoints=[endpoint], - expires_at=expires_at, - token=token, +@pytest.mark.parametrize( + "datapath_request_fn", + [ + pytest.param( + functools.partial(snapshot_read, read=SnapshotRead()), id="snapshot_read" + ), + pytest.param( + functools.partial(atomic_write, write=AtomicWrite()), id="atomic_write" + ), + ], +) +@pytest_mark_asyncio +async def test_datapath_request_function__handles_network_error( + client: TestClient, + unused_tcp_port_factory: Callable[[], int], + datapath_request_fn: functools.partial[Awaitable[Result[object, object]]], +) -> None: + server_url = client.make_url("/") + server_url = server_url.with_port(unused_tcp_port_factory()) + + meta, endpoint = meta_endpoint(make_database_metadata(endpoints=server_url)) + + # will fail to connect to URL with nothing listening on the port + result = await datapath_request_fn( + session=client.session, + 
meta=meta, + endpoint=endpoint, + ) + assert isinstance(result, Err) + assert result.error == RequestUnsuccessful( + "Failed to make Data Path HTTP request to KV server", + endpoint=endpoint, + auto_retry=AutoRetry.AFTER_BACKOFF, ) - return meta, endpoint -@pytest.mark.parametrize( - "path, mk_error", - [ - ( - "/violation_2xx_text_body", - lambda endpoint: ResponseUnsuccessful( - "Server responded to Data Path request with unexpected HTTP status", - status=201, - body_text="Strange behaviour.", - auto_retry=AutoRetry.NEVER, - endpoint=endpoint, - ), +generic_datapath_unsuccessful_response_params: Sequence[ + tuple[str, Callable[[EndpointInfo], DataPathDenoKvError]] +] = [ + ( + "/violation_2xx_text_body", + lambda endpoint: ResponseUnsuccessful( + "Server responded to Data Path request with unexpected HTTP status", + status=201, + body_text="Strange behaviour.", + auto_retry=AutoRetry.NEVER, + endpoint=endpoint, ), - ( - "/violation_2xx_protobuf_body", - lambda endpoint: ResponseUnsuccessful( - "Server responded to Data Path request with unexpected HTTP status", - status=201, - body_text="Response content-type: application/x-protobuf", - auto_retry=AutoRetry.NEVER, - endpoint=endpoint, - ), + ), + ( + "/violation_2xx_protobuf_body", + lambda endpoint: ResponseUnsuccessful( + "Server responded to Data Path request with unexpected HTTP status", + status=201, + body_text="Response content-type: application/x-protobuf", + auto_retry=AutoRetry.NEVER, + endpoint=endpoint, ), - ( - "/violation_307", - lambda endpoint: ResponseUnsuccessful( - "Server responded to Data Path request with unexpected HTTP status", - status=307, - body_text="testdb: redirecting to /foo", - auto_retry=AutoRetry.NEVER, - endpoint=endpoint, - ), + ), + ( + "/violation_307", + lambda endpoint: ResponseUnsuccessful( + "Server responded to Data Path request with unexpected HTTP status", + status=307, + body_text="testdb: redirecting to /foo", + auto_retry=AutoRetry.NEVER, + endpoint=endpoint, ), - ( 
- "/errors_401", - lambda endpoint: ResponseUnsuccessful( - "Server rejected Data Path request indicating client error", - status=401, - body_text="testdb: Unauthorized", - auto_retry=AutoRetry.NEVER, - endpoint=endpoint, - ), + ), + ( + "/errors_401", + lambda endpoint: ResponseUnsuccessful( + "Server rejected Data Path request indicating client error", + status=401, + body_text="testdb: Unauthorized", + auto_retry=AutoRetry.NEVER, + endpoint=endpoint, ), - ( - "/errors_503", - lambda endpoint: ResponseUnsuccessful( - "Server failed to respond to Data Path request indicating server error", - status=503, - body_text="testdb: Unavailable", - auto_retry=AutoRetry.AFTER_BACKOFF, - endpoint=endpoint, - ), + ), + ( + "/errors_503", + lambda endpoint: ResponseUnsuccessful( + "Server failed to respond to Data Path request indicating server error", + status=503, + body_text="testdb: Unavailable", + auto_retry=AutoRetry.AFTER_BACKOFF, + endpoint=endpoint, ), - ( - "/violation_bad_content_type", - lambda endpoint: ProtocolViolation( - "response content-type is not application/x-protobuf: text/plain", - data="text/plain", - endpoint=endpoint, - ), + ), + ( + "/violation_bad_content_type", + lambda endpoint: ProtocolViolation( + "response content-type is not application/x-protobuf: text/plain", + data="text/plain", + endpoint=endpoint, ), + ), +] + + +@pytest.mark.parametrize( + "path, mk_error", + [ + *generic_datapath_unsuccessful_response_params, ( "/violation_invalid_protobuf_body", lambda endpoint: ProtocolViolation( @@ -455,7 +555,7 @@ async def test_snapshot_read__handles_unsuccessful_responses( mk_error: Callable[[EndpointInfo], DataPathDenoKvError], ) -> None: server_url = client.make_url(path) - meta, endpoint = make_database_metadata_for_endpoint(endpoint_url=server_url) + meta, endpoint = meta_endpoint(make_database_metadata(endpoints=server_url)) error = mk_error(endpoint) assert isinstance(error, DataPathDenoKvError) read = SnapshotRead(ranges=[]) @@ -470,31 
+570,6 @@ async def test_snapshot_read__handles_unsuccessful_responses( assert result.error == error -@pytest_mark_asyncio -async def test_snapshot_read__handles_network_error( - client: TestClient, unused_tcp_port_factory: Callable[[], int] -) -> None: - server_url = client.make_url("/") - server_url = server_url.with_port(unused_tcp_port_factory()) - - meta, endpoint = make_database_metadata_for_endpoint(endpoint_url=server_url) - read = SnapshotRead(ranges=[]) - - # will fail to connect to URL with nothing listening on the port - result = await snapshot_read( - session=client.session, - meta=meta, - endpoint=endpoint, - read=read, - ) - assert isinstance(result, Err) - assert result.error == RequestUnsuccessful( - "Failed to make Data Path HTTP request to KV server", - endpoint=endpoint, - auto_retry=AutoRetry.AFTER_BACKOFF, - ) - - @pytest.mark.parametrize( "read_ranges, result_ranges", [ @@ -609,8 +684,8 @@ async def test_snapshot_read__reads_expected_values( version: Literal[1, 2, 3], ) -> None: server_url = client.make_url(f"/v{version}/consistency/strong/") - meta, endpoint = make_database_metadata_for_endpoint( - endpoint_url=server_url, version=version + meta, endpoint = meta_endpoint( + make_database_metadata(endpoints=server_url, version=version) ) ver = add_entries(mock_db, example_entries) @@ -635,6 +710,236 @@ async def test_snapshot_read__reads_expected_values( assert actual_result_ranges == expected_result_ranges +@pytest_mark_asyncio +async def test_atomic_write__raises_when_given_endpoint_without_strong_consistency( + client: TestClient, +) -> None: + # this is considered an avoidable programmer error, so it raises + meta, eventual_endpoint = meta_endpoint( + make_database_metadata( + URL("https://example/"), endpoint_consistency=ConsistencyLevel.EVENTUAL + ) + ) + with pytest.raises( + ValueError, + match=r"endpoints used with atomic_write must be " + r"", + ): + await atomic_write( + session=client.session, + meta=meta, + 
endpoint=eventual_endpoint, + write=AtomicWrite(), + ) + + +@pytest.mark.parametrize( + "path, mk_error", + [ + *generic_datapath_unsuccessful_response_params, + ( + "/violation_invalid_protobuf_body", + lambda endpoint: ProtocolViolation( + "Server responded to Data Path request with invalid " + "AtomicWriteOutput", + data=b"\x00foo", + endpoint=endpoint, + ), + ), + ( + "/success_with_failed_checks", + lambda endpoint: ProtocolViolation( + "Server responded to Data Path Atomic Write with SUCCESS " + "containing failed checks", + data=AtomicWriteOutput( + status=AtomicWriteStatus.AW_SUCCESS, + failed_checks=[0], + versionstamp=VersionStamp(0), + ), + endpoint=endpoint, + ), + ), + ( + "/success_with_invalid_versionstamp", + lambda endpoint: ProtocolViolation( + "Server responded to Data Path Atomic Write with SUCCESS " + "containing an invalid versionstamp", + data=AtomicWriteOutput( + status=AtomicWriteStatus.AW_SUCCESS, + versionstamp=b"\xff", + ), + endpoint=endpoint, + ), + ), + ( + "/check_failure_with_out_of_bounds_index", + lambda endpoint: ProtocolViolation( + "Server responded to Data Path Atomic Write with CHECK_FAILURE " + "referencing out-of-bounds check index", + data=AtomicWriteOutput( + status=AtomicWriteStatus.AW_CHECK_FAILURE, + failed_checks=[1], + ), + endpoint=endpoint, + ), + ), + ( + "/check_failure_without_failed_checks", + lambda endpoint: CheckFailure( + "Not all checks required by the Atomic Write passed", + all_checks=[ + Check(key=pack_key(("x",)), versionstamp=bytes(VersionStamp(0))) + ], + failed_check_indexes=[], + endpoint=endpoint, + ), + ), + ( + "/unspecified_status", + lambda endpoint: ProtocolViolation( + "Server responded to Data Path Atomic Write request " + "with status UNSPECIFIED", + data=AtomicWriteOutput(status=AtomicWriteStatus.AW_UNSPECIFIED), + endpoint=endpoint, + ), + ), + ( + "/invalid_status", + lambda endpoint: ProtocolViolation( + "Server responded to Data Path Atomic Write request " + "with status unknown: 6", 
+ data=AtomicWriteOutput(status=6), # type: ignore[arg-type] + endpoint=endpoint, + ), + ), + ( + "/unusable", + lambda endpoint: EndpointNotUsable( + "Server responded to Data Path request indicating it is cannot " + "write this database", + reason=EndpointNotUsableReason.DISABLED, + endpoint=endpoint, + ), + ), + ], +) +@pytest_mark_asyncio +async def test_atomic_write__handles_unsuccessful_responses( + client: TestClient, + path: str, + mk_error: Callable[[EndpointInfo], DataPathDenoKvError], +) -> None: + server_url = client.make_url(path) + meta, endpoint = meta_endpoint(make_database_metadata(endpoints=server_url)) + error = mk_error(endpoint) + assert isinstance(error, DataPathDenoKvError) + + result = await atomic_write( + session=client.session, + meta=meta, + endpoint=endpoint, + write=AtomicWrite( + checks=[Check(key=pack_key(("x",)), versionstamp=bytes(VersionStamp(0)))] + ), + ) + assert isinstance(result, Err) + assert result.error == error + + +@pytest.fixture +def example_entries_write() -> Mapping[KvKeyTuple, object]: + return {("bigint", 1): JSBigInt(10)} + + +# There's not really much point in testing many successful mutations here, as +# our atomic_write() function is just passing along the encoded protobuf data +# without doing anything to it — we're just testing the db implementation if we +# were to test lots of things here. Error cases are where all the work is. 
+@pytest.mark.parametrize( + "write, read_ranges, result_ranges", + [ + pytest.param(AtomicWrite(), [], [], id="empty"), + pytest.param( + AtomicWrite( + mutations=[ + Mutation( + key=pack_key(("bigint", 1)), + value=KvValue( + data=bytes(default_v8_encoder.encode(JSBigInt(20))), + encoding=ValueEncoding.VE_V8, + ), + mutation_type=MutationType.M_SET, + ) + ] + ), + [ + ReadRange( + start=pack_key(("bigint", 1)), end=pack_key(("bigint", 2)), limit=1 + ), + ], + [[(KvKey("bigint", 1), JSBigInt(20))]], + id="set", + ), + pytest.param( + AtomicWrite( + mutations=[ + Mutation( + key=pack_key(("bigint", 1)), + value=KvValue( + data=bytes(default_v8_encoder.encode(JSBigInt(20))), + encoding=ValueEncoding.VE_V8, + ), + mutation_type=MutationType.M_SUM, + ) + ] + ), + [ + ReadRange( + start=pack_key(("bigint", 1)), end=pack_key(("bigint", 2)), limit=1 + ), + ], + [[(KvKey("bigint", 1), JSBigInt(30))]], + id="sum", + ), + ], +) +@pytest.mark.parametrize("version", [1, 2, 3]) +@pytest_mark_asyncio +async def test_atomic_write__writes_expected_values( + client: TestClient, + mock_db: MockKvDb, + example_entries_write: Mapping[KvKeyTuple, object], + write: AtomicWrite, + read_ranges: list[ReadRange], + result_ranges: list[list[tuple[KvKeyTuple, object]]], + version: Literal[1, 2, 3], +) -> None: + server_url = client.make_url(f"/v{version}/consistency/strong/") + meta, endpoint = meta_endpoint( + make_database_metadata(endpoints=server_url, version=version) + ) + add_entries(mock_db, example_entries_write) + + write_result = await atomic_write( + session=client.session, meta=meta, endpoint=endpoint, write=write + ) + + assert is_ok(write_result) + write_ver = VersionStamp(write_result.value) + + actual_result_ranges = [ + [unsafe_parse_protobuf_kv_entry(raw_entry) for raw_entry in res_range.values] + for res_range in ( + mock_db.snapshot_read_range(read=range) for range in read_ranges + ) + ] + expected_result_ranges = [ + [KvEntry(key, value, versionstamp=write_ver) for (key, 
value) in entries] + for entries in result_ranges + ] + assert actual_result_ranges == expected_result_ranges + + @pytest.mark.parametrize( "raw_entry, decoded", [ @@ -1056,3 +1361,89 @@ def test_is_any_kv_key() -> None: assert is_any_kv_key(("a", 1, 1.0, True, b"b")) assert not is_any_kv_key([]) assert not is_any_kv_key(((),)) + + +@pytest.fixture +def example_endpoint() -> EndpointInfo: + _, endpoint = meta_endpoint( + make_database_metadata(endpoints=URL("https://example.com")) + ) + return endpoint + + +def test_CheckFailure(example_endpoint: EndpointInfo) -> None: + checks = [ + Check(key=bytes(KvKey(f"a{i}")), versionstamp=bytes(VersionStamp(i))) + for i in range(4) + ] + msg = "Not all checks required by the Atomic Write passed" + e = CheckFailure( + msg, + all_checks=iter(checks), + failed_check_indexes=[3, 0, 2], + endpoint=example_endpoint, + ) + assert e.all_checks == tuple(checks) + assert e.failed_check_indexes == {0, 2, 3} + # failed_check_indexes are ordered ascending + assert list(e.failed_check_indexes) == [0, 2, 3] + assert e.endpoint is example_endpoint + assert msg in str(e) + + +@pytest.mark.parametrize("failed_check_indexes", [None, ()]) +def test_CheckFailure__failed_check_indexes_is_None_when_no_indexes( + failed_check_indexes: Iterable[int] | None, example_endpoint: EndpointInfo +) -> None: + checks = [ + Check(key=bytes(KvKey(f"a{i}")), versionstamp=bytes(VersionStamp(i))) + for i in range(4) + ] + # Failed_check_indexes can be empty (the self-hosted sqlite implementation + # does not return the indexes of failed checks). 
+ e = CheckFailure( + "Foo", + all_checks=iter(checks), + failed_check_indexes=failed_check_indexes, + endpoint=example_endpoint, + ) + assert e.all_checks == tuple(checks) + assert e.failed_check_indexes is None + + +def test_CheckFailure__validates_constructor_args( + example_endpoint: EndpointInfo, +) -> None: + checks = [Check(key=bytes(KvKey("a")), versionstamp=bytes(VersionStamp(1)))] + + with pytest.raises(ValueError, match=r"all_checks is empty"): + CheckFailure( + "Foo", all_checks=[], failed_check_indexes=[], endpoint=example_endpoint + ) + + with pytest.raises( + IndexError, match=r"failed_check_indexes contains out-of-bounds index" + ): + CheckFailure( + "Foo", + all_checks=checks, + failed_check_indexes=[5], + endpoint=example_endpoint, + ) + + +def test_ResponseUnsuccessful(example_endpoint: EndpointInfo) -> None: + msg = "Server rejected Data Path request indicating client error" + response_body = "Info about what is wrong." + status = 400 + e = ResponseUnsuccessful( + msg, + status=status, + body_text=response_body, + endpoint=example_endpoint, + auto_retry=AutoRetry.NEVER, + ) + assert str(msg) in str(e) + assert str(status) in str(e) + assert str(response_body) in str(e) + assert e.endpoint is example_endpoint diff --git a/test/test_errors.py b/test/test_errors.py new file mode 100644 index 0000000..cf96ff4 --- /dev/null +++ b/test/test_errors.py @@ -0,0 +1,20 @@ +import pytest + +from denokv.errors import DenoKvError + + +def test_errors_are_regular_exceptions() -> None: + """Errors must be caught by generic Exception handlers — not BaseException.""" + with pytest.raises(Exception): # noqa: B017 + raise DenoKvError("error") + + +def test_DenoKvError_message() -> None: + assert DenoKvError().message == "DenoKvError" + assert DenoKvError("Foo bar").message == "Foo bar" + + class CustomError(DenoKvError): + pass + + assert CustomError().message == "CustomError" + assert CustomError("Bar baz").message == "Bar baz" diff --git a/test/test_kv.py 
b/test/test_kv.py index 2389c43..465c56f 100644 --- a/test/test_kv.py +++ b/test/test_kv.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import re import sys import weakref from contextlib import asynccontextmanager @@ -9,6 +10,7 @@ from datetime import timedelta from functools import partial from itertools import repeat +from typing import Literal # noqa: TID251 from unittest.mock import AsyncMock from unittest.mock import Mock from unittest.mock import patch @@ -18,12 +20,16 @@ import pytest import pytest_asyncio import v8serialize +from aiohttp import web +from aiohttp.test_utils import TestClient as _TestClient from fdb.tuple import unpack from hypothesis import HealthCheck from hypothesis import given from hypothesis import settings from hypothesis import strategies as st from v8serialize import Decoder +from v8serialize.jstypes import JSBigInt +from v8serialize.jstypes import JSMap from yarl import URL from denokv import datapath @@ -33,13 +39,24 @@ from denokv._datapath_pb2 import SnapshotReadOutput from denokv._datapath_pb2 import SnapshotReadStatus from denokv._datapath_pb2 import ValueEncoding +from denokv._kv_values import KvEntry +from denokv._kv_values import KvU64 +from denokv._kv_values import VersionStamp +from denokv._kv_writes import DEFAULT_ENQUEUE_RETRY_DELAY_COUNT +from denokv._kv_writes import LIMIT_KVU64 +from denokv._kv_writes import Check +from denokv._kv_writes import FailedWrite +from denokv._kv_writes import Limit +from denokv._kv_writes import SumArgs from denokv._pycompat.enum import StrEnum from denokv._pycompat.typing import Any from denokv._pycompat.typing import AsyncGenerator +from denokv._pycompat.typing import Awaitable from denokv._pycompat.typing import Callable from denokv._pycompat.typing import Generator from denokv._pycompat.typing import Mapping from denokv._pycompat.typing import TypeAlias +from denokv._pycompat.typing import Union from denokv._pycompat.typing import cast from denokv.asyncio import 
loop_time from denokv.auth import ConsistencyLevel @@ -54,80 +71,48 @@ from denokv.datapath import KvKeyTuple from denokv.datapath import RequestUnsuccessful from denokv.datapath import ResponseUnsuccessful -from denokv.datapath import SnapshotReadResult from denokv.datapath import increment_packed_key from denokv.datapath import pack_key from denokv.errors import DenoKvError from denokv.errors import InvalidCursor from denokv.kv import Authenticator from denokv.kv import AuthenticatorFn +from denokv.kv import Base64KeySuffixCursorFormat from denokv.kv import CachedValue from denokv.kv import DatabaseMetadataCache from denokv.kv import EndpointSelector from denokv.kv import Kv -from denokv.kv import KvEntry +from denokv.kv import KvCredentials from denokv.kv import KvFlags from denokv.kv import KvListOptions -from denokv.kv import KvU64 +from denokv.kv import ListContext +from denokv.kv import ListKvEntry from denokv.kv import OpenKvFinalize -from denokv.kv import VersionStamp from denokv.kv import normalize_key from denokv.kv import open_kv from denokv.kv_keys import KvKey from denokv.result import Err from denokv.result import Ok from denokv.result import Result +from denokv.result import is_err +from denokv.result import is_ok from test.advance_time import advance_time from test.denokv_testing import ExampleCursorFormat from test.denokv_testing import MockKvDb from test.denokv_testing import add_entries from test.denokv_testing import assume_ok -from test.denokv_testing import mk_db_meta +from test.denokv_testing import create_dataclass_slots_test +from test.denokv_testing import make_database_metadata +from test.denokv_testing import mock_db_api from test.denokv_testing import unsafe_parse_protobuf_kv_entry -pytest_mark_asyncio = pytest.mark.asyncio() - - -@given(v=st.integers(min_value=0, max_value=2**80 - 1)) -def test_VersionStamp_init(v: int) -> None: - vs_int = VersionStamp(v) - assert int(vs_int) == v - assert VersionStamp(str(vs_int)) == vs_int - assert 
VersionStamp(bytes(vs_int)) == vs_int - assert bytes(vs_int) == vs_int - assert isinstance(vs_int, bytes) - - -@given(i=st.integers(min_value=0, max_value=2**64 - 1)) -def test_KvU64_init(i: int) -> None: - u64 = KvU64(i) - assert int(u64) == i - assert KvU64(bytes(u64)) == u64 - assert u64.to_bytes() == bytes(u64) - assert u64.to_bytes() == i.to_bytes(8, "little") - +TestClient: TypeAlias = _TestClient[web.Request, web.Application] -@given( - v1=st.integers(min_value=0, max_value=2**80 - 1), - v2=st.integers(min_value=0, max_value=2**80 - 1), -) -def test_VersionStamp_ordering(v1: int, v2: int) -> None: - vs1, vs2 = VersionStamp(v1), VersionStamp(v2) - if v1 < v2: - assert vs1 < vs2 - elif v1 > v2: - assert vs1 > vs2 - else: - assert vs1 == vs2 - - -def test_KVU64__bytes() -> None: - assert KvU64(bytes(KvU64(123456789))).value == 123456789 - assert KvU64(KvU64(123456789).to_bytes()).value == 123456789 +pytest_mark_asyncio = pytest.mark.asyncio() def test_EndpointSelector__rejects_meta_without_strong_endpoint() -> None: - meta_no_strong = mk_db_meta( + meta_no_strong = make_database_metadata( [ EndpointInfo( url=URL("https://example.com/eventual/"), @@ -141,7 +126,7 @@ def test_EndpointSelector__rejects_meta_without_strong_endpoint() -> None: def test_EndpointSelector__single() -> None: - meta = mk_db_meta( + meta = make_database_metadata( [ endpoint := EndpointInfo( url=URL("https://example.com/"), consistency=ConsistencyLevel.STRONG @@ -155,7 +140,7 @@ def test_EndpointSelector__single() -> None: def test_EndpointSelector__multi() -> None: - meta = mk_db_meta( + meta = make_database_metadata( [ endpoint_eventual := EndpointInfo( url=URL("https://example.com/eventual/"), @@ -341,21 +326,43 @@ def mock_snapshot_read() -> Generator[Mock]: yield mock +@pytest.fixture +def mock_atomic_write() -> Generator[Mock]: + mock = AsyncMock(side_effect=NotImplementedError) + with patch("denokv.datapath.atomic_write", mock) as mock: + yield mock + + @pytest.fixture def 
retry_delays() -> Backoff: return () -@pytest_asyncio.fixture -async def client_session() -> AsyncGenerator[aiohttp.ClientSession]: - async with aiohttp.ClientSession() as cs: - yield cs +@pytest.fixture +def client_session(client: TestClient) -> aiohttp.ClientSession: + return client.session + + +@pytest.fixture(params=[1, 2, 3], ids=lambda v: f"datapath_v{v}") +def datapath_version(request: pytest.FixtureRequest) -> Literal[1, 2, 3]: + assert request.param in (1, 2, 3) + return cast(Literal[1, 2, 3], request.param) @pytest.fixture -def meta() -> DatabaseMetadata: - return mk_db_meta( - [EndpointInfo(URL("https://example.com/"), ConsistencyLevel.STRONG)] +def datapath_endpoint_url( + client: TestClient, datapath_version: Literal[1, 2, 3] +) -> URL: + return client.make_url(f"/v{datapath_version}/consistency/strong/") + + +@pytest.fixture +def meta( + datapath_version: Literal[1, 2, 3], datapath_endpoint_url: URL +) -> DatabaseMetadata: + return make_database_metadata( + [EndpointInfo(datapath_endpoint_url, ConsistencyLevel.STRONG)], + version=datapath_version, ) @@ -378,7 +385,25 @@ def kv_flags() -> KvFlags: @pytest.fixture -def create_db( +def mock_db() -> MockKvDb: + return MockKvDb() + + +@pytest.fixture +def db_api(mock_db: MockKvDb) -> web.Application: + return mock_db_api(mock_db) + + +@pytest_asyncio.fixture +async def client( + db_api: web.Application, + aiohttp_client: Callable[[web.Application], Awaitable[TestClient]], +) -> TestClient: + return await aiohttp_client(db_api) + + +@pytest.fixture +def create_kv( client_session: aiohttp.ClientSession, auth_fn: AuthenticatorFn, retry_delays: Backoff, @@ -396,8 +421,8 @@ def create_db( @pytest.fixture -def db(create_db: partial[Kv]) -> Kv: - return create_db() +def kv(create_kv: partial[Kv]) -> Kv: + return create_kv() @pytest.fixture @@ -430,19 +455,19 @@ def pack_kv_entry( @pytest_mark_asyncio async def test_Kv_get__rejects_invalid_arguments( - db: Kv, mock_snapshot_read: AsyncMock + kv: Kv, 
mock_snapshot_read: AsyncMock ) -> None: with pytest.raises( TypeError, match=r"cannot use positional keys and keys keyword argument" ): - await db.get(("a", 1), keys=[("a", 2)]) # type: ignore[call-overload] + await kv.get(("a", 1), keys=[("a", 2)]) # type: ignore[call-overload] with pytest.raises(TypeError, match=r"at least one key argument must be passed"): - await db.get() # type: ignore[call-overload] + await kv.get() # type: ignore[call-overload] @pytest_mark_asyncio async def test_Kv_get__returns_single_value_for_single_key( - db: Kv, mock_snapshot_read: AsyncMock + kv: Kv, mock_snapshot_read: AsyncMock ) -> None: read_output = SnapshotReadOutput( ranges=[ReadRangeOutput(values=[pack_kv_entry(("a", 1), b"x")])], @@ -454,7 +479,7 @@ async def test_Kv_get__returns_single_value_for_single_key( mock_snapshot_read.side_effect = None mock_snapshot_read.return_value = Ok(read_output) - k, kval = await db.get(("a", 1)) + k, kval = await kv.get(("a", 1)) assert k == ("a", 1) assert kval is not None @@ -472,7 +497,7 @@ class ArgKind(StrEnum): @pytest.mark.parametrize("arg_kind", ArgKind) @pytest_mark_asyncio async def test_Kv_get__returns_n_values_for_n_keys( - n: int, arg_kind: ArgKind, db: Kv, mock_snapshot_read: AsyncMock + n: int, arg_kind: ArgKind, kv: Kv, mock_snapshot_read: AsyncMock ) -> None: read_output = SnapshotReadOutput( ranges=[ @@ -492,9 +517,9 @@ async def test_Kv_get__returns_n_values_for_n_keys( mock_snapshot_read.return_value = Ok(read_output) if arg_kind is ArgKind.KWARGS: - values = await db.get(keys=[("i", i) for i in range(n)]) + values = await kv.get(keys=[("i", i) for i in range(n)]) else: - values = await db.get(*[("i", i) for i in range(n)]) + values = await kv.get(*[("i", i) for i in range(n)]) assert isinstance(values, tuple) assert len(values) == n @@ -514,7 +539,7 @@ async def test_Kv_get__returns_n_values_for_n_keys( ) @pytest_mark_asyncio async def test_Kv_get__treats_int_as_float_when_IntAsNumber_enabled( - db: Kv, 
mock_snapshot_read: AsyncMock, int_type: type + kv: Kv, mock_snapshot_read: AsyncMock, int_type: type ) -> None: read_output = SnapshotReadOutput( ranges=[ReadRangeOutput(values=[pack_kv_entry(("a", int_type(1)), b"x")])], @@ -526,7 +551,7 @@ async def test_Kv_get__treats_int_as_float_when_IntAsNumber_enabled( mock_snapshot_read.side_effect = None mock_snapshot_read.return_value = Ok(read_output) - k, kval = await db.get(("a", 1)) + k, kval = await kv.get(("a", 1)) assert k == ("a", 1) # 1 == 1.0 assert type(k[1]) is int_type @@ -585,14 +610,14 @@ def retryable_errors( @given(data=st.data(), retry_delays=st.sampled_from([[], [1.0], [1.0, 2.0, 4.0]])) @pytest_mark_asyncio async def test_Kv_get__retries_retryable_snapshot_read_errors( - create_db: partial[Kv], + create_kv: partial[Kv], meta: DatabaseMetadata, mock_snapshot_read: AsyncMock, data: st.DataObject, retry_delays: Backoff, ) -> None: auth_fn = AsyncMock(name="auth_fn", return_value=Ok(meta)) - db = create_db(retry=retry_delays, auth=auth_fn) + db = create_kv(retry=retry_delays, auth=auth_fn) retry_errors: list[DataPathError] = [] def fail_with_retryable_error(*args: Any, **kwargs: Any) -> Err[DenoKvError]: @@ -629,17 +654,17 @@ def fail_with_retryable_error(*args: Any, **kwargs: Any) -> Err[DenoKvError]: @pytest_mark_asyncio -async def test_Kv_list__rejects_invalid_arguments(db: Kv) -> None: +async def test_Kv_list__rejects_invalid_arguments(kv: Kv) -> None: with pytest.raises(ValueError, match=r"limit cannot be negative"): - async for _ in db.list(limit=-1): + async for _ in kv.list(limit=-1): raise AssertionError("should not generate values") with pytest.raises(ValueError, match=r"batch_size cannot be < 1"): - async for _ in db.list(batch_size=0): + async for _ in kv.list(batch_size=0): raise AssertionError("should not generate values") with pytest.raises(InvalidCursor, match=r"cursor is not valid URL-safe base64"): - async for _ in db.list(cursor="x"): + async for _ in kv.list(cursor="x"): raise 
AssertionError("should not generate values") @@ -658,12 +683,12 @@ def pack_example_cursor(key: KvKeyTuple) -> str: ) @pytest_mark_asyncio async def test_Kv_list__rejects_cursor_outside_listed_range( - db: Kv, range_options: KvListOptions, cursor_key: KvKeyTuple + kv: Kv, range_options: KvListOptions, cursor_key: KvKeyTuple ) -> None: with pytest.raises( InvalidCursor, match=r"cursor is not within the the start and end key range" ): - async for _ in db.list( + async for _ in kv.list( **KvListOptions( **range_options, cursor_format_type=ExampleCursorFormat, @@ -673,11 +698,6 @@ async def test_Kv_list__rejects_cursor_outside_listed_range( raise AssertionError("should not generate values") -@pytest.fixture -def mock_db() -> MockKvDb: - return MockKvDb() - - @pytest.fixture def list_example_entries() -> Mapping[KvKeyTuple, object]: return { @@ -687,33 +707,6 @@ def list_example_entries() -> Mapping[KvKeyTuple, object]: } -@pytest.fixture -def mock_snapshot_read_to_return_mock_db_results( - mock_snapshot_read: AsyncMock, mock_db: MockKvDb -) -> Callable[[], AsyncMock]: - async def snapshot_read_effect( - *, - session: aiohttp.ClientSession, - meta: DatabaseMetadata, - endpoint: EndpointInfo, - read: SnapshotRead, - ) -> SnapshotReadResult: - assert len(read.ranges) == 1 - snapshot_read_output = SnapshotReadOutput( - ranges=[mock_db.snapshot_read_range(read.ranges[0])], - read_disabled=False, - read_is_strongly_consistent=True, - status=SnapshotReadStatus.SR_SUCCESS, - ) - return Ok(snapshot_read_output) - - def apply() -> AsyncMock: - mock_snapshot_read.side_effect = snapshot_read_effect - return mock_snapshot_read - - return apply - - list_example_keys = st.one_of( st.none(), st.just(()), @@ -766,8 +759,7 @@ def list_example_cursors( @pytest_mark_asyncio async def test_Kv_list__generates_values_from_sequential_snapshot_reads( data: st.DataObject, - db: Kv, - mock_snapshot_read_to_return_mock_db_results: Callable[[], AsyncMock], + kv: Kv, mock_db: MockKvDb, 
list_example_entries: Mapping[KvKeyTuple, object], prefix: KvKeyTuple | None, @@ -780,7 +772,6 @@ async def test_Kv_list__generates_values_from_sequential_snapshot_reads( ) -> None: mock_db.clear() add_entries(mock_db, list_example_entries) - mock_snapshot_read_to_return_mock_db_results() # Kv.list() should be equivalent to reading the listed range in one go. listed_range = datapath.read_range_multi( @@ -815,7 +806,7 @@ async def test_Kv_list__generates_values_from_sequential_snapshot_reads( ] results: list[tuple[KvKeyTuple, object, VersionStamp]] = [] - async for kv_entry in db.list( + async for kv_entry in kv.list( prefix=prefix, start=start, end=end, @@ -854,15 +845,12 @@ async def test_Kv_list__generates_values_from_sequential_snapshot_reads( @pytest_mark_asyncio async def test_Kv_list__retries_retryable_snapshot_read_errors( - create_db: partial[Kv], + create_kv: partial[Kv], meta: DatabaseMetadata, mock_snapshot_read: AsyncMock, - # mock_snapshot_read_to_return_mock_db_results: Callable[[], AsyncMock], - # mock_db: MockKvDb, - # list_example_entries: Mapping[KvKeyTuple, object], ) -> None: auth_fn = AsyncMock(name="auth_fn", return_value=Ok(meta)) - db = create_db(retry=repeat(0), auth=auth_fn) + db = create_kv(retry=repeat(0), auth=auth_fn) auth_fn.side_effect = [ Err(MetadataExchangeDenoKvError("Failed", retryable=True)), @@ -978,6 +966,472 @@ async def test_Kv_list__retries_retryable_snapshot_read_errors( assert len(auth_fn.mock_calls) == 7 +@pytest_mark_asyncio +async def test_Kv_write__set(kv: Kv) -> None: + _, before = await kv.get(("foo", 1)) + result = await kv.atomic().set(("foo", 1), "Hi").write() + _, after = await kv.get(("foo", 1)) + + assert before is None + assert is_ok(result) + assert after and after.value == "Hi" + + +@pytest_mark_asyncio +async def test_Kv_write__set_versioned(kv: Kv) -> None: + result = await kv.atomic().set(("foo", 1), "Hi", versioned=True).write() + assert is_ok(result) + _, entry = await kv.get(("foo", 1, 
str(result.versionstamp))) + assert entry and entry.value == "Hi" + + +ErrorPredicate: TypeAlias = Callable[[Union[BaseException, None]], bool] + + +def match_client_error(server_msg_content: str) -> ErrorPredicate: + def is_client_error(e: BaseException | None) -> bool: + return ( + isinstance(e, FailedWrite) + and isinstance(e.__cause__, ResponseUnsuccessful) + and e.__cause__.status == 400 + and server_msg_content in e.__cause__.body_text + ) + + return is_client_error + + +def match_error( + kind: type[BaseException], + containing: str | None = None, + matching: str | re.Pattern[str] | None = None, + cause: ErrorPredicate | None = None, +) -> ErrorPredicate: + if containing is not None: + if matching is not None: + raise ValueError("containing and matching args cannot both be set") + matching = re.escape(containing) + elif matching is None: + raise ValueError("containing or matching args must be set") + + def is_error(e: BaseException | None) -> bool: + return ( + isinstance(e, kind) + and bool(re.search(matching, str(e))) + and (cause is None or cause(e.__cause__)) + ) + + return is_error + + +def match_write_failure( + kind: type[BaseException], + containing: str | None = None, + matching: str | re.Pattern[str] | None = None, +) -> ErrorPredicate: + return match_error( + kind=FailedWrite, + containing="", + cause=match_error(kind, containing=containing, matching=matching), + ) + + +@asynccontextmanager +async def validate_write_outcome( + kv: Kv, + initial_val: object | None, + result: object | None | ErrorPredicate, +) -> AsyncGenerator[tuple[Kv, KvKeyTuple]]: + match_error: ErrorPredicate | None = result if callable(result) else None + + if initial_val is not None: + assert is_ok(await kv.atomic().set(("foo", 0), initial_val).write()) + else: + assert is_ok(await kv.atomic().delete(("foo", 0)).write()) + + try: + yield (kv, ("foo", 0)) + assert not match_error, "write succeeded but is expected to fail" + except AssertionError: + raise + except Exception as 
e: + if not match_error: + raise + assert match_error(e), f"Did not match error: {e!r}" + return + + (_, a) = await kv.get(("foo", 0)) + if result is None: + assert a is None + else: + assert a and a.value == result and type(a.value) is type(result) + + +# fmt: off +_params_test_Kv_write__sum = pytest.mark.parametrize( + "initial_val, sum_val, sum_kwargs, result", + [ + (12, 3, {}, 15), + (12, 3.5, {}, 15.5), + (JSBigInt(12), JSBigInt(3), {}, JSBigInt(15)), + (12, -3, {}, 9), + (12, -3.5, {}, 8.5), + (JSBigInt(12), JSBigInt(-3), {}, JSBigInt(9)), + (None, 3, {}, 3), + (12.5, 2.5, {}, 15), + (12.5, -2.5, {}, 10), + (None, 2.5, {}, 2.5), + (None, -2.5, {}, -2.5), + (KvU64(12), KvU64(3), {}, KvU64(15)), + (KvU64(12), 3, SumArgs(number_type='u64'), KvU64(15)), + (KvU64(12), -3, SumArgs(number_type='u64'), KvU64(9)), + (None, KvU64(3), {}, KvU64(3)), + # KvU64 wraps on overflow + (KvU64(1), -3, SumArgs(number_type='u64'), KvU64(2**64 - 2)), + (KvU64(2**64 - 2), 3, SumArgs(number_type='u64'), KvU64(1)), + # Limits + (JSBigInt(12), JSBigInt(10), SumArgs(clamp_under=10, clamp_over=20), JSBigInt(20)), # noqa: E501 + (JSBigInt(12), JSBigInt(-10), SumArgs(clamp_under=10, clamp_over=20), JSBigInt(10)), # noqa: E501 + (12.0, 10.0, SumArgs(clamp_under=10.0, clamp_over=20.0), 20), + (12.0, -10.0, SumArgs(clamp_under=10.0, clamp_over=20.0), 10), + (KvU64(12), KvU64(10), SumArgs(clamp_under=10, clamp_over=20), KvU64(20)), + (KvU64(12), -10, SumArgs(number_type='u64', clamp_under=10, clamp_over=20), KvU64(10)), # noqa: E501 + # limit via Limit object + (12, -10, SumArgs(limit=Limit(10, 20, 'clamp')), 10), + # overflow with limit_exceeded error causes write to fail with client error + pytest.param(12, 10, SumArgs(abort_under=10, abort_over=20 ), match_client_error("Mutation is not a valid M_SUM operation"), id='err-limit-high-BigInt'), # noqa: E501 + pytest.param(12, -10, SumArgs(abort_under=10, abort_over=20 ), match_client_error("Mutation is not a valid M_SUM operation"), 
id='err-limit-low-BigInt'), # noqa: E501 + pytest.param(12.0, 10.0, SumArgs(abort_under=10.0, abort_over=20.0 ), match_client_error("Mutation is not a valid M_SUM operation"), id='err-limit-high-Number'), # noqa: E501 + pytest.param(12.0, -10.0, SumArgs(abort_under=10.0, abort_over=20.0 ), match_client_error("Mutation is not a valid M_SUM operation"), id='err-limit-low-Number'), # noqa: E501 + # Cannot use limit_exceeded other than wrap for KvU64 + pytest.param(KvU64(12), KvU64(1), SumArgs(abort_over=100), lambda e: isinstance(e, ValueError) and "Number type 'u64' does not support abort limits" == str(e), id='err-invalid-exceeded-KvU64'), # noqa: E501 + # Cannot use limit_exceeded wrap for BigInt/Number + pytest.param(1, JSBigInt(1), SumArgs(limit=LIMIT_KVU64), match_error(ValueError, "Number type 'bigint' does not support wrap limits"), id='err-invalid-exceeded-BigInt'), # noqa: E501 + pytest.param(1.0, 1.0, SumArgs(limit=LIMIT_KVU64), match_error(ValueError, "Number type 'float' does not support wrap limits"), id='err-invalid-exceeded-Number'), # noqa: E501 + ], +) +# fmt: on +@_params_test_Kv_write__sum +@pytest_mark_asyncio +async def test_Kv_write__atomic_sum( + kv: Kv, + initial_val: int | float | JSBigInt | KvU64 | None, + sum_val: int | float | JSBigInt | KvU64, + sum_kwargs: SumArgs[Any, Any, Any], + result: int | float | KvU64 | Callable[[Exception], bool], +) -> None: + async with validate_write_outcome(kv, initial_val, result) as (kv, key): + sum_args = SumArgs(key=key, delta=sum_val, **sum_kwargs) + assert is_ok(await kv.atomic().sum(**sum_args).write()) # type: ignore[arg-type] + + +@_params_test_Kv_write__sum +@pytest_mark_asyncio +async def test_Kv_write__sum( + kv: Kv, + initial_val: int | float | JSBigInt | KvU64 | None, + sum_val: int | float | JSBigInt | KvU64, + sum_kwargs: SumArgs[Any, Any, Any], + result: int | float | KvU64 | Callable[[Exception], bool], +) -> None: + async with validate_write_outcome(kv, initial_val, result) as (kv, key): + 
sum_args = SumArgs(key=key, delta=sum_val, **sum_kwargs) + assert isinstance(await kv.sum(**sum_args), VersionStamp) # type: ignore[arg-type] + + +_params_test_Kv_write__max = pytest.mark.parametrize( + "initial_val, max_val, max_kwargs, result", + [ + (JSBigInt(12), JSBigInt(3), {}, JSBigInt(12)), + (JSBigInt(3), JSBigInt(12), {}, JSBigInt(12)), + (12.5, 3, {}, 12.5), + (3, 12.5, {}, 12.5), + (KvU64(12), KvU64(3), {}, KvU64(12)), + (KvU64(3), KvU64(12), {}, KvU64(12)), + ( + JSBigInt(1), + 2.0, + {}, + # The errors reference M_SUM because bigint/number implement min/max + # using clamped M_SUM operations, not the actual M_MIN/M_MAX, + # because they only support u64. + match_write_failure( + ResponseUnsuccessful, + "SnapshotWrite is not valid: " + "Cannot apply operation M_SUM, number types are incompatible: " + "current type: JSBigInt (VE_V8 BigInt), " + "operand type: int/float (VE_V8 Number)", + ), + ), + ( + 1.5, + JSBigInt(2), + {}, + match_write_failure( + ResponseUnsuccessful, + "SnapshotWrite is not valid: " + "Cannot apply operation M_SUM, number types are incompatible: " + "current type: int/float (VE_V8 Number), " + "operand type: JSBigInt (VE_V8 BigInt)", + ), + ), + ( + KvU64(1), + 2.0, + {}, + match_write_failure( + ResponseUnsuccessful, + "SnapshotWrite is not valid: " + "Cannot apply operation M_SUM, number types are incompatible: " + "current type: KvU64 (VE_LE64), " + "operand type: int/float (VE_V8 Number)", + ), + ), + ( + 2.0, + KvU64(1), + {}, + match_write_failure( + ResponseUnsuccessful, + "SnapshotWrite is not valid: " + "Cannot apply operation M_MAX, number types are incompatible: " + "current type: int/float (VE_V8 Number), " + "operand type: KvU64 (VE_LE64)", + ), + ), + ], +) + + +@_params_test_Kv_write__max +@pytest_mark_asyncio +async def test_Kv_write__atomic_max( + kv: Kv, + initial_val: int | float | KvU64 | None, + max_val: KvU64, + max_kwargs: dict[str, Any], + result: int | float | KvU64 | Callable[[Exception], bool], +) -> 
None: + async with validate_write_outcome(kv, initial_val, result): + assert is_ok(await kv.atomic().max(("foo", 0), max_val, **max_kwargs).write()) + + +@_params_test_Kv_write__max +@pytest_mark_asyncio +async def test_Kv_write__max( + kv: Kv, + initial_val: int | float | KvU64 | None, + max_val: KvU64, + max_kwargs: dict[str, Any], + result: int | float | KvU64 | Callable[[Exception], bool], +) -> None: + async with validate_write_outcome(kv, initial_val, result): + assert isinstance(await kv.max(("foo", 0), max_val, **max_kwargs), VersionStamp) + + +_params_test_Kv_write__min = pytest.mark.parametrize( + "initial_val, min_val, min_kwargs, result", + [ + (JSBigInt(12), JSBigInt(3), {}, JSBigInt(3)), + (JSBigInt(3), JSBigInt(12), {}, JSBigInt(3)), + (12, 3.1, {}, 3.1), + (3.1, 12, {}, 3.1), + (KvU64(12), KvU64(3), {}, KvU64(3)), + (KvU64(3), KvU64(12), {}, KvU64(3)), + ( + JSBigInt(1), + 2.0, + {}, + # The errors reference M_SUM because bigint/number implement min/max + # using clamped M_SUM operations, not the actual M_MIN/M_MAX, + # because they only support u64. 
+ match_write_failure( + ResponseUnsuccessful, + "SnapshotWrite is not valid: " + "Cannot apply operation M_SUM, number types are incompatible: " + "current type: JSBigInt (VE_V8 BigInt), " + "operand type: int/float (VE_V8 Number)", + ), + ), + ( + 1.5, + JSBigInt(2), + {}, + match_write_failure( + ResponseUnsuccessful, + "SnapshotWrite is not valid: " + "Cannot apply operation M_SUM, number types are incompatible: " + "current type: int/float (VE_V8 Number), " + "operand type: JSBigInt (VE_V8 BigInt)", + ), + ), + ( + KvU64(1), + 2.0, + {}, + match_write_failure( + ResponseUnsuccessful, + "SnapshotWrite is not valid: " + "Cannot apply operation M_SUM, number types are incompatible: " + "current type: KvU64 (VE_LE64), " + "operand type: int/float (VE_V8 Number)", + ), + ), + ( + 2.0, + KvU64(1), + {}, + match_write_failure( + ResponseUnsuccessful, + "SnapshotWrite is not valid: " + "Cannot apply operation M_MIN, number types are incompatible: " + "current type: int/float (VE_V8 Number), " + "operand type: KvU64 (VE_LE64)", + ), + ), + ], +) + + +@_params_test_Kv_write__min +@pytest_mark_asyncio +async def test_Kv_write__atomic_min( + kv: Kv, + initial_val: int | float | KvU64 | None, + min_val: KvU64, + min_kwargs: dict[str, Any], + result: int | float | KvU64 | Callable[[Exception], bool], +) -> None: + async with validate_write_outcome(kv, initial_val, result): + assert is_ok(await kv.atomic().min(("foo", 0), min_val, **min_kwargs).write()) + + +@_params_test_Kv_write__min +@pytest_mark_asyncio +async def test_Kv_write__min( + kv: Kv, + initial_val: int | float | KvU64 | None, + min_val: KvU64, + min_kwargs: dict[str, Any], + result: int | float | KvU64 | Callable[[Exception], bool], +) -> None: + async with validate_write_outcome(kv, initial_val, result): + assert isinstance(await kv.min(("foo", 0), min_val, **min_kwargs), VersionStamp) + + +@pytest.mark.parametrize("initial_val", [None, 42]) +@pytest_mark_asyncio +async def test_Kv_write__atomic_delete( + kv: 
Kv, initial_val: int | float | KvU64 | None +) -> None: + async with validate_write_outcome(kv, initial_val, result=None): + assert is_ok(await kv.atomic().delete(("foo", 0)).write()) + + +@pytest.mark.parametrize("initial_val", [None, 42]) +@pytest_mark_asyncio +async def test_Kv_write__delete( + kv: Kv, initial_val: int | float | KvU64 | None +) -> None: + async with validate_write_outcome(kv, initial_val, result=None): + assert isinstance(await kv.delete(("foo", 0)), VersionStamp) + + +@pytest_mark_asyncio +async def test_Kv_write__atomic_check__allows_write_when_matching(kv: Kv) -> None: + async with validate_write_outcome(kv, None, result=42) as (kv, key): + assert is_ok(await kv.atomic().check(key, None).set(key, 42).write()) + + async with validate_write_outcome(kv, 41, result=42) as (kv, key): + _, initial = await kv.get(key) + assert initial + assert is_ok( + await kv.atomic().check(key, initial.versionstamp).set(key, 42).write() + ) + + +@pytest_mark_asyncio +async def test_Kv_write__atomic_check__fails_write_when_mismatching(kv: Kv) -> None: + async with validate_write_outcome(kv, None, result=None) as (kv, key): + result = await kv.atomic().check(key, VersionStamp(1)).set(key, 42).write() + assert is_err(result) + assert result.conflicts[key].versionstamp == VersionStamp(1) + + async with validate_write_outcome(kv, 41, result=42) as (kv, key): + _, initial = await kv.get(key) + assert initial + assert is_ok( + await kv.atomic().check(key, initial.versionstamp).set(key, 42).write() + ) + # Try to change from original version + result = await kv.atomic().check(key, initial.versionstamp).set(key, 80).write() + assert is_err(result) + assert result.conflicts[key].versionstamp == initial.versionstamp + + +@pytest_mark_asyncio +async def test_Kv_write__check__returns_False_when_mismatching(kv: Kv) -> None: + async with validate_write_outcome(kv, None, result=None) as (kv, key): + assert (await kv.check(key, VersionStamp(1))) is False + assert (await 
kv.check(KvEntry(key, None, VersionStamp(1)))) is False + assert (await kv.check(Check(key, VersionStamp(1)))) is False + + async with validate_write_outcome(kv, 41, result=41) as (kv, key): + _, initial = await kv.get(key) + assert initial + wrong_ver = VersionStamp(int(initial.versionstamp) + 1) + assert (await kv.check(key)) is False + assert (await kv.check(key, None)) is False + assert (await kv.check(key, wrong_ver)) is False + assert (await kv.check(KvEntry(key, None, wrong_ver))) is False + assert (await kv.check(Check(key, None))) is False + assert (await kv.check(Check(key, wrong_ver))) is False + + +@pytest_mark_asyncio +async def test_Kv_write__check__returns_True_when_matching(kv: Kv) -> None: + async with validate_write_outcome(kv, None, result=None) as (kv, key): + assert (await kv.check(key)) is True + assert (await kv.check(key, None)) is True + assert (await kv.check(Check(key, None))) is True + + async with validate_write_outcome(kv, 41, result=41) as (kv, key): + _, initial = await kv.get(key) + assert initial + assert (await kv.check(key, initial.versionstamp)) is True + assert (await kv.check(KvEntry(key, None, initial.versionstamp))) is True + assert (await kv.check(Check(key, initial.versionstamp))) is True + + +@pytest_mark_asyncio +async def test_Kv_write__enqueue(kv: Kv, mock_db: MockKvDb) -> None: + assert len(mock_db.queued_messages) == 0 + + t = datetime.now() + timedelta(seconds=60) + await ( + kv.atomic() + .enqueue( + {"foo": "bar"}, + delivery_time=t, + retry_delays=[1, 2], + dead_letter_keys=[("foo", 1), ("bar", 2)], + ) + .enqueue({"baz": "boz"}) + .write() + ) + + assert len(mock_db.queued_messages) == 2 + a, b = mock_db.queued_messages + assert a.payload == JSMap(foo="bar") + assert a.deadline_ms == pytest.approx(t.timestamp() * 1000, rel=1) + assert a.backoff_schedule == [1000, 2000] # milliseconds + assert a.keys_if_undelivered == [KvKey("foo", 1), KvKey("bar", 2)] + + assert b.payload == JSMap(baz="boz") + assert 
b.deadline_ms == 0 + assert len(b.backoff_schedule) == DEFAULT_ENQUEUE_RETRY_DELAY_COUNT + assert b.keys_if_undelivered == [] + + @pytest_mark_asyncio async def test_aclose() -> None: authenticator = Mock() @@ -1172,3 +1626,57 @@ async def all_inner_tasks_awaited() -> AsyncGenerator[None]: inner_tasks = asyncio.all_tasks() - pre_existing_tasks if inner_tasks: await asyncio.wait(inner_tasks) + + +LIST_CONTEXT = ListContext( + None, + None, + None, + b"", + b"", + None, + None, + False, + ConsistencyLevel.STRONG, + 1, + lambda lc: Base64KeySuffixCursorFormat(lc.packed_start, lc.packed_end), +) + +DB_META = DatabaseMetadata( + version=2, + database_id=UUID("AD50A341-5351-4FC3-82D0-72CFEE369A09"), + token="thisisnotasecret", + expires_at=datetime.now(), + endpoints=( + EndpointInfo( + url=URL("https://db.example.com/v2"), + consistency=ConsistencyLevel.STRONG, + ), + ), +) + + +@pytest.fixture( + params=[ + pytest.param( + ListKvEntry(KvKey("a"), 42, VersionStamp(1), LIST_CONTEXT), id="ListKvEntry" + ), + pytest.param(EndpointSelector(DB_META), id="EndpointSelector"), + pytest.param(CachedValue(fresh_until=42, value=42), id="CachedValue"), + pytest.param(KvCredentials(URL("http://example"), ""), id="KvCredentials"), + pytest.param( + Authenticator(cast(Any, None), cast(Any, None), cast(Any, None)), + id="Authenticator", + ), + pytest.param(LIST_CONTEXT, id="ListContext"), + pytest.param( + Base64KeySuffixCursorFormat(b"", b""), id="Base64KeySuffixCursorFormat" + ), + ] +) +def instance(request: pytest.FixtureRequest) -> object: + param: object = request.param + return param + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() diff --git a/test/test_kv_keys__kvkey.py b/test/test_kv_keys__kvkey.py index b75dd9a..5115d3f 100644 --- a/test/test_kv_keys__kvkey.py +++ b/test/test_kv_keys__kvkey.py @@ -2,7 +2,7 @@ import re import weakref -from typing import Literal +from typing import Literal # noqa: TID251 import pytest from fdb.tuple import 
pack @@ -16,12 +16,15 @@ from denokv.datapath import KvKeyTuple from denokv.datapath import pack_key from denokv.kv_keys import KvKey +from test.denokv_testing import create_dataclass_slots_test -def test_instances_do_not_define_dict() -> None: - k = KvKey() - with pytest.raises(AttributeError): - print(k.__dict__) +@pytest.fixture +def instance() -> KvKey: + return KvKey("a") + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() def test_instances_are_KvKeyEncodable() -> None: diff --git a/test/test_kv_keys__kvkeyrange.py b/test/test_kv_keys__kvkeyrange.py index d8de0d6..f2fd3d6 100644 --- a/test/test_kv_keys__kvkeyrange.py +++ b/test/test_kv_keys__kvkeyrange.py @@ -1,4 +1,4 @@ -from typing import Literal +from typing import Literal # noqa: TID251 import pytest @@ -11,6 +11,7 @@ from denokv.kv_keys import KvKeyRange from denokv.kv_keys import StartBoundary from denokv.kv_keys import StopBoundary +from test.denokv_testing import create_dataclass_slots_test def test_types() -> None: @@ -192,3 +193,20 @@ def test_contains__stop( key_range = KvKeyRange(IncludeAll(), stop) assert (key in key_range) == key_included + + +@pytest.fixture( + params=[ + pytest.param(Include("b", 10), id="Include"), + pytest.param(IncludePrefix("b", 10), id="IncludePrefix"), + pytest.param(Exclude("b", 10), id="Exclude"), + pytest.param(IncludeAll(), id="IncludeAll"), + pytest.param(KvKeyRange(), id="KvKeyRange"), + ] +) +def instance(request: pytest.FixtureRequest) -> object: + param: object = request.param + return param + + +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() diff --git a/test/test_result.py b/test/test_result.py index 4e9394c..3d65699 100644 --- a/test/test_result.py +++ b/test/test_result.py @@ -1,17 +1,18 @@ from __future__ import annotations -import sys -from typing import Literal +from typing import Literal # noqa: TID251 # noqa: TID251 from unittest.mock import Mock import pytest from denokv._pycompat.typing 
import TYPE_CHECKING from denokv._pycompat.typing import Any +from denokv._pycompat.typing import Callable from denokv._pycompat.typing import Iterable from denokv._pycompat.typing import Never from denokv._pycompat.typing import Sequence from denokv._pycompat.typing import TypeIs +from denokv._pycompat.typing import Union from denokv._pycompat.typing import cast from denokv.result import AnyFailure from denokv.result import AnySuccess @@ -27,17 +28,23 @@ from denokv.result import Some from denokv.result import is_err from denokv.result import is_ok +from test.denokv_testing import create_dataclass_slots_test -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="<3.10 does not use slots for dataclass" +@pytest.fixture( + params=[ + pytest.param(lambda: Some(1), id="Some"), + pytest.param(lambda: Nothing(), id="Nothing"), + pytest.param(lambda: Ok(1), id="Ok"), + pytest.param(lambda: Err("x"), id="Err"), + ] ) -def test_Option__instances_use_slots_to_avoid_dict() -> None: - with pytest.raises(AttributeError): - print(Some(1).__dict__) +def instance(request: pytest.FixtureRequest) -> Option[int] | Result[int, str]: + param: Callable[[], Union[Option[int], Result[int, str]]] = request.param + return param() + - with pytest.raises(AttributeError): - print(Nothing().__dict__) +test_instances_dont_have_dict_because_of_slots = create_dataclass_slots_test() def test_Option__satisfies_OptionMethods() -> None: @@ -201,17 +208,6 @@ def type_check_zip_with(a: Option[str], b: Option[object]) -> Option[int]: return a.zip_with(b, int) # type: ignore[arg-type] -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="<3.10 does not use slots for dataclass" -) -def test_Result__instances_use_slots_to_avoid_dict() -> None: - with pytest.raises(AttributeError): - print(Ok(1).__dict__) - - with pytest.raises(AttributeError): - print(Err("x").__dict__) - - def test_Result__satisfies_ResultMethods() -> None: ok = Ok(1) err = Err("x")