Skip to content

Commit 27f668a

Browse files
committed
address feedback
1 parent 4ae5d12 commit 27f668a

7 files changed

Lines changed: 171 additions & 136 deletions

File tree

examples/example-fastmcp-mcp/src/auth0/__init__.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,22 @@
55
including token verification, middleware, and scoped tool decorators.
66
"""
77

8+
import logging
9+
import os
10+
from typing import Callable, Union
11+
812
from mcp.server.auth.routes import create_protected_resource_routes
913
from mcp.server.fastmcp import FastMCP
1014
from starlette.middleware import Middleware
15+
from starlette.requests import Request
16+
from starlette.responses import JSONResponse
1117
from starlette.routing import Route, Router
1218

19+
from .errors import AuthenticationRequired, InsufficientScope, MalformedAuthorizationRequest
1320
from .middleware import Auth0Middleware
1421

22+
logger = logging.getLogger(__name__)
23+
1524

1625
class Auth0Mcp:
1726
def __init__(self, name: str, audience: str, domain: str):
@@ -56,3 +65,55 @@ def register_scopes(self, scopes: list[str]) -> None:
5665
"""
5766
if scopes:
5867
self._scopes_supported.update(scopes)
68+
69+
def exception_handlers(self) -> dict[Union[int, type[Exception]], Callable]:
70+
return {
71+
AuthenticationRequired: self._auth_error_handler,
72+
InsufficientScope: self._auth_error_handler,
73+
MalformedAuthorizationRequest: self._auth_error_handler,
74+
# Generic fallback for any other exceptions
75+
Exception: self._generic_exception_handler,
76+
}
77+
78+
def _auth_error_handler(self, request: Request, exc: Exception):
79+
"""
80+
Handle auth errors: malformed authorization requests, missing auth, invalid tokens, and insufficient scopes.
81+
"""
82+
# Include resource metadata parameter for 401 responses per RFC 9728 Section 5.1
83+
include_resource_metadata = exc.status_code == 401
84+
85+
return JSONResponse(
86+
{
87+
"error": exc.error_code,
88+
"error_description": exc.description
89+
},
90+
status_code=exc.status_code,
91+
headers={"WWW-Authenticate": self._build_www_authenticate_header(exc.error_code, exc.description, include_resource_metadata)},
92+
)
93+
94+
def _generic_exception_handler(self, request:Request, exc: Exception):
95+
"""
96+
Fallback handler for all other exceptions.
97+
"""
98+
logger.error(f"Unexpected error in: {exc}", exc_info=exc)
99+
100+
# Return standard HTTP 500 error
101+
return JSONResponse(
102+
{
103+
"error": "internal_server_error",
104+
"error_description": "An unexpected error occurred"
105+
},
106+
status_code=500,
107+
)
108+
109+
def _build_www_authenticate_header(self, error_code: str, description: str, include_resource_metadata: bool = False) -> str:
110+
"""
111+
Build WWW-Authenticate header according to RFC 9728 Section 5.1.
112+
"""
113+
www_auth_params = [f'error="{error_code}"', f'error_description="{description}"']
114+
115+
if include_resource_metadata:
116+
metadata_url = f"{os.getenv('MCP_SERVER_URL')}/.well-known/oauth-protected-resource"
117+
www_auth_params.append(f'resource_metadata="{metadata_url}"')
118+
119+
return f"Bearer {', '.join(www_auth_params)}"
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
from collections.abc import Iterable
5+
from functools import wraps
6+
7+
from mcp.server.fastmcp import Context
8+
9+
from .errors import AuthenticationRequired, InsufficientScope
10+
11+
12+
def require_scopes(required_scopes: Iterable[str]):
13+
"""
14+
Decorator that requires scopes on MCP tools.
15+
16+
Example:
17+
@mcp.tool(...)
18+
@require_scopes(["tool:greet", "tool:whoami"])
19+
def my_tool(name: str, ctx: Context) -> str:
20+
return f"Hello {name}!"
21+
"""
22+
required_scopes_list = list(required_scopes)
23+
def decorator(func):
24+
@wraps(func)
25+
async def wrapper(*args, **kwargs):
26+
# ctx is passed in either kw or positional
27+
ctx: Context | None = (kwargs.get("ctx") if isinstance(kwargs.get("ctx"), Context) else None) or next((arg for arg in args if isinstance(arg, Context)), None)
28+
if ctx is None:
29+
raise TypeError("ctx: Context is required")
30+
31+
auth = getattr(ctx.request_context.request.state, "auth", {})
32+
if not auth:
33+
raise AuthenticationRequired("Authentication required")
34+
35+
user_scopes = set(auth.get("scopes", []))
36+
missing_scopes = [s for s in required_scopes_list if s not in user_scopes]
37+
if missing_scopes:
38+
raise InsufficientScope(f"Missing required scopes: {missing_scopes}")
39+
40+
# Call the original function
41+
if asyncio.iscoroutinefunction(func):
42+
return await func(*args, **kwargs)
43+
else:
44+
return func(*args, **kwargs)
45+
return wrapper
46+
return decorator
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
class AuthenticationRequired(Exception):
2+
"""
3+
Raised when authentication is required but missing.
4+
5+
This maps to HTTP 401 Unauthorized status.
6+
Indicates the request lacks valid authentication credentials.
7+
"""
8+
status_code = 401
9+
error_code = "invalid_token"
10+
default_description = "Authentication required"
11+
12+
def __init__(self, message: str | None = None):
13+
self.description = message or self.default_description
14+
super().__init__(self.description)
15+
16+
17+
class InsufficientScope(Exception):
18+
"""
19+
Raised when user lacks required OAuth scopes.
20+
21+
This maps to HTTP 403 Forbidden status.
22+
Indicates the user is authenticated but doesn't have permission
23+
to access the requested resource due to insufficient scopes.
24+
"""
25+
status_code = 403
26+
error_code = "insufficient_scope"
27+
default_description = "Insufficient scope"
28+
29+
def __init__(self, message: str | None = None):
30+
self.description = message or self.default_description
31+
super().__init__(self.description)
32+
33+
34+
class MalformedAuthorizationRequest(Exception):
35+
"""
36+
Raised when authorization request is malformed.
37+
38+
This maps to HTTP 400 Bad Request status.
39+
Indicates the authorization header or token format is invalid.
40+
"""
41+
status_code = 400
42+
error_code = "invalid_request"
43+
default_description = "Malformed authorization request"
44+
45+
def __init__(self, message: str | None = None):
46+
self.description = message or self.default_description
47+
super().__init__(self.description)

examples/example-fastmcp-mcp/src/auth0/middleware.py

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import logging
2-
import os
32

43
from auth0_api_python import ApiClient, ApiClientOptions
54
from auth0_api_python.errors import VerifyAccessTokenError
65
from starlette.middleware.base import BaseHTTPMiddleware
76
from starlette.requests import Request
8-
from starlette.responses import JSONResponse
97
from starlette.types import ASGIApp
108

9+
from .errors import AuthenticationRequired, MalformedAuthorizationRequest
10+
1111
logger = logging.getLogger(__name__)
1212

1313
class Auth0Middleware(BaseHTTPMiddleware):
@@ -29,13 +29,9 @@ async def dispatch(self, request: Request, call_next):
2929
# Extract Authorization header
3030
auth_header = request.headers.get("authorization")
3131
if not auth_header:
32-
return self._return_auth_error_response(status_code=401, error="Authentication required", description="Missing Authorization header")
32+
raise AuthenticationRequired("Missing Authorization header")
3333
if not auth_header.lower().startswith("bearer "):
34-
return self._return_auth_error_response(
35-
status_code=401,
36-
error="Authentication required",
37-
description="Invalid Authorization header format"
38-
)
34+
raise MalformedAuthorizationRequest("Invalid Authorization header format")
3935

4036
# Extract and verify token
4137
token = auth_header[7:] # Remove "Bearer " prefix
@@ -72,25 +68,8 @@ async def dispatch(self, request: Request, call_next):
7268
return await call_next(request)
7369
except VerifyAccessTokenError as e:
7470
logger.error(f"Token verification failed: {str(e)}")
75-
return self._return_auth_error_response(
76-
status_code=401,
77-
error="Authentication failed",
78-
description="Invalid token"
79-
)
71+
raise AuthenticationRequired("Invalid token")
8072
except Exception as e:
8173
logger.error(f"Unexpected error in middleware: {str(e)}")
82-
return self._return_auth_error_response(
83-
status_code=500,
84-
error="Internal Server Error",
85-
description="Internal Server Error"
86-
)
87-
88-
def _return_auth_error_response(self, status_code: int, error: str, description: str) -> JSONResponse:
89-
www_auth_parts = [f'error="{error}"', f'error_description="{description}"', f'resource_metadata="{os.getenv("MCP_SERVER_URL")}"']
90-
www_authenticate = f"Bearer {', '.join(www_auth_parts)}"
91-
92-
return JSONResponse(
93-
status_code=status_code,
94-
content={"error": error, "error_description": description},
95-
headers={"WWW-Authenticate": www_authenticate}
96-
)
74+
# Re-raise unexpected errors to be handled by generic exception handler
75+
raise

examples/example-fastmcp-mcp/src/auth0/tools.py

Lines changed: 0 additions & 97 deletions
This file was deleted.

examples/example-fastmcp-mcp/src/server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]:
4444
),
4545
],
4646
lifespan=lifespan,
47+
exception_handlers=auth0_mcp.exception_handlers(),
4748
)
4849

4950
# Wrap ASGI application with CORS middleware to expose Mcp-Session-Id header

examples/example-fastmcp-mcp/src/tools.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22

33
from mcp.server.fastmcp import Context
44

5-
from .auth0.tools import create_scoped_tool_decorator, get_auth_info
5+
from .auth0.authz import require_scopes
66

77

88
def register_tools(auth0Mcp):
99
"""
1010
Register all tools with the MCP server.
1111
"""
1212
mcp = auth0Mcp.mcp
13-
# Create a scoped_tool decorator bound to the mcp instance
14-
scoped_tool = create_scoped_tool_decorator(auth0Mcp)
13+
# Register scopes used by tools for Protected Resource Metadata
14+
auth0Mcp.register_scopes(["tool:greet", "tool:whoami"])
1515

1616
# Tool without required scopes
1717
@mcp.tool()
@@ -20,31 +20,29 @@ def echo(text: str) -> str:
2020
return text
2121

2222
# A MCP tool with required scopes
23-
@scoped_tool(
24-
required_scopes=["tool:greet"],
23+
@mcp.tool(
2524
name="greet",
2625
title="Greet Tool",
2726
description="Greets a user",
2827
annotations={"readOnlyHint": True}
2928
)
29+
@require_scopes(["tool:greet"])
3030
def greet(name: str, ctx: Context) -> str:
3131
name = (name or "").strip() or "world"
32-
request = ctx.request_context.request
33-
auth_info = get_auth_info(request)
32+
auth_info = ctx.request_context.request.state.auth
3433
user_id = auth_info.get("extra", {}).get("sub")
3534
return f"Hello, {name}! You are authenticated as {user_id}"
3635

3736
# A MCP tool with required scopes
38-
@scoped_tool(
39-
required_scopes=["tool:whoami"],
37+
@mcp.tool(
4038
name="whoami",
4139
title="Who Am I Tool",
4240
description="Returns information about the authenticated user",
4341
annotations={"readOnlyHint": True}
4442
)
43+
@require_scopes(["tool:whoami"])
4544
def whoami(ctx: Context) -> str:
46-
request = ctx.request_context.request
47-
auth_info = get_auth_info(request)
45+
auth_info = ctx.request_context.request.state.auth
4846

4947
response_data = {
5048
"user": auth_info.get("extra", {}),

0 commit comments

Comments
 (0)