diff --git a/packages/mcpplugin/src/microsoft_teams/mcpplugin/ai_plugin.py b/packages/mcpplugin/src/microsoft_teams/mcpplugin/ai_plugin.py index b2b76fad..2742a53e 100644 --- a/packages/mcpplugin/src/microsoft_teams/mcpplugin/ai_plugin.py +++ b/packages/mcpplugin/src/microsoft_teams/mcpplugin/ai_plugin.py @@ -17,6 +17,7 @@ from .models import McpCachedValue, McpClientPluginParams, McpToolDetails from .transport import create_transport +from .url_validation import UrlValidationParams, validate_mcp_server_url REFETCH_TIMEOUT_MS = 24 * 60 * 60 * 1000 # 1 day @@ -248,6 +249,13 @@ async def _fetch_tools_from_server(self, url: str, params: McpClientPluginParams Raises: Exception: If connection or tool listing fails """ + await validate_mcp_server_url( + url, + UrlValidationParams( + allow_private_network=params.allow_private_network, + validate_url=params.validate_url, + ), + ) transport_context = create_transport(url, params.transport or "streamable_http", params.headers) async with transport_context as (read_stream, write_stream): @@ -285,6 +293,13 @@ async def _call_mcp_tool( Returns: Tool execution result as string or list of strings """ + await validate_mcp_server_url( + url, + UrlValidationParams( + allow_private_network=params.allow_private_network, + validate_url=params.validate_url, + ), + ) transport_context = create_transport(url, params.transport or "streamable_http", params.headers) async with transport_context as (read_stream, write_stream): diff --git a/packages/mcpplugin/src/microsoft_teams/mcpplugin/models/params.py b/packages/mcpplugin/src/microsoft_teams/mcpplugin/models/params.py index 69d3a3a0..79b404c3 100644 --- a/packages/mcpplugin/src/microsoft_teams/mcpplugin/models/params.py +++ b/packages/mcpplugin/src/microsoft_teams/mcpplugin/models/params.py @@ -25,3 +25,5 @@ class McpClientPluginParams: ) skip_if_unavailable: Optional[bool] = True # Continue if server is unavailable refetch_timeout_ms: Optional[int] = None # Override default cache timeout + allow_private_network: bool = False # Allow MCP URLs resolving to loopback/RFC1918/link-local + validate_url: Optional[Callable[[str], Union[bool, Awaitable[bool]]]] = None # Fully replace default URL filter diff --git a/packages/mcpplugin/src/microsoft_teams/mcpplugin/server_plugin.py b/packages/mcpplugin/src/microsoft_teams/mcpplugin/server_plugin.py index 3ed7578a..b47eb862 100644 --- a/packages/mcpplugin/src/microsoft_teams/mcpplugin/server_plugin.py +++ b/packages/mcpplugin/src/microsoft_teams/mcpplugin/server_plugin.py @@ -6,7 +6,7 @@ import importlib.metadata import logging from inspect import isawaitable -from typing import Annotated, Any, TypeVar, cast +from typing import Annotated, Any, Awaitable, Callable, Optional, TypeVar, Union, cast from fastmcp import FastMCP from fastmcp.tools import FunctionTool @@ -20,6 +20,7 @@ PluginStartEvent, ) from pydantic import BaseModel +from starlette.requests import Request try: version = importlib.metadata.version("microsoft-teams-mcpplugin") @@ -30,6 +31,44 @@ P = TypeVar("P", bound=BaseModel) +RequireAuthCallable = Callable[[Request], Union[bool, Awaitable[bool]]] + + +class _AuthMiddleware: + """ASGI middleware that gates inbound MCP requests behind ``require_auth``.""" + + def __init__(self, app: Any, require_auth: RequireAuthCallable) -> None: + self.app = app + self.require_auth = require_auth + + async def __call__(self, scope: Any, receive: Any, send: Any) -> None: + if scope.get("type") != "http": + await self.app(scope, receive, send) + return + + request = Request(scope, receive=receive) + ok = False + try: + result = self.require_auth(request) + if isawaitable(result): + result = await result + ok = bool(result) + except Exception as err: # noqa: BLE001 + logger.debug("require_auth raised: %s", err) + + if not ok: + await send( + { + "type": "http.response.start", + "status": 401, + "headers": [(b"content-type", b"text/plain")], + } + ) + await send({"type": "http.response.body", "body": b"unauthorized"}) + return + + await self.app(scope, receive, send) + @Plugin(name="mcp-server", version=version, description="MCP server plugin that exposes AI functions as MCP tools") class McpServerPlugin(PluginBase): @@ -44,17 +83,27 @@ class McpServerPlugin(PluginBase): # Dependency injection http_server: Annotated[HttpServer, DependencyMetadata()] - def __init__(self, name: str = "teams-mcp-server", path: str = "/mcp"): + def __init__( + self, + name: str = "teams-mcp-server", + path: str = "/mcp", + require_auth: Optional[RequireAuthCallable] = None, + ): """ Initialize the MCP server plugin. Args: name: The name of the MCP server for identification path: The HTTP path to mount the MCP server on (default: /mcp) + require_auth: Optional callable gating inbound MCP requests. Receives + a Starlette Request; return True to allow, False (or raise) to + reject with HTTP 401. When unset, all requests are accepted and + a warning is emitted at plugin startup. """ self.mcp_server = FastMCP(name) self.path = path self._mounted = False + self._require_auth = require_auth @property def server(self) -> FastMCP: @@ -165,7 +214,16 @@ async def on_start(self, event: PluginStartEvent) -> None: # We mount the mcp server as a separate app at self.path mcp_http_app = self.mcp_server.http_app(path=self.path, transport="http") adapter.lifespans.append(mcp_http_app.lifespan) # pyright: ignore[reportArgumentType] - adapter.app.mount("/", mcp_http_app) + + if self._require_auth is not None: + adapter.app.mount("/", _AuthMiddleware(mcp_http_app, self._require_auth)) + else: + logger.warning( + "McpServerPlugin started without require_auth. All MCP requests at %s " + "will be accepted. Pass require_auth to enforce authentication.", + self.path, + ) + adapter.app.mount("/", mcp_http_app) self._mounted = True diff --git a/packages/mcpplugin/src/microsoft_teams/mcpplugin/url_validation.py b/packages/mcpplugin/src/microsoft_teams/mcpplugin/url_validation.py new file mode 100644 index 00000000..54719af9 --- /dev/null +++ b/packages/mcpplugin/src/microsoft_teams/mcpplugin/url_validation.py @@ -0,0 +1,111 @@ +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +import asyncio +import ipaddress +import logging +import socket +from dataclasses import dataclass +from inspect import isawaitable +from typing import Awaitable, Callable, List, Optional, Union +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + + +class UrlValidationError(ValueError): + """Raised when an MCP server URL fails validation.""" + + +@dataclass +class UrlValidationParams: + """Parameters controlling MCP server URL validation.""" + + allow_private_network: bool = False + validate_url: Optional[Callable[[str], Union[bool, Awaitable[bool]]]] = None + + +async def validate_mcp_server_url(url: str, params: Optional[UrlValidationParams] = None) -> str: + """ + Validate a URL destined for an MCP server connection. + + When ``validate_url`` is provided, it fully replaces the default checks. + Otherwise the default policy rejects non-http(s) schemes, and (unless + ``allow_private_network`` is True) rejects URLs whose hostname resolves + to a private / loopback / link-local address. + + Returns the original URL on success. Raises :class:`UrlValidationError` + on rejection. + """ + params = params or UrlValidationParams() + + try: + parsed = urlparse(url) + except ValueError as err: + raise UrlValidationError(f"Invalid URL: {url!r}") from err + + if not parsed.scheme: + raise UrlValidationError(f"Invalid URL: {url!r}") + + if params.validate_url is not None: + result = params.validate_url(url) + if isawaitable(result): + result = await result + if not result: + raise UrlValidationError(f"URL rejected by validate_url: {url}") + return url + + if parsed.scheme not in ("http", "https"): + raise UrlValidationError(f"URL scheme {parsed.scheme!r} is not allowed; must be http or https") + + if not parsed.hostname: + raise UrlValidationError(f"Invalid URL: {url!r}") + + if params.allow_private_network: + return url + + addresses = await _resolve_host(parsed.hostname) + if not addresses: + raise UrlValidationError(f"URL {url} did not resolve to any address") + for address in addresses: + if is_private_address(address): + raise UrlValidationError( + f"URL {url} resolves to private or loopback address {address}; set allow_private_network=True to bypass" + ) + + return url + + +def is_private_address(address: str) -> bool: + """True if an IP address is loopback, private, link-local, unspecified, or IPv6 site-local.""" + try: + ip = ipaddress.ip_address(address) + except ValueError: + return True # Unknown: fail closed. + if ip.is_loopback or ip.is_private or ip.is_link_local or ip.is_unspecified: + return True + # RFC 4291 deprecated IPv6 site-local (fec0::/10). Python's is_private does + # not classify it, but we reject for parity with the C# SDK. + if isinstance(ip, ipaddress.IPv6Address): + packed = int(ip) + return (packed >> 118) == 0x3FB # top 10 bits == 1111111011 + return False + + +async def _resolve_host(host: str) -> List[str]: + """Resolve a hostname to its IP addresses; short-circuit IP literals.""" + try: + ipaddress.ip_address(host) + return [host] + except ValueError: + pass + + loop = asyncio.get_running_loop() + try: + results = await loop.getaddrinfo(host, None) + except socket.gaierror as err: + raise UrlValidationError(f"Could not resolve {host}: {err}") from err + + return list({entry[4][0] for entry in results}) diff --git a/packages/mcpplugin/tests/test_ai_plugin.py b/packages/mcpplugin/tests/test_ai_plugin.py index 13f30406..4ae4d781 100644 --- a/packages/mcpplugin/tests/test_ai_plugin.py +++ b/packages/mcpplugin/tests/test_ai_plugin.py @@ -106,6 +106,19 @@ async def create_transport(self, url: str, transport_type: str, headers: Optiona yield (read_stream, write_stream) +@pytest.fixture(autouse=True) +def _stub_url_validation_dns(): + """ + Prevent DNS lookups from running when ai_plugin runs URL validation + against placeholder test hosts (e.g., http://test-server). + """ + with patch( + "microsoft_teams.mcpplugin.url_validation._resolve_host", + new=AsyncMock(return_value=["8.8.8.8"]), + ): + yield + + @pytest.fixture def sample_tools() -> List[MockMCPTool]: """Sample MCP tools for testing.""" diff --git a/packages/mcpplugin/tests/test_server_plugin.py b/packages/mcpplugin/tests/test_server_plugin.py index 1069544b..b7131177 100644 --- a/packages/mcpplugin/tests/test_server_plugin.py +++ b/packages/mcpplugin/tests/test_server_plugin.py @@ -3,12 +3,12 @@ Licensed under the MIT License. """ -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from microsoft_teams.ai import Function from microsoft_teams.apps import FastAPIAdapter, HttpServer, PluginStartEvent -from microsoft_teams.mcpplugin.server_plugin import McpServerPlugin +from microsoft_teams.mcpplugin.server_plugin import McpServerPlugin, _AuthMiddleware from pydantic import BaseModel, ValidationError # pyright: basic @@ -352,6 +352,113 @@ async def test_on_start_exception_propagates_and_not_mounted( assert plugin._mounted is False + async def test_on_start_without_require_auth_emits_warning( + self, plugin: McpServerPlugin, mock_mcp_server: MagicMock, mock_http_server + ): + """Starting without require_auth logs the 'no auth' warning once.""" + mock_fastapi_adapter = MagicMock(spec=FastAPIAdapter) + mock_fastapi_adapter.lifespans = [] + mock_fastapi_adapter.app = MagicMock() + mock_http_server.adapter = mock_fastapi_adapter + plugin.http_server = mock_http_server + mock_mcp_server.http_app = MagicMock(return_value=MagicMock()) + + with patch("microsoft_teams.mcpplugin.server_plugin.logger") as mock_logger: + await plugin.on_start(PluginStartEvent(port=3978)) + + warning_calls = [call for call in mock_logger.warning.call_args_list if "without require_auth" in call.args[0]] + assert len(warning_calls) == 1 + + async def test_on_start_with_require_auth_wraps_in_middleware(self, mock_http_server, mock_mcp_server: MagicMock): + """Starting with require_auth wraps the mcp app in _AuthMiddleware and skips the warning.""" + with patch("microsoft_teams.mcpplugin.server_plugin.FastMCP") as mock_fastmcp_class: + mock_fastmcp_class.return_value = mock_mcp_server + plugin = McpServerPlugin(require_auth=lambda _req: True) + + mock_fastapi_adapter = MagicMock(spec=FastAPIAdapter) + mock_fastapi_adapter.lifespans = [] + mock_fastapi_adapter.app = MagicMock() + mock_http_server.adapter = mock_fastapi_adapter + plugin.http_server = mock_http_server + + mock_mcp_http_app = MagicMock() + mock_mcp_server.http_app = MagicMock(return_value=mock_mcp_http_app) + + with patch("microsoft_teams.mcpplugin.server_plugin.logger") as mock_logger: + await plugin.on_start(PluginStartEvent(port=3978)) + + mount_args = mock_fastapi_adapter.app.mount.call_args + assert mount_args.args[0] == "/" + assert isinstance(mount_args.args[1], _AuthMiddleware) + assert not any("without require_auth" in call.args[0] for call in mock_logger.warning.call_args_list) + + +class TestAuthMiddleware: + """Tests for _AuthMiddleware ASGI wrapper.""" + + def _http_scope(self) -> dict: + return {"type": "http", "method": "POST", "path": "/mcp", "headers": []} + + async def _noop_receive(self) -> dict: + return {"type": "http.request", "body": b"", "more_body": False} + + async def test_passthrough_non_http_scope(self): + app = AsyncMock() + mw = _AuthMiddleware(app, lambda _req: True) + + scope = {"type": "lifespan"} + send = AsyncMock() + await mw(scope, self._noop_receive, send) + + app.assert_awaited_once_with(scope, self._noop_receive, send) + send.assert_not_called() + + async def test_allows_when_require_auth_returns_true(self): + app = AsyncMock() + mw = _AuthMiddleware(app, lambda _req: True) + + scope = self._http_scope() + send = AsyncMock() + await mw(scope, self._noop_receive, send) + + app.assert_awaited_once() + send.assert_not_called() + + async def test_allows_when_require_auth_async_returns_true(self): + app = AsyncMock() + + async def require_auth(_req): + return True + + mw = _AuthMiddleware(app, require_auth) + await mw(self._http_scope(), self._noop_receive, AsyncMock()) + app.assert_awaited_once() + + async def test_rejects_with_401_when_returning_false(self): + app = AsyncMock() + mw = _AuthMiddleware(app, lambda _req: False) + + send = AsyncMock() + await mw(self._http_scope(), self._noop_receive, send) + + app.assert_not_called() + start_call = send.await_args_list[0] + assert start_call.args[0]["type"] == "http.response.start" + assert start_call.args[0]["status"] == 401 + + async def test_rejects_with_401_when_raising(self): + app = AsyncMock() + + def require_auth(_req): + raise RuntimeError("bad token") + + mw = _AuthMiddleware(app, require_auth) + send = AsyncMock() + await mw(self._http_scope(), self._noop_receive, send) + + app.assert_not_called() + assert send.await_args_list[0].args[0]["status"] == 401 + class TestMcpServerPluginOnStop: """Tests for McpServerPlugin.on_stop method.""" diff --git a/packages/mcpplugin/tests/test_url_validation.py b/packages/mcpplugin/tests/test_url_validation.py new file mode 100644 index 00000000..3dc9186d --- /dev/null +++ b/packages/mcpplugin/tests/test_url_validation.py @@ -0,0 +1,178 @@ +""" +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. +""" + +from unittest.mock import AsyncMock, patch + +import pytest +from microsoft_teams.mcpplugin.url_validation import ( + UrlValidationError, + UrlValidationParams, + is_private_address, + validate_mcp_server_url, +) + +RESOLVE_HOST = "microsoft_teams.mcpplugin.url_validation._resolve_host" + + +@pytest.mark.parametrize( + "address,expected", + [ + ("127.0.0.1", True), + ("10.0.0.1", True), + ("10.255.255.255", True), + ("172.16.0.1", True), + ("172.31.255.255", True), + ("192.168.1.1", True), + ("169.254.169.254", True), + ("8.8.8.8", False), + ("1.1.1.1", False), + ("172.15.0.1", False), + ("172.32.0.1", False), + ("::1", True), + ("fc00::1", True), + ("fd00::1", True), + ("fe80::1", True), + ("fec0::1", True), + ("::", True), + ("2001:4860:4860::8888", False), + ("::ffff:127.0.0.1", True), + ("::ffff:8.8.8.8", False), + ("not-an-ip", True), + ], +) +def test_is_private_address(address: str, expected: bool) -> None: + assert is_private_address(address) == expected + + +@pytest.mark.asyncio +async def test_rejects_unparseable_url() -> None: + with pytest.raises(UrlValidationError): + await validate_mcp_server_url("not a url") + + +@pytest.mark.asyncio +async def test_rejects_non_http_schemes() -> None: + with pytest.raises(UrlValidationError, match="scheme"): + await validate_mcp_server_url("file:///etc/passwd") + with pytest.raises(UrlValidationError, match="scheme"): + await validate_mcp_server_url("ftp://example.com/x") + + +@pytest.mark.asyncio +async def test_accepts_public_url_with_public_dns() -> None: + with patch(RESOLVE_HOST, new=AsyncMock(return_value=["8.8.8.8"])): + result = await validate_mcp_server_url("https://example.com/mcp") + assert result == "https://example.com/mcp" + + +@pytest.mark.asyncio +async def test_rejects_url_resolving_to_private_ip() -> None: + with patch(RESOLVE_HOST, new=AsyncMock(return_value=["10.0.0.5"])): + with pytest.raises(UrlValidationError, match="private or loopback"): + await validate_mcp_server_url("https://internal.example.com/mcp") + + +@pytest.mark.asyncio +async def test_rejects_when_any_resolved_address_is_private() -> None: + with patch(RESOLVE_HOST, new=AsyncMock(return_value=["8.8.8.8", "192.168.1.1"])): + with pytest.raises(UrlValidationError, match="private or loopback"): + await validate_mcp_server_url("https://mixed.example.com/mcp") + + +@pytest.mark.asyncio +async def test_rejects_ip_literal_private() -> None: + # _resolve_host short-circuits for IP literals (no DNS call), so let the real + # implementation run; rejection should still fire from the private-IP check. + with pytest.raises(UrlValidationError, match="private or loopback"): + await validate_mcp_server_url("http://127.0.0.1:3000") + + +@pytest.mark.asyncio +async def test_accepts_private_ip_when_allow_private_network() -> None: + result = await validate_mcp_server_url( + "http://127.0.0.1:3000", + UrlValidationParams(allow_private_network=True), + ) + assert result == "http://127.0.0.1:3000" + + +@pytest.mark.asyncio +async def test_accepts_private_hostname_when_allow_private_network_skips_dns() -> None: + resolve = AsyncMock() + with patch(RESOLVE_HOST, new=resolve): + result = await validate_mcp_server_url( + "https://internal.example.com/mcp", + UrlValidationParams(allow_private_network=True), + ) + assert result == "https://internal.example.com/mcp" + resolve.assert_not_called() + + +@pytest.mark.asyncio +async def test_validate_url_sync_fully_replaces_default_checks() -> None: + seen: list[str] = [] + + def validator(url: str) -> bool: + seen.append(url) + return True + + result = await validate_mcp_server_url( + "file:///etc/passwd", + UrlValidationParams(validate_url=validator), + ) + assert result == "file:///etc/passwd" + assert seen == ["file:///etc/passwd"] + + +@pytest.mark.asyncio +async def test_validate_url_async_replaces_default_checks() -> None: + async def validator(url: str) -> bool: + return "example" in url + + result = await validate_mcp_server_url( + "https://example.com/mcp", + UrlValidationParams(validate_url=validator), + ) + assert result == "https://example.com/mcp" + + +@pytest.mark.asyncio +async def test_validate_url_rejects_when_returning_false() -> None: + with pytest.raises(UrlValidationError, match="rejected by validate_url"): + await validate_mcp_server_url( + "https://example.com/mcp", + UrlValidationParams(validate_url=lambda _url: False), + ) + + +@pytest.mark.asyncio +async def test_rejects_when_dns_lookup_fails() -> None: + # _resolve_host wraps socket.gaierror as UrlValidationError; verify the + # caller surfaces that rather than swallowing it. + with patch( + RESOLVE_HOST, + new=AsyncMock(side_effect=UrlValidationError("Could not resolve nonexistent.invalid")), + ): + with pytest.raises(UrlValidationError, match="Could not resolve"): + await validate_mcp_server_url("https://nonexistent.invalid/mcp") + + +@pytest.mark.asyncio +async def test_rejects_when_dns_returns_empty_list() -> None: + with patch(RESOLVE_HOST, new=AsyncMock(return_value=[])): + with pytest.raises(UrlValidationError, match="did not resolve"): + await validate_mcp_server_url("https://example.com/mcp") + + +@pytest.mark.asyncio +async def test_propagates_exceptions_from_validate_url() -> None: + def boom(_url: str) -> bool: + raise RuntimeError("custom failure") + + with pytest.raises(RuntimeError, match="custom failure"): + await validate_mcp_server_url( + "https://example.com/mcp", + UrlValidationParams(validate_url=boom), + )