diff --git a/pkg/relay/session_handler.go b/pkg/relay/session_handler.go index f075faa..823d020 100644 --- a/pkg/relay/session_handler.go +++ b/pkg/relay/session_handler.go @@ -224,23 +224,18 @@ func (h *sessionHandler) run(ctx context.Context) error { // request and continues serving the session. This matches §9.5's rule that // a single bad request must not break unrelated subscriptions. func (h *sessionHandler) runRequestLoop(ctx context.Context) error { - for { - req, err := h.sess.AcceptRequest(ctx) - if err != nil { - // A malformed / duplicate / overflowing / unknown - // AUTHORIZATION_TOKEN alias is a session-level fault per - // §10.2.2: close the session with the mapped SESSION_ERROR - // code rather than just tearing down the request loop. - if tce, ok := errors.AsType[*session.TokenCacheError](err); ok { - h.log.LogAttrs(ctx, slog.LevelDebug, "relay closing session on token cache error", - slog.String("err", err.Error()), - slog.Uint64("code", uint64(tce.Code))) - _ = h.sess.Close(tce.Code, tce.Error()) - } - return err - } - h.dispatch(ctx, req) + err := h.requestMux(ctx).Run(ctx, h.sess) + // A malformed / duplicate / overflowing / unknown AUTHORIZATION_TOKEN alias + // surfaces from AcceptRequest as a session-level fault per §10.2.2: close the + // session with the mapped SESSION_ERROR code rather than just tearing down + // the request loop. + if tce, ok := errors.AsType[*session.TokenCacheError](err); ok { + h.log.LogAttrs(ctx, slog.LevelDebug, "relay closing session on token cache error", + slog.String("err", err.Error()), + slog.Uint64("code", uint64(tce.Code))) + _ = h.sess.Close(tce.Code, tce.Error()) } + return err } // runDataLoop accepts inbound data streams and routes each by type: subgroup @@ -281,8 +276,9 @@ func (h *sessionHandler) runDataLoop(ctx context.Context) error { } } -// dispatch routes one inbound [*session.Request] to the handler responsible -// for its First-message type. Each handler is expected to: +// requestMux builds the per-session [session.RequestMux] that routes each inbound +// request to the handler responsible for its First-message type. Each handler is +// expected to: // // 1. Authorize the request. // 2. Reply with either *_OK or REQUEST_ERROR. @@ -290,62 +286,94 @@ func (h *sessionHandler) runDataLoop(ctx context.Context) error { // warrants (or close it cleanly on rejection). // 4. Update [registry.TrackRegistry] / [registry.NamespaceRegistry] as appropriate. // -// Unknown message types are treated as protocol violations per §3.3.2: we -// reset the bidi stream and log. We do NOT close the session — §9.5 ("If a -// Session is closed due to an unknown or invalid control message or Object, -// the Relay MUST NOT propagate that message or Object to another Session") -// implies the relay must isolate the failure to the one request. -func (h *sessionHandler) dispatch(ctx context.Context, req *session.Request) { - h.log.LogAttrs(ctx, slog.LevelDebug, "relay dispatching request", - slog.String("type", fmt.Sprintf("%T", req.First))) - - // §10.2.2: apply the application's TokenVerifier to the request's resolved - // AUTHORIZATION_TOKEN(s) before any handler runs. A denial is per-request: - // reply REQUEST_ERROR with the mapped code and keep the session running. - if err := h.sess.VerifyRequestTokens(ctx, req); err != nil { - h.rejectTokenDenied(ctx, req, err) - return - } +// Two cross-cutting policies are shared across the per-type handlers: +// verifyRequest applies the §10.2.2 token-verification pre-step, and +// namespaceRequest folds in the §13.7.1 per-session cap for the three +// namespace-state requests (the §13.1 subscription cap is inline on SUBSCRIBE). +// +// An unknown / unexpected first-message type is a protocol violation per §3.3.2: +// OnUnknown resets the bidi stream and logs. The session is NOT closed — §9.5 +// ("if a Session is closed due to an unknown or invalid control message [...] the +// Relay MUST NOT propagate that message [...] to another Session") means the +// relay isolates the failure to the one request. +func (h *sessionHandler) requestMux(ctx context.Context) *session.RequestMux { + mux := session.NewRequestMux() - switch msg := req.First.(type) { - case *message.Subscribe: + session.HandleType(mux, func(req *session.Request, msg *message.Subscribe) { + if !h.verifyRequest(ctx, req) { + return + } // §13.1: bound concurrent subscriptions per session. if !h.limiter.acquireSub() { h.rejectExcessiveLoad(ctx, req, "subscription") return } h.spawn(func() { defer h.limiter.releaseSub(); h.handleSubscribe(ctx, req, msg) }) - case *message.Publish: - h.spawn(func() { h.handlePublish(ctx, req, msg) }) - case *message.Fetch: - h.spawn(func() { h.handleFetch(ctx, req, msg) }) - case *message.TrackStatus: - h.spawn(func() { h.handleTrackStatus(ctx, req, msg) }) - case *message.PublishNamespace: - // §13.7.1: bound concurrent namespace-state requests per session. - if !h.limiter.acquireNamespace() { - h.rejectExcessiveLoad(ctx, req, "namespace request") + }) + session.HandleType(mux, func(req *session.Request, msg *message.Publish) { + if !h.verifyRequest(ctx, req) { return } - h.spawn(func() { defer h.limiter.releaseNamespace(); h.handlePublishNamespace(ctx, req, msg) }) - case *message.SubscribeNamespace: - if !h.limiter.acquireNamespace() { - h.rejectExcessiveLoad(ctx, req, "namespace request") + h.spawn(func() { h.handlePublish(ctx, req, msg) }) + }) + session.HandleType(mux, func(req *session.Request, msg *message.Fetch) { + if !h.verifyRequest(ctx, req) { return } - h.spawn(func() { defer h.limiter.releaseNamespace(); h.handleSubscribeNamespace(ctx, req, msg) }) - case *message.SubscribeTracks: - if !h.limiter.acquireNamespace() { - h.rejectExcessiveLoad(ctx, req, "namespace request") + h.spawn(func() { h.handleFetch(ctx, req, msg) }) + }) + session.HandleType(mux, func(req *session.Request, msg *message.TrackStatus) { + if !h.verifyRequest(ctx, req) { return } - h.spawn(func() { defer h.limiter.releaseNamespace(); h.handleSubscribeTracks(ctx, req, msg) }) - default: + h.spawn(func() { h.handleTrackStatus(ctx, req, msg) }) + }) + session.HandleType(mux, func(req *session.Request, msg *message.PublishNamespace) { + h.namespaceRequest(ctx, req, func() { h.handlePublishNamespace(ctx, req, msg) }) + }) + session.HandleType(mux, func(req *session.Request, msg *message.SubscribeNamespace) { + h.namespaceRequest(ctx, req, func() { h.handleSubscribeNamespace(ctx, req, msg) }) + }) + session.HandleType(mux, func(req *session.Request, msg *message.SubscribeTracks) { + h.namespaceRequest(ctx, req, func() { h.handleSubscribeTracks(ctx, req, msg) }) + }) + + mux.OnUnknown(func(req *session.Request) { h.log.LogAttrs(ctx, slog.LevelWarn, "relay rejected unknown request type", slog.String("type", fmt.Sprintf("%T", req.First))) req.Stream.CancelRead(uint64(moqt.StreamResetInternalError)) req.Stream.CancelWrite(uint64(moqt.StreamResetInternalError)) + }) + + return mux +} + +// verifyRequest runs the per-request dispatch log and the §10.2.2 token +// verification shared by every known request type. It returns false — after +// replying REQUEST_ERROR with the mapped code — when the request's resolved +// AUTHORIZATION_TOKEN is denied; the session stays up (a denial is per-request). +func (h *sessionHandler) verifyRequest(ctx context.Context, req *session.Request) bool { + h.log.LogAttrs(ctx, slog.LevelDebug, "relay dispatching request", + slog.String("type", fmt.Sprintf("%T", req.First))) + if err := h.sess.VerifyRequestTokens(ctx, req); err != nil { + h.rejectTokenDenied(ctx, req, err) + return false + } + return true +} + +// namespaceRequest wraps a namespace-state handler (PUBLISH_NAMESPACE, +// SUBSCRIBE_NAMESPACE, SUBSCRIBE_TRACKS) with the shared token verification and +// the §13.7.1 per-session cap, spawning fn under the limiter when admitted. +func (h *sessionHandler) namespaceRequest(ctx context.Context, req *session.Request, fn func()) { + if !h.verifyRequest(ctx, req) { + return + } + if !h.limiter.acquireNamespace() { + h.rejectExcessiveLoad(ctx, req, "namespace request") + return } + h.spawn(func() { defer h.limiter.releaseNamespace(); fn() }) } // spawn registers a goroutine with the handler's wg so run() can join it