Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 84 additions & 56 deletions pkg/relay/session_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,23 +224,18 @@ func (h *sessionHandler) run(ctx context.Context) error {
// request and continues serving the session. This matches §9.5's rule that
// a single bad request must not break unrelated subscriptions.
func (h *sessionHandler) runRequestLoop(ctx context.Context) error {
for {
req, err := h.sess.AcceptRequest(ctx)
if err != nil {
// A malformed / duplicate / overflowing / unknown
// AUTHORIZATION_TOKEN alias is a session-level fault per
// §10.2.2: close the session with the mapped SESSION_ERROR
// code rather than just tearing down the request loop.
if tce, ok := errors.AsType[*session.TokenCacheError](err); ok {
h.log.LogAttrs(ctx, slog.LevelDebug, "relay closing session on token cache error",
slog.String("err", err.Error()),
slog.Uint64("code", uint64(tce.Code)))
_ = h.sess.Close(tce.Code, tce.Error())
}
return err
}
h.dispatch(ctx, req)
err := h.requestMux(ctx).Run(ctx, h.sess)
// A malformed / duplicate / overflowing / unknown AUTHORIZATION_TOKEN alias
// surfaces from AcceptRequest as a session-level fault per §10.2.2: close the
// session with the mapped SESSION_ERROR code rather than just tearing down
// the request loop.
if tce, ok := errors.AsType[*session.TokenCacheError](err); ok {
h.log.LogAttrs(ctx, slog.LevelDebug, "relay closing session on token cache error",
slog.String("err", err.Error()),
slog.Uint64("code", uint64(tce.Code)))
_ = h.sess.Close(tce.Code, tce.Error())
}
return err
}

// runDataLoop accepts inbound data streams and routes each by type: subgroup
Expand Down Expand Up @@ -281,71 +276,104 @@ func (h *sessionHandler) runDataLoop(ctx context.Context) error {
}
}

// dispatch routes one inbound [*session.Request] to the handler responsible
// for its First-message type. Each handler is expected to:
// requestMux builds the per-session [session.RequestMux] that routes each inbound
// request to the handler responsible for its First-message type. Each handler is
// expected to:
//
// 1. Authorize the request.
// 2. Reply with either *_OK or REQUEST_ERROR.
// 3. Keep the bidi stream open for as long as the subscription's lifetime
// warrants (or close it cleanly on rejection).
// 4. Update [registry.TrackRegistry] / [registry.NamespaceRegistry] as appropriate.
//
// Unknown message types are treated as protocol violations per §3.3.2: we
// reset the bidi stream and log. We do NOT close the session — §9.5 ("If a
// Session is closed due to an unknown or invalid control message or Object,
// the Relay MUST NOT propagate that message or Object to another Session")
// implies the relay must isolate the failure to the one request.
func (h *sessionHandler) dispatch(ctx context.Context, req *session.Request) {
h.log.LogAttrs(ctx, slog.LevelDebug, "relay dispatching request",
slog.String("type", fmt.Sprintf("%T", req.First)))

// §10.2.2: apply the application's TokenVerifier to the request's resolved
// AUTHORIZATION_TOKEN(s) before any handler runs. A denial is per-request:
// reply REQUEST_ERROR with the mapped code and keep the session running.
if err := h.sess.VerifyRequestTokens(ctx, req); err != nil {
h.rejectTokenDenied(ctx, req, err)
return
}
// Two cross-cutting policies are shared across the per-type handlers:
// verifyRequest applies the §10.2.2 token-verification pre-step, and
// namespaceRequest folds in the §13.7.1 per-session cap for the three
// namespace-state requests (the §13.1 subscription cap is inline on SUBSCRIBE).
//
// An unknown / unexpected first-message type is a protocol violation per §3.3.2:
// OnUnknown resets the bidi stream and logs. The session is NOT closed — §9.5
// ("if a Session is closed due to an unknown or invalid control message [...] the
// Relay MUST NOT propagate that message [...] to another Session") means the
// relay isolates the failure to the one request.
func (h *sessionHandler) requestMux(ctx context.Context) *session.RequestMux {
mux := session.NewRequestMux()

switch msg := req.First.(type) {
case *message.Subscribe:
session.HandleType(mux, func(req *session.Request, msg *message.Subscribe) {
if !h.verifyRequest(ctx, req) {
return
}
// §13.1: bound concurrent subscriptions per session.
if !h.limiter.acquireSub() {
h.rejectExcessiveLoad(ctx, req, "subscription")
return
}
h.spawn(func() { defer h.limiter.releaseSub(); h.handleSubscribe(ctx, req, msg) })
case *message.Publish:
h.spawn(func() { h.handlePublish(ctx, req, msg) })
case *message.Fetch:
h.spawn(func() { h.handleFetch(ctx, req, msg) })
case *message.TrackStatus:
h.spawn(func() { h.handleTrackStatus(ctx, req, msg) })
case *message.PublishNamespace:
// §13.7.1: bound concurrent namespace-state requests per session.
if !h.limiter.acquireNamespace() {
h.rejectExcessiveLoad(ctx, req, "namespace request")
})
session.HandleType(mux, func(req *session.Request, msg *message.Publish) {
if !h.verifyRequest(ctx, req) {
return
}
h.spawn(func() { defer h.limiter.releaseNamespace(); h.handlePublishNamespace(ctx, req, msg) })
case *message.SubscribeNamespace:
if !h.limiter.acquireNamespace() {
h.rejectExcessiveLoad(ctx, req, "namespace request")
h.spawn(func() { h.handlePublish(ctx, req, msg) })
})
session.HandleType(mux, func(req *session.Request, msg *message.Fetch) {
if !h.verifyRequest(ctx, req) {
return
}
h.spawn(func() { defer h.limiter.releaseNamespace(); h.handleSubscribeNamespace(ctx, req, msg) })
case *message.SubscribeTracks:
if !h.limiter.acquireNamespace() {
h.rejectExcessiveLoad(ctx, req, "namespace request")
h.spawn(func() { h.handleFetch(ctx, req, msg) })
})
session.HandleType(mux, func(req *session.Request, msg *message.TrackStatus) {
if !h.verifyRequest(ctx, req) {
return
}
h.spawn(func() { defer h.limiter.releaseNamespace(); h.handleSubscribeTracks(ctx, req, msg) })
default:
h.spawn(func() { h.handleTrackStatus(ctx, req, msg) })
})
session.HandleType(mux, func(req *session.Request, msg *message.PublishNamespace) {
h.namespaceRequest(ctx, req, func() { h.handlePublishNamespace(ctx, req, msg) })
})
session.HandleType(mux, func(req *session.Request, msg *message.SubscribeNamespace) {
h.namespaceRequest(ctx, req, func() { h.handleSubscribeNamespace(ctx, req, msg) })
})
session.HandleType(mux, func(req *session.Request, msg *message.SubscribeTracks) {
h.namespaceRequest(ctx, req, func() { h.handleSubscribeTracks(ctx, req, msg) })
})

mux.OnUnknown(func(req *session.Request) {
h.log.LogAttrs(ctx, slog.LevelWarn, "relay rejected unknown request type",
slog.String("type", fmt.Sprintf("%T", req.First)))
req.Stream.CancelRead(uint64(moqt.StreamResetInternalError))
req.Stream.CancelWrite(uint64(moqt.StreamResetInternalError))
})

return mux
}

// verifyRequest runs the per-request dispatch log and the §10.2.2 token
// verification shared by every known request type. It returns false — after
// replying REQUEST_ERROR with the mapped code — when the request's resolved
// AUTHORIZATION_TOKEN is denied; the session stays up (a denial is per-request).
func (h *sessionHandler) verifyRequest(ctx context.Context, req *session.Request) bool {
h.log.LogAttrs(ctx, slog.LevelDebug, "relay dispatching request",
slog.String("type", fmt.Sprintf("%T", req.First)))
if err := h.sess.VerifyRequestTokens(ctx, req); err != nil {
h.rejectTokenDenied(ctx, req, err)
return false
}
return true
}

// namespaceRequest wraps a namespace-state handler (PUBLISH_NAMESPACE,
// SUBSCRIBE_NAMESPACE, SUBSCRIBE_TRACKS) with the shared token verification and
// the §13.7.1 per-session cap, spawning fn under the limiter when admitted.
func (h *sessionHandler) namespaceRequest(ctx context.Context, req *session.Request, fn func()) {
if !h.verifyRequest(ctx, req) {
return
}
if !h.limiter.acquireNamespace() {
h.rejectExcessiveLoad(ctx, req, "namespace request")
return
}
h.spawn(func() { defer h.limiter.releaseNamespace(); fn() })
}

// spawn registers a goroutine with the handler's wg so run() can join it
Expand Down