From 90d5b7bac9c27dcba98d6b84df6a2422f0554513 Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 20 Apr 2026 05:12:19 +0000 Subject: [PATCH 1/3] Exempt async CM methods from ASYNC910/911 when partner checkpoints ASYNC910 and ASYNC911 no longer require every `__aenter__`/`__aexit__` to contain a checkpoint. Per Trio's documentation, an async context manager only needs one of entry/exit to act as a checkpoint. When a class defines both methods, the one without an `await` is exempt if its partner contains one. When a class defines only one of the two, the partner is charitably assumed to be inherited from a base class and to contain a checkpoint, so the defined method is also exempt. Closes https://github.com/python-trio/flake8-async/issues/441 https://claude.ai/code/session_014jAydKywq31Ew4fVYGJdiG --- docs/changelog.rst | 1 + flake8_async/visitors/visitor91x.py | 62 +++++++++++++++++++++ tests/autofix_files/async910.py | 80 ++++++++++++++++++++++++++++ tests/autofix_files/async910.py.diff | 20 +++++++ tests/eval_files/async910.py | 77 ++++++++++++++++++++++++++ 5 files changed, 240 insertions(+) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5ac72f4..c704870 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,7 @@ Changelog Unreleased ========== - Autofix for :ref:`ASYNC910 ` / :ref:`ASYNC911 ` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 `); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) `_ +- :ref:`ASYNC910 ` and :ref:`ASYNC911 ` now accept ``__aenter__`` / ``__aexit__`` methods when the partner method provides the checkpoint, or when only one of the two is defined on a class that inherits from another class (charitably assuming the partner is inherited and contains a checkpoint). `(issue #441) `_ 25.7.1 ====== diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index 67f9d51..fc3409a 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -465,6 +465,17 @@ def __init__(self, *args: Any, **kwargs: Any): # used to transfer new body between visit_FunctionDef and leave_FunctionDef self.new_body: cst.BaseSuite | None = None + # Tracks whether the current scope is a class body and, if so, which of + # `__aenter__`/`__aexit__` are directly defined on it (values: True if + # that method contains an `await`, False otherwise, missing key if not + # defined). Used to exempt async context manager methods from + # ASYNC910/911 when their partner method provides the checkpoint, or + # when the partner is assumed inherited (not defined on this class). + self.async_cm_class: dict[str, bool] | None = None + # Set on entry to an exempt `__aenter__`/`__aexit__` so that + # `error_91x` skips emitting ASYNC910/911. + self.exempt_async_cm_method = False + def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool: if code is None: code = "ASYNC911" if self.has_yield else "ASYNC910" @@ -532,6 +543,44 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: self.suppress_imported_as.append("suppress") return + # Async context manager methods may legitimately skip checkpointing if the + # partner method provides the checkpoint, or if the partner is inherited + # from a base class (which we charitably assume contains a checkpoint). + # See https://github.com/python-trio/flake8-async/issues/441. + def visit_ClassDef(self, node: cst.ClassDef) -> None: + self.save_state(node, "async_cm_class") + defined: dict[str, bool] = {} + if isinstance(node.body, cst.IndentedBlock): + for stmt in node.body.body: + if ( + isinstance(stmt, cst.FunctionDef) + and stmt.asynchronous is not None + and stmt.name.value in ("__aenter__", "__aexit__") + ): + defined[stmt.name.value] = bool(m.findall(stmt, m.Await())) + self.async_cm_class = defined + + def leave_ClassDef( + self, original_node: cst.ClassDef, updated_node: cst.ClassDef + ) -> cst.ClassDef: + self.restore_state(original_node) + return updated_node + + def _is_exempt_async_cm_method(self, node: cst.FunctionDef) -> bool: + if self.async_cm_class is None: + return False + name = node.name.value + if name not in ("__aenter__", "__aexit__"): + return False + if name not in self.async_cm_class: + return False + partner = "__aexit__" if name == "__aenter__" else "__aenter__" + # Partner not defined in this class -> assume inherited with checkpoint. + if partner not in self.async_cm_class: + return True + # Partner defined and (charitably) contains a checkpoint. + return self.async_cm_class[partner] + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # `await` in default values happen in parent scope # we also know we don't ever modify parameters so we can ignore the return value @@ -543,6 +592,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: if func_has_decorator(node, "overload", "fixture") or func_empty_body(node): return False # subnodes can be ignored + is_exempt_cm = self._is_exempt_async_cm_method(node) + self.save_state( node, "has_yield", @@ -557,6 +608,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: "suppress_imported_as", # a copy is saved, but state is not reset "except_depth", "add_checkpoint_at_function_start", + "async_cm_class", + "exempt_async_cm_method", copy=True, ) self.uncheckpointed_statements = set() @@ -567,6 +620,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: self.taskgroup_has_start_soon = {} self.except_depth = 0 self.add_checkpoint_at_function_start = False + # Class-level context does not apply to nested scopes. + self.async_cm_class = None + self.exempt_async_cm_method = is_exempt_cm self.async_function = ( node.asynchronous is not None @@ -747,6 +803,12 @@ def error_91x( ) -> bool: assert not isinstance(statement, ArtificialStatement), statement + # Exempt `__aenter__`/`__aexit__` when the partner method contains a + # checkpoint, or when the partner is missing and charitably assumed + # inherited. + if self.exempt_async_cm_method: + return False + if isinstance(node, cst.FunctionDef): msg = "exit" else: diff --git a/tests/autofix_files/async910.py b/tests/autofix_files/async910.py index 0d67a69..6972072 100644 --- a/tests/autofix_files/async910.py +++ b/tests/autofix_files/async910.py @@ -636,3 +636,83 @@ async def foo_nested_empty_async(): async def bar(): ... await foo() + + +# Issue #441: async context manager methods may legitimately skip checkpointing +# if the partner method provides the checkpoint, or if the partner is inherited. +class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast + async def __aenter__(self): + await foo() + + async def __aexit__(self, exc_type, exc, tb): + print("fast exit") + + +class CtxWithTeardown: # safe: __aexit__ checkpoints, __aenter__ can be fast + async def __aenter__(self): + print("fast setup") + + async def __aexit__(self, exc_type, exc, tb): + await foo() + + +class CtxWithBothCheckpoint: # safe: both checkpoint + async def __aenter__(self): + await foo() + + async def __aexit__(self, exc_type, exc, tb): + await foo() + + +# fmt: off +class CtxNeitherCheckpoint: + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + await trio.lowlevel.checkpoint() + + async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + print("teardown") + await trio.lowlevel.checkpoint() +# fmt: on + + +# Only one method defined: charitably assume the other is inherited with a checkpoint. +class CtxOnlyAenter: # safe: __aexit__ assumed inherited with checkpoint + async def __aenter__(self): + print("setup") + + +class CtxOnlyAexit: # safe: __aenter__ assumed inherited with checkpoint + async def __aexit__(self, *a): + print("teardown") + + +class CtxOnlyAenterWithCheckpoint: # safe + async def __aenter__(self): + await foo() + + +class CtxOnlyAexitWithCheckpoint: # safe + async def __aexit__(self, *a): + await foo() + + +# a nested function named `__aenter__` inside another function is not a method +def not_a_class(): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + await trio.lowlevel.checkpoint() + + +# class nested inside a function still gets the exemption +def factory(): + class NestedCtx: # safe + async def __aenter__(self): + print("setup") + + +# nested class; outer class has nothing relevant +class Outer: + class Inner: # safe: charitable inheritance for __aexit__ + async def __aenter__(self): + print("setup") diff --git a/tests/autofix_files/async910.py.diff b/tests/autofix_files/async910.py.diff index c765401..f8bbf55 100644 --- a/tests/autofix_files/async910.py.diff +++ b/tests/autofix_files/async910.py.diff @@ -223,3 +223,23 @@ async def foo_nested_empty_async(): +@@ x,9 x,11 @@ + class CtxNeitherCheckpoint: + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") ++ await trio.lowlevel.checkpoint() + + async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + print("teardown") ++ await trio.lowlevel.checkpoint() + # fmt: on + + +@@ x,6 x,7 @@ + def not_a_class(): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") ++ await trio.lowlevel.checkpoint() + + + # class nested inside a function still gets the exemption diff --git a/tests/eval_files/async910.py b/tests/eval_files/async910.py index d370155..e6ae448 100644 --- a/tests/eval_files/async910.py +++ b/tests/eval_files/async910.py @@ -606,3 +606,80 @@ async def foo_nested_empty_async(): async def bar(): ... await foo() + + +# Issue #441: async context manager methods may legitimately skip checkpointing +# if the partner method provides the checkpoint, or if the partner is inherited. +class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast + async def __aenter__(self): + await foo() + + async def __aexit__(self, exc_type, exc, tb): + print("fast exit") + + +class CtxWithTeardown: # safe: __aexit__ checkpoints, __aenter__ can be fast + async def __aenter__(self): + print("fast setup") + + async def __aexit__(self, exc_type, exc, tb): + await foo() + + +class CtxWithBothCheckpoint: # safe: both checkpoint + async def __aenter__(self): + await foo() + + async def __aexit__(self, exc_type, exc, tb): + await foo() + + +# fmt: off +class CtxNeitherCheckpoint: + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + + async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + print("teardown") +# fmt: on + + +# Only one method defined: charitably assume the other is inherited with a checkpoint. +class CtxOnlyAenter: # safe: __aexit__ assumed inherited with checkpoint + async def __aenter__(self): + print("setup") + + +class CtxOnlyAexit: # safe: __aenter__ assumed inherited with checkpoint + async def __aexit__(self, *a): + print("teardown") + + +class CtxOnlyAenterWithCheckpoint: # safe + async def __aenter__(self): + await foo() + + +class CtxOnlyAexitWithCheckpoint: # safe + async def __aexit__(self, *a): + await foo() + + +# a nested function named `__aenter__` inside another function is not a method +def not_a_class(): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + + +# class nested inside a function still gets the exemption +def factory(): + class NestedCtx: # safe + async def __aenter__(self): + print("setup") + + +# nested class; outer class has nothing relevant +class Outer: + class Inner: # safe: charitable inheritance for __aexit__ + async def __aenter__(self): + print("setup") From f2968d4984c7a6cdee1ea0ad60d7ad64b08b5fd7 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 21 Apr 2026 19:34:39 +0000 Subject: [PATCH 2/3] Tighten async CM exemption rules Two refinements in response to review feedback: - If an `__aenter__`/`__aexit__` method contains any checkpoint-like construct (`await`, `async with`, or `async for`), it must always checkpoint. We no longer exempt such methods even when the partner provides a checkpoint -- conditional checkpoints are still flagged. - Only charitably assume a missing partner is inherited (with a checkpoint) when the class actually inherits from something. Classes with no base classes are treated as flat, and methods that don't checkpoint are flagged. `metaclass=` and other keyword arguments do not count as inheriting, since they live in `ClassDef.keywords` rather than `ClassDef.bases`. https://claude.ai/code/session_014jAydKywq31Ew4fVYGJdiG --- flake8_async/visitors/visitor91x.py | 36 ++++++++++++++----- tests/autofix_files/async910.py | 53 ++++++++++++++++++++++++---- tests/autofix_files/async910.py.diff | 31 +++++++++++++++- tests/eval_files/async910.py | 49 +++++++++++++++++++++---- 4 files changed, 147 insertions(+), 22 deletions(-) diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index fc3409a..d54c895 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -467,11 +467,15 @@ def __init__(self, *args: Any, **kwargs: Any): # Tracks whether the current scope is a class body and, if so, which of # `__aenter__`/`__aexit__` are directly defined on it (values: True if - # that method contains an `await`, False otherwise, missing key if not - # defined). Used to exempt async context manager methods from - # ASYNC910/911 when their partner method provides the checkpoint, or - # when the partner is assumed inherited (not defined on this class). + # that method contains a checkpoint-like construct, False otherwise, + # missing key if not defined). Used to exempt async context manager + # methods from ASYNC910/911 when their partner method provides the + # checkpoint, or when the partner is inherited from a base class. self.async_cm_class: dict[str, bool] | None = None + # Whether the enclosing class has an explicit base class (other than + # implicit `object`). We only assume a missing partner is inherited if + # the class actually inherits from something. + self.async_cm_class_has_bases = False # Set on entry to an exempt `__aenter__`/`__aexit__` so that # `error_91x` skips emitting ASYNC910/911. self.exempt_async_cm_method = False @@ -548,8 +552,13 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None: # from a base class (which we charitably assume contains a checkpoint). # See https://github.com/python-trio/flake8-async/issues/441. def visit_ClassDef(self, node: cst.ClassDef) -> None: - self.save_state(node, "async_cm_class") + self.save_state(node, "async_cm_class", "async_cm_class_has_bases") defined: dict[str, bool] = {} + checkpointy = ( + m.Await() + | m.With(asynchronous=m.Asynchronous()) + | m.For(asynchronous=m.Asynchronous()) + ) if isinstance(node.body, cst.IndentedBlock): for stmt in node.body.body: if ( @@ -557,8 +566,10 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: and stmt.asynchronous is not None and stmt.name.value in ("__aenter__", "__aexit__") ): - defined[stmt.name.value] = bool(m.findall(stmt, m.Await())) + defined[stmt.name.value] = bool(m.findall(stmt, checkpointy)) self.async_cm_class = defined + # Keyword args like `metaclass=` are in `node.keywords`, not `bases`. + self.async_cm_class_has_bases = bool(node.bases) def leave_ClassDef( self, original_node: cst.ClassDef, updated_node: cst.ClassDef @@ -574,11 +585,16 @@ def _is_exempt_async_cm_method(self, node: cst.FunctionDef) -> bool: return False if name not in self.async_cm_class: return False + # A method that contains any checkpoint must always checkpoint: we + # still check it normally so conditional checkpoints are flagged. + if self.async_cm_class[name]: + return False partner = "__aexit__" if name == "__aenter__" else "__aenter__" - # Partner not defined in this class -> assume inherited with checkpoint. if partner not in self.async_cm_class: - return True - # Partner defined and (charitably) contains a checkpoint. + # Partner is not defined on this class; only assume it is inherited + # (and contains a checkpoint) if the class inherits from something. + return self.async_cm_class_has_bases + # Partner defined; exempt iff it contains a checkpoint. return self.async_cm_class[partner] def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: @@ -609,6 +625,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: "except_depth", "add_checkpoint_at_function_start", "async_cm_class", + "async_cm_class_has_bases", "exempt_async_cm_method", copy=True, ) @@ -622,6 +639,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: self.add_checkpoint_at_function_start = False # Class-level context does not apply to nested scopes. self.async_cm_class = None + self.async_cm_class_has_bases = False self.exempt_async_cm_method = is_exempt_cm self.async_function = ( diff --git a/tests/autofix_files/async910.py b/tests/autofix_files/async910.py index 6972072..4f37e12 100644 --- a/tests/autofix_files/async910.py +++ b/tests/autofix_files/async910.py @@ -640,6 +640,10 @@ async def bar(): ... # Issue #441: async context manager methods may legitimately skip checkpointing # if the partner method provides the checkpoint, or if the partner is inherited. +class ACM: # a dummy base to opt into the charitable-inheritance assumption + pass + + class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast async def __aenter__(self): await foo() @@ -676,17 +680,43 @@ async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", # fmt: on -# Only one method defined: charitably assume the other is inherited with a checkpoint. -class CtxOnlyAenter: # safe: __aexit__ assumed inherited with checkpoint +# A method that contains any checkpoint is still required to always checkpoint. +class CtxAenterConditionalAexitFast(ACM): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + if _: + await foo() + await trio.lowlevel.checkpoint() + + async def __aexit__(self, *a): + print("fast exit") + + +# Only one method defined: charitably assume the other is inherited with a +# checkpoint -- but only when the class inherits from something. +class CtxOnlyAenterInherited(ACM): # safe: __aexit__ assumed inherited async def __aenter__(self): print("setup") -class CtxOnlyAexit: # safe: __aenter__ assumed inherited with checkpoint +class CtxOnlyAexitInherited(ACM): # safe: __aenter__ assumed inherited async def __aexit__(self, *a): print("teardown") +# fmt: off +class CtxOnlyAenter: # no base class -> don't assume inheritance + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + await trio.lowlevel.checkpoint() + + +class CtxOnlyAexit: # no base class -> don't assume inheritance + async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + print("teardown") + await trio.lowlevel.checkpoint() +# fmt: on + + class CtxOnlyAenterWithCheckpoint: # safe async def __aenter__(self): await foo() @@ -697,6 +727,17 @@ async def __aexit__(self, *a): await foo() +# keyword-only bases (like `metaclass=`) don't count as inheriting. +class Meta(type): + pass + + +class CtxMetaclassOnly(metaclass=Meta): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + await trio.lowlevel.checkpoint() + + # a nested function named `__aenter__` inside another function is not a method def not_a_class(): async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) @@ -704,15 +745,15 @@ async def __aenter__(self): # error: 4, "exit", Stmt("function definition", lin await trio.lowlevel.checkpoint() -# class nested inside a function still gets the exemption +# class nested inside a function still gets the exemption when it inherits def factory(): - class NestedCtx: # safe + class NestedCtx(ACM): # safe async def __aenter__(self): print("setup") # nested class; outer class has nothing relevant class Outer: - class Inner: # safe: charitable inheritance for __aexit__ + class Inner(ACM): # safe: charitable inheritance for __aexit__ async def __aenter__(self): print("setup") diff --git a/tests/autofix_files/async910.py.diff b/tests/autofix_files/async910.py.diff index f8bbf55..e6d7d19 100644 --- a/tests/autofix_files/async910.py.diff +++ b/tests/autofix_files/async910.py.diff @@ -236,10 +236,39 @@ @@ x,6 x,7 @@ + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + if _: + await foo() ++ await trio.lowlevel.checkpoint() + + async def __aexit__(self, *a): + print("fast exit") +@@ x,11 x,13 @@ + class CtxOnlyAenter: # no base class -> don't assume inheritance + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") ++ await trio.lowlevel.checkpoint() + + + class CtxOnlyAexit: # no base class -> don't assume inheritance + async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + print("teardown") ++ await trio.lowlevel.checkpoint() + # fmt: on + + +@@ x,12 x,14 @@ + class CtxMetaclassOnly(metaclass=Meta): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") ++ await trio.lowlevel.checkpoint() + + + # a nested function named `__aenter__` inside another function is not a method def not_a_class(): async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) print("setup") + await trio.lowlevel.checkpoint() - # class nested inside a function still gets the exemption + # class nested inside a function still gets the exemption when it inherits diff --git a/tests/eval_files/async910.py b/tests/eval_files/async910.py index e6ae448..720d68a 100644 --- a/tests/eval_files/async910.py +++ b/tests/eval_files/async910.py @@ -610,6 +610,10 @@ async def bar(): ... # Issue #441: async context manager methods may legitimately skip checkpointing # if the partner method provides the checkpoint, or if the partner is inherited. +class ACM: # a dummy base to opt into the charitable-inheritance assumption + pass + + class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast async def __aenter__(self): await foo() @@ -644,17 +648,40 @@ async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", # fmt: on -# Only one method defined: charitably assume the other is inherited with a checkpoint. -class CtxOnlyAenter: # safe: __aexit__ assumed inherited with checkpoint +# A method that contains any checkpoint is still required to always checkpoint. +class CtxAenterConditionalAexitFast(ACM): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + if _: + await foo() + + async def __aexit__(self, *a): + print("fast exit") + + +# Only one method defined: charitably assume the other is inherited with a +# checkpoint -- but only when the class inherits from something. +class CtxOnlyAenterInherited(ACM): # safe: __aexit__ assumed inherited async def __aenter__(self): print("setup") -class CtxOnlyAexit: # safe: __aenter__ assumed inherited with checkpoint +class CtxOnlyAexitInherited(ACM): # safe: __aenter__ assumed inherited async def __aexit__(self, *a): print("teardown") +# fmt: off +class CtxOnlyAenter: # no base class -> don't assume inheritance + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + + +class CtxOnlyAexit: # no base class -> don't assume inheritance + async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + print("teardown") +# fmt: on + + class CtxOnlyAenterWithCheckpoint: # safe async def __aenter__(self): await foo() @@ -665,21 +692,31 @@ async def __aexit__(self, *a): await foo() +# keyword-only bases (like `metaclass=`) don't count as inheriting. +class Meta(type): + pass + + +class CtxMetaclassOnly(metaclass=Meta): + async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) + print("setup") + + # a nested function named `__aenter__` inside another function is not a method def not_a_class(): async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) print("setup") -# class nested inside a function still gets the exemption +# class nested inside a function still gets the exemption when it inherits def factory(): - class NestedCtx: # safe + class NestedCtx(ACM): # safe async def __aenter__(self): print("setup") # nested class; outer class has nothing relevant class Outer: - class Inner: # safe: charitable inheritance for __aexit__ + class Inner(ACM): # safe: charitable inheritance for __aexit__ async def __aenter__(self): print("setup") From 791fc88f492308cc94fe7e6feb99e0cf2b4b04b0 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 21 Apr 2026 19:45:19 +0000 Subject: [PATCH 3/3] Only flag __aenter__ when neither CM method checkpoints When both `__aenter__` and `__aexit__` are defined and neither contains a checkpoint, we used to flag (and autofix) both methods, which produced redundant `lowlevel.checkpoint()` calls -- only one is needed for the async context manager to checkpoint. Prefer to report and fix `__aenter__` in this case; `__aexit__` is exempted since adding a checkpoint to either satisfies the rule. https://claude.ai/code/session_014jAydKywq31Ew4fVYGJdiG --- flake8_async/visitors/visitor91x.py | 16 ++++++++++------ tests/autofix_files/async910.py | 3 +-- tests/autofix_files/async910.py.diff | 8 ++------ tests/eval_files/async910.py | 2 +- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/flake8_async/visitors/visitor91x.py b/flake8_async/visitors/visitor91x.py index d54c895..123fba7 100644 --- a/flake8_async/visitors/visitor91x.py +++ b/flake8_async/visitors/visitor91x.py @@ -590,12 +590,16 @@ def _is_exempt_async_cm_method(self, node: cst.FunctionDef) -> bool: if self.async_cm_class[name]: return False partner = "__aexit__" if name == "__aenter__" else "__aenter__" - if partner not in self.async_cm_class: - # Partner is not defined on this class; only assume it is inherited - # (and contains a checkpoint) if the class inherits from something. - return self.async_cm_class_has_bases - # Partner defined; exempt iff it contains a checkpoint. - return self.async_cm_class[partner] + if partner in self.async_cm_class: + # Partner is defined on the class; if it checkpoints, we're fine. + if self.async_cm_class[partner]: + return True + # Neither method checkpoints -- to avoid double-flagging (and a + # redundant autofix), we report and fix only `__aenter__`. + return name == "__aexit__" + # Partner is not defined on this class; only assume it is inherited + # (and contains a checkpoint) if the class inherits from something. + return self.async_cm_class_has_bases def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: # `await` in default values happen in parent scope diff --git a/tests/autofix_files/async910.py b/tests/autofix_files/async910.py index 4f37e12..2b5baa4 100644 --- a/tests/autofix_files/async910.py +++ b/tests/autofix_files/async910.py @@ -674,9 +674,8 @@ async def __aenter__(self): # error: 4, "exit", Stmt("function definition", lin print("setup") await trio.lowlevel.checkpoint() - async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + async def __aexit__(self, *a): # only __aenter__ is flagged to avoid redundancy print("teardown") - await trio.lowlevel.checkpoint() # fmt: on diff --git a/tests/autofix_files/async910.py.diff b/tests/autofix_files/async910.py.diff index e6d7d19..9c55839 100644 --- a/tests/autofix_files/async910.py.diff +++ b/tests/autofix_files/async910.py.diff @@ -223,18 +223,14 @@ async def foo_nested_empty_async(): -@@ x,9 x,11 @@ +@@ x,6 x,7 @@ class CtxNeitherCheckpoint: async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) print("setup") + await trio.lowlevel.checkpoint() - async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + async def __aexit__(self, *a): # only __aenter__ is flagged to avoid redundancy print("teardown") -+ await trio.lowlevel.checkpoint() - # fmt: on - - @@ x,6 x,7 @@ async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) if _: diff --git a/tests/eval_files/async910.py b/tests/eval_files/async910.py index 720d68a..2f2850d 100644 --- a/tests/eval_files/async910.py +++ b/tests/eval_files/async910.py @@ -643,7 +643,7 @@ class CtxNeitherCheckpoint: async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line) print("setup") - async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line) + async def __aexit__(self, *a): # only __aenter__ is flagged to avoid redundancy print("teardown") # fmt: on