diff --git a/sse/sse.go b/sse/sse.go index 39b98646..697ebc7f 100644 --- a/sse/sse.go +++ b/sse/sse.go @@ -2,9 +2,11 @@ package sse import ( + "bufio" "context" "encoding/json" "fmt" + "io" "net/http" "os" "reflect" @@ -14,6 +16,7 @@ import ( "time" "github.com/danielgtaylor/huma/v2" + "github.com/valyala/fasthttp" ) // WriteTimeout is the timeout for writing to the client. @@ -54,6 +57,14 @@ func (s Sender) Data(data any) error { return s(Message{Data: data}) } +type flusherFunc func() error + +func (f flusherFunc) Flush() { + if err := f(); err != nil { + fmt.Fprintf(os.Stderr, "warning: flush failed: %v\n", err) + } +} + // Register a new SSE operation. The `eventTypeMap` maps from event name to // the type of the data that will be sent. The `f` function is called with // the context, input, and a `send` function that can be used to send messages @@ -133,8 +144,61 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an Body: func(ctx huma.Context) { ctx.SetHeader("Content-Type", "text/event-stream") bw := ctx.BodyWriter() - encoder := json.NewEncoder(bw) + send := func(deadliner writeDeadliner, flusher http.Flusher, writer io.Writer) Sender { + encoder := json.NewEncoder(writer) + return func(msg Message) error { + if deadliner != nil { + if err := deadliner.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil { + fmt.Fprintf(os.Stderr, "warning: unable to set write deadline: %v\n", err) + } + } else { + fmt.Fprintln(os.Stderr, "write deadline not supported by underlying writer") + } + // Write optional fields + if msg.ID > 0 { + writer.Write(fmt.Appendf(nil, "id: %d\n", msg.ID)) + } + if msg.Retry > 0 { + writer.Write(fmt.Appendf(nil, "retry: %d\n", msg.Retry)) + } + + event, ok := typeToEvent[deref(reflect.TypeOf(msg.Data))] + if !ok { + fmt.Fprintf(os.Stderr, "error: unknown event type %v\n", reflect.TypeOf(msg.Data)) + debug.PrintStack() + } + if event != "" && event != "message" { + // `message` is the default, so no need to transmit it. + writer.Write([]byte("event: " + event + "\n")) + } + + // Write the message data. + if _, err := writer.Write([]byte("data: ")); err != nil { + return err + } + if err := encoder.Encode(msg.Data); err != nil { + writer.Write([]byte(`{"error": "encode error: `)) + writer.Write([]byte(err.Error())) + writer.Write([]byte("\"}\n\n")) + return err + } + writer.Write([]byte("\n")) + if flusher != nil { + flusher.Flush() + } else { + fmt.Fprintln(os.Stderr, "error: unable to flush") + return fmt.Errorf("unable to flush: %w", http.ErrNotSupported) + } + return nil + } + } + if fastCtx, ok := bw.(*fasthttp.RequestCtx); ok { + fastCtx.SetBodyStreamWriter(func(bfw *bufio.Writer) { + f(ctx.Context(), input, send(fastCtx.Conn(), flusherFunc(bfw.Flush), bfw)) + }) + return + } // Get the flusher/deadliner from the response writer if possible. var flusher http.Flusher flushCheck := bw @@ -164,55 +228,8 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an } } - send := func(msg Message) error { - if deadliner != nil { - if err := deadliner.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil { - fmt.Fprintf(os.Stderr, "warning: unable to set write deadline: %v\n", err) - } - } else { - fmt.Fprintln(os.Stderr, "write deadline not supported by underlying writer") - } - - // Write optional fields - if msg.ID > 0 { - bw.Write(fmt.Appendf(nil, "id: %d\n", msg.ID)) - } - if msg.Retry > 0 { - bw.Write(fmt.Appendf(nil, "retry: %d\n", msg.Retry)) - } - - event, ok := typeToEvent[deref(reflect.TypeOf(msg.Data))] - if !ok { - fmt.Fprintf(os.Stderr, "error: unknown event type %v\n", reflect.TypeOf(msg.Data)) - debug.PrintStack() - } - if event != "" && event != "message" { - // `message` is the default, so no need to transmit it. - bw.Write([]byte("event: " + event + "\n")) - } - - // Write the message data. - if _, err := bw.Write([]byte("data: ")); err != nil { - return err - } - if err := encoder.Encode(msg.Data); err != nil { - bw.Write([]byte(`{"error": "encode error: `)) - bw.Write([]byte(err.Error())) - bw.Write([]byte("\"}\n\n")) - return err - } - bw.Write([]byte("\n")) - if flusher != nil { - flusher.Flush() - } else { - fmt.Fprintln(os.Stderr, "error: unable to flush") - return fmt.Errorf("unable to flush: %w", http.ErrNotSupported) - } - return nil - } - // Call the user-provided SSE handler. - f(ctx.Context(), input, send) + f(ctx.Context(), input, send(deadliner, flusher, bw)) }, }, nil })