Skip to content

Commit 8a6a6d0

Browse files
committed
♻️ Plugin system for fastapi app (yes, this time it's actually wired up) 😅
1 parent 9b55c29 commit 8a6a6d0

6 files changed

Lines changed: 188 additions & 15 deletions

File tree

src/fastapi_cli/cli.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import logging
2+
from importlib.metadata import entry_points as _entry_points
23
from pathlib import Path
34
from typing import Annotated, Any
45

56
import typer
67
from pydantic import ValidationError
78
from rich import print
89
from rich.tree import Tree
10+
from typer.models import CommandInfo
911

1012
from fastapi_cli.config import FastAPIConfig
1113
from fastapi_cli.discover import get_import_data, get_import_data_from_import_string
@@ -15,9 +17,7 @@
1517
from .logging import setup_logging
1618
from .utils.cli import get_rich_toolkit, get_uvicorn_log_config
1719

18-
app = typer.Typer(
19-
rich_markup_mode="rich", context_settings={"help_option_names": ["-h", "--help"]}
20-
)
20+
app = typer.Typer(rich_markup_mode="rich", context_settings={"help_option_names": ["-h", "--help"]})
2121

2222
logger = logging.getLogger(__name__)
2323

@@ -48,6 +48,55 @@
4848
pass
4949

5050

51+
def _cmd_name(registered_command: CommandInfo) -> Any:
52+
"""Return the effective CLI name for a registered Typer command."""
53+
if registered_command.name is not None:
54+
return registered_command.name
55+
if registered_command.callback is not None:
56+
return registered_command.callback.__name__.lower().replace("_", "-")
57+
return None
58+
59+
60+
def _load_cli_plugins(typer_app: typer.Typer) -> None:
61+
"""Load commands registered via the 'fastapi_cli.plugins' entry point group."""
62+
63+
# Seed with built-in command names so plugins overriding them get flagged.
64+
known: set[str] = set()
65+
for registered_command in typer_app.registered_commands:
66+
name = _cmd_name(registered_command)
67+
if name:
68+
known.add(name)
69+
70+
for entry_point in _entry_points(group="fastapi_cli.plugins"):
71+
# Snapshot length to slice off only what the plugin adds.
72+
cursor = len(typer_app.registered_commands)
73+
try:
74+
# resolves the `register` callable.
75+
entry_point.load()(typer_app)
76+
except Exception as e:
77+
# Warning on broken plugin ans continue CLI execution.
78+
logger.warning("Plugin '%s' failed to load: %s", entry_point.name, e)
79+
continue
80+
81+
# Walk only plugin's new commands to detect collision.
82+
collisions: list[str] = []
83+
for registered_command in typer_app.registered_commands[cursor:]:
84+
name = _cmd_name(registered_command)
85+
if not name:
86+
continue
87+
if name in known:
88+
collisions.append(name)
89+
known.add(name)
90+
91+
# One warning per plugin, listing all the names it overrode.
92+
if collisions:
93+
logger.warning(
94+
"Plugin '%s' overrides existing command(s): %s",
95+
entry_point.name,
96+
", ".join(sorted(collisions)),
97+
)
98+
99+
51100
def version_callback(value: bool) -> None:
52101
if value:
53102
print(f"FastAPI CLI version: [green]{__version__}[/green]")
@@ -58,9 +107,7 @@ def version_callback(value: bool) -> None:
58107
def callback(
59108
version: Annotated[
60109
bool | None,
61-
typer.Option(
62-
"--version", help="Show the version and exit.", callback=version_callback
63-
),
110+
typer.Option("--version", help="Show the version and exit.", callback=version_callback),
64111
] = None,
65112
verbose: bool = typer.Option(False, help="Enable verbose output"),
66113
) -> None:
@@ -88,9 +135,7 @@ def _get_module_tree(module_paths: list[Path]) -> Tree:
88135

89136
tree = root_tree
90137
for sub_path in module_paths[1:]:
91-
sub_name = (
92-
f"🐍 {sub_path.name}" if sub_path.is_file() else f"📁 {sub_path.name}"
93-
)
138+
sub_name = f"🐍 {sub_path.name}" if sub_path.is_file() else f"📁 {sub_path.name}"
94139
tree = tree.add(sub_name)
95140
if sub_path.is_dir():
96141
tree.add("[dim]🐍 __init__.py[/dim]")
@@ -125,9 +170,7 @@ def _run(
125170

126171
if entrypoint and (path or app):
127172
toolkit.print_line()
128-
toolkit.print(
129-
"[error]Cannot use --entrypoint together with path or --app arguments"
130-
)
173+
toolkit.print("[error]Cannot use --entrypoint together with path or --app arguments")
131174
toolkit.print_line()
132175
raise typer.Exit(code=1)
133176

@@ -221,9 +264,7 @@ def _run(
221264
port=port,
222265
reload=reload,
223266
reload_dirs=(
224-
[str(directory.resolve()) for directory in reload_dirs]
225-
if reload_dirs
226-
else None
267+
[str(directory.resolve()) for directory in reload_dirs] if reload_dirs else None
227268
),
228269
workers=workers,
229270
root_path=root_path,
@@ -448,4 +489,5 @@ def run(
448489

449490

450491
def main() -> None:
492+
_load_cli_plugins(app)
451493
app()

tests/assets/plugins/__init__.py

Whitespace-only changes.

tests/assets/plugins/broken.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import typer
2+
3+
4+
def register(app: typer.Typer) -> None:
5+
raise RuntimeError("intentionally broken plugin")

tests/assets/plugins/colliding.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import typer
2+
3+
4+
def register(app: typer.Typer) -> None:
5+
@app.command("dev") # collides with built-in dev command
6+
def dev() -> None:
7+
pass # pragma: no cover

tests/assets/plugins/sample.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import typer
2+
3+
4+
def register(app: typer.Typer) -> None:
5+
@app.command("ping")
6+
def ping() -> None:
7+
"""Test command added by plugin."""
8+
typer.echo("pong") # pragma: no cover

tests/test_cli_plugin.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import sys
2+
from collections.abc import Generator
3+
from importlib.metadata import EntryPoint
4+
from pathlib import Path
5+
from unittest.mock import patch
6+
7+
import pytest
8+
import typer
9+
10+
from fastapi_cli.cli import _load_cli_plugins
11+
12+
assets_path = Path(__file__).parent / "assets"
13+
14+
15+
@pytest.fixture
16+
def plugins_on_path() -> Generator[None, None, None]:
17+
original_path = sys.path.copy()
18+
sys.path.insert(0, str(assets_path))
19+
try:
20+
yield
21+
finally:
22+
sys.path[:] = original_path
23+
for key in list(sys.modules.keys()):
24+
if key.startswith("plugins."):
25+
del sys.modules[key]
26+
27+
28+
@pytest.fixture
29+
def app() -> typer.Typer:
30+
app = typer.Typer()
31+
return app
32+
33+
34+
def _entry_point(name: str, module_attr: str) -> EntryPoint:
35+
return EntryPoint(
36+
name=name,
37+
value=f"plugins.{module_attr}",
38+
group="fastapi_cli.plugins",
39+
)
40+
41+
42+
def test_load_cli_plugins_happy_path(
43+
plugins_on_path: None,
44+
app: typer.Typer,
45+
) -> None:
46+
entry_point = _entry_point("sample", "sample:register")
47+
48+
with patch("fastapi_cli.cli._entry_points", return_value=[entry_point]):
49+
_load_cli_plugins(app)
50+
51+
names = {registered_command.name for registered_command in app.registered_commands}
52+
assert "ping" in names
53+
54+
55+
def test_load_cli_plugins_logs_on_failure(
56+
plugins_on_path: None,
57+
app: typer.Typer,
58+
) -> None:
59+
entry_point = _entry_point("broken", "broken:register")
60+
61+
with (
62+
patch("fastapi_cli.cli._entry_points", return_value=[entry_point]),
63+
patch("fastapi_cli.cli.logger") as mock_logger,
64+
):
65+
_load_cli_plugins(app)
66+
67+
mock_logger.warning.assert_called_once()
68+
_fmt, entry_point_name, *_ = mock_logger.warning.call_args.args
69+
assert entry_point_name == "broken"
70+
71+
72+
def test_load_cli_plugins_warns_on_collision_with_builtin(
73+
plugins_on_path: None,
74+
app: typer.Typer,
75+
) -> None:
76+
77+
@app.command("dev")
78+
def existing() -> None:
79+
pass # pragma: no cover
80+
81+
entry_point = _entry_point("colliding", "colliding:register")
82+
83+
with (
84+
patch("fastapi_cli.cli._entry_points", return_value=[entry_point]),
85+
patch("fastapi_cli.cli.logger") as mock_logger,
86+
):
87+
_load_cli_plugins(app)
88+
89+
mock_logger.warning.assert_called_once()
90+
_fmt, entry_point_name, collisions = mock_logger.warning.call_args.args
91+
assert entry_point_name == "colliding"
92+
assert "dev" in collisions
93+
94+
95+
def test_load_cli_plugins_warns_on_cross_plugin_collision(
96+
plugins_on_path: None,
97+
app: typer.Typer,
98+
) -> None:
99+
first = _entry_point("sample", "sample:register")
100+
duplicate = _entry_point("duplicate", "sample:register")
101+
102+
with (
103+
patch("fastapi_cli.cli._entry_points", return_value=[first, duplicate]),
104+
patch("fastapi_cli.cli.logger") as mock_logger,
105+
):
106+
_load_cli_plugins(app)
107+
108+
mock_logger.warning.assert_called_once()
109+
_fmt, entry_point_name, collisions = mock_logger.warning.call_args.args
110+
assert entry_point_name == "duplicate"
111+
assert "ping" in collisions

0 commit comments

Comments
 (0)