diff --git a/galaxy/core/events.py b/galaxy/core/events.py index 5e0e5cc4d..156733fe0 100644 --- a/galaxy/core/events.py +++ b/galaxy/core/events.py @@ -45,6 +45,10 @@ class EventType(Enum): ) DEVICE_STATUS_CHANGED = "device_status_changed" # Device status changed + # LLM cost/token tracking events + LLM_CALL_COMPLETED = "llm_call_completed" # LLM API call finished + COST_THRESHOLD_EXCEEDED = "cost_threshold_exceeded" # Session cost exceeded configured threshold + @dataclass class Event: @@ -120,6 +124,37 @@ class DeviceEvent(Event): all_devices: Dict[str, Dict[str, Any]] # Snapshot of all devices in registry +@dataclass +class LLMCallEvent(Event): + """ + LLM API call completed event. + + Extends base Event class with LLM-specific cost and token usage + information for tracking and aggregation. + """ + + agent_type: str # e.g. "HOST_AGENT", "CONSTELLATION_AGENT" + model: str # e.g. "gpt-4o", "claude-3-5-sonnet-20241022" + prompt_tokens: int + completion_tokens: int + cost: float + duration_ms: float # wall time of the API call + + +@dataclass +class CostThresholdExceededEvent(Event): + """ + Session cost threshold exceeded event. + + Published when the accumulated session cost crosses the configured + cost_alert_threshold, allowing observers to surface alerts. + """ + + session_id: str + total_cost: float + threshold: float + + class IEventObserver(ABC): """ Interface for event observers. diff --git a/galaxy/session/observers/base_observer.py b/galaxy/session/observers/base_observer.py index 728ce31f2..7d72c36b7 100644 --- a/galaxy/session/observers/base_observer.py +++ b/galaxy/session/observers/base_observer.py @@ -11,9 +11,11 @@ from ...agents.constellation_agent import ConstellationAgent from ...core.events import ( ConstellationEvent, + CostThresholdExceededEvent, Event, EventType, IEventObserver, + LLMCallEvent, TaskEvent, ) from ...visualization.change_detector import VisualizationChangeDetector @@ -107,12 +109,21 @@ class SessionMetricsObserver(IEventObserver): Observer that collects session metrics and statistics. """ - def __init__(self, session_id: str, logger: Optional[logging.Logger] = None): + _LLM_CALLS_CAP = 500 + + def __init__( + self, + session_id: str, + logger: Optional[logging.Logger] = None, + cost_alert_threshold: float = 0.0, + ): """ Initialize SessionMetricsObserver. :param session_id: Unique session identifier for metrics tracking :param logger: Optional logger instance (creates default if None) + :param cost_alert_threshold: Emit CostThresholdExceededEvent when total + cost exceeds this value. ``0.0`` (default) disables the check. """ self.metrics: Dict[str, Any] = { "session_id": session_id, @@ -127,7 +138,18 @@ def __init__(self, session_id: str, logger: Optional[logging.Logger] = None): "total_constellation_time": 0.0, "constellation_timings": {}, "constellation_modifications": {}, # Track modifications per constellation + "llm_metrics": { + "total_cost": 0.0, + "total_prompt_tokens": 0, + "total_completion_tokens": 0, + "total_api_calls": 0, + "cost_by_agent": {}, # agent_type -> float + "cost_by_model": {}, # model_name -> float + "calls": [], # list of LLMCallEvent data (capped at last 500) + }, } + self._cost_alert_threshold = cost_alert_threshold + self._threshold_already_exceeded = False self.logger = logger or logging.getLogger(__name__) async def on_event(self, event: Event) -> None: @@ -136,7 +158,9 @@ async def on_event(self, event: Event) -> None: :param event: Event instance for metrics collection """ - if isinstance(event, TaskEvent): + if isinstance(event, LLMCallEvent): + await self._handle_llm_call_event(event) + elif isinstance(event, TaskEvent): await self._handle_task_event(event) elif isinstance(event, ConstellationEvent): await self._handle_constellation_event(event) @@ -304,6 +328,65 @@ def _handle_constellation_modified(self, event: ConstellationEvent) -> None: modification_record ) + async def _handle_llm_call_event(self, event: LLMCallEvent) -> None: + """ + Handle LLM_CALL_COMPLETED event — update llm_metrics aggregates. + + :param event: LLMCallEvent instance + """ + llm = self.metrics["llm_metrics"] + + # Update totals + llm["total_cost"] += event.cost + llm["total_prompt_tokens"] += event.prompt_tokens + llm["total_completion_tokens"] += event.completion_tokens + llm["total_api_calls"] += 1 + + # Per-agent and per-model breakdowns + llm["cost_by_agent"][event.agent_type] = ( + llm["cost_by_agent"].get(event.agent_type, 0.0) + event.cost + ) + llm["cost_by_model"][event.model] = ( + llm["cost_by_model"].get(event.model, 0.0) + event.cost + ) + + # Append call record (cap at _LLM_CALLS_CAP) + call_record = { + "agent_type": event.agent_type, + "model": event.model, + "prompt_tokens": event.prompt_tokens, + "completion_tokens": event.completion_tokens, + "cost": event.cost, + "duration_ms": event.duration_ms, + "timestamp": event.timestamp, + } + calls = llm["calls"] + calls.append(call_record) + if len(calls) > self._LLM_CALLS_CAP: + llm["calls"] = calls[-self._LLM_CALLS_CAP :] + + # Cost threshold alerting + if ( + self._cost_alert_threshold > 0 + and llm["total_cost"] > self._cost_alert_threshold + and not self._threshold_already_exceeded + ): + self._threshold_already_exceeded = True + self.logger.warning( + "Session %s exceeded cost threshold: $%.4f", + self.metrics["session_id"], + llm["total_cost"], + ) + from ...core.events import get_event_bus + + threshold_event = CostThresholdExceededEvent( + event_type=EventType.COST_THRESHOLD_EXCEEDED, + session_id=self.metrics["session_id"], + total_cost=llm["total_cost"], + threshold=self._cost_alert_threshold, + ) + await get_event_bus().publish_event(threshold_event) + def get_metrics(self) -> Dict[str, Any]: """ Get collected metrics with computed statistics. diff --git a/galaxy/webui/frontend/src/components/layout/RightPanel.tsx b/galaxy/webui/frontend/src/components/layout/RightPanel.tsx index d19bcb5dd..d62e513ef 100644 --- a/galaxy/webui/frontend/src/components/layout/RightPanel.tsx +++ b/galaxy/webui/frontend/src/components/layout/RightPanel.tsx @@ -1,10 +1,11 @@ import React, { useEffect, useMemo } from 'react'; import { shallow } from 'zustand/shallow'; import clsx from 'clsx'; -import { Network, Star } from 'lucide-react'; +import { Network, Star, DollarSign } from 'lucide-react'; import ConstellationBlock from '../constellation/ConstellationBlock'; import TaskList from '../tasks/TaskList'; import TaskDetailPanel from '../tasks/TaskDetailPanel'; +import CostDashboard from '../metrics/CostDashboard'; import { ConstellationSummary, Task, useGalaxyStore } from '../../store/galaxyStore'; const statusColors: Record = { @@ -22,6 +23,7 @@ const RightPanel: React.FC = () => { ui, setActiveConstellation, setActiveTask, + setRightPanelTab, } = useGalaxyStore( (state) => ({ constellations: state.constellations, @@ -29,6 +31,7 @@ const RightPanel: React.FC = () => { ui: state.ui, setActiveConstellation: state.setActiveConstellation, setActiveTask: state.setActiveTask, + setRightPanelTab: state.setRightPanelTab, }), shallow, ); @@ -77,9 +80,47 @@ const RightPanel: React.FC = () => { setActiveConstellation(selected || null); }; + const isCostView = ui.rightPanelTab === 'cost'; + return (
+ {/* Panel tab bar */} +
+ + +
+ + {/* Cost dashboard */} + {isCostView && ( +
+ +
+ )} + {/* Constellation Overview - Top half */} + {!isCostView && (
@@ -116,12 +157,14 @@ const RightPanel: React.FC = () => { />
+ )} {/* TaskStar List or Task Detail - Bottom half */} + {!isCostView && (
{activeTask ? ( - setActiveTask(null)} /> ) : ( @@ -142,6 +185,7 @@ const RightPanel: React.FC = () => { )}
+ )}
); }; diff --git a/galaxy/webui/frontend/src/components/metrics/CostByModelChart.tsx b/galaxy/webui/frontend/src/components/metrics/CostByModelChart.tsx new file mode 100644 index 000000000..22bea6015 --- /dev/null +++ b/galaxy/webui/frontend/src/components/metrics/CostByModelChart.tsx @@ -0,0 +1,51 @@ +import React from 'react'; + +interface BarChartProps { + data: Record; + label: string; + colorClass?: string; +} + +/** + * Horizontal bar chart rendered with pure Tailwind CSS. + * Used to show cost broken down by a string key (model name, agent type, etc.). + */ +const CostByModelChart: React.FC = ({ + data, + label, + colorClass = 'bg-cyan-500', +}) => { + const entries = Object.entries(data).sort(([, a], [, b]) => b - a); + const max = entries.length > 0 ? entries[0][1] : 0; + + if (entries.length === 0) { + return ( +
No data yet
+ ); + } + + return ( +
+
{label}
+ {entries.map(([key, value]) => { + const pct = max > 0 ? (value / max) * 100 : 0; + return ( +
+
{key}
+
+
+
+
+ ${value.toFixed(4)} +
+
+ ); + })} +
+ ); +}; + +export default CostByModelChart; diff --git a/galaxy/webui/frontend/src/components/metrics/CostDashboard.tsx b/galaxy/webui/frontend/src/components/metrics/CostDashboard.tsx new file mode 100644 index 000000000..1880ff22e --- /dev/null +++ b/galaxy/webui/frontend/src/components/metrics/CostDashboard.tsx @@ -0,0 +1,204 @@ +import React, { useState } from 'react'; +import { shallow } from 'zustand/shallow'; +import { DollarSign, Zap, Activity, X } from 'lucide-react'; +import { useGalaxyStore, LLMCallRecord } from '../../store/galaxyStore'; +import CostByModelChart from './CostByModelChart'; + +const formatCost = (v: number) => `$${v.toFixed(4)}`; +const formatTs = (ts: number) => { + try { + return new Intl.DateTimeFormat('en-US', { + hour: '2-digit', + minute: '2-digit', + second: '2-digit', + }).format(new Date(ts)); + } catch { + return '—'; + } +}; + +interface CostAlertBannerProps { + message: string; + onDismiss: () => void; +} + +const CostAlertBanner: React.FC = ({ message, onDismiss }) => ( +
+ {message} + +
+); + +interface RecentCallsTableProps { + calls: LLMCallRecord[]; +} + +const RecentCallsTable: React.FC = ({ calls }) => { + const recent = [...calls].reverse().slice(0, 50); + + if (recent.length === 0) { + return ( +
+ No LLM calls recorded yet. +
+ ); + } + + return ( +
+ + + + + + + + + + + + + + {recent.map((call, i) => ( + + + + + + + + + + ))} + +
TimeAgentModelInOutCostms
{formatTs(call.timestamp)} + {call.agent_type} + + {call.model} + {call.prompt_tokens.toLocaleString()}{call.completion_tokens.toLocaleString()} + {formatCost(call.cost)} + + {call.duration_ms.toFixed(0)} +
+
+ ); +}; + +/** + * CostDashboard — shows LLM cost and token metrics accumulated during the + * current Galaxy session. Updates in real-time via WebSocket events handled + * by the Zustand store (appendLLMCall / pushNotification). + */ +const CostDashboard: React.FC = () => { + const { llmMetrics, notifications, dismissNotification } = useGalaxyStore( + (state) => ({ + llmMetrics: state.llmMetrics, + notifications: state.notifications, + dismissNotification: state.dismissNotification, + }), + shallow, + ); + + // Surface the most recent unread cost_alert notification as a banner. + const costAlert = notifications.find( + (n) => !n.read && n.source === 'llm_metrics' && n.severity === 'warning', + ); + + const [showTable, setShowTable] = useState(false); + + return ( +
+ {/* Cost alert banner */} + {costAlert && ( + dismissNotification(costAlert.id)} + /> + )} + + {/* Summary row */} +
+
+
+ + Total Cost +
+
+ {formatCost(llmMetrics.totalCost)} +
+
+ +
+
+ + Prompt Tokens +
+
+ {llmMetrics.totalPromptTokens.toLocaleString()} +
+
+ +
+
+ + Completion +
+
+ {llmMetrics.totalCompletionTokens.toLocaleString()} +
+
+ +
+
+ + API Calls +
+
+ {llmMetrics.totalApiCalls} +
+
+
+ + {/* Cost by model */} +
+ +
+ + {/* Cost by agent */} +
+ +
+ + {/* Recent calls toggle */} +
+ + {showTable && ( + + )} +
+
+ ); +}; + +export default CostDashboard; diff --git a/galaxy/webui/frontend/src/main.tsx b/galaxy/webui/frontend/src/main.tsx index 666ed8750..64f6656f9 100644 --- a/galaxy/webui/frontend/src/main.tsx +++ b/galaxy/webui/frontend/src/main.tsx @@ -5,6 +5,7 @@ import './index.css'; import { GalaxyEvent, getWebSocketClient } from './services/websocket'; import { createClientId, + LLMCallRecord, NotificationItem, Task, TaskLogEntry, @@ -495,6 +496,35 @@ const handleDeviceEvent = (event: GalaxyEvent) => { // Status is still tracked and displayed in the UI }; +const handleLLMMetricsUpdate = (event: GalaxyEvent) => { + const store = useGalaxyStore.getState(); + const call: LLMCallRecord = { + agent_type: event.agent_type || 'unknown', + model: event.model || 'unknown', + prompt_tokens: event.prompt_tokens ?? 0, + completion_tokens: event.completion_tokens ?? 0, + cost: event.cost ?? 0, + duration_ms: event.duration_ms ?? 0, + timestamp: safeTimestamp(event), + }; + store.appendLLMCall(call); +}; + +const handleCostAlert = (event: GalaxyEvent) => { + const store = useGalaxyStore.getState(); + const totalCost: number = event.total_cost ?? 0; + const threshold: number = event.threshold ?? 0; + store.pushNotification({ + id: `cost-alert-${Date.now()}`, + title: 'Cost threshold exceeded', + description: `Session cost $${totalCost.toFixed(4)} exceeded threshold $${threshold.toFixed(4)}`, + severity: 'warning', + timestamp: Date.now(), + read: false, + source: 'llm_metrics', + }); +}; + const handleGenericEvent = (event: GalaxyEvent) => { // Handle session control messages (use 'type' field instead of 'event_type') const messageType = event.type || event.event_type; @@ -540,6 +570,17 @@ const handleGenericEvent = (event: GalaxyEvent) => { return; } + // Handle LLM metrics events + if (messageType === 'llm_metrics_update') { + handleLLMMetricsUpdate(event); + return; + } + + if (messageType === 'cost_alert') { + handleCostAlert(event); + return; + } + // Handle device events if (event.event_type?.startsWith('device_')) { handleDeviceEvent(event); diff --git a/galaxy/webui/frontend/src/services/websocket.ts b/galaxy/webui/frontend/src/services/websocket.ts index 3bae129ed..1b1eda9cf 100644 --- a/galaxy/webui/frontend/src/services/websocket.ts +++ b/galaxy/webui/frontend/src/services/websocket.ts @@ -28,6 +28,15 @@ export interface GalaxyEvent { message?: string; session_name?: string; task_name?: string; + // LLM metrics events + model?: string; + prompt_tokens?: number; + completion_tokens?: number; + cost?: number; + duration_ms?: number; + // Cost threshold events + total_cost?: number; + threshold?: number; } export type EventCallback = (event: GalaxyEvent) => void; diff --git a/galaxy/webui/frontend/src/store/galaxyStore.ts b/galaxy/webui/frontend/src/store/galaxyStore.ts index 1d065d5e2..0fa495e24 100644 --- a/galaxy/webui/frontend/src/store/galaxyStore.ts +++ b/galaxy/webui/frontend/src/store/galaxyStore.ts @@ -124,6 +124,27 @@ export interface Device { highlightUntil?: number; } +export interface LLMCallRecord { + agent_type: string; + model: string; + prompt_tokens: number; + completion_tokens: number; + cost: number; + duration_ms: number; + timestamp: number; +} + +export interface LLMMetrics { + totalCost: number; + totalPromptTokens: number; + totalCompletionTokens: number; + totalApiCalls: number; + costByAgent: Record; + costByModel: Record; + recentCalls: LLMCallRecord[]; + lastUpdated: number | null; +} + export interface NotificationItem { id: string; title: string; @@ -148,7 +169,7 @@ interface SessionState { interface UIState { searchQuery: string; messageKindFilter: MessageKind | 'all'; - rightPanelTab: 'constellation' | 'tasks' | 'details'; + rightPanelTab: 'constellation' | 'tasks' | 'details' | 'cost'; activeConstellationId: string | null; activeTaskId: string | null; activeDeviceId: string | null; @@ -214,6 +235,10 @@ interface GalaxyStore { toggleLeftDrawer: (open?: boolean) => void; toggleRightDrawer: (open?: boolean) => void; + llmMetrics: LLMMetrics; + setLLMMetrics: (metrics: LLMMetrics) => void; + appendLLMCall: (call: LLMCallRecord) => void; + toggleDebugMode: () => void; toggleHighContrast: () => void; resetSessionState: (options?: { clearHistory?: boolean }) => void; @@ -738,6 +763,38 @@ export const useGalaxyStore = create()((set, get) => ({ })), })), + llmMetrics: { + totalCost: 0, + totalPromptTokens: 0, + totalCompletionTokens: 0, + totalApiCalls: 0, + costByAgent: {}, + costByModel: {}, + recentCalls: [], + lastUpdated: null, + }, + setLLMMetrics: (metrics) => set({ llmMetrics: metrics }), + appendLLMCall: (call) => + set((state) => { + const prev = state.llmMetrics; + const costByAgent = { ...prev.costByAgent }; + costByAgent[call.agent_type] = (costByAgent[call.agent_type] ?? 0) + call.cost; + const costByModel = { ...prev.costByModel }; + costByModel[call.model] = (costByModel[call.model] ?? 0) + call.cost; + return { + llmMetrics: { + totalCost: prev.totalCost + call.cost, + totalPromptTokens: prev.totalPromptTokens + call.prompt_tokens, + totalCompletionTokens: prev.totalCompletionTokens + call.completion_tokens, + totalApiCalls: prev.totalApiCalls + 1, + costByAgent, + costByModel, + recentCalls: [...prev.recentCalls, call].slice(-500), + lastUpdated: Date.now(), + }, + }; + }), + ui: { ...defaultUIState(), activeConstellationId: mockData?.constellation.id || null, diff --git a/galaxy/webui/models/responses.py b/galaxy/webui/models/responses.py index da3d9b9b1..7ebc83cea 100644 --- a/galaxy/webui/models/responses.py +++ b/galaxy/webui/models/responses.py @@ -195,3 +195,37 @@ class ErrorMessage(BaseModel): type: Literal[WebSocketMessageType.ERROR] = WebSocketMessageType.ERROR message: str = Field(..., description="Error message describing what went wrong") + + +class LLMCallRecord(BaseModel): + """ + Record of a single LLM API call. + + Captures token usage, cost, and timing for one completed LLM call. + """ + + agent_type: str = Field(..., description="Agent type that made the call") + model: str = Field(..., description="Model name used for the call") + prompt_tokens: int = Field(..., description="Number of prompt tokens consumed") + completion_tokens: int = Field(..., description="Number of completion tokens generated") + cost: float = Field(..., description="Estimated cost in USD") + duration_ms: float = Field(..., description="Wall-clock duration of the API call in milliseconds") + timestamp: float = Field(..., description="Unix timestamp when the call completed") + + +class SessionCostSummary(BaseModel): + """ + Aggregated LLM cost and token usage for the active session. + + Summarises all LLM calls made during a Galaxy session including + per-agent and per-model breakdowns. + """ + + session_id: str = Field(..., description="Unique identifier of the session") + total_cost: float = Field(..., description="Total estimated cost in USD") + total_prompt_tokens: int = Field(..., description="Total prompt tokens consumed") + total_completion_tokens: int = Field(..., description="Total completion tokens generated") + total_api_calls: int = Field(..., description="Total number of LLM API calls made") + cost_by_agent: Dict[str, float] = Field(..., description="Cost breakdown by agent type") + cost_by_model: Dict[str, float] = Field(..., description="Cost breakdown by model name") + recent_calls: List[LLMCallRecord] = Field(..., description="Most recent LLM call records") diff --git a/galaxy/webui/routers/__init__.py b/galaxy/webui/routers/__init__.py index e20700041..f9eacf6e7 100644 --- a/galaxy/webui/routers/__init__.py +++ b/galaxy/webui/routers/__init__.py @@ -11,11 +11,13 @@ from galaxy.webui.routers.auth import router as auth_router from galaxy.webui.routers.health import router as health_router from galaxy.webui.routers.devices import router as devices_router +from galaxy.webui.routers.metrics import router as metrics_router from galaxy.webui.routers.websocket import router as websocket_router __all__ = [ "auth_router", "health_router", "devices_router", + "metrics_router", "websocket_router", ] diff --git a/galaxy/webui/routers/metrics.py b/galaxy/webui/routers/metrics.py new file mode 100644 index 000000000..3509fef90 --- /dev/null +++ b/galaxy/webui/routers/metrics.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Metrics router for Galaxy Web UI. + +Exposes LLM cost and token-usage data accumulated during the active Galaxy +session. All data originates from ``SessionMetricsObserver`` and is served +by ``MetricsService``. +""" + +import logging +from typing import Literal + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import PlainTextResponse, Response + +from galaxy.webui.dependencies import get_app_state, verify_api_key +from galaxy.webui.models.responses import SessionCostSummary +from galaxy.webui.services.metrics_service import MetricsService + +router = APIRouter(prefix="/api/metrics", tags=["metrics"]) +logger = logging.getLogger(__name__) + + +@router.get( + "/cost", + response_model=SessionCostSummary, + dependencies=[Depends(verify_api_key)], +) +async def get_session_cost() -> SessionCostSummary: + """ + Return aggregated LLM cost and token usage for the active session. + + Includes per-agent and per-model cost breakdowns plus the most recent + call records (up to the last 500 stored by the observer). + + :return: ``SessionCostSummary`` for the active session. + :raises HTTPException: 404 when no active session is available. + """ + app_state = get_app_state() + service = MetricsService(app_state) + + summary = service.get_cost_summary() + if summary is None: + raise HTTPException(status_code=404, detail="No active session") + + return SessionCostSummary(**summary) + + +@router.get( + "/cost/export", + dependencies=[Depends(verify_api_key)], +) +async def export_cost_log( + format: Literal["json", "csv"] = "json", +) -> Response: + """ + Download the full LLM call log for the active session. + + :param format: Output format — ``json`` (default) or ``csv``. + :return: File download response. + :raises HTTPException: 404 when no active session is available. + """ + app_state = get_app_state() + service = MetricsService(app_state) + + if service.get_cost_summary() is None: + raise HTTPException(status_code=404, detail="No active session") + + if format == "csv": + content = service.export_calls_csv() + return PlainTextResponse( + content=content, + media_type="text/csv", + headers={"Content-Disposition": "attachment; filename=llm_calls.csv"}, + ) + + content = service.export_calls_json() + return PlainTextResponse( + content=content, + media_type="application/json", + headers={"Content-Disposition": "attachment; filename=llm_calls.json"}, + ) diff --git a/galaxy/webui/server.py b/galaxy/webui/server.py index d5d375ba0..0437b976e 100644 --- a/galaxy/webui/server.py +++ b/galaxy/webui/server.py @@ -28,7 +28,7 @@ from galaxy.core.events import get_event_bus from galaxy.webui.dependencies import get_app_state -from galaxy.webui.routers import auth_router, health_router, devices_router, websocket_router +from galaxy.webui.routers import auth_router, health_router, devices_router, metrics_router, websocket_router from galaxy.webui.websocket_observer import WebSocketObserver if TYPE_CHECKING: @@ -107,6 +107,7 @@ async def lifespan(app: FastAPI): app.include_router(auth_router) app.include_router(health_router) app.include_router(devices_router) +app.include_router(metrics_router) app.include_router(websocket_router) # Mount frontend static files if built diff --git a/galaxy/webui/services/__init__.py b/galaxy/webui/services/__init__.py index 2c72c9491..3aa0ba36a 100644 --- a/galaxy/webui/services/__init__.py +++ b/galaxy/webui/services/__init__.py @@ -11,9 +11,11 @@ from galaxy.webui.services.device_service import DeviceService from galaxy.webui.services.galaxy_service import GalaxyService from galaxy.webui.services.config_service import ConfigService +from galaxy.webui.services.metrics_service import MetricsService __all__ = [ "DeviceService", "GalaxyService", "ConfigService", + "MetricsService", ] diff --git a/galaxy/webui/services/metrics_service.py b/galaxy/webui/services/metrics_service.py new file mode 100644 index 000000000..e1e77eaf0 --- /dev/null +++ b/galaxy/webui/services/metrics_service.py @@ -0,0 +1,140 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Metrics service for Galaxy Web UI. + +Reads LLM cost and token metrics accumulated by SessionMetricsObserver +and exposes them in a form suitable for the API layer. +""" + +import csv +import io +import json +import logging +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from galaxy.webui.dependencies import AppState + + +class MetricsService: + """ + Thin service layer over the SessionMetricsObserver metrics dict. + + All data originates from ``SessionMetricsObserver.metrics["llm_metrics"]`` + which is updated in real time as LLM calls complete. + """ + + def __init__(self, app_state: "AppState") -> None: + """ + Initialise MetricsService. + + :param app_state: Application state providing access to the Galaxy session. + """ + self._app_state = app_state + self._logger = logging.getLogger(__name__) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_llm_metrics(self) -> Optional[Dict[str, Any]]: + """ + Return the ``llm_metrics`` dict from the active session observer, or + ``None`` when no session / observer is available. + """ + session = self._app_state.galaxy_session + if session is None: + return None + + observer = getattr(session, "_metrics_observer", None) + if observer is None: + return None + + return observer.metrics.get("llm_metrics") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def get_cost_summary(self) -> Optional[Dict[str, Any]]: + """ + Return a dict matching the ``SessionCostSummary`` response model. + + Returns ``None`` when no active session exists. + + :return: Cost summary dict or None. + """ + session = self._app_state.galaxy_session + if session is None: + return None + + observer = getattr(session, "_metrics_observer", None) + if observer is None: + return None + + llm = observer.metrics.get("llm_metrics") + if llm is None: + return None + + session_id: str = observer.metrics.get("session_id", "unknown") + + raw_calls: List[Dict[str, Any]] = llm.get("calls", []) + recent_calls = [ + { + "agent_type": c.get("agent_type", ""), + "model": c.get("model", ""), + "prompt_tokens": c.get("prompt_tokens", 0), + "completion_tokens": c.get("completion_tokens", 0), + "cost": c.get("cost", 0.0), + "duration_ms": c.get("duration_ms", 0.0), + "timestamp": c.get("timestamp", 0.0), + } + for c in raw_calls + ] + + return { + "session_id": session_id, + "total_cost": llm.get("total_cost", 0.0), + "total_prompt_tokens": llm.get("total_prompt_tokens", 0), + "total_completion_tokens": llm.get("total_completion_tokens", 0), + "total_api_calls": llm.get("total_api_calls", 0), + "cost_by_agent": llm.get("cost_by_agent", {}), + "cost_by_model": llm.get("cost_by_model", {}), + "recent_calls": recent_calls, + } + + def export_calls_json(self) -> str: + """ + Serialise the full LLM call log as a JSON string. + + :return: JSON string of call records. + """ + llm = self._get_llm_metrics() + calls = llm.get("calls", []) if llm else [] + return json.dumps(calls, indent=2) + + def export_calls_csv(self) -> str: + """ + Serialise the full LLM call log as a CSV string. + + :return: CSV string with headers. + """ + llm = self._get_llm_metrics() + calls: List[Dict[str, Any]] = llm.get("calls", []) if llm else [] + + output = io.StringIO() + fieldnames = [ + "timestamp", + "agent_type", + "model", + "prompt_tokens", + "completion_tokens", + "cost", + "duration_ms", + ] + writer = csv.DictWriter(output, fieldnames=fieldnames, extrasaction="ignore") + writer.writeheader() + writer.writerows(calls) + return output.getvalue() diff --git a/galaxy/webui/websocket_observer.py b/galaxy/webui/websocket_observer.py index 18afdffd0..d83bd9e8c 100644 --- a/galaxy/webui/websocket_observer.py +++ b/galaxy/webui/websocket_observer.py @@ -18,9 +18,11 @@ from galaxy.core.events import ( AgentEvent, ConstellationEvent, + CostThresholdExceededEvent, DeviceEvent, Event, IEventObserver, + LLMCallEvent, TaskEvent, ) @@ -101,7 +103,11 @@ def serialize_event(self, event: Event) -> Dict[str, Any]: } # Add type-specific fields using polymorphism - if isinstance(event, TaskEvent): + if isinstance(event, LLMCallEvent): + base_dict.update(self._serialize_llm_call_event_fields(event)) + elif isinstance(event, CostThresholdExceededEvent): + base_dict.update(self._serialize_cost_threshold_event_fields(event)) + elif isinstance(event, TaskEvent): base_dict.update(self._serialize_task_event_fields(event)) elif isinstance(event, ConstellationEvent): base_dict.update(self._serialize_constellation_event_fields(event)) @@ -169,6 +175,39 @@ def _serialize_device_event_fields(self, event: DeviceEvent) -> Dict[str, Any]: "all_devices": self.serialize_value(event.all_devices), } + def _serialize_llm_call_event_fields(self, event: LLMCallEvent) -> Dict[str, Any]: + """ + Extract LLM call-specific fields and set the frontend message type. + + :param event: The LLM call event to serialize + :return: Dictionary of LLM call-specific fields + """ + return { + "message_type": "llm_metrics_update", + "agent_type": event.agent_type, + "model": event.model, + "prompt_tokens": event.prompt_tokens, + "completion_tokens": event.completion_tokens, + "cost": event.cost, + "duration_ms": event.duration_ms, + } + + def _serialize_cost_threshold_event_fields( + self, event: CostThresholdExceededEvent + ) -> Dict[str, Any]: + """ + Extract cost-threshold-specific fields and set the frontend message type. + + :param event: The cost threshold exceeded event to serialize + :return: Dictionary of cost threshold-specific fields + """ + return { + "message_type": "cost_alert", + "session_id": event.session_id, + "total_cost": event.total_cost, + "threshold": event.threshold, + } + def serialize_value(self, value: Any) -> Any: """ Serialize a value to JSON-compatible format. diff --git a/ufo/llm/base.py b/ufo/llm/base.py index be81960d5..0932c3fae 100644 --- a/ufo/llm/base.py +++ b/ufo/llm/base.py @@ -3,12 +3,20 @@ import abc from importlib import import_module -from typing import Dict +from typing import Dict, NamedTuple import functools from ufo.llm.config_helper import get_agent_config from config.config_loader import get_ufo_config, get_galaxy_config +class CostResult(NamedTuple): + """Token usage and cost from a single LLM API call.""" + + cost: float + prompt_tokens: int + completion_tokens: int + + class BaseService(abc.ABC): @abc.abstractmethod def __init__(self, *args, **kwargs): @@ -122,7 +130,7 @@ def get_cost_estimator( prices: Dict[str, float], prompt_tokens: int, completion_tokens: int, - ) -> float: + ) -> CostResult: """ Calculates the cost estimate for using a specific model based on the number of prompt tokens and completion tokens. :param api_type: The type of api used. @@ -130,7 +138,7 @@ def get_cost_estimator( :param prices: A dictionary containing the prices for different models. :param prompt_tokens: The number of prompt tokens used. :param completion_tokens: The number of completion tokens used. - :return: The estimated cost for using the model. + :return: CostResult with cost and token counts. """ if api_type.lower() == "openai": @@ -154,5 +162,5 @@ def get_cost_estimator( + completion_tokens * prices[name]["output"] / 1000 ) else: - return 0 - return cost + cost = 0.0 + return CostResult(cost=cost, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) diff --git a/ufo/llm/claude.py b/ufo/llm/claude.py index 9c0602ba7..1224c3105 100644 --- a/ufo/llm/claude.py +++ b/ufo/llm/claude.py @@ -6,7 +6,7 @@ import anthropic from PIL import Image -from ufo.llm.base import BaseService +from ufo.llm.base import BaseService, CostResult logger = logging.getLogger(__name__) @@ -57,7 +57,9 @@ def chat_completion( max_tokens = max_tokens if max_tokens is not None else self.config["MAX_TOKENS"] responses = [] - cost = 0.0 + total_cost = 0.0 + total_prompt_tokens = 0 + total_completion_tokens = 0 system_prompt, user_prompt = self.process_messages(messages) for _ in range(n): @@ -72,13 +74,16 @@ def chat_completion( responses.append(response.content[0].text) prompt_tokens = response.usage.input_tokens completion_tokens = response.usage.output_tokens - cost += self.get_cost_estimator( + call_result = self.get_cost_estimator( self.api_type, self.model, self.prices, prompt_tokens, completion_tokens, ) + total_cost += call_result.cost + total_prompt_tokens += call_result.prompt_tokens + total_completion_tokens += call_result.completion_tokens except Exception as e: import traceback @@ -91,7 +96,11 @@ def chat_completion( time.sleep(3) continue - return responses, cost + return responses, CostResult( + cost=total_cost, + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + ) def process_messages( self, messages: List[Dict[str, str]] diff --git a/ufo/llm/gemini.py b/ufo/llm/gemini.py index 233cca9fe..76fc9d9e6 100644 --- a/ufo/llm/gemini.py +++ b/ufo/llm/gemini.py @@ -9,7 +9,7 @@ from google import genai from google.genai.types import GenerateContentConfig, Part, GenerateContentResponse -from ufo.llm.base import BaseService +from ufo.llm.base import BaseService, CostResult from ufo.llm.response_schema import ( AppAgentResponse, @@ -106,7 +106,7 @@ def chat_completion( ) prompt_tokens = response.usage_metadata.prompt_token_count completion_tokens = response.usage_metadata.candidates_token_count - cost = self.get_cost_estimator( + cost_result = self.get_cost_estimator( self.api_type, self.model, self.prices, @@ -125,7 +125,7 @@ def chat_completion( ) time.sleep(sleep_time) - return self.get_text_from_all_candidates(response), cost + return self.get_text_from_all_candidates(response), cost_result def process_messages(self, messages: List[Dict[str, str]]) -> List[str]: """ diff --git a/ufo/llm/llm_call.py b/ufo/llm/llm_call.py index 589d3e5c9..b59a9a6b4 100644 --- a/ufo/llm/llm_call.py +++ b/ufo/llm/llm_call.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import logging +import time from ufo.llm import AgentType from typing import Tuple @@ -92,8 +93,11 @@ def get_completions( api_type_lower = api_type.lower() service = BaseService.get_service(api_type_lower, agent_type, api_model.lower()) if service: - response, cost = service.chat_completion(messages, n) - return response, cost + t0 = time.monotonic() + response, cost_result = service.chat_completion(messages, n) + duration_ms = (time.monotonic() - t0) * 1000.0 + _emit_llm_call_event(agent_type, api_model, cost_result, duration_ms) + return response, cost_result.cost else: raise ValueError(f"API_TYPE {api_type} not supported") except Exception as e: @@ -109,3 +113,32 @@ def get_completions( ) else: raise e + + +def _emit_llm_call_event(agent_type, model: str, cost_result, duration_ms: float) -> None: + """Emit LLMCallEvent on the Galaxy event bus (best-effort; never raises).""" + try: + from galaxy.core.events import EventType, LLMCallEvent, get_event_bus + import asyncio + + agent_type_str = agent_type.value if hasattr(agent_type, "value") else str(agent_type) + event = LLMCallEvent( + event_type=EventType.LLM_CALL_COMPLETED, + source_id="llm_call", + timestamp=time.time(), + data={}, + agent_type=agent_type_str, + model=model, + prompt_tokens=cost_result.prompt_tokens, + completion_tokens=cost_result.completion_tokens, + cost=cost_result.cost, + duration_ms=duration_ms, + ) + bus = get_event_bus() + try: + loop = asyncio.get_running_loop() + loop.create_task(bus.publish_event(event)) + except RuntimeError: + asyncio.run(bus.publish_event(event)) + except Exception as exc: + logging.getLogger(__name__).debug("LLMCallEvent emit failed: %s", exc) diff --git a/ufo/llm/openai.py b/ufo/llm/openai.py index 613ee7b4f..2dcbdbf9b 100644 --- a/ufo/llm/openai.py +++ b/ufo/llm/openai.py @@ -183,20 +183,20 @@ def _chat_completion( prompt_tokens = usage.prompt_tokens completion_tokens = usage.completion_tokens - cost = self.get_cost_estimator( + cost_result = self.get_cost_estimator( self.api_type, self.model, self.prices, prompt_tokens, completion_tokens, ) - return collected_content, cost + return collected_content, cost_result else: usage = response.usage prompt_tokens = usage.prompt_tokens completion_tokens = usage.completion_tokens - cost = self.get_cost_estimator( + cost_result = self.get_cost_estimator( self.api_type, self.model, self.prices, @@ -204,7 +204,7 @@ def _chat_completion( completion_tokens, ) - return [response.choices[0].message.content], cost + return [response.choices[0].message.content], cost_result except openai.APITimeoutError as e: # Handle timeout error, e.g. retry or log @@ -287,7 +287,7 @@ def _responses_completion( input_tokens = usage.get("input_tokens", 0) output_tokens = usage.get("output_tokens", 0) - cost = self.get_cost_estimator( + cost_result = self.get_cost_estimator( self.api_type, self.model, self.prices, @@ -295,7 +295,7 @@ def _responses_completion( output_tokens, ) - return [content_text], cost + return [content_text], cost_result @staticmethod def _messages_to_responses_input( @@ -393,7 +393,7 @@ def _chat_completion_operator( input_tokens = 0 output_tokens = 0 - cost = self.get_cost_estimator( + cost_result = self.get_cost_estimator( self.api_type, self.config_llm["API_MODEL"], self.prices, @@ -401,7 +401,7 @@ def _chat_completion_operator( output_tokens, ) - return [response], cost + return [response], cost_result @functools.lru_cache() @staticmethod @@ -891,7 +891,7 @@ def chat_completion( input_tokens = 0 output_tokens = 0 - cost = self.get_cost_estimator( + cost_result = self.get_cost_estimator( self.api_type, self.api_model, self.prices, @@ -899,7 +899,7 @@ def chat_completion( output_tokens, ) - return [response], cost + return [response], cost_result def get_token_provider(self): """