diff --git a/docs/reference/hooks.md b/docs/reference/hooks.md index ebf64a6e..1f593f09 100644 --- a/docs/reference/hooks.md +++ b/docs/reference/hooks.md @@ -32,11 +32,20 @@ if err != nil { http.ListenAndServe(":8080", handler) ``` -Register middleware before calling `Handler()` or `ServeMux()`. Middleware runs -in registration order; a middleware that does not call `next` owns the response -and skips generated headers, metrics, static serving, and request-time route -dispatch for that request. App-owned startup code can still wrap the returned -handler with ordinary middleware: +Register middleware before calling `App()`, `Handler()`, or `ServeMux()`. +Middleware runs in registration order; a middleware that does not call `next` +owns the response and skips generated headers, metrics, static serving, and +request-time route dispatch for that request. + +`App()` snapshots the registered chain around its raw application mux. Routes +mounted by lifecycle services before server startup therefore pass through the +same middleware as health, static, backend, dynamic sitemap, and realtime +routes. `ServeMux()` mounts the generated route graph behind the same finalized +wrapper; routes added directly to that returned mux afterward are caller-owned +and need their own middleware policy. + +App-owned startup code can still wrap the returned handler with ordinary +middleware: ```go handler, err := gowdkapp.Handler() diff --git a/internal/appgen/appgen_test.go b/internal/appgen/appgen_test.go index 08347350..89642df5 100644 --- a/internal/appgen/appgen_test.go +++ b/internal/appgen/appgen_test.go @@ -105,7 +105,10 @@ func TestGenerateWritesEmbeddedSPAApp(t *testing.T) { "func configuredServices() ([]gowdkruntime.Service, error)", "func RegisterMiddleware(middleware gowdkruntime.Middleware)", `gowdkruntime "github.com/cssbruno/gowdk/runtime/app"`, - `mux.Handle("/", gowdkruntime.ApplyMiddlewares(&gowdkruntime.Handler{`, + `handler := gowdkruntime.ApplyMiddlewares(mux, registeredMiddlewares()...)`, + `Handler: handler, Mux: mux`, + `return gowdkruntime.ApplyMiddlewares(mux, registeredMiddlewares()...), nil`, + `mux.Handle("/", &gowdkruntime.Handler{`, `Identity: identity,`, `Assets: gowdkruntime.LoadAssetManifest(root),`, `ErrorPages: gowdkruntime.LoadErrorPages(root),`, @@ -181,8 +184,19 @@ func TestGenerateWritesDynamicSitemapRoute(t *testing.T) { } assertSourceOrder(t, source, `mux.Handle("/sitemap.xml", gowdkseo.Handler`, - `mux.Handle("/", gowdkruntime.ApplyMiddlewares`, + `mux.Handle("/", &gowdkruntime.Handler`, ) + for _, want := range []string{ + `handler := gowdkruntime.ApplyMiddlewares(mux, registeredMiddlewares()...)`, + `mux.Handle("/", gowdkruntime.ApplyMiddlewares(routes, registeredMiddlewares()...))`, + } { + if !strings.Contains(source, want) { + t.Fatalf("expected generated middleware pipeline to contain %q:\n%s", want, source) + } + } + if strings.Contains(source, `mux.Handle("/", gowdkruntime.ApplyMiddlewares(&gowdkruntime.Handler`) { + t.Fatalf("generated root route must stay unwrapped until the final mux is composed:\n%s", source) + } } func TestGenerateWiresConfiguredLifecycleServices(t *testing.T) { @@ -1730,7 +1744,10 @@ func TestGenerateBackendAppRegistersBackendRoutes(t *testing.T) { `func RegisterMiddleware(middleware gowdkruntime.Middleware)`, `if err := validateEnvContract(); err != nil {`, `backendRouter, err := newBackendRouter()`, - `mux.Handle("/", gowdkruntime.ApplyMiddlewares(backendRouter, registeredMiddlewares()...))`, + `handler := gowdkruntime.ApplyMiddlewares(mux, registeredMiddlewares()...)`, + `return gowdkruntime.ApplyMiddlewares(mux, registeredMiddlewares()...), nil`, + `mux.Handle("/", backendRouter)`, + `mux.Handle("/", gowdkruntime.ApplyMiddlewares(routes, registeredMiddlewares()...))`, `func validateEnvContract() error`, `value := os.Getenv("GOWDK_TEST_DATABASE_URL")`, `missing = append(missing, "GOWDK_TEST_DATABASE_URL is required but is not set")`, @@ -1751,6 +1768,9 @@ func TestGenerateBackendAppRegistersBackendRoutes(t *testing.T) { if strings.Contains(source, `func backend(response http.ResponseWriter, request *http.Request) bool`) { t.Fatalf("expected backend-only app to use BackendRouter instead of generated backend dispatcher:\n%s", source) } + if strings.Contains(source, `mux.Handle("/", gowdkruntime.ApplyMiddlewares(backendRouter`) { + t.Fatalf("backend-only route mux should stay raw and be wrapped after route graph construction:\n%s", source) + } } func TestGenerateRenamesBackendAliasReservedByGeneratedRuntime(t *testing.T) { @@ -1860,7 +1880,10 @@ func TestGenerateBackendAppWiresSecurityHeaders(t *testing.T) { source := string(payload) for _, expected := range []string{ `"strings"`, - `mux.Handle("/", gowdkruntime.ApplyMiddlewares(http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {`, + `handler := gowdkruntime.ApplyMiddlewares(mux, registeredMiddlewares()...)`, + `return gowdkruntime.ApplyMiddlewares(mux, registeredMiddlewares()...), nil`, + `mux.Handle("/", http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {`, + `mux.Handle("/", gowdkruntime.ApplyMiddlewares(routes, registeredMiddlewares()...))`, `for name, value := range map[string]string{"X-Frame-Options": "DENY"} {`, `if strings.TrimSpace(name) == "" {`, `response.Header().Set(name, value)`, @@ -1873,6 +1896,9 @@ func TestGenerateBackendAppWiresSecurityHeaders(t *testing.T) { if strings.Contains(source, `mux.Handle("/", backendRouter)`) { t.Fatalf("backend-only app with configured security headers should wrap the router:\n%s", source) } + if strings.Contains(source, `mux.Handle("/", gowdkruntime.ApplyMiddlewares(http.HandlerFunc`) { + t.Fatalf("security-header route mux should stay raw and be wrapped after route graph construction:\n%s", source) + } } func TestGenerateWiresCORSForAPIRoutes(t *testing.T) { diff --git a/internal/appgen/source.go b/internal/appgen/source.go index e593d0fc..627bcce9 100644 --- a/internal/appgen/source.go +++ b/internal/appgen/source.go @@ -422,7 +422,15 @@ func handlerDecl() ast.Decl { {Type: sel("http", "Handler")}, {Type: id("error")}, }, []ast.Stmt{ - &ast.ReturnStmt{Results: []ast.Expr{call(sel("ServeMux"))}}, + define([]ast.Expr{id("mux"), id("err")}, call(id("newServeMux"), call(sel("gowdkruntime", "InstanceIdentity")))), + &ast.IfStmt{ + Cond: notNil("err"), + Body: block(&ast.ReturnStmt{Results: []ast.Expr{id("nil"), id("err")}}), + }, + &ast.ReturnStmt{Results: []ast.Expr{ + applyRegisteredMiddlewaresExpr(id("mux")), + id("nil"), + }}, }) } @@ -431,7 +439,14 @@ func serveMuxDecl(options Options, embedded bool) ast.Decl { {Type: &ast.StarExpr{X: sel("http", "ServeMux")}}, {Type: id("error")}, }, []ast.Stmt{ - &ast.ReturnStmt{Results: []ast.Expr{call(id("newServeMux"), call(sel("gowdkruntime", "InstanceIdentity")))}}, + define([]ast.Expr{id("routes"), id("err")}, call(id("newServeMux"), call(sel("gowdkruntime", "InstanceIdentity")))), + &ast.IfStmt{ + Cond: notNil("err"), + Body: block(&ast.ReturnStmt{Results: []ast.Expr{id("nil"), id("err")}}), + }, + define([]ast.Expr{id("mux")}, call(sel("http", "NewServeMux"))), + exprStmt(call(selExpr(id("mux"), "Handle"), stringLit("/"), applyRegisteredMiddlewaresExpr(id("routes")))), + &ast.ReturnStmt{Results: []ast.Expr{id("mux"), id("nil")}}, }) } @@ -468,13 +483,12 @@ func newServeMuxDecl(options Options, embedded bool) ast.Decl { stmts = append(stmts, exprStmt(call(selExpr(id("mux"), "Handle"), id("RealtimeEventsPath"), call(id("realtimeEventsHandler"))))) } if embedded { - stmts = append(stmts, exprStmt(call(selExpr(id("mux"), "Handle"), stringLit("/"), &ast.CallExpr{ - Fun: sel("gowdkruntime", "ApplyMiddlewares"), - Args: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.CompositeLit{ + stmts = append(stmts, exprStmt(call(selExpr(id("mux"), "Handle"), stringLit("/"), &ast.UnaryExpr{ + Op: token.AND, + X: &ast.CompositeLit{ Type: sel("gowdkruntime", "Handler"), Elts: embeddedHandlerFields(options, id("identity")), - }}, call(id("registeredMiddlewares"))}, - Ellipsis: token.Pos(1), + }, }))) } else { stmts = append(stmts, exprStmt(call(selExpr(id("mux"), "Handle"), stringLit("/"), backendOnlyHandlerExpr(options)))) @@ -801,11 +815,7 @@ func backendOnlyHandlerExpr(options Options) ast.Expr { if headers := securityHeadersExpr(options); headers != nil { handler = call(sel("http", "HandlerFunc"), backendOnlySecurityHeadersHandlerFunc(handler, headers)) } - return &ast.CallExpr{ - Fun: sel("gowdkruntime", "ApplyMiddlewares"), - Args: []ast.Expr{handler, call(id("registeredMiddlewares"))}, - Ellipsis: token.Pos(1), - } + return handler } func backendOnlyBaseHandlerExpr(options Options) ast.Expr { diff --git a/internal/appgen/source_lifecycle.go b/internal/appgen/source_lifecycle.go index 3d5da92e..20082f2b 100644 --- a/internal/appgen/source_lifecycle.go +++ b/internal/appgen/source_lifecycle.go @@ -21,6 +21,7 @@ func appDecl(options Options) ast.Decl { Cond: notNil("err"), Body: block(&ast.ReturnStmt{Results: []ast.Expr{id("nil"), id("err")}}), }, + define([]ast.Expr{id("handler")}, applyRegisteredMiddlewaresExpr(id("mux"))), define([]ast.Expr{id("values")}, &ast.CompositeLit{ Type: &ast.MapType{Key: id("string"), Value: id("any")}, }), @@ -35,7 +36,7 @@ func appDecl(options Options) ast.Decl { &ast.ReturnStmt{Results: []ast.Expr{&ast.UnaryExpr{Op: token.AND, X: &ast.CompositeLit{ Type: sel("gowdkruntime", "Application"), Elts: []ast.Expr{ - keyValue("Handler", id("mux")), + keyValue("Handler", id("handler")), keyValue("Mux", id("mux")), keyValue("Identity", id("identity")), keyValue("Services", id("services")), diff --git a/internal/appgen/source_middleware.go b/internal/appgen/source_middleware.go index f1b0f41b..30ab291a 100644 --- a/internal/appgen/source_middleware.go +++ b/internal/appgen/source_middleware.go @@ -40,6 +40,19 @@ func registerMiddlewareDecl() ast.Decl { }) } +// applyRegisteredMiddlewaresExpr snapshots the registered chain around the +// finalized route graph instead of wrapping only its fallback route. +func applyRegisteredMiddlewaresExpr(handler ast.Expr) ast.Expr { + return &ast.CallExpr{ + Fun: sel("gowdkruntime", "ApplyMiddlewares"), + Args: []ast.Expr{ + handler, + call(id("registeredMiddlewares")), + }, + Ellipsis: token.Pos(1), + } +} + func registeredMiddlewaresDecl() ast.Decl { return funcDecl("registeredMiddlewares", nil, []*ast.Field{ {Type: &ast.ArrayType{Elt: sel("gowdkruntime", "Middleware")}}, diff --git a/internal/appgen/testdata/generated_go_golden/app.go.golden b/internal/appgen/testdata/generated_go_golden/app.go.golden index f537fd81..6d6f21f1 100644 --- a/internal/appgen/testdata/generated_go_golden/app.go.golden +++ b/internal/appgen/testdata/generated_go_golden/app.go.golden @@ -49,16 +49,21 @@ func App() (*gowdkruntime.Application, error) { if err != nil { return nil, err } + handler := gowdkruntime.ApplyMiddlewares(mux, registeredMiddlewares()...) values := map[string]any{} values[gowdkruntime.ServiceValueContractRegistry] = ContractRegistry() services, err := configuredServices() if err != nil { return nil, err } - return &gowdkruntime.Application{Handler: mux, Mux: mux, Identity: identity, Services: services, Values: values}, nil + return &gowdkruntime.Application{Handler: handler, Mux: mux, Identity: identity, Services: services, Values: values}, nil } func Handler() (http.Handler, error) { - return ServeMux() + mux, err := newServeMux(gowdkruntime.InstanceIdentity()) + if err != nil { + return nil, err + } + return gowdkruntime.ApplyMiddlewares(mux, registeredMiddlewares()...), nil } func newServeMux(identity gowdkruntime.Identity) (*http.ServeMux, error) { if err := loadEnvFile(); err != nil { @@ -79,11 +84,17 @@ func newServeMux(identity gowdkruntime.Identity) (*http.ServeMux, error) { return nil, err } mux := http.NewServeMux() - mux.Handle("/", gowdkruntime.ApplyMiddlewares(&gowdkruntime.Handler{Root: root, Identity: identity, Assets: gowdkruntime.LoadAssetManifest(root), ErrorPages: gowdkruntime.LoadErrorPages(root), Backend: backendRouter.HandlerFunc(), CSRF: csrfTokenSource, SSRExact: ssrExact, SSRDynamic: ssrDynamic, RequestTimeout: gowdkruntime.DefaultRequestTimeout}, registeredMiddlewares()...)) + mux.Handle("/", &gowdkruntime.Handler{Root: root, Identity: identity, Assets: gowdkruntime.LoadAssetManifest(root), ErrorPages: gowdkruntime.LoadErrorPages(root), Backend: backendRouter.HandlerFunc(), CSRF: csrfTokenSource, SSRExact: ssrExact, SSRDynamic: ssrDynamic, RequestTimeout: gowdkruntime.DefaultRequestTimeout}) return mux, nil } func ServeMux() (*http.ServeMux, error) { - return newServeMux(gowdkruntime.InstanceIdentity()) + routes, err := newServeMux(gowdkruntime.InstanceIdentity()) + if err != nil { + return nil, err + } + mux := http.NewServeMux() + mux.Handle("/", gowdkruntime.ApplyMiddlewares(routes, registeredMiddlewares()...)) + return mux, nil } func configuredServices() ([]gowdkruntime.Service, error) { return nil, nil diff --git a/runtime/app/lifecycle_test.go b/runtime/app/lifecycle_test.go index 87ceae64..9e64e767 100644 --- a/runtime/app/lifecycle_test.go +++ b/runtime/app/lifecycle_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "net/http/httptest" "strings" "sync/atomic" "testing" @@ -44,6 +45,33 @@ func TestRunMountsServicesBeforeRun(t *testing.T) { } } +func TestMiddlewareWrappedMuxIncludesRoutesMountedAfterComposition(t *testing.T) { + mux := http.NewServeMux() + var calls atomic.Int32 + handler := ApplyMiddlewares(mux, func(next http.Handler) http.Handler { + return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { + calls.Add(1) + response.Header().Set("X-GOWDK-Middleware", "applied") + next.ServeHTTP(response, request) + }) + }) + mux.HandleFunc("/service", func(response http.ResponseWriter, _ *http.Request) { + response.WriteHeader(http.StatusNoContent) + }) + + response := httptest.NewRecorder() + handler.ServeHTTP(response, httptest.NewRequest(http.MethodGet, "/service", nil)) + if response.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d", response.Code, http.StatusNoContent) + } + if got := response.Header().Get("X-GOWDK-Middleware"); got != "applied" { + t.Fatalf("middleware header = %q, want applied", got) + } + if got := calls.Load(); got != 1 { + t.Fatalf("middleware calls = %d, want 1", got) + } +} + func TestRunIgnoresNilAndNoOpServices(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel()