Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 105 additions & 99 deletions internal/mcp/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,62 @@ func setupHTTPRequest(ctx context.Context, url string, requestBody []byte, heade
return httpReq, nil
}

// httpRequestResult contains the result of an HTTP request execution
type httpRequestResult struct {
StatusCode int
ResponseBody []byte
Header http.Header
}

// executeHTTPRequest executes an HTTP JSON-RPC request and returns the response details.
// This helper consolidates the common pattern of: create request → marshal → setup HTTP → execute → read response.
// It handles connection errors consistently and provides method-specific error messages.
// The headerModifier function allows callers to modify headers before the request is sent.
func (c *Connection) executeHTTPRequest(ctx context.Context, method string, params interface{}, requestID uint64, headerModifier func(*http.Request)) (*httpRequestResult, error) {
// Create JSON-RPC request
request := createJSONRPCRequest(requestID, method, params)

// Marshal request body
requestBody, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal %s request: %w", method, err)
}

// Create HTTP request with standard headers
httpReq, err := setupHTTPRequest(ctx, c.httpURL, requestBody, c.headers)
if err != nil {
return nil, err
}

// Allow caller to modify headers (e.g., add session ID)
if headerModifier != nil {
headerModifier(httpReq)
}

// Execute HTTP request
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
// Check if it's a connection error (cannot connect at all)
if isHTTPConnectionError(err) {
return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err)
}
return nil, fmt.Errorf("%s HTTP request failed: %w", method, err)
}
defer httpResp.Body.Close()

// Read response body
responseBody, err := io.ReadAll(httpResp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read %s response: %w", method, err)
}

return &httpRequestResult{
StatusCode: httpResp.StatusCode,
ResponseBody: responseBody,
Header: httpResp.Header,
}, nil
}

// NewConnection creates a new MCP connection using the official SDK
func NewConnection(ctx context.Context, serverID, command string, args []string, env map[string]string) (*Connection, error) {
logger.LogInfo("backend", "Creating new MCP backend connection, command=%s, args=%v", command, sanitize.SanitizeArgs(args))
Expand Down Expand Up @@ -573,42 +629,23 @@ func (c *Connection) initializeHTTPSession() (string, error) {
},
}

request := createJSONRPCRequest(requestID, "initialize", initParams)

requestBody, err := json.Marshal(request)
if err != nil {
return "", fmt.Errorf("failed to marshal initialize request: %w", err)
}

logConn.Printf("Sending initialize request: %s", string(requestBody))

// Create HTTP request with standard headers
httpReq, err := setupHTTPRequest(context.Background(), c.httpURL, requestBody, c.headers)
if err != nil {
return "", err
}
logConn.Printf("Sending initialize request")

// Generate a temporary session ID for the initialize request
// Some backends may require this header even during initialization
tempSessionID := fmt.Sprintf("awmg-init-%d", requestID)
httpReq.Header.Set("Mcp-Session-Id", tempSessionID)
logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID)

logConn.Printf("Sending initialize to %s", c.httpURL)

// Send request
httpResp, err := c.httpClient.Do(httpReq)
// Execute HTTP request with custom header modification
result, err := c.executeHTTPRequest(context.Background(), "initialize", initParams, requestID, func(httpReq *http.Request) {
httpReq.Header.Set("Mcp-Session-Id", tempSessionID)
logConn.Printf("Sending initialize with temporary session ID: %s", tempSessionID)
})
if err != nil {
// Check if it's a connection error (cannot connect at all)
if isHTTPConnectionError(err) {
return "", fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err)
}
return "", fmt.Errorf("failed to send initialize request to %s: %w", c.httpURL, err)
return "", err
}
defer httpResp.Body.Close()

// Capture the Mcp-Session-Id from response headers
sessionID := httpResp.Header.Get("Mcp-Session-Id")
sessionID := result.Header.Get("Mcp-Session-Id")
if sessionID != "" {
logConn.Printf("Captured Mcp-Session-Id from response: %s", sessionID)
} else {
Expand All @@ -618,31 +655,25 @@ func (c *Connection) initializeHTTPSession() (string, error) {
logConn.Printf("No Mcp-Session-Id in response, using temporary session ID: %s", sessionID)
}

// Read response body
responseBody, err := io.ReadAll(httpResp.Body)
if err != nil {
return "", fmt.Errorf("failed to read initialize response: %w", err)
}

logConn.Printf("Initialize response: status=%d, body_len=%d, session=%s", httpResp.StatusCode, len(responseBody), sessionID)
logConn.Printf("Initialize response: status=%d, body_len=%d, session=%s", result.StatusCode, len(result.ResponseBody), sessionID)

// Check for HTTP errors
if httpResp.StatusCode != http.StatusOK {
return "", fmt.Errorf("initialize failed: status=%d, body=%s", httpResp.StatusCode, string(responseBody))
if result.StatusCode != http.StatusOK {
return "", fmt.Errorf("initialize failed: status=%d, body=%s", result.StatusCode, string(result.ResponseBody))
}

// Parse JSON-RPC response to check for errors
// The response might be in SSE format (event: message\ndata: {...})
// Try to parse as JSON first, if that fails, try SSE format
var rpcResponse Response

if err := json.Unmarshal(responseBody, &rpcResponse); err != nil {
if err := json.Unmarshal(result.ResponseBody, &rpcResponse); err != nil {
// Try parsing as SSE format
logConn.Printf("Initial JSON parse failed, attempting SSE format parsing")
sseData, sseErr := parseSSEResponse(responseBody)
sseData, sseErr := parseSSEResponse(result.ResponseBody)
if sseErr != nil {
// Include the response body to help debug what the server actually returned
bodyPreview := string(responseBody)
bodyPreview := string(result.ResponseBody)
if len(bodyPreview) > 500 {
bodyPreview = bodyPreview[:500] + "... (truncated)"
}
Expand Down Expand Up @@ -674,84 +705,59 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
params = ensureToolCallArguments(params)
}

// Create JSON-RPC request
request := createJSONRPCRequest(requestID, method, params)

requestBody, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("failed to marshal request: %w", err)
}

// Create HTTP request with standard headers
httpReq, err := setupHTTPRequest(ctx, c.httpURL, requestBody, c.headers)
if err != nil {
return nil, err
}

// Add Mcp-Session-Id header with priority:
// 1) Context session ID (if explicitly provided for this request)
// 2) Stored httpSessionID from initialization
var sessionID string
if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" {
sessionID = ctxSessionID
logConn.Printf("Using session ID from context: %s", sessionID)
} else if c.httpSessionID != "" {
sessionID = c.httpSessionID
logConn.Printf("Using stored session ID from initialization: %s", sessionID)
}

if sessionID != "" {
httpReq.Header.Set("Mcp-Session-Id", sessionID)
} else {
logConn.Printf("No session ID available (backend may not require session management)")
}

logConn.Printf("Sending HTTP request to %s: method=%s, id=%d", c.httpURL, method, requestID)

// Send request using the reusable HTTP client
httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
// Check if it's a connection error (cannot connect at all)
if isHTTPConnectionError(err) {
return nil, fmt.Errorf("cannot connect to HTTP backend at %s: %w", c.httpURL, err)
// Execute HTTP request with custom header modification for session ID
result, err := c.executeHTTPRequest(ctx, method, params, requestID, func(httpReq *http.Request) {
// Add Mcp-Session-Id header with priority:
// 1) Context session ID (if explicitly provided for this request)
// 2) Stored httpSessionID from initialization
var sessionID string
if ctxSessionID, ok := ctx.Value(SessionIDContextKey).(string); ok && ctxSessionID != "" {
sessionID = ctxSessionID
logConn.Printf("Using session ID from context: %s", sessionID)
} else if c.httpSessionID != "" {
sessionID = c.httpSessionID
logConn.Printf("Using stored session ID from initialization: %s", sessionID)
}
return nil, fmt.Errorf("failed to send HTTP request to %s: %w", c.httpURL, err)
}
defer httpResp.Body.Close()

// Read response
responseBody, err := io.ReadAll(httpResp.Body)
if sessionID != "" {
httpReq.Header.Set("Mcp-Session-Id", sessionID)
} else {
logConn.Printf("No session ID available (backend may not require session management)")
}
})
if err != nil {
return nil, fmt.Errorf("failed to read HTTP response: %w", err)
return nil, err
}

logConn.Printf("Received HTTP response: status=%d, body_len=%d", httpResp.StatusCode, len(responseBody))
logConn.Printf("Received HTTP response: status=%d, body_len=%d", result.StatusCode, len(result.ResponseBody))

// Parse JSON-RPC response
// The response might be in SSE format (event: message\ndata: {...})
// Try to parse as JSON first, if that fails, try SSE format
var rpcResponse Response

if err := json.Unmarshal(responseBody, &rpcResponse); err != nil {
if err := json.Unmarshal(result.ResponseBody, &rpcResponse); err != nil {
// Try parsing as SSE format
logConn.Printf("Initial JSON parse failed, attempting SSE format parsing")
sseData, sseErr := parseSSEResponse(responseBody)
sseData, sseErr := parseSSEResponse(result.ResponseBody)
if sseErr != nil {
// If we have a non-OK HTTP status and can't parse the response,
// construct a JSON-RPC error response with HTTP error details
if httpResp.StatusCode != http.StatusOK {
logConn.Printf("HTTP error status=%d, body cannot be parsed as JSON-RPC", httpResp.StatusCode)
if result.StatusCode != http.StatusOK {
logConn.Printf("HTTP error status=%d, body cannot be parsed as JSON-RPC", result.StatusCode)
return &Response{
JSONRPC: "2.0",
Error: &ResponseError{
Code: -32603, // Internal error
Message: fmt.Sprintf("HTTP %d: %s", httpResp.StatusCode, http.StatusText(httpResp.StatusCode)),
Data: json.RawMessage(responseBody),
Message: fmt.Sprintf("HTTP %d: %s", result.StatusCode, http.StatusText(result.StatusCode)),
Data: json.RawMessage(result.ResponseBody),
},
}, nil
}
// Include the response body to help debug what the server actually returned
bodyPreview := string(responseBody)
bodyPreview := string(result.ResponseBody)
if len(bodyPreview) > 500 {
bodyPreview = bodyPreview[:500] + "... (truncated)"
}
Expand All @@ -762,14 +768,14 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
if err := json.Unmarshal(sseData, &rpcResponse); err != nil {
// If we have a non-OK HTTP status and can't parse the SSE data,
// construct a JSON-RPC error response with HTTP error details
if httpResp.StatusCode != http.StatusOK {
logConn.Printf("HTTP error status=%d, SSE data cannot be parsed as JSON-RPC", httpResp.StatusCode)
if result.StatusCode != http.StatusOK {
logConn.Printf("HTTP error status=%d, SSE data cannot be parsed as JSON-RPC", result.StatusCode)
return &Response{
JSONRPC: "2.0",
Error: &ResponseError{
Code: -32603, // Internal error
Message: fmt.Sprintf("HTTP %d: %s", httpResp.StatusCode, http.StatusText(httpResp.StatusCode)),
Data: json.RawMessage(responseBody),
Message: fmt.Sprintf("HTTP %d: %s", result.StatusCode, http.StatusText(result.StatusCode)),
Data: json.RawMessage(result.ResponseBody),
},
}, nil
}
Expand All @@ -781,14 +787,14 @@ func (c *Connection) sendHTTPRequest(ctx context.Context, method string, params
// Check for HTTP errors after parsing
// If we have a non-OK status but successfully parsed a JSON-RPC response,
// pass it through (it may already contain an error field)
if httpResp.StatusCode != http.StatusOK {
logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", httpResp.StatusCode)
if result.StatusCode != http.StatusOK {
logConn.Printf("HTTP error status=%d with valid JSON-RPC response, passing through", result.StatusCode)
// If the response doesn't already have an error, construct one
if rpcResponse.Error == nil {
rpcResponse.Error = &ResponseError{
Code: -32603, // Internal error
Message: fmt.Sprintf("HTTP %d: %s", httpResp.StatusCode, http.StatusText(httpResp.StatusCode)),
Data: responseBody,
Message: fmt.Sprintf("HTTP %d: %s", result.StatusCode, http.StatusText(result.StatusCode)),
Data: result.ResponseBody,
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var log = logger.New("main:main")

func main() {
log.Print("Starting MCP Gateway application")

// Build version string with metadata
versionStr := buildVersionString()
log.Printf("Built version string: %s", versionStr)
Expand Down