diff --git a/internal/mcp/connection.go b/internal/mcp/connection.go index cc3c3b33..d46fd303 100644 --- a/internal/mcp/connection.go +++ b/internal/mcp/connection.go @@ -170,6 +170,62 @@ func setupHTTPRequest(ctx context.Context, url string, requestBody []byte, heade return httpReq, nil } +// httpRequestResult contains the result of an HTTP request execution +type httpRequestResult struct { + StatusCode int + ResponseBody []byte + Header http.Header +} + +// executeHTTPRequest executes an HTTP JSON-RPC request and returns the response details. +// This helper consolidates the common pattern of: create request → marshal → setup HTTP → execute → read response. +// It handles connection errors consistently and provides method-specific error messages. +// The headerModifier function allows callers to modify headers before the request is sent. +func (c *Connection) executeHTTPRequest(ctx context.Context, method string, params interface{}, requestID uint64, headerModifier func(*http.Request)) (*httpRequestResult, error) { + // Create JSON-RPC request + request := createJSONRPCRequest(requestID, method, params) + + // Marshal request body + requestBody, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("failed to marshal %s request: %w", method, err) + } + + // Create HTTP request with standard headers + httpReq, err := setupHTTPRequest(ctx, c.httpURL, requestBody, c.headers) + if err != nil { + return nil, err + } + + // Allow caller to modify headers (e.g., add session ID) + if headerModifier != nil { + headerModifier(httpReq) + } + + // Execute HTTP request + httpResp, err := c.httpClient.Do(httpReq) + if err != nil { + // Check if it's a connection error (cannot connect at all) + if isHTTPConnectionError(err) { + return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err) + } + return nil, fmt.Errorf("%s HTTP request failed: %w", method, err) + } + defer httpResp.Body.Close() + + // Read response body + responseBody, err := io.ReadAll(httpResp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read %s response: %w", method, err) + } + + return &httpRequestResult{ + StatusCode: httpResp.StatusCode, + ResponseBody: responseBody, + Header: httpResp.Header, + }, nil +} + // NewConnection creates a new MCP connection using the official SDK func NewConnection(ctx context.Context, serverID, command string, args []string, env map[string]string) (*Connection, error) { logger.LogInfo("backend", "Creating new MCP backend connection, command=%s, args=%v", command, sanitize.SanitizeArgs(args)) @@ -573,42 +629,23 @@ func (c *Connection) initializeHTTPSession() (string, error) { }, } - request := createJSONRPCRequest(requestID, "initialize", initParams) - - requestBody, err := json.Marshal(request) - if err != nil { - return "", fmt.Errorf("failed to marshal initialize request: %w", err) - } - - logConn.Printf("Sending initialize request: %s", string(requestBody)) - - // Create HTTP request with standard headers - httpReq, err := setupHTTPRequest(context.Background(), c.httpURL, requestBody, c.headers) - if err != nil { - return "", err - } + logConn.Printf("Sending initialize request") // Generate a temporary session ID for the initialize request // Some backends may require this header even during initialization tempSessionID := fmt.Sprintf("awmg-init-%d", requestID) - httpReq.Header.Set("Mcp-Session-Id", tempSessionID) - logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID) - logConn.Printf("Sending initialize to %s", c.httpURL) - - // Send request - httpResp, err := c.httpClient.Do(httpReq) + // Execute HTTP request with custom header modification + result, err := c.executeHTTPRequest(context.Background(), "initialize", initParams, requestID, func(httpReq *http.Request) { + httpReq.Header.Set("Mcp-Session-Id", tempSessionID) + logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID) + }) if err != nil { - // Check if it's a connection error (cannot connect at all) - if isHTTPConnectionError(err) { - return "", fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err) - } - return "", fmt.Errorf("failed to send initialize request to %s: %w", c.httpURL, err) + return "", err } - defer httpResp.Body.Close() // Capture the Mcp-Session-Id from response headers - sessionID := httpResp.Header.Get("Mcp-Session-Id") + sessionID := result.Header.Get("Mcp-Session-Id") if sessionID != "" { logConn.Printf("Captured Mcp-Session-Id from response: %s", sessionID) } else { @@ -618,17 +655,11 @@ func (c *Connection) initializeHTTPSession() (string, error) { logConn.Printf("No Mcp-Session-Id in response, using temporary session ID: %s", sessionID) } - // Read response body - responseBody, err := io.ReadAll(httpResp.Body) - if err != nil { - return "", fmt.Errorf("failed to read initialize response: %w", err) - } - - logConn.Printf("Initialize response: status=%d, body_len=%d, session=%s", httpResp.StatusCode, len(responseBody), sessionID) + logConn.Printf("Initialize response: status=%d, body_len=%d, session=%s", result.StatusCode, len(result.ResponseBody), sessionID) // Check for HTTP errors - if httpResp.StatusCode != http.StatusOK { - return "", fmt.Errorf("initialize failed: status=%d, body=%s", httpResp.StatusCode, string(responseBody)) + if result.StatusCode != http.StatusOK { + return "", fmt.Errorf("initialize failed: status=%d, body=%s", result.StatusCode, string(result.ResponseBody)) } // Parse JSON-RPC response to check for errors @@ -636,13 +667,13 @@ func (c *Connection) initializeHTTPSession() (string, error) { // Try to parse as JSON first, if that fails, try SSE format var rpcResponse Response - if err := json.Unmarshal(responseBody, &rpcResponse); err != nil { + if err := json.Unmarshal(result.ResponseBody, &rpcResponse); err != nil { // Try parsing as SSE format logConn.Printf("Initial JSON parse failed, attempting SSE format parsing") - sseData, sseErr := parseSSEResponse(responseBody) + sseData, sseErr := parseSSEResponse(result.ResponseBody) if sseErr != nil { // Include the response body to help debug what the server actually returned - bodyPreview := string(responseBody) + bodyPreview := string(result.ResponseBody) if len(bodyPreview) > 500 { bodyPreview = bodyPreview[:500] + "... (truncated)" } @@ -674,84 +705,59 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params params = ensureToolCallArguments(params) } - // Create JSON-RPC request - request := createJSONRPCRequest(requestID, method, params) - - requestBody, err := json.Marshal(request) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - // Create HTTP request with standard headers - httpReq, err := setupHTTPRequest(ctx, c.httpURL, requestBody, c.headers) - if err != nil { - return nil, err - } - - // Add Mcp-Session-Id header with priority: - // 1) Context session ID (if explicitly provided for this request) - // 2) Stored httpSessionID from initialization - var sessionID string - if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" { - sessionID = ctxSessionID - logConn.Printf("Using session ID from context: %s", sessionID) - } else if c.httpSessionID != "" { - sessionID = c.httpSessionID - logConn.Printf("Using stored session ID from initialization: %s", sessionID) - } - - if sessionID != "" { - httpReq.Header.Set("Mcp-Session-Id", sessionID) - } else { - logConn.Printf("No session ID available (backend may not require session management)") - } - logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID) - // Send request using the reusable HTTP client - httpResp, err := c.httpClient.Do(httpReq) - if err != nil { - // Check if it's a connection error (cannot connect at all) - if isHTTPConnectionError(err) { - return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err) + // Execute HTTP request with custom header modification for session ID + result, err := c.executeHTTPRequest(ctx, method, params, requestID, func(httpReq *http.Request) { + // Add Mcp-Session-Id header with priority: + // 1) Context session ID (if explicitly provided for this request) + // 2) Stored httpSessionID from initialization + var sessionID string + if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" { + sessionID = ctxSessionID + logConn.Printf("Using session ID from context: %s", sessionID) + } else if c.httpSessionID != "" { + sessionID = c.httpSessionID + logConn.Printf("Using stored session ID from initialization: %s", sessionID) } - return nil, fmt.Errorf("failed to send HTTP request to %s: %w", c.httpURL, err) - } - defer httpResp.Body.Close() - // Read response - responseBody, err := io.ReadAll(httpResp.Body) + if sessionID != "" { + httpReq.Header.Set("Mcp-Session-Id", sessionID) + } else { + logConn.Printf("No session ID available (backend may not require session management)") + } + }) if err != nil { - return nil, fmt.Errorf("failed to read HTTP response: %w", err) + return nil, err } - logConn.Printf("Received HTTP response: status=%d, body_len=%d", httpResp.StatusCode, len(responseBody)) + logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody)) // Parse JSON-RPC response // The response might be in SSE format (event: message\ndata: {...}) // Try to parse as JSON first, if that fails, try SSE format var rpcResponse Response - if err := json.Unmarshal(responseBody, &rpcResponse); err != nil { + if err := json.Unmarshal(result.ResponseBody, &rpcResponse); err != nil { // Try parsing as SSE format logConn.Printf("Initial JSON parse failed, attempting SSE format parsing") - sseData, sseErr := parseSSEResponse(responseBody) + sseData, sseErr := parseSSEResponse(result.ResponseBody) if sseErr != nil { // If we have a non-OK HTTP status and can't parse the response, // construct a JSON-RPC error response with HTTP error details - if httpResp.StatusCode != http.StatusOK { - logConn.Printf("HTTP error status=%d, body cannot be parsed as JSON-RPC", httpResp.StatusCode) + if result.StatusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d, body cannot be parsed as JSON-RPC", result.StatusCode) return &Response{ JSONRPC: "2.0", Error: &ResponseError{ Code: -32603, // Internal error - Message: fmt.Sprintf("HTTP %d: %s", httpResp.StatusCode, http.StatusText(httpResp.StatusCode)), - Data: json.RawMessage(responseBody), + Message: fmt.Sprintf("HTTP %d: %s", result.StatusCode, http.StatusText(result.StatusCode)), + Data: json.RawMessage(result.ResponseBody), }, }, nil } // Include the response body to help debug what the server actually returned - bodyPreview := string(responseBody) + bodyPreview := string(result.ResponseBody) if len(bodyPreview) > 500 { bodyPreview = bodyPreview[:500] + "... (truncated)" } @@ -762,14 +768,14 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params if err := json.Unmarshal(sseData, &rpcResponse); err != nil { // If we have a non-OK HTTP status and can't parse the SSE data, // construct a JSON-RPC error response with HTTP error details - if httpResp.StatusCode != http.StatusOK { - logConn.Printf("HTTP error status=%d, SSE data cannot be parsed as JSON-RPC", httpResp.StatusCode) + if result.StatusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d, SSE data cannot be parsed as JSON-RPC", result.StatusCode) return &Response{ JSONRPC: "2.0", Error: &ResponseError{ Code: -32603, // Internal error - Message: fmt.Sprintf("HTTP %d: %s", httpResp.StatusCode, http.StatusText(httpResp.StatusCode)), - Data: json.RawMessage(responseBody), + Message: fmt.Sprintf("HTTP %d: %s", result.StatusCode, http.StatusText(result.StatusCode)), + Data: json.RawMessage(result.ResponseBody), }, }, nil } @@ -781,14 +787,14 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params // Check for HTTP errors after parsing // If we have a non-OK status but successfully parsed a JSON-RPC response, // pass it through (it may already contain an error field) - if httpResp.StatusCode != http.StatusOK { - logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", httpResp.StatusCode) + if result.StatusCode != http.StatusOK { + logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", result.StatusCode) // If the response doesn't already have an error, construct one if rpcResponse.Error == nil { rpcResponse.Error = &ResponseError{ Code: -32603, // Internal error - Message: fmt.Sprintf("HTTP %d: %s", httpResp.StatusCode, http.StatusText(httpResp.StatusCode)), - Data: responseBody, + Message: fmt.Sprintf("HTTP %d: %s", result.StatusCode, http.StatusText(result.StatusCode)), + Data: result.ResponseBody, } } } diff --git a/main.go b/main.go index 97160503..299f88f1 100644 --- a/main.go +++ b/main.go @@ -13,7 +13,7 @@ var log = logger.New("main:main") func main() { log.Print("Starting MCP Gateway application") - + // Build version string with metadata versionStr := buildVersionString() log.Printf("Built version string: %s", versionStr)