diff --git a/huma.go b/huma.go index db4e9a79..ca4dd5d4 100644 --- a/huma.go +++ b/huma.go @@ -143,7 +143,7 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p name = c if f.Type == cookieType { - // Special case: this will be parsed from a string input to a + // Special case: this will be parsed from a string input to an // `http.Cookie` struct. f.Type = stringType } @@ -244,7 +244,7 @@ type headerInfo struct { func findHeaders(t reflect.Type) *findResult[*headerInfo] { return findInType(t, nil, func(sf reflect.StructField, i []int) *headerInfo { - // Ignore embedded fields + // Ignore embedded fields. if sf.Anonymous { return nil } @@ -253,6 +253,7 @@ func findHeaders(t reflect.Type) *findResult[*headerInfo] { if header == "" { header = sf.Name } + timeFormat := "" if sf.Type == timeType { timeFormat = http.TimeFormat @@ -260,6 +261,7 @@ func findHeaders(t reflect.Type) *findResult[*headerInfo] { timeFormat = f } } + return &headerInfo{sf, header, timeFormat} }, false, "Status", "Body") } @@ -439,9 +441,9 @@ func _findInType[T comparable](t reflect.Type, path []int, result *findResult[T] result.Paths = append(result.Paths, findResultPath[T]{fi, v}) } } - if f.Anonymous || recurseFields || deref(f.Type).Kind() != reflect.Struct { + if f.Anonymous || recurseFields || baseType(f.Type).Kind() != reflect.Struct { // Always process embedded structs and named fields which are not - // structs. If `recurseFields` is true then we also process named + // structs. If `recurseFields` is true, then we also process named // struct fields recursively. visited[t] = struct{}{} _findInType(f.Type, fi, result, onType, onField, recurseFields, visited, ignore...) @@ -661,10 +663,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I) if documenter, ok := api.(OperationDocumenter); ok { // Enables customization of OpenAPI documentation behavior for operations. documenter.DocumentOperation(&op) - } else { - if !op.Hidden { - oapi.AddOperation(&op) - } + } else if !op.Hidden { + oapi.AddOperation(&op) } resolvers := findResolvers(resolverType, inputType) @@ -1152,7 +1152,7 @@ func initResponses(op *Operation) { } } -// processInputType validates the input type, extracts expected requests and +// processInputType validates the input type, extracts expected requests, and // defines them on the operation op. func processInputType(inputType reflect.Type, op *Operation, registry Registry) (*findResult[*paramFieldInfo], []int, bool, []int, rawBodyType, *Schema) { inputParams := findParams(registry, op, inputType) @@ -1398,11 +1398,30 @@ func processOutputType(outputType reflect.Type, op *Operation, registry Registry } outHeaders := findHeaders(outputType) for _, entry := range outHeaders.Paths { + v := entry.Value + + // Check if this field or any parent is hidden. + hidden := false + currentType := outputType + for _, idx := range entry.Path { + currentType = baseType(currentType) + + field := currentType.Field(idx) + if boolTag(field, "hidden", false) { + hidden = true + break + } + + currentType = field.Type + } + if hidden { + continue + } + // Document the header's name and type. if op.Responses[defaultStatusStr].Headers == nil { op.Responses[defaultStatusStr].Headers = map[string]*Param{} } - v := entry.Value f := v.Field if f.Type.Kind() == reflect.Slice { f.Type = deref(f.Type.Elem()) diff --git a/huma_test.go b/huma_test.go index 1f09708a..9ab0a74d 100644 --- a/huma_test.go +++ b/huma_test.go @@ -1895,33 +1895,57 @@ Content-Type: text/plain Name: "response-headers", Register: func(t *testing.T, api huma.API) { type Resp struct { - Str string `header:"str"` - Int int `header:"int"` - Uint uint `header:"uint"` - Float float64 `header:"float"` - Bool bool `header:"bool"` - Date time.Time `header:"date"` - Empty string `header:"empty"` + Str string `header:"str"` + Int int `header:"int"` + Uint uint `header:"uint"` + Float float64 `header:"float"` + Bool bool `header:"bool"` + Date time.Time `header:"date"` + Empty string `header:"empty"` + CustomTime time.Time `header:"custom-time" timeFormat:"2006-01-02"` + WithoutTag string // No header tag - SHOULD be set as a header using field name. + LastModified time.Time // No header tag - SHOULD be set as a header using field name. } huma.Register(api, huma.Operation{ Method: http.MethodGet, Path: "/response-headers", }, func(ctx context.Context, input *struct{}) (*Resp, error) { - resp := &Resp{} - resp.Str = "str" - resp.Int = 1 - resp.Uint = 2 - resp.Float = 3.45 - resp.Bool = true - resp.Date = time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC) - return resp, nil + return &Resp{ + Str: "str", + Int: 1, + Uint: 2, + Float: 3.45, + Bool: true, + Date: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC), + CustomTime: time.Date(2023, 6, 15, 10, 30, 0, 0, time.UTC), + WithoutTag: "without-tag-value", + LastModified: time.Date(2023, 6, 15, 10, 30, 0, 0, time.UTC), + }, nil }) + + headers := api.OpenAPI().Paths["/response-headers"].Get.Responses["204"].Headers + + // Surface-level fields with explicit tags should be documented. + assert.NotNil(t, headers["str"]) + assert.NotNil(t, headers["int"]) + assert.NotNil(t, headers["uint"]) + assert.NotNil(t, headers["float"]) + assert.NotNil(t, headers["bool"]) + assert.NotNil(t, headers["date"]) + assert.NotNil(t, headers["empty"]) + assert.NotNil(t, headers["custom-time"]) + + // Surface-level fields without tags should be documented using field name. + assert.NotNil(t, headers["WithoutTag"]) + assert.NotNil(t, headers["LastModified"]) }, Method: http.MethodGet, URL: "/response-headers", Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { assert.Equal(t, http.StatusNoContent, resp.Code) + + // Surface-level fields with explicit tags should be set. assert.Equal(t, "str", resp.Header().Get("Str")) assert.Equal(t, "1", resp.Header().Get("Int")) assert.Equal(t, "2", resp.Header().Get("Uint")) @@ -1929,6 +1953,78 @@ Content-Type: text/plain assert.Equal(t, "true", resp.Header().Get("Bool")) assert.Equal(t, "Sun, 01 Jan 2023 12:00:00 GMT", resp.Header().Get("Date")) assert.Empty(t, resp.Header().Values("Empty")) + assert.Equal(t, "2023-06-15", resp.Header().Get("Custom-Time")) + + // Surface-level fields without tags should be set using field name. + assert.Equal(t, "without-tag-value", resp.Header().Get("WithoutTag")) + assert.Equal(t, "Thu, 15 Jun 2023 10:30:00 GMT", resp.Header().Get("LastModified")) + }, + }, + { + Name: "response-headers-hidden", + Register: func(t *testing.T, api huma.API) { + type HiddenHeaders struct { + HiddenWithTag string `header:"X-Hidden-With-Tag"` + HiddenWithoutTag string // No header tag - should be set as header using field name. + } + + type Resp struct { + *HiddenHeaders `hidden:"true"` + + VisibleWithTag string `header:"X-Visible-With-Tag"` + VisibleWithoutTag string // No header tag - SHOULD be set as a header using field name. + LastModified time.Time // No header tag - SHOULD be set as a header using field name. + Body struct { + Message string `json:"message"` + } + } + + huma.Register(api, huma.Operation{ + Method: http.MethodGet, + Path: "/response-headers-hidden", + }, func(ctx context.Context, input *struct{}) (*Resp, error) { + return &Resp{ + HiddenHeaders: &HiddenHeaders{ + HiddenWithTag: "hidden-with-tag-value", + HiddenWithoutTag: "should-be-header", + }, + VisibleWithTag: "visible-with-tag-value", + VisibleWithoutTag: "visible-without-tag-value", + LastModified: time.Date(2023, 6, 15, 10, 30, 0, 0, time.UTC), + Body: struct { + Message string `json:"message"` + }{ + Message: "Hello", + }, + }, nil + }) + + headers := api.OpenAPI().Paths["/response-headers-hidden"].Get.Responses["200"].Headers + + // Hidden headers should NOT appear in OpenAPI documentation. + assert.Nil(t, headers["X-Hidden-With-Tag"], "hidden header with tag should not appear in OpenAPI docs") + assert.Nil(t, headers["HiddenWithoutTag"], "hidden header without tag should not appear in OpenAPI docs") + + // Visible surface-level fields should appear in OpenAPI documentation. + assert.NotNil(t, headers["X-Visible-With-Tag"], "visible header with tag should appear in OpenAPI docs") + assert.NotNil(t, headers["VisibleWithoutTag"], "visible header without tag should appear in OpenAPI docs") + assert.NotNil(t, headers["LastModified"], "visible time header should appear in OpenAPI docs") + }, + Method: http.MethodGet, + URL: "/response-headers-hidden", + Assert: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, http.StatusOK, resp.Code) + + // Hidden headers with explicit tag SHOULD still be sent at runtime. + assert.Equal(t, "hidden-with-tag-value", resp.Header().Get("X-Hidden-With-Tag")) + + // Hidden headers without tag SHOULD still be sent at runtime using field name. + assert.Equal(t, "should-be-header", resp.Header().Get("HiddenWithoutTag")) + + // Visible surface-level fields should be sent at runtime. + assert.Equal(t, "visible-with-tag-value", resp.Header().Get("X-Visible-With-Tag")) + assert.Equal(t, "visible-without-tag-value", resp.Header().Get("VisibleWithoutTag")) + assert.Equal(t, "Thu, 15 Jun 2023 10:30:00 GMT", resp.Header().Get("LastModified")) }, }, { diff --git a/schema.go b/schema.go index 0214b105..b5788e22 100644 --- a/schema.go +++ b/schema.go @@ -47,6 +47,18 @@ var ( rawMessageType = reflect.TypeOf(json.RawMessage{}) ) +func baseType(t reflect.Type) reflect.Type { + t = deref(t) + for { + switch t.Kind() { + case reflect.Slice, reflect.Array, reflect.Map: + t = deref(t.Elem()) + default: + return t + } + } +} + func deref(t reflect.Type) reflect.Type { for t.Kind() == reflect.Ptr { t = t.Elem()