diff --git a/mypy/errors.py b/mypy/errors.py index b977f3b1a103..ba365d5e815b 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -6,7 +6,7 @@ from collections import defaultdict from collections.abc import Callable, Iterable, Iterator from itertools import chain -from typing import Final, Literal, NoReturn, TextIO, TypeAlias as _TypeAlias, TypeVar +from typing import Final, Literal, NamedTuple, NoReturn, TextIO, TypeAlias as _TypeAlias, TypeVar from typing_extensions import Self from mypy import errorcodes as codes @@ -229,11 +229,19 @@ def filtered_errors(self) -> list[ErrorInfo]: return self._filtered +class NonOverlapErrorInfo(NamedTuple): + line: int + column: int + end_line: int | None + end_column: int | None + kind: str + + class IterationDependentErrors: """An `IterationDependentErrors` instance serves to collect the `unreachable`, - `redundant-expr`, and `redundant-casts` errors, as well as the revealed types, - handled by the individual `IterationErrorWatcher` instances sequentially applied to - the same code section.""" + `redundant-expr`, and `redundant-casts` errors, as well as the revealed types and + non-overlapping types, handled by the individual `IterationErrorWatcher` instances + sequentially applied to the same code section.""" # One set of `unreachable`, `redundant-expr`, and `redundant-casts` errors per # iteration step. Meaning of the tuple items: ErrorCode, message, line, column, @@ -249,9 +257,13 @@ class IterationDependentErrors: # end_line, end_column: revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]] + # One dictionary of non-overlapping types per iteration step: + nonoverlapping_types: list[dict[NonOverlapErrorInfo, tuple[Type, Type]]] + def __init__(self) -> None: self.uselessness_errors = [] self.unreachable_lines = [] + self.nonoverlapping_types = [] self.revealed_types = defaultdict(list) def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]: @@ -271,6 +283,36 @@ def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCod context.end_column = error_info[5] yield error_info[1], context, error_info[0] + def yield_nonoverlapping_types( + self, + ) -> Iterator[tuple[tuple[list[Type], list[Type]], str, Context]]: + """Report expressions where non-overlapping types were detected for all iterations + were the expression was reachable.""" + + selected = set() + for candidate in set(chain.from_iterable(self.nonoverlapping_types)): + if all( + (candidate in nonoverlap) or (candidate.line in lines) + for nonoverlap, lines in zip(self.nonoverlapping_types, self.unreachable_lines) + ): + selected.add(candidate) + + persistent_nonoverlaps: dict[NonOverlapErrorInfo, tuple[list[Type], list[Type]]] = ( + defaultdict(lambda: ([], [])) + ) + for nonoverlaps in self.nonoverlapping_types: + for candidate, (left, right) in nonoverlaps.items(): + if candidate in selected: + types = persistent_nonoverlaps[candidate] + types[0].append(left) + types[1].append(right) + + for error_info, types in persistent_nonoverlaps.items(): + context = Context(line=error_info.line, column=error_info.column) + context.end_line = error_info.end_line + context.end_column = error_info.end_column + yield (types[0], types[1]), error_info.kind, context + def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]: """Yield all types revealed in at least one iteration step.""" @@ -283,8 +325,9 @@ def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]: class IterationErrorWatcher(ErrorWatcher): """Error watcher that filters and separately collects `unreachable` errors, - `redundant-expr` and `redundant-casts` errors, and revealed types when analysing - code sections iteratively to help avoid making too-hasty reports.""" + `redundant-expr` and `redundant-casts` errors, and revealed types and + non-overlapping types when analysing code sections iteratively to help avoid + making too-hasty reports.""" iteration_dependent_errors: IterationDependentErrors @@ -305,6 +348,7 @@ def __init__( ) self.iteration_dependent_errors = iteration_dependent_errors iteration_dependent_errors.uselessness_errors.append(set()) + iteration_dependent_errors.nonoverlapping_types.append({}) iteration_dependent_errors.unreachable_lines.append(set()) def on_error(self, file: str, info: ErrorInfo) -> bool: diff --git a/mypy/messages.py b/mypy/messages.py index 28a4f8d614ca..54c08415a302 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -29,6 +29,7 @@ ErrorWatcher, IterationDependentErrors, IterationErrorWatcher, + NonOverlapErrorInfo, ) from mypy.nodes import ( ARG_NAMED, @@ -1624,6 +1625,26 @@ def incompatible_typevar_value( ) def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None: + # In loops (and similar cases), the same expression might be analysed multiple + # times and thereby confronted with different types. We only want to raise a + # `comparison-overlap` error if it occurs in all cases and therefore collect the + # respective types of the current iteration here so that we can report the error + # later if it is persistent over all iteration steps: + for watcher in self.errors.get_watchers(): + if watcher._filter: + break + if isinstance(watcher, IterationErrorWatcher): + watcher.iteration_dependent_errors.nonoverlapping_types[-1][ + NonOverlapErrorInfo( + line=ctx.line, + column=ctx.column, + end_line=ctx.end_line, + end_column=ctx.end_column, + kind=kind, + ) + ] = (left, right) + return + left_str = "element" if kind == "container" else "left operand" right_str = "container item" if kind == "container" else "right operand" message = "Non-overlapping {} check ({} type: {}, {} type: {})" @@ -2513,8 +2534,11 @@ def match_statement_inexhaustive_match(self, typ: Type, context: Context) -> Non def iteration_dependent_errors(self, iter_errors: IterationDependentErrors) -> None: for error_info in iter_errors.yield_uselessness_error_infos(): self.fail(*error_info[:2], code=error_info[2]) + msu = mypy.typeops.make_simplified_union + for nonoverlaps, kind, context in iter_errors.yield_nonoverlapping_types(): + self.dangerous_comparison(msu(nonoverlaps[0]), msu(nonoverlaps[1]), kind, context) for types, context in iter_errors.yield_revealed_type_infos(): - self.reveal_type(mypy.typeops.make_simplified_union(types), context) + self.reveal_type(msu(types), context) def quote_type_string(type_string: str) -> str: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 00d33c86414f..d92e100f7b6d 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2446,6 +2446,41 @@ while x is not None and b(): x = f() [builtins fixtures/primitives.pyi] +[case testAvoidFalseNonOverlappingEqualityCheckInLoop1] +# flags: --allow-redefinition-new --local-partial-types --strict-equality + +x = 1 +while True: + if x == str(): + break + x = str() + if x == int(): # E: Non-overlapping equality check (left operand type: "str", right operand type: "int") + break +[builtins fixtures/primitives.pyi] + +[case testAvoidFalseNonOverlappingEqualityCheckInLoop2] +# flags: --allow-redefinition-new --local-partial-types --strict-equality + +class A: ... +class B: ... +class C: ... + +x = A() +while True: + if x == C(): # E: Non-overlapping equality check (left operand type: "Union[A, B]", right operand type: "C") + break + x = B() +[builtins fixtures/primitives.pyi] + +[case testAvoidFalseNonOverlappingEqualityCheckInLoop3] +# flags: --strict-equality + +for y in [1.0]: + if y is not None or y != "None": + ... + +[builtins fixtures/primitives.pyi] + [case testNarrowPromotionsInsideUnions1] from typing import Union