Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 66 additions & 49 deletions sse/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
package sse

import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"reflect"
Expand All @@ -14,6 +16,7 @@ import (
"time"

"github.com/danielgtaylor/huma/v2"
"github.com/valyala/fasthttp"
)

// WriteTimeout is the timeout for writing to the client.
Expand Down Expand Up @@ -54,6 +57,14 @@ func (s Sender) Data(data any) error {
return s(Message{Data: data})
}

type flusherFunc func() error

func (f flusherFunc) Flush() {
if err := f(); err != nil {
fmt.Fprintf(os.Stderr, "warning: flush failed: %v\n", err)
}
}

// Register a new SSE operation. The `eventTypeMap` maps from event name to
// the type of the data that will be sent. The `f` function is called with
// the context, input, and a `send` function that can be used to send messages
Expand Down Expand Up @@ -133,8 +144,61 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an
Body: func(ctx huma.Context) {
ctx.SetHeader("Content-Type", "text/event-stream")
bw := ctx.BodyWriter()
encoder := json.NewEncoder(bw)
send := func(deadliner writeDeadliner, flusher http.Flusher, writer io.Writer) Sender {
encoder := json.NewEncoder(writer)
return func(msg Message) error {
if deadliner != nil {
if err := deadliner.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil {
fmt.Fprintf(os.Stderr, "warning: unable to set write deadline: %v\n", err)
}
} else {
fmt.Fprintln(os.Stderr, "write deadline not supported by underlying writer")
}

// Write optional fields
if msg.ID > 0 {
writer.Write(fmt.Appendf(nil, "id: %d\n", msg.ID))
}
if msg.Retry > 0 {
writer.Write(fmt.Appendf(nil, "retry: %d\n", msg.Retry))
}

event, ok := typeToEvent[deref(reflect.TypeOf(msg.Data))]
if !ok {
fmt.Fprintf(os.Stderr, "error: unknown event type %v\n", reflect.TypeOf(msg.Data))
debug.PrintStack()
}
if event != "" && event != "message" {
// `message` is the default, so no need to transmit it.
writer.Write([]byte("event: " + event + "\n"))
}

// Write the message data.
if _, err := writer.Write([]byte("data: ")); err != nil {
return err
}
if err := encoder.Encode(msg.Data); err != nil {
writer.Write([]byte(`{"error": "encode error: `))
writer.Write([]byte(err.Error()))
writer.Write([]byte("\"}\n\n"))
return err
}
writer.Write([]byte("\n"))
if flusher != nil {
flusher.Flush()
} else {
fmt.Fprintln(os.Stderr, "error: unable to flush")
return fmt.Errorf("unable to flush: %w", http.ErrNotSupported)
}
return nil
}
}
if fastCtx, ok := bw.(*fasthttp.RequestCtx); ok {
fastCtx.SetBodyStreamWriter(func(bfw *bufio.Writer) {
f(ctx.Context(), input, send(fastCtx.Conn(), flusherFunc(bfw.Flush), bfw))
})
Comment thread
zhouyusd marked this conversation as resolved.
return
Comment thread
zhouyusd marked this conversation as resolved.
}
// Get the flusher/deadliner from the response writer if possible.
var flusher http.Flusher
flushCheck := bw
Expand Down Expand Up @@ -164,55 +228,8 @@ func Register[I any](api huma.API, op huma.Operation, eventTypeMap map[string]an
}
}

send := func(msg Message) error {
if deadliner != nil {
if err := deadliner.SetWriteDeadline(time.Now().Add(WriteTimeout)); err != nil {
fmt.Fprintf(os.Stderr, "warning: unable to set write deadline: %v\n", err)
}
} else {
fmt.Fprintln(os.Stderr, "write deadline not supported by underlying writer")
}

// Write optional fields
if msg.ID > 0 {
bw.Write(fmt.Appendf(nil, "id: %d\n", msg.ID))
}
if msg.Retry > 0 {
bw.Write(fmt.Appendf(nil, "retry: %d\n", msg.Retry))
}

event, ok := typeToEvent[deref(reflect.TypeOf(msg.Data))]
if !ok {
fmt.Fprintf(os.Stderr, "error: unknown event type %v\n", reflect.TypeOf(msg.Data))
debug.PrintStack()
}
if event != "" && event != "message" {
// `message` is the default, so no need to transmit it.
bw.Write([]byte("event: " + event + "\n"))
}

// Write the message data.
if _, err := bw.Write([]byte("data: ")); err != nil {
return err
}
if err := encoder.Encode(msg.Data); err != nil {
bw.Write([]byte(`{"error": "encode error: `))
bw.Write([]byte(err.Error()))
bw.Write([]byte("\"}\n\n"))
return err
}
bw.Write([]byte("\n"))
if flusher != nil {
flusher.Flush()
} else {
fmt.Fprintln(os.Stderr, "error: unable to flush")
return fmt.Errorf("unable to flush: %w", http.ErrNotSupported)
}
return nil
}

// Call the user-provided SSE handler.
f(ctx.Context(), input, send)
f(ctx.Context(), input, send(deadliner, flusher, bw))
},
}, nil
})
Expand Down