diff --git a/internal/api/handlers/v0/servers.go b/internal/api/handlers/v0/servers.go index f9f7ba5f7..52a3cfbbc 100644 --- a/internal/api/handlers/v0/servers.go +++ b/internal/api/handlers/v0/servers.go @@ -42,6 +42,8 @@ type ServerVersionsInput struct { } // RegisterServersEndpoints registers all server-related endpoints with a custom path prefix +// +//nolint:cyclop func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service.RegistryService) { // List servers endpoint huma.Register(api, huma.Operation{ @@ -52,6 +54,10 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service. Description: "Get a paginated list of MCP servers from the registry", Tags: []string{"servers"}, }, func(ctx context.Context, input *ListServersInput) (*Response[apiv0.ServerListResponse], error) { + if containsNULByte(input.Cursor) { + return nil, huma.Error400BadRequest("Invalid cursor: NUL byte not allowed") + } + // Build filter from input parameters filter := &database.ServerFilter{} @@ -119,12 +125,18 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service. if err != nil { return nil, huma.Error400BadRequest("Invalid server name encoding", err) } + if containsNULByte(serverName) { + return nil, huma.Error400BadRequest("Invalid server name: NUL byte not allowed") + } // URL-decode the version version, err := url.PathUnescape(input.Version) if err != nil { return nil, huma.Error400BadRequest("Invalid version encoding", err) } + if containsNULByte(version) { + return nil, huma.Error400BadRequest("Invalid version: NUL byte not allowed") + } var serverResponse *apiv0.ServerResponse // Handle "latest" as a special version @@ -160,6 +172,9 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service. if err != nil { return nil, huma.Error400BadRequest("Invalid server name encoding", err) } + if containsNULByte(serverName) { + return nil, huma.Error400BadRequest("Invalid server name: NUL byte not allowed") + } // Get all versions for this server servers, err := registry.GetAllVersionsByServerName(ctx, serverName) @@ -186,3 +201,7 @@ func RegisterServersEndpoints(api huma.API, pathPrefix string, registry service. }, nil }) } + +func containsNULByte(s string) bool { + return strings.IndexByte(s, 0) >= 0 +} diff --git a/internal/api/handlers/v0/servers_test.go b/internal/api/handlers/v0/servers_test.go index e0d19a031..6eaf8b2d9 100644 --- a/internal/api/handlers/v0/servers_test.go +++ b/internal/api/handlers/v0/servers_test.go @@ -461,6 +461,8 @@ func TestServersEndpointEdgeCases(t *testing.T) { {"limit too high", "?limit=1000", http.StatusUnprocessableEntity, "validation failed"}, {"negative limit", "?limit=-1", http.StatusUnprocessableEntity, "validation failed"}, {"invalid updated_since format", "?updated_since=invalid", http.StatusBadRequest, "Invalid updated_since format"}, + {"cursor contains NUL", "?cursor=%00", http.StatusBadRequest, "Invalid cursor"}, + {"cursor contains non NUL", "?cursor=server", http.StatusOK, ""}, {"future updated_since", "?updated_since=2030-01-01T00:00:00Z", http.StatusOK, ""}, {"very old updated_since", "?updated_since=1990-01-01T00:00:00Z", http.StatusOK, ""}, {"empty search parameter", "?search=", http.StatusOK, ""}, @@ -489,6 +491,16 @@ func TestServersEndpointEdgeCases(t *testing.T) { } }) + t.Run("path parameter NUL byte rejected", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/v0/servers/%00/versions", nil) + w := httptest.NewRecorder() + + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusBadRequest, w.Code) + assert.Contains(t, w.Body.String(), "Invalid server name") + }) + t.Run("response structure validation", func(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/v0/servers", nil) w := httptest.NewRecorder()