Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
name = c

if f.Type == cookieType {
// Special case: this will be parsed from a string input to a
// Special case: this will be parsed from a string input to an
// `http.Cookie` struct.
f.Type = stringType
}
Expand Down Expand Up @@ -244,7 +244,7 @@ type headerInfo struct {

func findHeaders(t reflect.Type) *findResult[*headerInfo] {
return findInType(t, nil, func(sf reflect.StructField, i []int) *headerInfo {
// Ignore embedded fields
// Ignore embedded fields.
if sf.Anonymous {
return nil
}
Expand All @@ -253,13 +253,15 @@ func findHeaders(t reflect.Type) *findResult[*headerInfo] {
if header == "" {
header = sf.Name
}

timeFormat := ""
if sf.Type == timeType {
timeFormat = http.TimeFormat
if f := sf.Tag.Get("timeFormat"); f != "" {
timeFormat = f
}
}

return &headerInfo{sf, header, timeFormat}
}, false, "Status", "Body")
}
Expand Down Expand Up @@ -439,9 +441,9 @@ func _findInType[T comparable](t reflect.Type, path []int, result *findResult[T]
result.Paths = append(result.Paths, findResultPath[T]{fi, v})
}
}
if f.Anonymous || recurseFields || deref(f.Type).Kind() != reflect.Struct {
if f.Anonymous || recurseFields || baseType(f.Type).Kind() != reflect.Struct {
// Always process embedded structs and named fields which are not
// structs. If `recurseFields` is true then we also process named
// structs. If `recurseFields` is true, then we also process named
// struct fields recursively.
visited[t] = struct{}{}
_findInType(f.Type, fi, result, onType, onField, recurseFields, visited, ignore...)
Expand Down Expand Up @@ -661,10 +663,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
if documenter, ok := api.(OperationDocumenter); ok {
// Enables customization of OpenAPI documentation behavior for operations.
documenter.DocumentOperation(&op)
} else {
if !op.Hidden {
oapi.AddOperation(&op)
}
} else if !op.Hidden {
oapi.AddOperation(&op)
}

resolvers := findResolvers(resolverType, inputType)
Expand Down Expand Up @@ -1152,7 +1152,7 @@ func initResponses(op *Operation) {
}
}

// processInputType validates the input type, extracts expected requests and
// processInputType validates the input type, extracts expected requests, and
// defines them on the operation op.
func processInputType(inputType reflect.Type, op *Operation, registry Registry) (*findResult[*paramFieldInfo], []int, bool, []int, rawBodyType, *Schema) {
inputParams := findParams(registry, op, inputType)
Expand Down Expand Up @@ -1398,11 +1398,30 @@ func processOutputType(outputType reflect.Type, op *Operation, registry Registry
}
outHeaders := findHeaders(outputType)
for _, entry := range outHeaders.Paths {
v := entry.Value

// Check if this field or any parent is hidden.
hidden := false
currentType := outputType
for _, idx := range entry.Path {
currentType = baseType(currentType)

field := currentType.Field(idx)
if boolTag(field, "hidden", false) {
hidden = true
break
}

currentType = field.Type
}
if hidden {
continue
}

// Document the header's name and type.
if op.Responses[defaultStatusStr].Headers == nil {
op.Responses[defaultStatusStr].Headers = map[string]*Param{}
}
v := entry.Value
f := v.Field
if f.Type.Kind() == reflect.Slice {
f.Type = deref(f.Type.Elem())
Expand Down
126 changes: 111 additions & 15 deletions huma_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1895,40 +1895,136 @@ Content-Type: text/plain
Name: "response-headers",
Register: func(t *testing.T, api huma.API) {
type Resp struct {
Str string `header:"str"`
Int int `header:"int"`
Uint uint `header:"uint"`
Float float64 `header:"float"`
Bool bool `header:"bool"`
Date time.Time `header:"date"`
Empty string `header:"empty"`
Str string `header:"str"`
Int int `header:"int"`
Uint uint `header:"uint"`
Float float64 `header:"float"`
Bool bool `header:"bool"`
Date time.Time `header:"date"`
Empty string `header:"empty"`
CustomTime time.Time `header:"custom-time" timeFormat:"2006-01-02"`
WithoutTag string // No header tag - SHOULD be set as a header using field name.
LastModified time.Time // No header tag - SHOULD be set as a header using field name.
}

huma.Register(api, huma.Operation{
Method: http.MethodGet,
Path: "/response-headers",
}, func(ctx context.Context, input *struct{}) (*Resp, error) {
resp := &Resp{}
resp.Str = "str"
resp.Int = 1
resp.Uint = 2
resp.Float = 3.45
resp.Bool = true
resp.Date = time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC)
return resp, nil
return &Resp{
Str: "str",
Int: 1,
Uint: 2,
Float: 3.45,
Bool: true,
Date: time.Date(2023, 1, 1, 12, 0, 0, 0, time.UTC),
CustomTime: time.Date(2023, 6, 15, 10, 30, 0, 0, time.UTC),
WithoutTag: "without-tag-value",
LastModified: time.Date(2023, 6, 15, 10, 30, 0, 0, time.UTC),
}, nil
})

headers := api.OpenAPI().Paths["/response-headers"].Get.Responses["204"].Headers

// Surface-level fields with explicit tags should be documented.
assert.NotNil(t, headers["str"])
assert.NotNil(t, headers["int"])
assert.NotNil(t, headers["uint"])
assert.NotNil(t, headers["float"])
assert.NotNil(t, headers["bool"])
assert.NotNil(t, headers["date"])
assert.NotNil(t, headers["empty"])
assert.NotNil(t, headers["custom-time"])

// Surface-level fields without tags should be documented using field name.
assert.NotNil(t, headers["WithoutTag"])
assert.NotNil(t, headers["LastModified"])
},
Method: http.MethodGet,
URL: "/response-headers",
Assert: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusNoContent, resp.Code)

// Surface-level fields with explicit tags should be set.
assert.Equal(t, "str", resp.Header().Get("Str"))
assert.Equal(t, "1", resp.Header().Get("Int"))
assert.Equal(t, "2", resp.Header().Get("Uint"))
assert.Equal(t, "3.45", resp.Header().Get("Float"))
assert.Equal(t, "true", resp.Header().Get("Bool"))
assert.Equal(t, "Sun, 01 Jan 2023 12:00:00 GMT", resp.Header().Get("Date"))
assert.Empty(t, resp.Header().Values("Empty"))
assert.Equal(t, "2023-06-15", resp.Header().Get("Custom-Time"))

// Surface-level fields without tags should be set using field name.
assert.Equal(t, "without-tag-value", resp.Header().Get("WithoutTag"))
assert.Equal(t, "Thu, 15 Jun 2023 10:30:00 GMT", resp.Header().Get("LastModified"))
},
},
{
Name: "response-headers-hidden",
Register: func(t *testing.T, api huma.API) {
type HiddenHeaders struct {
HiddenWithTag string `header:"X-Hidden-With-Tag"`
HiddenWithoutTag string // No header tag - should be set as header using field name.
}

type Resp struct {
*HiddenHeaders `hidden:"true"`

VisibleWithTag string `header:"X-Visible-With-Tag"`
VisibleWithoutTag string // No header tag - SHOULD be set as a header using field name.
LastModified time.Time // No header tag - SHOULD be set as a header using field name.
Body struct {
Message string `json:"message"`
}
}

huma.Register(api, huma.Operation{
Method: http.MethodGet,
Path: "/response-headers-hidden",
}, func(ctx context.Context, input *struct{}) (*Resp, error) {
return &Resp{
HiddenHeaders: &HiddenHeaders{
HiddenWithTag: "hidden-with-tag-value",
HiddenWithoutTag: "should-be-header",
},
VisibleWithTag: "visible-with-tag-value",
VisibleWithoutTag: "visible-without-tag-value",
LastModified: time.Date(2023, 6, 15, 10, 30, 0, 0, time.UTC),
Body: struct {
Message string `json:"message"`
}{
Message: "Hello",
},
}, nil
})

headers := api.OpenAPI().Paths["/response-headers-hidden"].Get.Responses["200"].Headers

// Hidden headers should NOT appear in OpenAPI documentation.
assert.Nil(t, headers["X-Hidden-With-Tag"], "hidden header with tag should not appear in OpenAPI docs")
assert.Nil(t, headers["HiddenWithoutTag"], "hidden header without tag should not appear in OpenAPI docs")

// Visible surface-level fields should appear in OpenAPI documentation.
assert.NotNil(t, headers["X-Visible-With-Tag"], "visible header with tag should appear in OpenAPI docs")
assert.NotNil(t, headers["VisibleWithoutTag"], "visible header without tag should appear in OpenAPI docs")
assert.NotNil(t, headers["LastModified"], "visible time header should appear in OpenAPI docs")
},
Method: http.MethodGet,
URL: "/response-headers-hidden",
Assert: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)

// Hidden headers with explicit tag SHOULD still be sent at runtime.
assert.Equal(t, "hidden-with-tag-value", resp.Header().Get("X-Hidden-With-Tag"))

// Hidden headers without tag SHOULD still be sent at runtime using field name.
assert.Equal(t, "should-be-header", resp.Header().Get("HiddenWithoutTag"))

// Visible surface-level fields should be sent at runtime.
assert.Equal(t, "visible-with-tag-value", resp.Header().Get("X-Visible-With-Tag"))
assert.Equal(t, "visible-without-tag-value", resp.Header().Get("VisibleWithoutTag"))
assert.Equal(t, "Thu, 15 Jun 2023 10:30:00 GMT", resp.Header().Get("LastModified"))
},
},
{
Expand Down
12 changes: 12 additions & 0 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ var (
rawMessageType = reflect.TypeOf(json.RawMessage{})
)

func baseType(t reflect.Type) reflect.Type {
t = deref(t)
for {
switch t.Kind() {
case reflect.Slice, reflect.Array, reflect.Map:
t = deref(t.Elem())
default:
return t
}
}
}

func deref(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
Expand Down
Loading