From 9f6aa0a50f97702638795faec7361e1de4cbdfa5 Mon Sep 17 00:00:00 2001 From: Dima Gerasimov Date: Sun, 24 May 2026 16:49:36 +0100 Subject: [PATCH] core: fix potential dupes when cache write fails --- src/cachew/__init__.py | 81 +++++++++++++++++++++++++-------- src/cachew/common.py | 8 ++++ src/cachew/tests/test_cachew.py | 29 +++++++++++- 3 files changed, 99 insertions(+), 19 deletions(-) diff --git a/src/cachew/__init__.py b/src/cachew/__init__.py index 76710a4..cfb1b00 100644 --- a/src/cachew/__init__.py +++ b/src/cachew/__init__.py @@ -5,13 +5,15 @@ import logging import stat import warnings -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Generator, Iterable, Iterator from dataclasses import dataclass +from enum import Enum, auto from pathlib import Path from typing import ( Any, Literal, Protocol, + assert_never, cast, overload, ) @@ -37,7 +39,7 @@ def orjson_dumps(*args: Any, **kwargs: Any) -> bytes: # type: ignore[misc] from .backend.common import AbstractBackend from .backend.file import FileBackend from .backend.sqlite import SqliteBackend -from .common import DEPENDENCIES, CacheReadError, CachewException, SourceHash +from .common import DEPENDENCIES, CacheReadError, CachewException, CacheWriteError, SourceHash from .logging_helper import make_logger from .marshall.cachew import CachewMarshall @@ -414,6 +416,25 @@ def composite_hash(self, *args, **kwargs) -> dict[str, Any]: return hash_parts +class _StreamState(Enum): + """ + Streaming lifecycle for cachew_wrapper. + Allowed transitions: + SETUP -> SOURCE when cachew deliberately yields from the wrapped function without caching. + SETUP -> CACHE_WRITE after cachew gets an exclusive write transaction. + SETUP -> FINISHED after a complete cache hit. + CACHE_WRITE -> SOURCE after a recoverable cache write failure; the same source iterator resumes uncached. + CACHE_WRITE -> FINISHED after the source iterator is exhausted and all items have been yielded. + SOURCE -> FINISHED after uncached streaming completes. + FINISHED is terminal; cleanup errors may be logged or raised, but must not trigger fallback that emits more items. + """ + + SETUP = auto() + SOURCE = auto() + CACHE_WRITE = auto() + FINISHED = auto() + + @dataclass class CacheSession[ItemT]: """ @@ -446,7 +467,7 @@ def cached_items(self) -> Iterator[ItemT]: f'failed to read cachew cache ({self.backend_name}:{self.resolved_cache_path}); remove the cache and try again' ) from e - def write_to_cache(self, datas: Iterable[ItemT]) -> Iterator[ItemT]: + def write_items_to_cache(self, datas: Iterable[ItemT]) -> Generator[ItemT, None, int]: if isinstance(self.backend, FileBackend): # FIXME uhhh.. this is a bit crap # but in sqlite mode we don't want to publish new hash before we write new items @@ -471,13 +492,19 @@ def flush() -> None: total_objects += 1 yield obj - dct = self.marshall.dump(obj) - blob = orjson_dumps(dct) + try: + dct = self.marshall.dump(obj) + blob = orjson_dumps(dct) + except Exception as e: + msg = f'failed to write cachew cache ({self.backend_name}:{self.resolved_cache_path})' + raise CacheWriteError(msg) from e chunk.append(blob) if len(chunk) >= self.chunk_by: flush() flush() + return total_objects + def finalize_cache(self, *, total_objects: int) -> None: self.backend.finalize(self.new_hash) self.logger.info( f'wrote {total_objects} objects to cachew ({self.backend_name}:{self.resolved_cache_path})' @@ -522,8 +549,7 @@ def cachew_wrapper[**P, ItemT]( # but it lets us save a function call, hence a stack frame # see test_recursive* early_exit = False - running_uncached = False - served_from_cache = False + stream_state = _StreamState.SETUP try: BackendCls = BACKENDS[C.backend] @@ -550,7 +576,7 @@ def cachew_wrapper[**P, ItemT]( if new_hash == old_hash: logger.debug('hash matched: loading from cache') yield from session.cached_items() - served_from_cache = True + stream_state = _StreamState.FINISHED return logger.debug('hash mismatch: computing data and writing to db') @@ -560,9 +586,9 @@ def cachew_wrapper[**P, ItemT]( # NOTE: this is the bit we really have to watch out for and not put in a helper function # otherwise it's causing an extra stack frame on every call # the rest (reading from cachew or writing to cachew) happens once per function call? so not a huge deal - running_uncached = True + stream_state = _StreamState.SOURCE yield from func(*args, **kwargs) - running_uncached = False + stream_state = _StreamState.FINISHED return if synthetic_key is not None: @@ -576,18 +602,31 @@ def cachew_wrapper[**P, ItemT]( kwargs[synthetic_key] = missing_synthetic_values # ty: ignore[invalid-assignment] # at this point we're guaranteed to have an exclusive write transaction + fit = iter(func(*args, **kwargs)) try: - yield from session.write_to_cache(func(*args, **kwargs)) + stream_state = _StreamState.CACHE_WRITE + total_objects = yield from session.write_items_to_cache(fit) except GeneratorExit: # GeneratorExit itself is not caught below, but SQLAlchemy cleanup during interpreter shutdown can raise a normal Exception while unwinding. early_exit = True raise + except CacheWriteError as e: + # If there is an error during marshalling, etc, we can't just reemit func(*args, **kwargs), we might end up with dupes + # so we try to switch back to the original iterator (fit) -- note it's reused/mutated iin write_items_to_cache + cachew_error(e, logger=logger) + stream_state = _StreamState.SOURCE + yield from fit + stream_state = _StreamState.FINISHED + return + stream_state = _StreamState.FINISHED + session.finalize_cache(total_objects=total_objects) except CacheReadError: # Cache read failures bypass THROW_ON_ERROR because fallback can duplicate already-yielded cached items. # This can be thrown from session.cached_items() raise except Exception as e: - if running_uncached: + if stream_state is _StreamState.SOURCE: + # SOURCE means the wrapped function is already streaming uncached, so its exceptions must propagate unchanged. raise # Work around known SQLAlchemy/sqlite shutdown noise; do not suppress other cleanup errors. @@ -595,16 +634,22 @@ def cachew_wrapper[**P, ItemT]( if early_exit and 'Cannot operate on a closed database' in str(e): return + if stream_state is _StreamState.CACHE_WRITE: + # CACHE_WRITE may have already yielded source items, so fallback could duplicate emitted output. + raise + cachew_error(e, logger=logger) - if served_from_cache: - # this can happen if we fully read from the cache, but hit some error while shutting backend down - # - we're past reading, so we emitted all items user wanted from cache - # - we don't want to yield any items from original func - # so it's safe to simply return + if stream_state is _StreamState.FINISHED: + # FINISHED means all requested items were emitted; cachew_error handles THROW_ON_ERROR, but fallback must not emit more data. return - yield from func(*args, **kwargs) + if stream_state is _StreamState.SETUP: + # SETUP means no user-visible items have been yielded, so fallback is safe. + yield from func(*args, **kwargs) + return + + assert_never(stream_state) __all__ = [ diff --git a/src/cachew/common.py b/src/cachew/common.py index e2eeba8..5af9ac5 100644 --- a/src/cachew/common.py +++ b/src/cachew/common.py @@ -20,6 +20,14 @@ class CacheReadError(CachewException): pass +class CacheWriteError(CachewException): + """ + Internal signal for defensive cache write fallback. + """ + + pass + + @dataclass class TypeNotSupported(CachewException): type_: type diff --git a/src/cachew/tests/test_cachew.py b/src/cachew/tests/test_cachew.py index 33a1574..1d12b49 100644 --- a/src/cachew/tests/test_cachew.py +++ b/src/cachew/tests/test_cachew.py @@ -1046,7 +1046,6 @@ def orig2(): assert list(fun()) == [123] -@pytest.mark.xfail(reason='cache write errors after yielding currently restart the source iterator', strict=True) def test_defensive_write_error_after_yield_does_not_duplicate( tmp_path: Path, restore_settings, @@ -1074,6 +1073,34 @@ def fun() -> Iterator[BB]: assert calls == 1 +def test_write_source_error_after_yield_propagates_without_retry( + tmp_path: Path, + restore_settings, +) -> None: + """ + If the wrapped iterator fails while cachew is writing, the source error must not trigger defensive retry. + """ + settings.THROW_ON_ERROR = False + + class UserError(Exception): + pass + + calls = 0 + + @cachew(tmp_path) + def fun() -> Iterator[int]: + nonlocal calls + calls += 1 + yield 1 + raise UserError('boom') + + it = iter(fun()) + assert next(it) == 1 + with pytest.raises(UserError, match='boom'): + next(it) + assert calls == 1 + + def test_defensive_read_error_after_yield_raises_cache_read_error( tmp_path: Path, restore_settings,