Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 73 additions & 10 deletions pkg/transport/proxy/transparent/transparent_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"regexp"
"strings"
"sync"
"sync/atomic"
"time"

"go.opentelemetry.io/otel"
Expand Down Expand Up @@ -69,8 +70,8 @@ type TransparentProxy struct {
// Sessions for tracking state
sessionManager *session.Manager

// If mcp server has been initialized
IsServerInitialized bool
// If mcp server has been initialized (atomic access)
isServerInitialized atomic.Bool

// Listener for the HTTP server
listener net.Listener
Expand All @@ -83,9 +84,31 @@ type TransparentProxy struct {

// Callback when health check fails (for remote servers)
onHealthCheckFailed types.HealthCheckFailedCallback

// Health check interval (default: 10 seconds)
healthCheckInterval time.Duration
}

// NewTransparentProxy creates a new transparent proxy with optional middlewares.
const (
// DefaultHealthCheckInterval is the default interval for health checks
DefaultHealthCheckInterval = 10 * time.Second
)

// Option is a functional option for configuring TransparentProxy
type Option func(*TransparentProxy)

// withHealthCheckInterval sets the health check interval.
// This is primarily useful for testing with shorter intervals.
// Ignores non-positive intervals; default will be used.
func withHealthCheckInterval(interval time.Duration) Option {
return func(p *TransparentProxy) {
if interval > 0 {
p.healthCheckInterval = interval
}
}
}

// NewTransparentProxy creates a new transparent proxy with optional middlewares and configuration options.
func NewTransparentProxy(
host string,
port int,
Expand All @@ -97,6 +120,34 @@ func NewTransparentProxy(
transportType string,
onHealthCheckFailed types.HealthCheckFailedCallback,
middlewares ...types.NamedMiddleware,
) *TransparentProxy {
return newTransparentProxyWithOptions(
host,
port,
targetURI,
prometheusHandler,
authInfoHandler,
enableHealthCheck,
isRemote,
transportType,
onHealthCheckFailed,
middlewares,
)
}

// newTransparentProxyWithOptions creates a new transparent proxy with optional configuration.
func newTransparentProxyWithOptions(
host string,
port int,
targetURI string,
prometheusHandler http.Handler,
authInfoHandler http.Handler,
enableHealthCheck bool,
isRemote bool,
transportType string,
onHealthCheckFailed types.HealthCheckFailedCallback,
middlewares []types.NamedMiddleware,
options ...Option,
) *TransparentProxy {
proxy := &TransparentProxy{
host: host,
Expand All @@ -110,6 +161,12 @@ func NewTransparentProxy(
isRemote: isRemote,
transportType: transportType,
onHealthCheckFailed: onHealthCheckFailed,
healthCheckInterval: DefaultHealthCheckInterval,
}

// Apply options
for _, opt := range options {
opt(proxy)
}

// Create health checker always for Kubernetes probes
Expand All @@ -128,14 +185,16 @@ type tracingTransport struct {
}

func (p *TransparentProxy) setServerInitialized() {
if !p.IsServerInitialized {
p.mutex.Lock()
p.IsServerInitialized = true
p.mutex.Unlock()
if p.isServerInitialized.CompareAndSwap(false, true) {
logger.Infof("Server was initialized successfully for %s", p.targetURI)
}
}

// serverInitialized returns whether the server has been initialized (thread-safe)
func (p *TransparentProxy) serverInitialized() bool {
return p.isServerInitialized.Load()
}

func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) {
tr := t.base
if tr == nil {
Expand Down Expand Up @@ -191,7 +250,7 @@ func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error)
return resp, nil
}
// status was ok and we saw an initialize call
if sawInitialize && !t.p.IsServerInitialized {
if sawInitialize && !t.p.serverInitialized() {
t.p.setServerInitialized()
return resp, nil
}
Expand Down Expand Up @@ -409,7 +468,11 @@ func (p *TransparentProxy) CloseListener() error {
}

func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
ticker := time.NewTicker(10 * time.Second)
interval := p.healthCheckInterval
if interval == 0 {
interval = DefaultHealthCheckInterval
}
ticker := time.NewTicker(interval)
defer ticker.Stop()

for {
Expand All @@ -422,7 +485,7 @@ func (p *TransparentProxy) monitorHealth(parentCtx context.Context) {
return
case <-ticker.C:
// Perform health check only if mcp server has been initialized
if p.IsServerInitialized {
if p.serverInitialized() {
alive := p.healthChecker.CheckHealth(parentCtx)
if alive.Status != healthcheck.StatusHealthy {
logger.Infof("Health check failed for %s; initiating proxy shutdown", p.targetURI)
Expand Down
Loading
Loading