diff --git a/pkg/moqt/session/pubsub.go b/pkg/moqt/session/publish.go similarity index 53% rename from pkg/moqt/session/pubsub.go rename to pkg/moqt/session/publish.go index ea6df68..9ffb77f 100644 --- a/pkg/moqt/session/pubsub.go +++ b/pkg/moqt/session/publish.go @@ -7,7 +7,6 @@ import ( "github.com/floatdrop/moq-go/pkg/moqt" "github.com/floatdrop/moq-go/pkg/moqt/message" - "github.com/floatdrop/moq-go/pkg/moqt/track" ) // Publication is a live track this side publishes objects on. It owns the @@ -128,117 +127,3 @@ func (s *Session) Publish(ctx context.Context, m *message.Publish) (*Publication func (s *Session) OpenPublish(m *message.Publish) (Stream, error) { return s.openAllocRequest(m) } - -// Subscription is a live subscriber-initiated track subscription. It owns the -// request stream (embedded, so Close / reads / message.Marshal work directly -// on it) plus the identifiers follow-up traffic needs — the Request ID and the -// publisher-assigned Track Alias — so the caller can send REQUEST_UPDATE via -// [Subscription.Update] without holding them separately. It is returned by -// [Session.Subscribe]. -type Subscription struct { - // Stream is the SUBSCRIBE request stream, still open for follow-up - // traffic: REQUEST_UPDATE and inbound PUBLISH_DONE. Close it to end the - // subscription. - Stream - - // OK is the parsed SUBSCRIBE_OK response — the publisher-assigned Track - // Alias, negotiated Parameters, and TrackProperties. - OK *message.SubscribeOK - - s *Session - requestID uint64 -} - -// TrackAlias reports the §11.1 Track Alias the publisher assigned to this -// subscription — the integer inbound subgroup and datagram streams carry to -// identify the track (see [Session.AcceptDataStream]). It is shorthand for -// sub.OK.TrackAlias. -func (sub *Subscription) TrackAlias() uint64 { return sub.OK.TrackAlias } - -// Update sends a REQUEST_UPDATE (§10.9) on the subscription stream and awaits -// the single REQUEST_OK / REQUEST_ERROR the spec mandates. params carries only -// the fields to change; any parameter omitted keeps its prior value on the -// peer. It is [Session.UpdateRequest] with this subscription's stream and -// Request ID filled in. -func (sub *Subscription) Update(ctx context.Context, params message.Parameters) (*message.RequestOK, error) { - return sub.s.UpdateRequest(ctx, sub.Stream, sub.requestID, params) -} - -// Subscribe opens a SUBSCRIBE request stream (§10.7) and awaits SUBSCRIBE_OK. -// The session assigns m.RequestID; the caller supplies the rest. On success a -// [Subscription] is returned whose embedded stream stays open for follow-up -// traffic (REQUEST_UPDATE via [Subscription.Update], inbound PUBLISH_DONE) and -// whose [Subscription.TrackAlias] matches the alias on inbound subgroup -// streams. REQUEST_ERROR is surfaced as a *RequestRejectedError and the stream -// is closed. -func (s *Session) Subscribe(ctx context.Context, m *message.Subscribe) (*Subscription, error) { - stream, err := s.openAllocRequest(m) - if err != nil { - return nil, err - } - resp, err := s.readResponse(ctx, stream) - if err != nil { - _ = stream.Close() - return nil, fmt.Errorf("moqt/session: read SUBSCRIBE response: %w", err) - } - switch r := resp.(type) { - case *message.SubscribeOK: - // §2.5.1: reject tracks with unknown mandatory track properties. - if err := s.validateTrackProperties(r.TrackProperties, "SUBSCRIBE_OK"); err != nil { - _ = stream.Close() - return nil, err - } - // §11.1: register the alias the publisher assigned so we can detect - // DUPLICATE_TRACK_ALIAS if the same alias is reused for a different track. - key := track.NewKey(m.Namespace, m.Name) - if err := s.RegisterInboundTrackAlias(r.TrackAlias, key); err != nil { - _ = stream.Close() - return nil, err - } - return &Subscription{Stream: stream, OK: r, s: s, requestID: m.RequestID}, nil - case *message.RequestError: - _ = stream.Close() - return nil, &RequestRejectedError{Code: r.ErrorCode, Reason: r.ErrorReason} - default: - _ = stream.Close() - return nil, fmt.Errorf("moqt/session: unexpected %s in SUBSCRIBE response", resp.Type()) - } -} - -// UpdateRequest sends a REQUEST_UPDATE (§10.9) on an already-established -// request stream and awaits the single REQUEST_OK / REQUEST_ERROR the spec -// mandates in response. requestID MUST be the Request ID of the original -// request the stream carries — REQUEST_UPDATE rides the original bidi stream -// and does NOT consume a new Request ID. params carries only the fields the -// caller wants to change; any parameter omitted keeps its prior value on the -// peer (§10.9). -// -// On REQUEST_OK the parsed message is returned and the stream is left open -// for further traffic. REQUEST_ERROR is surfaced as a *RequestRejectedError; -// the stream is left open so the caller can decide how to tear down (a failed -// subscription update is followed by PUBLISH_DONE from the publisher, §10.9). -func (s *Session) UpdateRequest( - ctx context.Context, - stream Stream, - requestID uint64, - params message.Parameters, -) (*message.RequestOK, error) { - if err := message.Marshal(stream, &message.RequestUpdate{ - RequestID: requestID, - Parameters: params, - }); err != nil { - return nil, fmt.Errorf("moqt/session: write REQUEST_UPDATE: %w", err) - } - resp, err := s.readResponse(ctx, stream) - if err != nil { - return nil, fmt.Errorf("moqt/session: read REQUEST_UPDATE response: %w", err) - } - switch r := resp.(type) { - case *message.RequestOK: - return r, nil - case *message.RequestError: - return nil, &RequestRejectedError{Code: r.ErrorCode, Reason: r.ErrorReason} - default: - return nil, fmt.Errorf("moqt/session: unexpected %s in REQUEST_UPDATE response", resp.Type()) - } -} diff --git a/pkg/moqt/session/publish_test.go b/pkg/moqt/session/publish_test.go new file mode 100644 index 0000000..f6ee78f --- /dev/null +++ b/pkg/moqt/session/publish_test.go @@ -0,0 +1,351 @@ +package session_test + +import ( + "errors" + "fmt" + "io" + "sync" + "testing" + + "github.com/floatdrop/moq-go/pkg/moqt" + "github.com/floatdrop/moq-go/pkg/moqt/message" + "github.com/floatdrop/moq-go/pkg/moqt/session" + "github.com/floatdrop/moq-go/pkg/moqt/track" + "github.com/floatdrop/moq-go/pkg/moqt/wire" +) + +// TestPublishRoundTrip exercises the full PUBLISH flow: +// +// 1. Client calls Session.Publish → sends PUBLISH on a bidi stream. +// 2. Server accepts the request, verifies the first message, replies REQUEST_OK. +// 3. Client receives the open stream from Publish(). +func TestPublishRoundTrip(t *testing.T) { + cli, srv := openPair(t) + ctx := t.Context() + + ns := wire.TrackNamespace{[]byte("example.com"), []byte("live")} + req := &message.Publish{ + Namespace: ns, + Name: []byte("video"), + TrackAlias: 42, + } + + var ( + wg sync.WaitGroup + serverErr error + clientErr error + gotStream session.Stream + ) + + // Server: accept PUBLISH, verify fields, reply REQUEST_OK. + wg.Go(func() { + r, err := srv.AcceptRequest(ctx) + if err != nil { + serverErr = err + return + } + pub, ok := r.First.(*message.Publish) + if !ok { + serverErr = errors.New("server: expected *message.Publish, got " + r.First.Type().String()) + return + } + // RequestID must have been assigned by the client (even, starts at 0). + if pub.RequestID != 0 { + serverErr = errors.New("server: unexpected RequestID") + return + } + if string(pub.Name) != string(req.Name) { + serverErr = errors.New("server: Name mismatch") + return + } + if pub.TrackAlias != req.TrackAlias { + serverErr = errors.New("server: TrackAlias mismatch") + return + } + serverErr = r.Reply(&message.RequestOK{}) + }) + + // Client: call Publish, check the returned stream is non-nil. + wg.Go(func() { + stream, err := cli.Publish(ctx, req) + if err != nil { + clientErr = err + return + } + gotStream = stream + }) + + wg.Wait() + + if serverErr != nil { + t.Fatalf("server: %v", serverErr) + } + if clientErr != nil { + t.Fatalf("client Publish: %v", clientErr) + } + if gotStream == nil { + t.Fatal("Publish returned nil stream") + } + _ = gotStream.Close() +} + +// TestPublishRejected verifies that Session.Publish returns a +// *RequestRejectedError when the server replies with REQUEST_ERROR. +func TestPublishRejected(t *testing.T) { + cli, srv := openPair(t) + ctx := t.Context() + + var wg sync.WaitGroup + + wg.Go(func() { + r, err := srv.AcceptRequest(ctx) + if err != nil { + return + } + _ = r.RejectError(moqt.RequestDoesNotExist, "track not found") + }) + + wg.Go(func() { + _, err := cli.Publish(ctx, &message.Publish{ + Namespace: wire.TrackNamespace{[]byte("ns")}, + Name: []byte("missing"), + TrackAlias: 1, + }) + var rejected *session.RequestRejectedError + if !errors.As(err, &rejected) { + t.Errorf("Publish error = %v (%T), want *session.RequestRejectedError", err, err) + return + } + if rejected.Code != moqt.RequestDoesNotExist { + t.Errorf("Code = %v, want RequestDoesNotExist", rejected.Code) + } + }) + + wg.Wait() +} + +// TestPublicationDone verifies Publication.Done writes a PUBLISH_DONE whose +// §10.11 Stream Count reflects the subgroups opened via the handle, with the +// given code and reason, and then FINs the request stream. +func TestPublicationDone(t *testing.T) { + cli, srv := openPair(t) + ctx := t.Context() + + var ( + wg sync.WaitGroup + serverErr error + gotDone *message.PublishDone + ) + + // Server: accept the PUBLISH, reply OK, drain the two subgroup streams the + // client opens (the in-process pipe is synchronous), then read PUBLISH_DONE + // off the request stream. + wg.Go(func() { + r, err := srv.AcceptRequest(ctx) + if err != nil { + serverErr = err + return + } + if err := r.Reply(&message.RequestOK{}); err != nil { + serverErr = err + return + } + for range 2 { + ds, err := srv.AcceptDataStream(ctx) + if err != nil { + serverErr = err + return + } + _, _ = io.Copy(io.Discard, ds) + } + msg, err := message.Parse(r.Stream) + if err != nil { + serverErr = err + return + } + pd, ok := msg.(*message.PublishDone) + if !ok { + serverErr = fmt.Errorf("got %T on request stream, want *message.PublishDone", msg) + return + } + gotDone = pd + }) + + // Client: publish, open two subgroups via the handle, end with Done. + wg.Go(func() { + pub, err := cli.Publish(ctx, &message.Publish{ + Namespace: wire.Namespace("ns"), + Name: []byte("track"), + }) + if err != nil { + return + } + for g := range uint64(2) { + sg, err := pub.OpenSubgroup(message.SubgroupHeader{ + SubgroupIDMode: message.SubgroupIDImplicitZero, + GroupID: g, + }) + if err != nil { + return + } + _ = sg.WriteObjectAt(0, &message.SubgroupObject{Payload: []byte("x")}) + _ = sg.Close() + } + _ = pub.Done(moqt.PublishDoneTrackEnded, "bye") + }) + + wg.Wait() + + if serverErr != nil { + t.Fatalf("server: %v", serverErr) + } + if gotDone == nil { + t.Fatal("no PUBLISH_DONE received") + } + if gotDone.StatusCode != moqt.PublishDoneTrackEnded { + t.Errorf("StatusCode = %v, want PublishDoneTrackEnded", gotDone.StatusCode) + } + if gotDone.StreamCount != 2 { + t.Errorf("StreamCount = %d, want 2 (subgroups opened via the handle)", gotDone.StreamCount) + } + if gotDone.ErrorReason != "bye" { + t.Errorf("ErrorReason = %q, want %q", gotDone.ErrorReason, "bye") + } +} + +// TestRequestAcceptPublish verifies AcceptPublish registers the publisher's +// Track Alias (§11.1) and replies REQUEST_OK so the client's Publish succeeds. +func TestRequestAcceptPublish(t *testing.T) { + cli, srv := openPair(t) + ctx := t.Context() + + var ( + wg sync.WaitGroup + serverErr error + gotKey track.Key + gotOK bool + ) + wg.Go(func() { + req, err := srv.AcceptRequest(ctx) + if err != nil { + serverErr = err + return + } + if err := req.AcceptPublish(); err != nil { + serverErr = err + return + } + gotKey, gotOK = srv.LookupInboundTrackAlias(7) + }) + var clientErr error + wg.Go(func() { + _, clientErr = cli.Publish(ctx, &message.Publish{ + Namespace: wire.Namespace("ns"), + Name: []byte("track"), + TrackAlias: 7, + }) + }) + wg.Wait() + + if serverErr != nil { + t.Fatalf("server: %v", serverErr) + } + if clientErr != nil { + t.Fatalf("client Publish: %v", clientErr) + } + if !gotOK { + t.Fatal("AcceptPublish did not register the inbound Track Alias") + } + if want := track.NewKey(wire.Namespace("ns"), []byte("track")); gotKey != want { + t.Errorf("registered key = %v, want %v", gotKey, want) + } +} + +// TestOpenPublish_SuccessDeliversPublish verifies the happy path: OpenPublish +// assigns a Request ID, writes PUBLISH as the stream's first message, and the +// peer accepts a bidi request carrying that exact PUBLISH. +func TestOpenPublish_SuccessDeliversPublish(t *testing.T) { + t.Parallel() + client, server := openPairWithLimits(t, -1) + + var ( + wg sync.WaitGroup + req *session.Request + err error + ) + wg.Go(func() { req, err = server.AcceptRequest(t.Context()) }) + + m := &message.Publish{ + Namespace: wire.TrackNamespace{[]byte("video")}, + Name: []byte("cam1"), + TrackAlias: 7, + } + stream, openErr := client.OpenPublish(m) + if openErr != nil { + t.Fatalf("OpenPublish: %v", openErr) + } + defer stream.Close() + // Client Request IDs are even (§10.1); the first allocation is 0. + if m.RequestID%2 != 0 { + t.Fatalf("client Request ID %d is not even", m.RequestID) + } + + wg.Wait() + if err != nil { + t.Fatalf("AcceptRequest: %v", err) + } + pub, ok := req.First.(*message.Publish) + if !ok { + t.Fatalf("server got %T, want *message.Publish", req.First) + } + if string(pub.Name) != "cam1" || pub.TrackAlias != 7 { + t.Fatalf("server got Publish{Name:%q, Alias:%d}, want {cam1, 7}", pub.Name, pub.TrackAlias) + } +} + +// TestOpenPublish_ExhaustedCreditReturnsErrNoStreamCredit pins the +// PUBLISH_BLOCKED trigger: with the client's bidi credit used up, OpenPublish +// returns session.ErrNoStreamCredit rather than blocking. +func TestOpenPublish_ExhaustedCreditReturnsErrNoStreamCredit(t *testing.T) { + t.Parallel() + client, server := openPairWithLimits(t, 1) + + // Drain accepts so the first (successful) open's delivery never backs up. + go func() { + for { + if _, err := server.AcceptRequest(t.Context()); err != nil { + return + } + } + }() + + // First publish consumes the single unit of credit. + first := &message.Publish{ + Namespace: wire.TrackNamespace{[]byte("video")}, + Name: []byte("cam1"), + } + s1, err := client.OpenPublish(first) + if err != nil { + t.Fatalf("OpenPublish #0: %v", err) + } + defer s1.Close() + firstID := first.RequestID + + // Second publish: credit exhausted → ErrNoStreamCredit, no ID consumed. + second := &message.Publish{ + Namespace: wire.TrackNamespace{[]byte("video")}, + Name: []byte("cam2"), + } + _, err = client.OpenPublish(second) + if !errors.Is(err, session.ErrNoStreamCredit) { + t.Fatalf("OpenPublish #1 err = %v, want ErrNoStreamCredit", err) + } + + // §6.1: a blocked attempt MUST NOT consume a Request ID. The next + // successful allocation (via a plain AllocRequestID) must be exactly + // firstID+2, proving the blocked OpenPublish left the sequence untouched. + if got := client.AllocRequestID(); got != firstID+2 { + t.Fatalf("Request ID after blocked OpenPublish = %d, want %d (firstID %d + 2)", + got, firstID+2, firstID) + } +} diff --git a/pkg/moqt/session/pubsub_ctx_test.go b/pkg/moqt/session/pubsub_ctx_test.go deleted file mode 100644 index b3c5a28..0000000 --- a/pkg/moqt/session/pubsub_ctx_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package session_test - -import ( - "context" - "errors" - "sync" - "testing" - "time" - - "github.com/floatdrop/moq-go/pkg/moqt/message" - "github.com/floatdrop/moq-go/pkg/moqt/wire" -) - -// TestSubscribe_ContextCancelUnblocksResponseWait verifies that cancelling the -// ctx unblocks an awaiting request method even though message.Parse reads from -// a context-free io.Reader. The server accepts the SUBSCRIBE but never replies, -// so Subscribe blocks in readResponse; the context.AfterFunc hook resets the -// read side on cancel and the call returns ctx.Err() (wrapped as context.Canceled). -func TestSubscribe_ContextCancelUnblocksResponseWait(t *testing.T) { - t.Parallel() - client, server := openPairWithLimits(t, -1) - - // Server accepts the request and holds it open without ever replying. - var srvWG sync.WaitGroup - srvWG.Go(func() { - // AcceptRequest reads the SUBSCRIBE (which unblocks the client's write); - // we then deliberately never send SUBSCRIBE_OK / REQUEST_ERROR. - _, _ = server.AcceptRequest(t.Context()) - }) - - ctx, cancel := context.WithCancel(context.Background()) - - type result struct { - err error - } - resCh := make(chan result, 1) - go func() { - _, err := client.Subscribe(ctx, &message.Subscribe{ - Namespace: wire.TrackNamespace{[]byte("video")}, - Name: []byte("cam1"), - }) - resCh <- result{err: err} - }() - - // Cancelling makes the blocked response read return; the call must surface - // context.Canceled rather than hang. - cancel() - - select { - case res := <-resCh: - if !errors.Is(res.err, context.Canceled) { - t.Fatalf("Subscribe err = %v, want context.Canceled", res.err) - } - case <-time.After(2 * time.Second): - t.Fatal("Subscribe did not return after ctx cancel") - } - srvWG.Wait() -} diff --git a/pkg/moqt/session/pubsub_openpublish_test.go b/pkg/moqt/session/pubsub_openpublish_test.go deleted file mode 100644 index 87cf4cc..0000000 --- a/pkg/moqt/session/pubsub_openpublish_test.go +++ /dev/null @@ -1,132 +0,0 @@ -package session_test - -import ( - "errors" - "sync" - "testing" - - "github.com/floatdrop/moq-go/pkg/moqt" - "github.com/floatdrop/moq-go/pkg/moqt/message" - "github.com/floatdrop/moq-go/pkg/moqt/session" - "github.com/floatdrop/moq-go/pkg/moqt/session/sessiontest" - "github.com/floatdrop/moq-go/pkg/moqt/wire" -) - -// openPairWithLimits performs the SETUP handshake over a credit-capped conn -// pair. aBidiLimit caps the client's outbound bidi-stream credit (the SETUP -// control stream is unidirectional, so it is unaffected by the cap). A -// negative limit means unlimited. -func openPairWithLimits(t *testing.T, aBidiLimit int) (*session.Session, *session.Session) { - t.Helper() - ctx := t.Context() - aConn, bConn := sessiontest.NewConnPairWithLimits(aBidiLimit, -1) - - var ( - wg sync.WaitGroup - aSess, bSess *session.Session - aErr, bErr error - ) - wg.Go(func() { aSess, aErr = session.Client(ctx, aConn) }) - wg.Go(func() { bSess, bErr = session.Server(ctx, bConn) }) - wg.Wait() - if aErr != nil { - t.Fatalf("client Open: %v", aErr) - } - if bErr != nil { - t.Fatalf("server Open: %v", bErr) - } - t.Cleanup(func() { - _ = aSess.Close(moqt.SessionNoError, "test cleanup") - _ = bSess.Close(moqt.SessionNoError, "test cleanup") - }) - return aSess, bSess -} - -// TestOpenPublish_SuccessDeliversPublish verifies the happy path: OpenPublish -// assigns a Request ID, writes PUBLISH as the stream's first message, and the -// peer accepts a bidi request carrying that exact PUBLISH. -func TestOpenPublish_SuccessDeliversPublish(t *testing.T) { - t.Parallel() - client, server := openPairWithLimits(t, -1) - - var ( - wg sync.WaitGroup - req *session.Request - err error - ) - wg.Go(func() { req, err = server.AcceptRequest(t.Context()) }) - - m := &message.Publish{ - Namespace: wire.TrackNamespace{[]byte("video")}, - Name: []byte("cam1"), - TrackAlias: 7, - } - stream, openErr := client.OpenPublish(m) - if openErr != nil { - t.Fatalf("OpenPublish: %v", openErr) - } - defer stream.Close() - // Client Request IDs are even (§10.1); the first allocation is 0. - if m.RequestID%2 != 0 { - t.Fatalf("client Request ID %d is not even", m.RequestID) - } - - wg.Wait() - if err != nil { - t.Fatalf("AcceptRequest: %v", err) - } - pub, ok := req.First.(*message.Publish) - if !ok { - t.Fatalf("server got %T, want *message.Publish", req.First) - } - if string(pub.Name) != "cam1" || pub.TrackAlias != 7 { - t.Fatalf("server got Publish{Name:%q, Alias:%d}, want {cam1, 7}", pub.Name, pub.TrackAlias) - } -} - -// TestOpenPublish_ExhaustedCreditReturnsErrNoStreamCredit pins the -// PUBLISH_BLOCKED trigger: with the client's bidi credit used up, OpenPublish -// returns session.ErrNoStreamCredit rather than blocking. -func TestOpenPublish_ExhaustedCreditReturnsErrNoStreamCredit(t *testing.T) { - t.Parallel() - client, server := openPairWithLimits(t, 1) - - // Drain accepts so the first (successful) open's delivery never backs up. - go func() { - for { - if _, err := server.AcceptRequest(t.Context()); err != nil { - return - } - } - }() - - // First publish consumes the single unit of credit. - first := &message.Publish{ - Namespace: wire.TrackNamespace{[]byte("video")}, - Name: []byte("cam1"), - } - s1, err := client.OpenPublish(first) - if err != nil { - t.Fatalf("OpenPublish #0: %v", err) - } - defer s1.Close() - firstID := first.RequestID - - // Second publish: credit exhausted → ErrNoStreamCredit, no ID consumed. - second := &message.Publish{ - Namespace: wire.TrackNamespace{[]byte("video")}, - Name: []byte("cam2"), - } - _, err = client.OpenPublish(second) - if !errors.Is(err, session.ErrNoStreamCredit) { - t.Fatalf("OpenPublish #1 err = %v, want ErrNoStreamCredit", err) - } - - // §6.1: a blocked attempt MUST NOT consume a Request ID. The next - // successful allocation (via a plain AllocRequestID) must be exactly - // firstID+2, proving the blocked OpenPublish left the sequence untouched. - if got := client.AllocRequestID(); got != firstID+2 { - t.Fatalf("Request ID after blocked OpenPublish = %d, want %d (firstID %d + 2)", - got, firstID+2, firstID) - } -} diff --git a/pkg/moqt/session/request.go b/pkg/moqt/session/request.go index d1b3e55..9a91161 100644 --- a/pkg/moqt/session/request.go +++ b/pkg/moqt/session/request.go @@ -265,6 +265,44 @@ func (s *Session) readResponse(ctx context.Context, stream Stream) (message.Mess return msg, err } +// UpdateRequest sends a REQUEST_UPDATE (§10.9) on an already-established +// request stream and awaits the single REQUEST_OK / REQUEST_ERROR the spec +// mandates in response. requestID MUST be the Request ID of the original +// request the stream carries — REQUEST_UPDATE rides the original bidi stream +// and does NOT consume a new Request ID. params carries only the fields the +// caller wants to change; any parameter omitted keeps its prior value on the +// peer (§10.9). +// +// On REQUEST_OK the parsed message is returned and the stream is left open +// for further traffic. REQUEST_ERROR is surfaced as a *RequestRejectedError; +// the stream is left open so the caller can decide how to tear down (a failed +// subscription update is followed by PUBLISH_DONE from the publisher, §10.9). +func (s *Session) UpdateRequest( + ctx context.Context, + stream Stream, + requestID uint64, + params message.Parameters, +) (*message.RequestOK, error) { + if err := message.Marshal(stream, &message.RequestUpdate{ + RequestID: requestID, + Parameters: params, + }); err != nil { + return nil, fmt.Errorf("moqt/session: write REQUEST_UPDATE: %w", err) + } + resp, err := s.readResponse(ctx, stream) + if err != nil { + return nil, fmt.Errorf("moqt/session: read REQUEST_UPDATE response: %w", err) + } + switch r := resp.(type) { + case *message.RequestOK: + return r, nil + case *message.RequestError: + return nil, &RequestRejectedError{Code: r.ErrorCode, Reason: r.ErrorReason} + default: + return nil, fmt.Errorf("moqt/session: unexpected %s in REQUEST_UPDATE response", resp.Type()) + } +} + // Reply marshals a response message onto the request's bidi stream. The // stream is left open so further messages can be written. Use RejectError or // Stream.Close to terminate the send direction. diff --git a/pkg/moqt/session/session_test.go b/pkg/moqt/session/session_test.go index 7e98191..541903f 100644 --- a/pkg/moqt/session/session_test.go +++ b/pkg/moqt/session/session_test.go @@ -57,6 +57,36 @@ func openPair(t *testing.T) (*session.Session, *session.Session) { return aSess, bSess } +// openPairWithLimits performs the SETUP handshake over a credit-capped conn +// pair. aBidiLimit caps the client's outbound bidi-stream credit (the SETUP +// control stream is unidirectional, so it is unaffected by the cap). A +// negative limit means unlimited. +func openPairWithLimits(t *testing.T, aBidiLimit int) (*session.Session, *session.Session) { + t.Helper() + ctx := t.Context() + aConn, bConn := sessiontest.NewConnPairWithLimits(aBidiLimit, -1) + + var ( + wg sync.WaitGroup + aSess, bSess *session.Session + aErr, bErr error + ) + wg.Go(func() { aSess, aErr = session.Client(ctx, aConn) }) + wg.Go(func() { bSess, bErr = session.Server(ctx, bConn) }) + wg.Wait() + if aErr != nil { + t.Fatalf("client Open: %v", aErr) + } + if bErr != nil { + t.Fatalf("server Open: %v", bErr) + } + t.Cleanup(func() { + _ = aSess.Close(moqt.SessionNoError, "test cleanup") + _ = bSess.Close(moqt.SessionNoError, "test cleanup") + }) + return aSess, bSess +} + func TestHandshakeExchangesPeerOptions(t *testing.T) { client, server := openPair(t) diff --git a/pkg/moqt/session/subscribe.go b/pkg/moqt/session/subscribe.go new file mode 100644 index 0000000..4bf44e3 --- /dev/null +++ b/pkg/moqt/session/subscribe.go @@ -0,0 +1,85 @@ +package session + +import ( + "context" + "fmt" + + "github.com/floatdrop/moq-go/pkg/moqt/message" + "github.com/floatdrop/moq-go/pkg/moqt/track" +) + +// Subscription is a live subscriber-initiated track subscription. It owns the +// request stream (embedded, so Close / reads / message.Marshal work directly +// on it) plus the identifiers follow-up traffic needs — the Request ID and the +// publisher-assigned Track Alias — so the caller can send REQUEST_UPDATE via +// [Subscription.Update] without holding them separately. It is returned by +// [Session.Subscribe]. +type Subscription struct { + // Stream is the SUBSCRIBE request stream, still open for follow-up + // traffic: REQUEST_UPDATE and inbound PUBLISH_DONE. Close it to end the + // subscription. + Stream + + // OK is the parsed SUBSCRIBE_OK response — the publisher-assigned Track + // Alias, negotiated Parameters, and TrackProperties. + OK *message.SubscribeOK + + s *Session + requestID uint64 +} + +// TrackAlias reports the §11.1 Track Alias the publisher assigned to this +// subscription — the integer inbound subgroup and datagram streams carry to +// identify the track (see [Session.AcceptDataStream]). It is shorthand for +// sub.OK.TrackAlias. +func (sub *Subscription) TrackAlias() uint64 { return sub.OK.TrackAlias } + +// Update sends a REQUEST_UPDATE (§10.9) on the subscription stream and awaits +// the single REQUEST_OK / REQUEST_ERROR the spec mandates. params carries only +// the fields to change; any parameter omitted keeps its prior value on the +// peer. It is [Session.UpdateRequest] with this subscription's stream and +// Request ID filled in. +func (sub *Subscription) Update(ctx context.Context, params message.Parameters) (*message.RequestOK, error) { + return sub.s.UpdateRequest(ctx, sub.Stream, sub.requestID, params) +} + +// Subscribe opens a SUBSCRIBE request stream (§10.7) and awaits SUBSCRIBE_OK. +// The session assigns m.RequestID; the caller supplies the rest. On success a +// [Subscription] is returned whose embedded stream stays open for follow-up +// traffic (REQUEST_UPDATE via [Subscription.Update], inbound PUBLISH_DONE) and +// whose [Subscription.TrackAlias] matches the alias on inbound subgroup +// streams. REQUEST_ERROR is surfaced as a *RequestRejectedError and the stream +// is closed. +func (s *Session) Subscribe(ctx context.Context, m *message.Subscribe) (*Subscription, error) { + stream, err := s.openAllocRequest(m) + if err != nil { + return nil, err + } + resp, err := s.readResponse(ctx, stream) + if err != nil { + _ = stream.Close() + return nil, fmt.Errorf("moqt/session: read SUBSCRIBE response: %w", err) + } + switch r := resp.(type) { + case *message.SubscribeOK: + // §2.5.1: reject tracks with unknown mandatory track properties. + if err := s.validateTrackProperties(r.TrackProperties, "SUBSCRIBE_OK"); err != nil { + _ = stream.Close() + return nil, err + } + // §11.1: register the alias the publisher assigned so we can detect + // DUPLICATE_TRACK_ALIAS if the same alias is reused for a different track. + key := track.NewKey(m.Namespace, m.Name) + if err := s.RegisterInboundTrackAlias(r.TrackAlias, key); err != nil { + _ = stream.Close() + return nil, err + } + return &Subscription{Stream: stream, OK: r, s: s, requestID: m.RequestID}, nil + case *message.RequestError: + _ = stream.Close() + return nil, &RequestRejectedError{Code: r.ErrorCode, Reason: r.ErrorReason} + default: + _ = stream.Close() + return nil, fmt.Errorf("moqt/session: unexpected %s in SUBSCRIBE response", resp.Type()) + } +} diff --git a/pkg/moqt/session/pubsub_test.go b/pkg/moqt/session/subscribe_test.go similarity index 70% rename from pkg/moqt/session/pubsub_test.go rename to pkg/moqt/session/subscribe_test.go index d9a1a43..29e26cf 100644 --- a/pkg/moqt/session/pubsub_test.go +++ b/pkg/moqt/session/subscribe_test.go @@ -1,11 +1,12 @@ package session_test import ( + "context" "errors" "fmt" - "io" "sync" "testing" + "time" "github.com/floatdrop/moq-go/pkg/moqt" "github.com/floatdrop/moq-go/pkg/moqt/message" @@ -14,213 +15,6 @@ import ( "github.com/floatdrop/moq-go/pkg/moqt/wire" ) -// --------------------------------------------------------------------------- -// Publish -// --------------------------------------------------------------------------- - -// TestPublishRoundTrip exercises the full PUBLISH flow: -// -// 1. Client calls Session.Publish → sends PUBLISH on a bidi stream. -// 2. Server accepts the request, verifies the first message, replies REQUEST_OK. -// 3. Client receives the open stream from Publish(). -func TestPublishRoundTrip(t *testing.T) { - cli, srv := openPair(t) - ctx := t.Context() - - ns := wire.TrackNamespace{[]byte("example.com"), []byte("live")} - req := &message.Publish{ - Namespace: ns, - Name: []byte("video"), - TrackAlias: 42, - } - - var ( - wg sync.WaitGroup - serverErr error - clientErr error - gotStream session.Stream - ) - - // Server: accept PUBLISH, verify fields, reply REQUEST_OK. - wg.Go(func() { - r, err := srv.AcceptRequest(ctx) - if err != nil { - serverErr = err - return - } - pub, ok := r.First.(*message.Publish) - if !ok { - serverErr = errors.New("server: expected *message.Publish, got " + r.First.Type().String()) - return - } - // RequestID must have been assigned by the client (even, starts at 0). - if pub.RequestID != 0 { - serverErr = errors.New("server: unexpected RequestID") - return - } - if string(pub.Name) != string(req.Name) { - serverErr = errors.New("server: Name mismatch") - return - } - if pub.TrackAlias != req.TrackAlias { - serverErr = errors.New("server: TrackAlias mismatch") - return - } - serverErr = r.Reply(&message.RequestOK{}) - }) - - // Client: call Publish, check the returned stream is non-nil. - wg.Go(func() { - stream, err := cli.Publish(ctx, req) - if err != nil { - clientErr = err - return - } - gotStream = stream - }) - - wg.Wait() - - if serverErr != nil { - t.Fatalf("server: %v", serverErr) - } - if clientErr != nil { - t.Fatalf("client Publish: %v", clientErr) - } - if gotStream == nil { - t.Fatal("Publish returned nil stream") - } - _ = gotStream.Close() -} - -// TestPublishRejected verifies that Session.Publish returns a -// *RequestRejectedError when the server replies with REQUEST_ERROR. -func TestPublishRejected(t *testing.T) { - cli, srv := openPair(t) - ctx := t.Context() - - var wg sync.WaitGroup - - wg.Go(func() { - r, err := srv.AcceptRequest(ctx) - if err != nil { - return - } - _ = r.RejectError(moqt.RequestDoesNotExist, "track not found") - }) - - wg.Go(func() { - _, err := cli.Publish(ctx, &message.Publish{ - Namespace: wire.TrackNamespace{[]byte("ns")}, - Name: []byte("missing"), - TrackAlias: 1, - }) - var rejected *session.RequestRejectedError - if !errors.As(err, &rejected) { - t.Errorf("Publish error = %v (%T), want *session.RequestRejectedError", err, err) - return - } - if rejected.Code != moqt.RequestDoesNotExist { - t.Errorf("Code = %v, want RequestDoesNotExist", rejected.Code) - } - }) - - wg.Wait() -} - -// TestPublicationDone verifies Publication.Done writes a PUBLISH_DONE whose -// §10.11 Stream Count reflects the subgroups opened via the handle, with the -// given code and reason, and then FINs the request stream. -func TestPublicationDone(t *testing.T) { - cli, srv := openPair(t) - ctx := t.Context() - - var ( - wg sync.WaitGroup - serverErr error - gotDone *message.PublishDone - ) - - // Server: accept the PUBLISH, reply OK, drain the two subgroup streams the - // client opens (the in-process pipe is synchronous), then read PUBLISH_DONE - // off the request stream. - wg.Go(func() { - r, err := srv.AcceptRequest(ctx) - if err != nil { - serverErr = err - return - } - if err := r.Reply(&message.RequestOK{}); err != nil { - serverErr = err - return - } - for range 2 { - ds, err := srv.AcceptDataStream(ctx) - if err != nil { - serverErr = err - return - } - _, _ = io.Copy(io.Discard, ds) - } - msg, err := message.Parse(r.Stream) - if err != nil { - serverErr = err - return - } - pd, ok := msg.(*message.PublishDone) - if !ok { - serverErr = fmt.Errorf("got %T on request stream, want *message.PublishDone", msg) - return - } - gotDone = pd - }) - - // Client: publish, open two subgroups via the handle, end with Done. - wg.Go(func() { - pub, err := cli.Publish(ctx, &message.Publish{ - Namespace: wire.Namespace("ns"), - Name: []byte("track"), - }) - if err != nil { - return - } - for g := range uint64(2) { - sg, err := pub.OpenSubgroup(message.SubgroupHeader{ - SubgroupIDMode: message.SubgroupIDImplicitZero, - GroupID: g, - }) - if err != nil { - return - } - _ = sg.WriteObjectAt(0, &message.SubgroupObject{Payload: []byte("x")}) - _ = sg.Close() - } - _ = pub.Done(moqt.PublishDoneTrackEnded, "bye") - }) - - wg.Wait() - - if serverErr != nil { - t.Fatalf("server: %v", serverErr) - } - if gotDone == nil { - t.Fatal("no PUBLISH_DONE received") - } - if gotDone.StatusCode != moqt.PublishDoneTrackEnded { - t.Errorf("StatusCode = %v, want PublishDoneTrackEnded", gotDone.StatusCode) - } - if gotDone.StreamCount != 2 { - t.Errorf("StreamCount = %d, want 2 (subgroups opened via the handle)", gotDone.StreamCount) - } - if gotDone.ErrorReason != "bye" { - t.Errorf("ErrorReason = %q, want %q", gotDone.ErrorReason, "bye") - } -} - -// --------------------------------------------------------------------------- -// Subscribe -// --------------------------------------------------------------------------- - // TestSubscribeRoundTrip exercises the full SUBSCRIBE flow: // // 1. Client calls Session.Subscribe → sends SUBSCRIBE on a bidi stream. @@ -390,6 +184,52 @@ func TestSubscribeRequestIDIncrement(t *testing.T) { } } +// TestSubscribe_ContextCancelUnblocksResponseWait verifies that cancelling the +// ctx unblocks an awaiting request method even though message.Parse reads from +// a context-free io.Reader. The server accepts the SUBSCRIBE but never replies, +// so Subscribe blocks in readResponse; the context.AfterFunc hook resets the +// read side on cancel and the call returns ctx.Err() (wrapped as context.Canceled). +func TestSubscribe_ContextCancelUnblocksResponseWait(t *testing.T) { + t.Parallel() + client, server := openPairWithLimits(t, -1) + + // Server accepts the request and holds it open without ever replying. + var srvWG sync.WaitGroup + srvWG.Go(func() { + // AcceptRequest reads the SUBSCRIBE (which unblocks the client's write); + // we then deliberately never send SUBSCRIBE_OK / REQUEST_ERROR. + _, _ = server.AcceptRequest(t.Context()) + }) + + ctx, cancel := context.WithCancel(context.Background()) + + type result struct { + err error + } + resCh := make(chan result, 1) + go func() { + _, err := client.Subscribe(ctx, &message.Subscribe{ + Namespace: wire.TrackNamespace{[]byte("video")}, + Name: []byte("cam1"), + }) + resCh <- result{err: err} + }() + + // Cancelling makes the blocked response read return; the call must surface + // context.Canceled rather than hang. + cancel() + + select { + case res := <-resCh: + if !errors.Is(res.err, context.Canceled) { + t.Fatalf("Subscribe err = %v, want context.Canceled", res.err) + } + case <-time.After(2 * time.Second): + t.Fatal("Subscribe did not return after ctx cancel") + } + srvWG.Wait() +} + // --------------------------------------------------------------------------- // Duplicate Track Alias detection (§11.1) // --------------------------------------------------------------------------- @@ -642,54 +482,6 @@ func TestRequestAcceptSubscribe(t *testing.T) { } } -// TestRequestAcceptPublish verifies AcceptPublish registers the publisher's -// Track Alias (§11.1) and replies REQUEST_OK so the client's Publish succeeds. -func TestRequestAcceptPublish(t *testing.T) { - cli, srv := openPair(t) - ctx := t.Context() - - var ( - wg sync.WaitGroup - serverErr error - gotKey track.Key - gotOK bool - ) - wg.Go(func() { - req, err := srv.AcceptRequest(ctx) - if err != nil { - serverErr = err - return - } - if err := req.AcceptPublish(); err != nil { - serverErr = err - return - } - gotKey, gotOK = srv.LookupInboundTrackAlias(7) - }) - var clientErr error - wg.Go(func() { - _, clientErr = cli.Publish(ctx, &message.Publish{ - Namespace: wire.Namespace("ns"), - Name: []byte("track"), - TrackAlias: 7, - }) - }) - wg.Wait() - - if serverErr != nil { - t.Fatalf("server: %v", serverErr) - } - if clientErr != nil { - t.Fatalf("client Publish: %v", clientErr) - } - if !gotOK { - t.Fatal("AcceptPublish did not register the inbound Track Alias") - } - if want := track.NewKey(wire.Namespace("ns"), []byte("track")); gotKey != want { - t.Errorf("registered key = %v, want %v", gotKey, want) - } -} - // TestAcceptWrongType verifies the accept helpers reject a mismatched request // type without writing a response. func TestAcceptWrongType(t *testing.T) {