diff --git a/README.md b/README.md index c17a596..060e650 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ That's it. ShiftAPI reflects your Go types into an OpenAPI 3.1 spec at `/openapi ### Generic type-safe handlers -Generic free functions capture your request and response types at compile time. Handlers with a body (`Post`, `Put`, `Patch`) receive the decoded request as a typed value. Handlers without a body (`Get`, `Delete`, `Head`) just receive the request. +Generic free functions capture your request and response types at compile time. Handlers with a body (`Post`, `Put`, `Patch`) receive the decoded request as a typed value. Handlers without a body (`Get`, `Delete`, `Head`) just receive the request. Query-param variants (`GetWithQuery`, `PostWithQuery`, etc.) add a typed query struct as well. ```go // POST — body is decoded and passed as *CreateUser @@ -96,6 +96,30 @@ shiftapi.Get(api, "/users/{id}", func(r *http.Request) (*User, error) { }) ``` +### Typed query parameters + +Define a struct with `query` tags and use `GetWithQuery`, `DeleteWithQuery`, `PostWithQuery`, etc. Query params are parsed, validated, and documented in the OpenAPI spec automatically. + +```go +type SearchQuery struct { + Q string `query:"q" validate:"required"` + Page int `query:"page" validate:"min=1"` + Limit int `query:"limit" validate:"min=1,max=100"` +} + +shiftapi.GetWithQuery(api, "/search", func(r *http.Request, query SearchQuery) (*Results, error) { + return doSearch(query.Q, query.Page, query.Limit), nil +}) +``` + +Supports `string`, `bool`, `int*`, `uint*`, `float*` scalars, `*T` pointers for optional params, and `[]T` slices for repeated params (e.g. `?tag=a&tag=b`). Use `query:"-"` to skip a field. Parse errors return `400`; validation failures return `422`. + +For handlers that need both query parameters and a request body, use `PostWithQuery`, `PutWithQuery`, or `PatchWithQuery`: + +```go +shiftapi.PostWithQuery[CreateQuery, CreateBody, *Result](api, "/items", handler) +``` + ### Validation Built-in validation via [go-playground/validator](https://github.com/go-playground/validator). Struct tags are enforced at runtime *and* reflected into the OpenAPI schema. @@ -198,6 +222,11 @@ const { data: greeting } = await client.POST("/greet", { body: { name: "frank" }, }); // body and response are fully typed from your Go structs + +const { data: results } = await client.GET("/search", { + params: { query: { q: "hello", page: 1, limit: 10 } }, +}); +// query params are fully typed too — { q: string, page?: number, limit?: number } ``` In dev mode the plugin also starts the Go server, proxies API requests through Vite, watches `.go` files, and hot-reloads the frontend when types change. diff --git a/examples/greeter/main.go b/examples/greeter/main.go index d29a829..33ae2ba 100644 --- a/examples/greeter/main.go +++ b/examples/greeter/main.go @@ -22,6 +22,26 @@ func greet(r *http.Request, body *Person) (*Greeting, error) { return &Greeting{Hello: body.Name}, nil } +type SearchQuery struct { + Q string `query:"q" validate:"required"` + Page int `query:"page" validate:"min=1"` + Limit int `query:"limit" validate:"min=1,max=100"` +} + +type SearchResult struct { + Query string `json:"query"` + Page int `json:"page"` + Limit int `json:"limit"` +} + +func search(r *http.Request, query SearchQuery) (*SearchResult, error) { + return &SearchResult{ + Query: query.Q, + Page: query.Page, + Limit: query.Limit, + }, nil +} + type Status struct { OK bool `json:"ok"` } @@ -43,6 +63,14 @@ func main() { }), ) + shiftapi.GetWithQuery(api, "/search", search, + shiftapi.WithRouteInfo(shiftapi.RouteInfo{ + Summary: "Search for things", + Description: "Search with typed query parameters", + Tags: []string{"search"}, + }), + ) + shiftapi.Get(api, "/health", health, shiftapi.WithRouteInfo(shiftapi.RouteInfo{ Summary: "Health check", diff --git a/handler.go b/handler.go index bff3adf..6975df6 100644 --- a/handler.go +++ b/handler.go @@ -13,6 +13,12 @@ type HandlerFunc[Resp any] func(r *http.Request) (Resp, error) // HandlerFuncWithBody is a typed handler for methods with a request body (POST, PUT, PATCH, etc.). type HandlerFuncWithBody[Body, Resp any] func(r *http.Request, body Body) (Resp, error) +// HandlerFuncWithQuery is a typed handler for methods with typed query parameters. +type HandlerFuncWithQuery[Query, Resp any] func(r *http.Request, query Query) (Resp, error) + +// HandlerFuncWithQueryAndBody is a typed handler for methods with both typed query parameters and a request body. +type HandlerFuncWithQueryAndBody[Query, Body, Resp any] func(r *http.Request, query Query, body Body) (Resp, error) + func adapt[Resp any](fn HandlerFunc[Resp], status int) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { resp, err := fn(r) @@ -44,6 +50,55 @@ func adaptWithBody[Body, Resp any](fn HandlerFuncWithBody[Body, Resp], status in } } +func adaptWithQuery[Query, Resp any](fn HandlerFuncWithQuery[Query, Resp], status int, validate func(any) error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + query, err := parseQuery[Query](r.URL.Query()) + if err != nil { + writeError(w, Error(http.StatusBadRequest, err.Error())) + return + } + if err := validate(query); err != nil { + writeError(w, err) + return + } + resp, err := fn(r, query) + if err != nil { + writeError(w, err) + return + } + writeJSON(w, status, resp) + } +} + +func adaptWithQueryAndBody[Query, Body, Resp any](fn HandlerFuncWithQueryAndBody[Query, Body, Resp], status int, validate func(any) error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + query, err := parseQuery[Query](r.URL.Query()) + if err != nil { + writeError(w, Error(http.StatusBadRequest, err.Error())) + return + } + if err := validate(query); err != nil { + writeError(w, err) + return + } + var body Body + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + writeError(w, Error(http.StatusBadRequest, "invalid request body")) + return + } + if err := validate(body); err != nil { + writeError(w, err) + return + } + resp, err := fn(r, query, body) + if err != nil { + writeError(w, err) + return + } + writeJSON(w, status, resp) + } +} + func writeJSON(w http.ResponseWriter, status int, v any) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(status) diff --git a/handlerFuncs.go b/handlerFuncs.go index 0f0dec6..54455ef 100644 --- a/handlerFuncs.go +++ b/handlerFuncs.go @@ -18,7 +18,7 @@ func registerRoute[Resp any]( var resp Resp outType := reflect.TypeOf(resp) - if err := api.updateSchema(method, path, nil, outType, cfg.info, cfg.status); err != nil { + if err := api.updateSchema(method, path, nil, nil, outType, cfg.info, cfg.status); err != nil { panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, path, err)) } @@ -40,7 +40,7 @@ func registerRouteWithBody[Body, Resp any]( var resp Resp outType := reflect.TypeOf(resp) - if err := api.updateSchema(method, path, inType, outType, cfg.info, cfg.status); err != nil { + if err := api.updateSchema(method, path, nil, inType, outType, cfg.info, cfg.status); err != nil { panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, path, err)) } @@ -48,6 +48,52 @@ func registerRouteWithBody[Body, Resp any]( api.mux.HandleFunc(pattern, adaptWithBody(fn, cfg.status, api.validateBody)) } +func registerRouteWithQuery[Query, Resp any]( + api *API, + method string, + path string, + fn HandlerFuncWithQuery[Query, Resp], + options ...RouteOption, +) { + cfg := applyRouteOptions(options) + + var query Query + queryType := reflect.TypeOf(query) + var resp Resp + outType := reflect.TypeOf(resp) + + if err := api.updateSchema(method, path, queryType, nil, outType, cfg.info, cfg.status); err != nil { + panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, path, err)) + } + + pattern := fmt.Sprintf("%s %s", method, path) + api.mux.HandleFunc(pattern, adaptWithQuery(fn, cfg.status, api.validateBody)) +} + +func registerRouteWithQueryAndBody[Query, Body, Resp any]( + api *API, + method string, + path string, + fn HandlerFuncWithQueryAndBody[Query, Body, Resp], + options ...RouteOption, +) { + cfg := applyRouteOptions(options) + + var query Query + queryType := reflect.TypeOf(query) + var body Body + inType := reflect.TypeOf(body) + var resp Resp + outType := reflect.TypeOf(resp) + + if err := api.updateSchema(method, path, queryType, inType, outType, cfg.info, cfg.status); err != nil { + panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, path, err)) + } + + pattern := fmt.Sprintf("%s %s", method, path) + api.mux.HandleFunc(pattern, adaptWithQueryAndBody(fn, cfg.status, api.validateBody)) +} + // No-body methods // Get registers a GET handler. @@ -96,3 +142,37 @@ func Patch[Body, Resp any](api *API, path string, fn HandlerFuncWithBody[Body, R func Connect[Resp any](api *API, path string, fn HandlerFunc[Resp], options ...RouteOption) { registerRoute(api, http.MethodConnect, path, fn, options...) } + +// Query methods (no body) + +// GetWithQuery registers a GET handler with typed query parameters. +func GetWithQuery[Query, Resp any](api *API, path string, fn HandlerFuncWithQuery[Query, Resp], options ...RouteOption) { + registerRouteWithQuery(api, http.MethodGet, path, fn, options...) +} + +// DeleteWithQuery registers a DELETE handler with typed query parameters. +func DeleteWithQuery[Query, Resp any](api *API, path string, fn HandlerFuncWithQuery[Query, Resp], options ...RouteOption) { + registerRouteWithQuery(api, http.MethodDelete, path, fn, options...) +} + +// HeadWithQuery registers a HEAD handler with typed query parameters. +func HeadWithQuery[Query, Resp any](api *API, path string, fn HandlerFuncWithQuery[Query, Resp], options ...RouteOption) { + registerRouteWithQuery(api, http.MethodHead, path, fn, options...) +} + +// Query + body methods + +// PostWithQuery registers a POST handler with typed query parameters and a request body. +func PostWithQuery[Query, Body, Resp any](api *API, path string, fn HandlerFuncWithQueryAndBody[Query, Body, Resp], options ...RouteOption) { + registerRouteWithQueryAndBody(api, http.MethodPost, path, fn, options...) +} + +// PutWithQuery registers a PUT handler with typed query parameters and a request body. +func PutWithQuery[Query, Body, Resp any](api *API, path string, fn HandlerFuncWithQueryAndBody[Query, Body, Resp], options ...RouteOption) { + registerRouteWithQueryAndBody(api, http.MethodPut, path, fn, options...) +} + +// PatchWithQuery registers a PATCH handler with typed query parameters and a request body. +func PatchWithQuery[Query, Body, Resp any](api *API, path string, fn HandlerFuncWithQueryAndBody[Query, Body, Resp], options ...RouteOption) { + registerRouteWithQueryAndBody(api, http.MethodPatch, path, fn, options...) +} diff --git a/query.go b/query.go new file mode 100644 index 0000000..0e4d7b6 --- /dev/null +++ b/query.go @@ -0,0 +1,145 @@ +package shiftapi + +import ( + "fmt" + "net/url" + "reflect" + "strconv" + "strings" +) + +// parseQuery populates a struct of type T from URL query parameters. +// It uses the `query` struct tag for parameter names, falling back to the field name. +// A tag of `query:"-"` causes the field to be skipped. +func parseQuery[T any](values url.Values) (T, error) { + var result T + rv := reflect.ValueOf(&result).Elem() + rt := rv.Type() + + for rt.Kind() == reflect.Ptr { + rv.Set(reflect.New(rt.Elem())) + rv = rv.Elem() + rt = rt.Elem() + } + + if rt.Kind() != reflect.Struct { + return result, fmt.Errorf("query type must be a struct, got %s", rt.Kind()) + } + + for i := range rt.NumField() { + field := rt.Field(i) + if !field.IsExported() { + continue + } + + name := queryFieldName(field) + if name == "-" { + continue + } + + fv := rv.Field(i) + ft := field.Type + + // Handle pointer fields (optional params) + if ft.Kind() == reflect.Ptr { + rawValues, exists := values[name] + if !exists || len(rawValues) == 0 { + continue // leave nil + } + ptr := reflect.New(ft.Elem()) + if err := setScalarValue(ptr.Elem(), rawValues[0]); err != nil { + return result, &queryParseError{Field: name, Err: err} + } + fv.Set(ptr) + continue + } + + // Handle slice fields + if ft.Kind() == reflect.Slice { + rawValues, exists := values[name] + if !exists || len(rawValues) == 0 { + continue + } + elemType := ft.Elem() + slice := reflect.MakeSlice(ft, len(rawValues), len(rawValues)) + for j, raw := range rawValues { + elem := reflect.New(elemType).Elem() + if err := setScalarValue(elem, raw); err != nil { + return result, &queryParseError{Field: name, Err: err} + } + slice.Index(j).Set(elem) + } + fv.Set(slice) + continue + } + + // Handle scalar fields + raw := values.Get(name) + if raw == "" { + continue + } + if err := setScalarValue(fv, raw); err != nil { + return result, &queryParseError{Field: name, Err: err} + } + } + + return result, nil +} + +// queryFieldName returns the query parameter name for a struct field. +func queryFieldName(f reflect.StructField) string { + tag := f.Tag.Get("query") + if tag == "" { + return f.Name + } + name, _, _ := strings.Cut(tag, ",") + if name == "" { + return f.Name + } + return name +} + +// setScalarValue parses a string and sets the value on a reflect.Value. +func setScalarValue(v reflect.Value, raw string) error { + switch v.Kind() { + case reflect.String: + v.SetString(raw) + case reflect.Bool: + b, err := strconv.ParseBool(raw) + if err != nil { + return fmt.Errorf("invalid boolean value %q", raw) + } + v.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(raw, 10, v.Type().Bits()) + if err != nil { + return fmt.Errorf("invalid integer value %q", raw) + } + v.SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + n, err := strconv.ParseUint(raw, 10, v.Type().Bits()) + if err != nil { + return fmt.Errorf("invalid unsigned integer value %q", raw) + } + v.SetUint(n) + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(raw, v.Type().Bits()) + if err != nil { + return fmt.Errorf("invalid float value %q", raw) + } + v.SetFloat(n) + default: + return fmt.Errorf("unsupported query parameter type %s", v.Kind()) + } + return nil +} + +// queryParseError is returned when a query parameter cannot be parsed. +type queryParseError struct { + Field string + Err error +} + +func (e *queryParseError) Error() string { + return fmt.Sprintf("invalid query parameter %q: %v", e.Field, e.Err) +} diff --git a/schema.go b/schema.go index 5e75923..fa591fa 100644 --- a/schema.go +++ b/schema.go @@ -12,7 +12,7 @@ import ( var pathParamRe = regexp.MustCompile(`\{([^}]+)\}`) -func (a *API) updateSchema(method, path string, inType, outType reflect.Type, info *RouteInfo, status int) error { +func (a *API) updateSchema(method, path string, queryType, inType, outType reflect.Type, info *RouteInfo, status int) error { op := &openapi3.Operation{ OperationID: operationID(method, path), Responses: openapi3.NewResponses(), @@ -34,6 +34,15 @@ func (a *API) updateSchema(method, path string, inType, outType reflect.Type, in }) } + // Query parameters + if queryType != nil { + queryParams, err := a.generateQueryParams(queryType) + if err != nil { + return err + } + op.Parameters = append(op.Parameters, queryParams...) + } + // Response schema statusStr := fmt.Sprintf("%d", status) outSchema, err := a.generateSchemaRef(outType) @@ -210,6 +219,88 @@ func (a *API) generateSchemaRef(t reflect.Type) (*openapi3.SchemaRef, error) { return schema, nil } +// generateQueryParams produces OpenAPI parameter definitions for a query struct type. +func (a *API) generateQueryParams(t reflect.Type) ([]*openapi3.ParameterRef, error) { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("query type must be a struct, got %s", t.Kind()) + } + + var params []*openapi3.ParameterRef + for i := range t.NumField() { + field := t.Field(i) + if !field.IsExported() { + continue + } + name := queryFieldName(field) + if name == "-" { + continue + } + + schema := fieldToOpenAPISchema(field.Type) + + // Apply validation constraints + if err := validateSchemaCustomizer(name, field.Type, field.Tag, schema.Value); err != nil { + return nil, err + } + + required := hasRule(field.Tag.Get("validate"), "required") + + params = append(params, &openapi3.ParameterRef{ + Value: &openapi3.Parameter{ + Name: name, + In: "query", + Required: required, + Schema: schema, + }, + }) + } + return params, nil +} + +// fieldToOpenAPISchema maps a Go type to an OpenAPI schema. +func fieldToOpenAPISchema(t reflect.Type) *openapi3.SchemaRef { + // Unwrap pointer + if t.Kind() == reflect.Ptr { + return fieldToOpenAPISchema(t.Elem()) + } + + // Handle slices + if t.Kind() == reflect.Slice { + items := scalarToOpenAPISchema(t.Elem()) + return &openapi3.SchemaRef{ + Value: &openapi3.Schema{ + Type: &openapi3.Types{"array"}, + Items: items, + }, + } + } + + return scalarToOpenAPISchema(t) +} + +// scalarToOpenAPISchema maps a scalar Go type to an OpenAPI schema. +func scalarToOpenAPISchema(t reflect.Type) *openapi3.SchemaRef { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch t.Kind() { + case reflect.String: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}} + case reflect.Bool: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"boolean"}}} + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"integer"}}} + case reflect.Float32, reflect.Float64: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"number"}}} + default: + return &openapi3.SchemaRef{Value: &openapi3.Schema{Type: &openapi3.Types{"string"}}} + } +} + func scrubRefs(s *openapi3.SchemaRef) { if s == nil || s.Value == nil || len(s.Value.Properties) == 0 { return diff --git a/shiftapi_test.go b/shiftapi_test.go index a25482a..dd02a43 100644 --- a/shiftapi_test.go +++ b/shiftapi_test.go @@ -1397,6 +1397,490 @@ func TestValidationNestedStructMissingFields(t *testing.T) { } } +// --- Query parameter test types --- + +type SearchQuery struct { + Q string `query:"q" validate:"required"` + Page int `query:"page" validate:"min=1"` + Limit int `query:"limit" validate:"min=1,max=100"` +} + +type SearchResult struct { + Query string `json:"query"` + Page int `json:"page"` + Limit int `json:"limit"` +} + +type TagQuery struct { + Tags []string `query:"tag"` +} + +type TagResult struct { + Tags []string `json:"tags"` +} + +type OptionalQuery struct { + Name string `query:"name"` + Debug *bool `query:"debug"` + Limit *int `query:"limit"` +} + +type OptionalResult struct { + Name string `json:"name"` + HasDebug bool `json:"has_debug"` + Debug bool `json:"debug"` + HasLimit bool `json:"has_limit"` + Limit int `json:"limit"` +} + +type FilterQuery struct { + Status string `query:"status" validate:"oneof=active inactive pending"` +} + +// --- Query parameter runtime tests --- + +func TestGetWithQueryBasic(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/search", func(r *http.Request, query SearchQuery) (*SearchResult, error) { + return &SearchResult{Query: query.Q, Page: query.Page, Limit: query.Limit}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/search?q=hello&page=2&limit=10", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[SearchResult](t, resp) + if result.Query != "hello" { + t.Errorf("expected Query=hello, got %q", result.Query) + } + if result.Page != 2 { + t.Errorf("expected Page=2, got %d", result.Page) + } + if result.Limit != 10 { + t.Errorf("expected Limit=10, got %d", result.Limit) + } +} + +func TestGetWithQueryMissingRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/search", func(r *http.Request, query SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + // Missing required "q" param + resp := doRequest(t, api, http.MethodGet, "/search?page=1&limit=10", "") + if resp.StatusCode != http.StatusUnprocessableEntity { + t.Fatalf("expected 422, got %d", resp.StatusCode) + } +} + +func TestGetWithQueryInvalidType(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/search", func(r *http.Request, query SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + // "page" should be an int, not "abc" + resp := doRequest(t, api, http.MethodGet, "/search?q=test&page=abc&limit=10", "") + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithQuerySliceParams(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/tags", func(r *http.Request, query TagQuery) (*TagResult, error) { + return &TagResult{Tags: query.Tags}, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/tags?tag=a&tag=b&tag=c", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[TagResult](t, resp) + if len(result.Tags) != 3 { + t.Fatalf("expected 3 tags, got %d", len(result.Tags)) + } + expected := []string{"a", "b", "c"} + for i, tag := range result.Tags { + if tag != expected[i] { + t.Errorf("expected tag[%d]=%q, got %q", i, expected[i], tag) + } + } +} + +func TestGetWithQueryOptionalPointer(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/optional", func(r *http.Request, query OptionalQuery) (*OptionalResult, error) { + result := &OptionalResult{Name: query.Name} + if query.Debug != nil { + result.HasDebug = true + result.Debug = *query.Debug + } + if query.Limit != nil { + result.HasLimit = true + result.Limit = *query.Limit + } + return result, nil + }) + + // With optional params + resp := doRequest(t, api, http.MethodGet, "/optional?name=test&debug=true&limit=50", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[OptionalResult](t, resp) + if !result.HasDebug || !result.Debug { + t.Error("expected debug=true") + } + if !result.HasLimit || result.Limit != 50 { + t.Error("expected limit=50") + } + + // Without optional params + resp2 := doRequest(t, api, http.MethodGet, "/optional?name=test", "") + if resp2.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp2.StatusCode) + } + result2 := decodeJSON[OptionalResult](t, resp2) + if result2.HasDebug { + t.Error("expected debug to be absent") + } + if result2.HasLimit { + t.Error("expected limit to be absent") + } +} + +func TestPostWithQueryAndBody(t *testing.T) { + api := newTestAPI(t) + + type CreateQuery struct { + DryRun bool `query:"dry_run"` + } + type CreateBody struct { + Name string `json:"name"` + } + type CreateResult struct { + Name string `json:"name"` + DryRun bool `json:"dry_run"` + } + + shiftapi.PostWithQuery[CreateQuery, CreateBody, *CreateResult](api, "/items", func(r *http.Request, query CreateQuery, body CreateBody) (*CreateResult, error) { + return &CreateResult{Name: body.Name, DryRun: query.DryRun}, nil + }) + + resp := doRequest(t, api, http.MethodPost, "/items?dry_run=true", `{"name":"widget"}`) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[CreateResult](t, resp) + if result.Name != "widget" { + t.Errorf("expected Name=widget, got %q", result.Name) + } + if !result.DryRun { + t.Error("expected DryRun=true") + } +} + +func TestGetWithQueryAndPathParams(t *testing.T) { + api := newTestAPI(t) + + type ItemQuery struct { + Fields string `query:"fields"` + } + + shiftapi.GetWithQuery(api, "/items/{id}", func(r *http.Request, query ItemQuery) (*map[string]string, error) { + return &map[string]string{ + "id": r.PathValue("id"), + "fields": query.Fields, + }, nil + }) + + resp := doRequest(t, api, http.MethodGet, "/items/abc123?fields=name,price", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[map[string]string](t, resp) + if result["id"] != "abc123" { + t.Errorf("expected id=abc123, got %q", result["id"]) + } + if result["fields"] != "name,price" { + t.Errorf("expected fields=name,price, got %q", result["fields"]) + } +} + +func TestDeleteWithQuery(t *testing.T) { + api := newTestAPI(t) + + type DeleteQuery struct { + Force bool `query:"force"` + } + + shiftapi.DeleteWithQuery(api, "/items/{id}", func(r *http.Request, query DeleteQuery) (*map[string]any, error) { + return &map[string]any{ + "id": r.PathValue("id"), + "force": query.Force, + }, nil + }) + + resp := doRequest(t, api, http.MethodDelete, "/items/42?force=true", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } +} + +func TestGetWithQueryValidationConstraint(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/filter", func(r *http.Request, query FilterQuery) (*map[string]string, error) { + return &map[string]string{"status": query.Status}, nil + }) + + // Valid value + resp := doRequest(t, api, http.MethodGet, "/filter?status=active", "") + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + + // Invalid value → 422 + resp2 := doRequest(t, api, http.MethodGet, "/filter?status=unknown", "") + if resp2.StatusCode != http.StatusUnprocessableEntity { + t.Fatalf("expected 422, got %d", resp2.StatusCode) + } +} + +// --- Query parameter spec tests --- + +func TestSpecQueryParamsDocumented(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/search", func(r *http.Request, query SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + if op == nil { + t.Fatal("expected GET operation on /search") + } + + // Should have 3 query params: q, page, limit + queryParams := 0 + for _, p := range op.Parameters { + if p.Value.In == "query" { + queryParams++ + } + } + if queryParams != 3 { + t.Fatalf("expected 3 query parameters, got %d", queryParams) + } +} + +func TestSpecQueryParamTypes(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/search", func(r *http.Request, query SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + + paramByName := make(map[string]*openapi3.Parameter) + for _, p := range op.Parameters { + paramByName[p.Value.Name] = p.Value + } + + // q is a string + if q, ok := paramByName["q"]; !ok { + t.Fatal("expected 'q' query parameter") + } else if !q.Schema.Value.Type.Is("string") { + t.Errorf("expected q type 'string', got %v", q.Schema.Value.Type) + } + + // page is an integer + if page, ok := paramByName["page"]; !ok { + t.Fatal("expected 'page' query parameter") + } else if !page.Schema.Value.Type.Is("integer") { + t.Errorf("expected page type 'integer', got %v", page.Schema.Value.Type) + } + + // limit is an integer + if limit, ok := paramByName["limit"]; !ok { + t.Fatal("expected 'limit' query parameter") + } else if !limit.Schema.Value.Type.Is("integer") { + t.Errorf("expected limit type 'integer', got %v", limit.Schema.Value.Type) + } +} + +func TestSpecQueryParamRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/search", func(r *http.Request, query SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + + paramByName := make(map[string]*openapi3.Parameter) + for _, p := range op.Parameters { + paramByName[p.Value.Name] = p.Value + } + + // q has validate:"required" so it should be required + if !paramByName["q"].Required { + t.Error("expected 'q' to be required") + } + // page does not have required tag + if paramByName["page"].Required { + t.Error("expected 'page' to not be required") + } +} + +func TestSpecQueryParamValidationConstraints(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/search", func(r *http.Request, query SearchQuery) (*SearchResult, error) { + return &SearchResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + + paramByName := make(map[string]*openapi3.Parameter) + for _, p := range op.Parameters { + paramByName[p.Value.Name] = p.Value + } + + // page has min=1 + pageSchema := paramByName["page"].Schema.Value + if pageSchema.Min == nil || *pageSchema.Min != 1 { + t.Errorf("expected page minimum 1, got %v", pageSchema.Min) + } + + // limit has min=1,max=100 + limitSchema := paramByName["limit"].Schema.Value + if limitSchema.Min == nil || *limitSchema.Min != 1 { + t.Errorf("expected limit minimum 1, got %v", limitSchema.Min) + } + if limitSchema.Max == nil || *limitSchema.Max != 100 { + t.Errorf("expected limit maximum 100, got %v", limitSchema.Max) + } +} + +func TestSpecQueryParamEnum(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/filter", func(r *http.Request, query FilterQuery) (*Empty, error) { + return &Empty{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/filter").Get + + var statusParam *openapi3.Parameter + for _, p := range op.Parameters { + if p.Value.Name == "status" { + statusParam = p.Value + break + } + } + if statusParam == nil { + t.Fatal("expected 'status' query parameter") + } + if len(statusParam.Schema.Value.Enum) != 3 { + t.Fatalf("expected 3 enum values, got %d", len(statusParam.Schema.Value.Enum)) + } +} + +func TestSpecQueryParamSliceType(t *testing.T) { + api := newTestAPI(t) + shiftapi.GetWithQuery(api, "/tags", func(r *http.Request, query TagQuery) (*TagResult, error) { + return &TagResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/tags").Get + + var tagParam *openapi3.Parameter + for _, p := range op.Parameters { + if p.Value.Name == "tag" { + tagParam = p.Value + break + } + } + if tagParam == nil { + t.Fatal("expected 'tag' query parameter") + } + if !tagParam.Schema.Value.Type.Is("array") { + t.Errorf("expected tag type 'array', got %v", tagParam.Schema.Value.Type) + } + if tagParam.Schema.Value.Items == nil || !tagParam.Schema.Value.Items.Value.Type.Is("string") { + t.Error("expected tag items type 'string'") + } +} + +func TestSpecQueryParamsCombinedWithPathParams(t *testing.T) { + api := newTestAPI(t) + + type ItemQuery struct { + Fields string `query:"fields"` + } + + shiftapi.GetWithQuery(api, "/items/{id}", func(r *http.Request, query ItemQuery) (*Item, error) { + return &Item{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/items/{id}").Get + + pathParams := 0 + queryParams := 0 + for _, p := range op.Parameters { + switch p.Value.In { + case "path": + pathParams++ + case "query": + queryParams++ + } + } + if pathParams != 1 { + t.Errorf("expected 1 path parameter, got %d", pathParams) + } + if queryParams != 1 { + t.Errorf("expected 1 query parameter, got %d", queryParams) + } +} + +func TestSpecPostWithQueryHasQueryParamsAndBody(t *testing.T) { + api := newTestAPI(t) + + type CreateQuery struct { + DryRun bool `query:"dry_run"` + } + + shiftapi.PostWithQuery[CreateQuery, Item, *Item](api, "/items", func(r *http.Request, query CreateQuery, body Item) (*Item, error) { + return &body, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/items").Post + if op == nil { + t.Fatal("expected POST operation on /items") + } + + // Should have query params + queryParams := 0 + for _, p := range op.Parameters { + if p.Value.In == "query" { + queryParams++ + } + } + if queryParams != 1 { + t.Errorf("expected 1 query parameter, got %d", queryParams) + } + + // Should also have a request body + if op.RequestBody == nil { + t.Error("expected request body on POST with query and body") + } +} + func contains(slice []string, item string) bool { for _, s := range slice { if s == item {