Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions packages/mcpplugin/src/microsoft_teams/mcpplugin/ai_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from .models import McpCachedValue, McpClientPluginParams, McpToolDetails
from .transport import create_transport
from .url_validation import UrlValidationParams, validate_mcp_server_url

REFETCH_TIMEOUT_MS = 24 * 60 * 60 * 1000 # 1 day

Expand Down Expand Up @@ -248,6 +249,13 @@ async def _fetch_tools_from_server(self, url: str, params: McpClientPluginParams
Raises:
Exception: If connection or tool listing fails
"""
await validate_mcp_server_url(
url,
UrlValidationParams(
allow_private_network=params.allow_private_network,
validate_url=params.validate_url,
),
)
transport_context = create_transport(url, params.transport or "streamable_http", params.headers)

async with transport_context as (read_stream, write_stream):
Expand Down Expand Up @@ -285,6 +293,13 @@ async def _call_mcp_tool(
Returns:
Tool execution result as string or list of strings
"""
await validate_mcp_server_url(
url,
UrlValidationParams(
allow_private_network=params.allow_private_network,
validate_url=params.validate_url,
),
)
transport_context = create_transport(url, params.transport or "streamable_http", params.headers)

async with transport_context as (read_stream, write_stream):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ class McpClientPluginParams:
)
skip_if_unavailable: Optional[bool] = True # Continue if server is unavailable
refetch_timeout_ms: Optional[int] = None # Override default cache timeout
allow_private_network: bool = False # Allow MCP URLs resolving to loopback/RFC1918/link-local
validate_url: Optional[Callable[[str], Union[bool, Awaitable[bool]]]] = None # Fully replace default URL filter
64 changes: 61 additions & 3 deletions packages/mcpplugin/src/microsoft_teams/mcpplugin/server_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import importlib.metadata
import logging
from inspect import isawaitable
from typing import Annotated, Any, TypeVar, cast
from typing import Annotated, Any, Awaitable, Callable, Optional, TypeVar, Union, cast

from fastmcp import FastMCP
from fastmcp.tools import FunctionTool
Expand All @@ -20,6 +20,7 @@
PluginStartEvent,
)
from pydantic import BaseModel
from starlette.requests import Request

try:
version = importlib.metadata.version("microsoft-teams-mcpplugin")
Expand All @@ -30,6 +31,44 @@

P = TypeVar("P", bound=BaseModel)

RequireAuthCallable = Callable[[Request], Union[bool, Awaitable[bool]]]


class _AuthMiddleware:
"""ASGI middleware that gates inbound MCP requests behind ``require_auth``."""

def __init__(self, app: Any, require_auth: RequireAuthCallable) -> None:
self.app = app
self.require_auth = require_auth

async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
if scope.get("type") != "http":
await self.app(scope, receive, send)
return

request = Request(scope, receive=receive)
ok = False
try:
result = self.require_auth(request)
if isawaitable(result):
result = await result
ok = bool(result)
except Exception as err: # noqa: BLE001
logger.debug("require_auth raised: %s", err)

if not ok:
await send(
{
"type": "http.response.start",
"status": 401,
"headers": [(b"content-type", b"text/plain")],
}
)
await send({"type": "http.response.body", "body": b"unauthorized"})
return

await self.app(scope, receive, send)


@Plugin(name="mcp-server", version=version, description="MCP server plugin that exposes AI functions as MCP tools")
class McpServerPlugin(PluginBase):
Expand All @@ -44,17 +83,27 @@ class McpServerPlugin(PluginBase):
# Dependency injection
http_server: Annotated[HttpServer, DependencyMetadata()]

def __init__(self, name: str = "teams-mcp-server", path: str = "/mcp"):
def __init__(
self,
name: str = "teams-mcp-server",
path: str = "/mcp",
require_auth: Optional[RequireAuthCallable] = None,
):
"""
Initialize the MCP server plugin.

Args:
name: The name of the MCP server for identification
path: The HTTP path to mount the MCP server on (default: /mcp)
require_auth: Optional callable gating inbound MCP requests. Receives
a Starlette Request; return True to allow, False (or raise) to
reject with HTTP 401. When unset, all requests are accepted and
a warning is emitted at plugin startup.
"""
self.mcp_server = FastMCP(name)
self.path = path
self._mounted = False
self._require_auth = require_auth

@property
def server(self) -> FastMCP:
Expand Down Expand Up @@ -165,7 +214,16 @@ async def on_start(self, event: PluginStartEvent) -> None:
# We mount the mcp server as a separate app at self.path
mcp_http_app = self.mcp_server.http_app(path=self.path, transport="http")
adapter.lifespans.append(mcp_http_app.lifespan) # pyright: ignore[reportArgumentType]
adapter.app.mount("/", mcp_http_app)

if self._require_auth is not None:
adapter.app.mount("/", _AuthMiddleware(mcp_http_app, self._require_auth))
else:
logger.warning(
"McpServerPlugin started without require_auth. All MCP requests at %s "
"will be accepted. Pass require_auth to enforce authentication.",
self.path,
)
adapter.app.mount("/", mcp_http_app)

self._mounted = True

Expand Down
111 changes: 111 additions & 0 deletions packages/mcpplugin/src/microsoft_teams/mcpplugin/url_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
"""

import asyncio
import ipaddress
import logging
import socket
from dataclasses import dataclass
from inspect import isawaitable
from typing import Awaitable, Callable, List, Optional, Union
from urllib.parse import urlparse

logger = logging.getLogger(__name__)


class UrlValidationError(ValueError):
"""Raised when an MCP server URL fails validation."""


@dataclass
class UrlValidationParams:
"""Parameters controlling MCP server URL validation."""

allow_private_network: bool = False
validate_url: Optional[Callable[[str], Union[bool, Awaitable[bool]]]] = None


async def validate_mcp_server_url(url: str, params: Optional[UrlValidationParams] = None) -> str:
"""
Validate a URL destined for an MCP server connection.

When ``validate_url`` is provided, it fully replaces the default checks.
Otherwise the default policy rejects non-http(s) schemes, and (unless
``allow_private_network`` is True) rejects URLs whose hostname resolves
to a private / loopback / link-local address.

Returns the original URL on success. Raises :class:`UrlValidationError`
on rejection.
"""
params = params or UrlValidationParams()

try:
parsed = urlparse(url)
except ValueError as err:
raise UrlValidationError(f"Invalid URL: {url!r}") from err

if not parsed.scheme:
raise UrlValidationError(f"Invalid URL: {url!r}")

if params.validate_url is not None:
result = params.validate_url(url)
if isawaitable(result):
result = await result
if not result:
raise UrlValidationError(f"URL rejected by validate_url: {url}")
return url

if parsed.scheme not in ("http", "https"):
raise UrlValidationError(f"URL scheme {parsed.scheme!r} is not allowed; must be http or https")

if not parsed.hostname:
raise UrlValidationError(f"Invalid URL: {url!r}")

if params.allow_private_network:
return url

addresses = await _resolve_host(parsed.hostname)
if not addresses:
raise UrlValidationError(f"URL {url} did not resolve to any address")
for address in addresses:
if is_private_address(address):
raise UrlValidationError(
f"URL {url} resolves to private or loopback address {address}; set allow_private_network=True to bypass"
)

return url


def is_private_address(address: str) -> bool:
"""True if an IP address is loopback, private, link-local, unspecified, or IPv6 site-local."""
try:
ip = ipaddress.ip_address(address)
except ValueError:
return True # Unknown: fail closed.
if ip.is_loopback or ip.is_private or ip.is_link_local or ip.is_unspecified:
return True
# RFC 4291 deprecated IPv6 site-local (fec0::/10). Python's is_private does
# not classify it, but we reject for parity with the C# SDK.
if isinstance(ip, ipaddress.IPv6Address):
packed = int(ip)
return (packed >> 118) == 0x3FB # top 10 bits == 1111111011
return False


async def _resolve_host(host: str) -> List[str]:
"""Resolve a hostname to its IP addresses; short-circuit IP literals."""
try:
ipaddress.ip_address(host)
return [host]
except ValueError:
pass

loop = asyncio.get_running_loop()
try:
results = await loop.getaddrinfo(host, None)
except socket.gaierror as err:
raise UrlValidationError(f"Could not resolve {host}: {err}") from err

return list({entry[4][0] for entry in results})
13 changes: 13 additions & 0 deletions packages/mcpplugin/tests/test_ai_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ async def create_transport(self, url: str, transport_type: str, headers: Optiona
yield (read_stream, write_stream)


@pytest.fixture(autouse=True)
def _stub_url_validation_dns():
"""
Prevent DNS lookups from running when ai_plugin runs URL validation
against placeholder test hosts (e.g., http://test-server).
"""
with patch(
"microsoft_teams.mcpplugin.url_validation._resolve_host",
new=AsyncMock(return_value=["8.8.8.8"]),
):
yield


@pytest.fixture
def sample_tools() -> List[MockMCPTool]:
"""Sample MCP tools for testing."""
Expand Down
Loading
Loading