165 changes: 79 additions & 86 deletions mypy/checkpattern.py
@@ -2,7 +2,6 @@

from __future__ import annotations

import itertools
from collections import defaultdict
from typing import Final, NamedTuple

@@ -233,11 +232,12 @@ def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:

def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
#
# check for existence of a starred pattern
# Step 1. Check for existence of a starred pattern
#
current_type = get_proper_type(self.type_context[-1])
if not self.can_match_sequence(current_type):
return self.early_non_match()

star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)]
star_position: int | None = None
if len(star_positions) == 1:
@@ -248,98 +248,91 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
if star_position is not None:
required_patterns -= 1

# 1. Go through all possible types and filter to only those which are sequences that
# could match that number of items
# 2. If there is exactly one tuple left with an unpack, then use that type
# and the unpack index
# 3. Otherwise, take the product of the item types so that each index can have a
# unique type. For tuples with unpack fallback to merging all of their types
# for each index, since we can't handle multiple unpacked items at once yet.

# Whether we have encountered a type that we don't know how to handle in the union
unknown_type = False
# A list of types that could match any of the items in the sequence.
sequence_types: list[Type] = []
# A list of tuple types that could match the sequence, per index
tuple_types: list[list[Type]] = []
# A list of all the unpack tuple types that we encountered, each containing the
# tuple type, unpack index, and union index
unpack_tuple_types: list[tuple[TupleType, int, int]] = []
for i, t in enumerate(
current_type.items if isinstance(current_type, UnionType) else [current_type]
):
t = get_proper_type(t)
if isinstance(t, TupleType):
tuple_items = list(t.items)
unpack_index = find_unpack_in_list(tuple_items)
if unpack_index is None:
size_diff = len(tuple_items) - required_patterns
if size_diff < 0:
continue
if size_diff > 0 and star_position is None:
continue
if not size_diff and star_position is not None:
# Above we subtract from required_patterns if star_position is not None
tuple_items.append(UninhabitedType())
tuple_types.append(tuple_items)
else:
normalized_inner_types = []
for it in tuple_items:
# Unfortunately, it is not possible to "split" the TypeVarTuple
# into individual items, so we just use its upper bound for the whole
# analysis instead.
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
it = UnpackType(it.type.upper_bound)
normalized_inner_types.append(it)
if (
len(normalized_inner_types) - 1 > required_patterns
and star_position is None
):
continue
t = t.copy_modified(items=normalized_inner_types)
unpack_tuple_types.append((t, unpack_index, i))
# In case we have multiple unpacks we want to combine them all, so add
# the combined tuple type to the sequence types.
sequence_types.append(self.chk.iterable_item_type(tuple_fallback(t), o))
elif isinstance(t, AnyType):
sequence_types.append(AnyType(TypeOfAny.from_another_any, t))
elif self.chk.type_is_iterable(t) and isinstance(t, Instance):
sequence_types.append(self.chk.iterable_item_type(t, o))
#
# Step 2. If we have a union, recurse and return the combined result
#
if isinstance(current_type, UnionType):
match_types: list[Type] = []
rest_types: list[Type] = []
captures_list: dict[Expression, list[Type]] = {}

if star_position is not None:
star_pattern = o.patterns[star_position]
assert isinstance(star_pattern, StarredPattern)
star_expr = star_pattern.capture
else:
unknown_type = True

inner_types: list[Type]
star_expr = None

for t in current_type.items:
match_type, rest_type, captures = self.accept(o, t)
match_types.append(match_type)
rest_types.append(rest_type)
if not is_uninhabited(match_type):
for expr, typ in captures.items():
p_typ = get_proper_type(typ)
if expr not in captures_list:
captures_list[expr] = []
# Avoid adding in a list[Never] for empty list captures
if (
expr == star_expr
and isinstance(p_typ, Instance)
and p_typ.type.fullname == "builtins.list"
and is_uninhabited(p_typ.args[0])
):
continue
captures_list[expr].append(typ)

return PatternType(
make_simplified_union(match_types),
make_simplified_union(rest_types),
{expr: make_simplified_union(types) for expr, types in captures_list.items()},
)

# If we only got one unpack tuple type, we can use that
#
# Step 3. Get inner types of original type
#
unpack_index = None
if len(unpack_tuple_types) == 1 and len(sequence_types) == 1 and not tuple_types:
update_tuple_type, unpack_index, union_index = unpack_tuple_types[0]
inner_types = update_tuple_type.items
if isinstance(current_type, UnionType):
union_items = list(current_type.items)
union_items[union_index] = update_tuple_type
current_type = get_proper_type(UnionType.make_union(items=union_items))
if isinstance(current_type, TupleType):
inner_types = current_type.items
unpack_index = find_unpack_in_list(inner_types)
if unpack_index is None:
size_diff = len(inner_types) - required_patterns
if size_diff < 0:
return self.early_non_match()
elif size_diff > 0 and star_position is None:
return self.early_non_match()
else:
current_type = update_tuple_type
# If we only got tuples we can't match, then exit early
elif not tuple_types and not sequence_types and not unknown_type:
return self.early_non_match()
elif tuple_types:
inner_types = [
make_simplified_union([*sequence_types, *[t for t in group if t is not None]])
for group in itertools.zip_longest(*tuple_types)
]
elif sequence_types:
inner_types = [make_simplified_union(sequence_types)] * len(o.patterns)
normalized_inner_types = []
for it in inner_types:
# Unfortunately, it is not possible to "split" the TypeVarTuple
# into individual items, so we just use its upper bound for the whole
# analysis instead.
if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
it = UnpackType(it.type.upper_bound)
normalized_inner_types.append(it)
inner_types = normalized_inner_types
current_type = current_type.copy_modified(items=normalized_inner_types)
if len(inner_types) - 1 > required_patterns and star_position is None:
return self.early_non_match()
elif isinstance(current_type, AnyType):
inner_type: Type = AnyType(TypeOfAny.from_another_any, current_type)
inner_types = [inner_type] * len(o.patterns)
elif isinstance(current_type, TupleType):
inner_type = self.chk.iterable_item_type(tuple_fallback(current_type), o)
inner_types = [inner_type] * len(o.patterns)
elif isinstance(current_type, Instance) and self.chk.type_is_iterable(current_type):
inner_type = self.chk.iterable_item_type(current_type, o)
inner_types = [inner_type] * len(o.patterns)
else:
inner_types = [self.chk.named_type("builtins.object")] * len(o.patterns)
inner_type = self.chk.named_type("builtins.object")
inner_types = [inner_type] * len(o.patterns)

#
# match inner patterns
# Step 4. Match inner patterns
#
contracted_new_inner_types: list[Type] = []
contracted_rest_inner_types: list[Type] = []
captures: dict[Expression, Type] = {}
captures = {} # dict[Expression, Type]

contracted_inner_types = self.contract_starred_pattern_types(
inner_types, star_position, required_patterns
@@ -359,10 +352,10 @@ def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
)

#
# Calculate new type
# Step 5. Calculate new type
#
new_type: Type
rest_type: Type = current_type
rest_type = current_type
if isinstance(current_type, TupleType) and unpack_index is None:
narrowed_inner_types = []
inner_rest_types = []
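
For context, a minimal sketch (a hypothetical example, not part of this diff) of the narrowing the rewritten visit_sequence_pattern targets: when the subject is a union, the checker now recurses into each union member and combines the per-member results, so captures narrow per member rather than to the merged union. The revealed types in the comments follow the updated testMatchUnionTwoTuples expectations below.

def describe(m: tuple[int, int] | tuple[str, str]) -> str:
    # Hypothetical example mirroring testMatchUnionTwoTuples (Python 3.10+).
    match m:
        case (42, a):
            # With per-union-member matching, mypy reveals `a` as int here
            # instead of the merged int | str.
            return f"int pair ending in {a}"
        case ("yes", b):
            # Likewise `b` is revealed as str in this arm.
            return f"str pair ending in {b}"
        case _:
            return "other"

print(describe((42, 7)))       # int pair ending in 7
print(describe(("yes", "x")))  # str pair ending in x
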
44 changes: 24 additions & 20 deletions test-data/unit/check-python310.test
@@ -1773,42 +1773,38 @@ match m4:
reveal_type(a4) # N: Revealed type is "builtins.str"
reveal_type(b4) # N: Revealed type is "builtins.str"

# properly handles unpack when all other patterns are not sequences
m5: tuple[int, Unpack[tuple[float, ...]]] | None
match m5:
case (a5, b5):
reveal_type(a5) # N: Revealed type is "builtins.int"
reveal_type(b5) # N: Revealed type is "builtins.float"

# currently can't handle combining unpacking with other sequence patterns, if this happens revert to worst case
# of combining all types
m6: tuple[int, Unpack[tuple[float, ...]]] | list[str]
match m6:
case (a6, b6):
reveal_type(a6) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
reveal_type(b6) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
reveal_type(a6) # N: Revealed type is "builtins.int | builtins.str"
reveal_type(b6) # N: Revealed type is "builtins.float | builtins.str"

# but do still separate types from non-unpacked types
m7: tuple[int, Unpack[tuple[float, ...]]] | tuple[str, str]
match m7:
case (a7, b7, *rest7):
reveal_type(a7) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
reveal_type(b7) # N: Revealed type is "builtins.int | builtins.float | builtins.str"
reveal_type(rest7) # N: Revealed type is "builtins.list[builtins.int | builtins.float]"
reveal_type(a7) # N: Revealed type is "builtins.int | builtins.str"
reveal_type(b7) # N: Revealed type is "builtins.float | builtins.str"
reveal_type(rest7) # N: Revealed type is "builtins.list[builtins.float]"

# verify that if we are unpacking, it will get the type of the sequence if the tuple is too short
m8: tuple[int, str] | list[float]
match m8:
case (a8, b8, *rest8):
reveal_type(a8) # N: Revealed type is "builtins.float | builtins.int"
reveal_type(b8) # N: Revealed type is "builtins.float | builtins.str"
reveal_type(a8) # N: Revealed type is "builtins.int | builtins.float"
reveal_type(b8) # N: Revealed type is "builtins.str | builtins.float"
reveal_type(rest8) # N: Revealed type is "builtins.list[builtins.float]"

m9: tuple[str, str, int] | tuple[str, str]
match m9:
case (a9, *rest9):
reveal_type(a9) # N: Revealed type is "builtins.str"
reveal_type(rest9) # N: Revealed type is "builtins.list[builtins.str | builtins.int]"
reveal_type(rest9) # N: Revealed type is "builtins.list[builtins.str | builtins.int] | builtins.list[builtins.str]"

[builtins fixtures/tuple.pyi]

@@ -2261,15 +2257,23 @@ match foo:
reveal_type(x) # N: Revealed type is "builtins.int"
[builtins fixtures/tuple.pyi]

[case testMatchUnionTwoTuplesNoCrash]
var: tuple[int, int] | tuple[str, str]
[case testMatchUnionTwoTuples]
# flags: --strict-equality --warn-unreachable

def main(var: tuple[int, int] | tuple[str, str]):
match var:
case (42, a):
reveal_type(a) # N: Revealed type is "builtins.int"
case ("yes", b):
reveal_type(b) # N: Revealed type is "builtins.str"

# TODO: we can infer better here.
match var:
case (42, a):
reveal_type(a) # N: Revealed type is "builtins.int | builtins.str"
case ("yes", b):
reveal_type(b) # N: Revealed type is "builtins.int | builtins.str"

def main2(var: tuple[int, int] | tuple[str, str] | tuple[str, int]):
match var:
case (42, a):
reveal_type(a) # N: Revealed type is "builtins.int"
case ("yes", b):
reveal_type(b) # N: Revealed type is "builtins.str | builtins.int"
[builtins fixtures/tuple.pyi]

[case testMatchNamedAndKeywordsAreTheSame]
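
As a usage sketch (my own example, assuming Python 3.11+ for typing.Unpack; typing_extensions.Unpack works the same on older versions), the starred-pattern case exercised by m7 above can be tried directly:

from typing import Unpack

def split(m: tuple[int, Unpack[tuple[float, ...]]] | tuple[str, str]) -> None:
    # Mirrors the m7 test case: per its updated expectations, mypy reveals
    # `first` as int | str, `second` as float | str, and `rest` as
    # list[float] (the empty-list capture from the tuple[str, str] member
    # is dropped rather than widened to list[Never]).
    match m:
        case (first, second, *rest):
            print(first, second, rest)
        case _:
            print("too short")

split((1, 2.0, 3.0))   # 1 2.0 [3.0]
split(("a", "b"))      # a b []
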