Skip to content
131 changes: 130 additions & 1 deletion tests/test_type_conversion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from pathlib import Path
from typing import Any
from typing import Any, Optional, Union

import click
import pytest
Expand Down Expand Up @@ -48,6 +48,135 @@ def opt(user: str | None = None):
assert "User: Camila" in result.output


@pytest.mark.parametrize(
("value", "expected"),
[("0", "ROOTED!"), ("12", "ID: 12"), ("name", "USER: name")],
)
def test_union(value, expected):
app = typer.Typer()

@app.command()
def opt(id_or_name: int | str):
if isinstance(id_or_name, int):
if id_or_name == 0:
print("ROOTED!")
else:
print(f"ID: {id_or_name}")
else:
print(f"USER: {id_or_name}")

result = runner.invoke(app, [value])
assert result.exit_code == 0
assert expected in result.output


def test_union_optional():
app = typer.Typer()

@app.command()
def cmd(x: int | str | None = None):
print(f"x={x!r} ({type(x).__name__})")

result = runner.invoke(app)
assert result.exit_code == 0
assert "x=None (NoneType)" in result.output

result = runner.invoke(app, ["--x", "7"])
assert result.exit_code == 0
assert "x=7 (int)" in result.output

result = runner.invoke(app, ["--x", "hello"])
assert result.exit_code == 0
assert "x='hello' (str)" in result.output


def test_union_rejects_invalid():
app = typer.Typer()

@app.command()
def cmd(x: int | float):
print(x)

result = runner.invoke(app, ["1"])
assert result.exit_code == 0
assert "1" in result.output

result = runner.invoke(app, ["not-a-number"])
assert result.exit_code != 0


def test_union_metavar_in_help():
app = typer.Typer()

@app.command()
def cmd(x: int | str):
"""Cmd."""

result = runner.invoke(app, ["--help"])
assert result.exit_code == 0
assert "INTEGER | TEXT" in result.output


@pytest.mark.parametrize(
("value", "expected"),
[("0", "ROOTED!"), ("12", "ID: 12"), ("name", "USER: name")],
)
def test_union_pipe(value, expected):
app = typer.Typer()

@app.command()
def opt(id_or_name: int | str):
if isinstance(id_or_name, int):
if id_or_name == 0:
print("ROOTED!")
else:
print(f"ID: {id_or_name}")
else:
print(f"USER: {id_or_name}")

result = runner.invoke(app, [value])
assert result.exit_code == 0
assert expected in result.output


def test_union_pipe_optional():
app = typer.Typer()

@app.command()
def cmd(x: int | str | None = None):
print(f"x={x!r} ({type(x).__name__})")

result = runner.invoke(app)
assert result.exit_code == 0
assert "x=None (NoneType)" in result.output

result = runner.invoke(app, ["--x", "7"])
assert result.exit_code == 0
assert "x=7 (int)" in result.output

result = runner.invoke(app, ["--x", "hello"])
assert result.exit_code == 0
assert "x='hello' (str)" in result.output


@pytest.mark.parametrize("args", [[], ["--x", "7"], ["--x", "hello"]])
def test_union_pipe_and_typing_equivalent(args):
def make_app(annotation):
app = typer.Typer()

@app.command()
def cmd(x: annotation = None):
print(f"x={x!r} ({type(x).__name__})")

return app

typing_out = runner.invoke(make_app(Union[int, str, None]), args).output # noqa: UP007
pipe_out = runner.invoke(make_app(int | str | None), args).output
optional_out = runner.invoke(make_app(Optional[int | str]), args).output # noqa: UP045

assert typing_out == pipe_out == optional_out


def test_optional_tuple():
app = typer.Typer()

Expand Down
42 changes: 39 additions & 3 deletions typer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1517,6 +1517,30 @@ def wrapper(**kwargs: Any) -> Any:
return wrapper


class UnionParamType(click.ParamType):
@property
def name(self) -> str: # type: ignore[override]
return " | ".join(_type.name for _type in self._types)

def __init__(self, types: list[click.ParamType]):
super().__init__()
self._types = types

def convert(
self,
value: Any,
param: click.Parameter | None,
ctx: click.Context | None,
) -> Any:
error_messages = []
for _type in self._types:
try:
return _type.convert(value, param, ctx)
except click.BadParameter as e:
error_messages.append(str(e))
self.fail("\n" + "\nbut also\n".join(error_messages), param, ctx)


def get_click_type(
*, annotation: Any, parameter_info: ParameterInfo
) -> click.ParamType:
Expand Down Expand Up @@ -1617,6 +1641,13 @@ def get_click_type(
literal_values(annotation),
case_sensitive=parameter_info.case_sensitive,
)
elif get_origin(annotation) is not None and is_union(get_origin(annotation)):
types = [
get_click_type(annotation=arg, parameter_info=parameter_info)
for arg in get_args(annotation)
if arg is not NoneType
]
return UnionParamType(types)
raise RuntimeError(f"Type not yet supported: {annotation}") # pragma: no cover


Expand Down Expand Up @@ -1665,9 +1696,14 @@ def get_click_param(
if type_ is NoneType:
continue
types.append(type_)
assert len(types) == 1, "Typer Currently doesn't support Union types"
main_type = types[0]
origin = get_origin(main_type)
if len(types) == 1:
(main_type,) = types
origin = get_origin(main_type)
else:
for type_ in get_args(main_type):
assert not get_origin(type_), (
"Union types with complex sub-types are not currently supported"
)
# Handle Tuples and Lists
if lenient_issubclass(origin, list):
main_type = get_args(main_type)[0]
Expand Down
Loading