diff --git a/backend/internal/adapters/scm/github/observer_provider.go b/backend/internal/adapters/scm/github/observer_provider.go index bfb703d2..8e0ca91b 100644 --- a/backend/internal/adapters/scm/github/observer_provider.go +++ b/backend/internal/adapters/scm/github/observer_provider.go @@ -57,36 +57,36 @@ func (p *Provider) RepoPRListGuard(ctx context.Context, repo ports.SCMRepo, etag return ports.SCMGuardResult{ETag: firstNonEmptyHeader(resp.ETag, etag), NotModified: resp.NotModified}, nil } -// DetectPRByBranch finds an open PR whose head branch matches the session branch. -func (p *Provider) DetectPRByBranch(ctx context.Context, repo ports.SCMRepo, branch string) (ports.SCMPRObservation, error) { - branch = strings.TrimSpace(branch) - if branch == "" { - return ports.SCMPRObservation{}, fmt.Errorf("%w: empty branch", ErrNotFound) - } - pulls, err := p.detectPRByHead(ctx, repo, repo.Owner+":"+branch) - if err != nil { - return ports.SCMPRObservation{}, err - } - if len(pulls) == 0 { - return ports.SCMPRObservation{}, fmt.Errorf("%w: no open PR for branch %s", ErrNotFound, branch) - } - return restListPullToSCM(pulls[0]), nil -} - -func (p *Provider) detectPRByHead(ctx context.Context, repo ports.SCMRepo, head string) ([]restListPull, error) { - q := url.Values{} - q.Set("state", "open") - q.Set("head", head) - q.Set("per_page", "10") - resp, err := p.client.doREST(ctx, http.MethodGet, repoPath(repo.Owner, repo.Name, "pulls"), q, nil) - if err != nil { - return nil, err - } - var pulls []restListPull - if err := json.Unmarshal(resp.Body, &pulls); err != nil { - return nil, fmt.Errorf("github scm: decode branch PR list: %w", err) +// ListOpenPRsByRepo lists every open pull request in the repository so the +// observer can attribute each to a session by head-branch prefix. It paginates +// the REST pulls endpoint; AO repos are not expected to carry thousands of +// concurrent open PRs, and the observer only calls this when the repo PR-list +// ETag guard reports a change. +func (p *Provider) ListOpenPRsByRepo(ctx context.Context, repo ports.SCMRepo) ([]ports.SCMPRObservation, error) { + const perPage = 100 + out := []ports.SCMPRObservation{} + for page := 1; ; page++ { + q := url.Values{} + q.Set("state", "open") + q.Set("sort", "updated") + q.Set("direction", "desc") + q.Set("per_page", strconv.Itoa(perPage)) + q.Set("page", strconv.Itoa(page)) + resp, err := p.client.doREST(ctx, http.MethodGet, repoPath(repo.Owner, repo.Name, "pulls"), q, nil) + if err != nil { + return nil, err + } + var pulls []restListPull + if err := json.Unmarshal(resp.Body, &pulls); err != nil { + return nil, fmt.Errorf("github scm: decode open PR list: %w", err) + } + for _, pull := range pulls { + out = append(out, restListPullToSCM(pull)) + } + if len(pulls) < perPage { + return out, nil + } } - return pulls, nil } // CommitChecksGuard checks GitHub's per-commit check-runs ETag guard. diff --git a/backend/internal/domain/pr.go b/backend/internal/domain/pr.go index 704c0a7c..89b4a961 100644 --- a/backend/internal/domain/pr.go +++ b/backend/internal/domain/pr.go @@ -16,6 +16,8 @@ type PRFacts struct { Review ReviewDecision Mergeability Mergeability ReviewComments bool // has unresolved review comments (any author) to address + SourceBranch string + TargetBranch string UpdatedAt time.Time } diff --git a/backend/internal/domain/session.go b/backend/internal/domain/session.go index 7e289b37..6cc639e5 100644 --- a/backend/internal/domain/session.go +++ b/backend/internal/domain/session.go @@ -61,4 +61,8 @@ type Session struct { SessionRecord Status SessionStatus `json:"status"` TerminalHandleID string `json:"terminalHandleId,omitempty"` + // PRs are the session's attributed pull requests (one session can own many). + // They feed status derivation and are surfaced on the API read model. Not + // serialized here: the HTTP boundary maps them to the curated wire shape. + PRs []PRFacts `json:"-"` } diff --git a/backend/internal/httpd/apispec/openapi.yaml b/backend/internal/httpd/apispec/openapi.yaml index c91c02e3..96aa607e 100644 --- a/backend/internal/httpd/apispec/openapi.yaml +++ b/backend/internal/httpd/apispec/openapi.yaml @@ -1136,6 +1136,49 @@ components: - sessionId - reason type: object + ControllersSessionView: + properties: + activity: + $ref: '#/components/schemas/DomainActivity' + createdAt: + format: date-time + type: string + displayName: + type: string + harness: + type: string + id: + type: string + isTerminated: + type: boolean + issueId: + type: string + kind: + type: string + projectId: + type: string + prs: + items: + $ref: '#/components/schemas/SessionPRFacts' + type: array + status: + type: string + terminalHandleId: + type: string + updatedAt: + format: date-time + type: string + required: + - id + - projectId + - kind + - activity + - isTerminated + - createdAt + - updatedAt + - status + - prs + type: object DegradedProject: properties: id: @@ -1220,7 +1263,7 @@ components: properties: sessions: items: - $ref: '#/components/schemas/Session' + $ref: '#/components/schemas/ControllersSessionView' type: array required: - sessions @@ -1397,7 +1440,7 @@ components: ok: type: boolean session: - $ref: '#/components/schemas/Session' + $ref: '#/components/schemas/ControllersSessionView' sessionId: type: string required: @@ -1496,44 +1539,6 @@ components: - sessionId - message type: object - Session: - properties: - activity: - $ref: '#/components/schemas/DomainActivity' - createdAt: - format: date-time - type: string - displayName: - type: string - harness: - type: string - id: - type: string - isTerminated: - type: boolean - issueId: - type: string - kind: - type: string - projectId: - type: string - status: - type: string - terminalHandleId: - type: string - updatedAt: - format: date-time - type: string - required: - - id - - projectId - - kind - - activity - - isTerminated - - createdAt - - updatedAt - - status - type: object SessionPRFacts: properties: ci: @@ -1571,7 +1576,7 @@ components: SessionResponse: properties: session: - $ref: '#/components/schemas/Session' + $ref: '#/components/schemas/ControllersSessionView' required: - session type: object diff --git a/backend/internal/httpd/controllers/dto.go b/backend/internal/httpd/controllers/dto.go index 557c42ee..551319ad 100644 --- a/backend/internal/httpd/controllers/dto.go +++ b/backend/internal/httpd/controllers/dto.go @@ -111,9 +111,18 @@ type CleanupSessionsQuery struct { Project string `query:"project,omitempty" description:"Project id filter. When omitted, clean terminated sessions across all projects."` } +// SessionView is the session wire shape: the domain read model plus the +// session's attributed pull requests in the curated SessionPRFacts shape. One +// session can own many PRs (e.g. a stack), so prs is a list. The embedded +// domain.Session.PRs field is json:"-"; this curated prs is what serializes. +type SessionView struct { + domain.Session + PRs []SessionPRFacts `json:"prs"` +} + // ListSessionsResponse is the body of GET /api/v1/sessions. type ListSessionsResponse struct { - Sessions []domain.Session `json:"sessions"` + Sessions []SessionView `json:"sessions"` } // SpawnSessionRequest is the body of POST /api/v1/sessions. @@ -128,7 +137,7 @@ type SpawnSessionRequest struct { // SessionResponse is the { session } body shared by session create/get. type SessionResponse struct { - Session domain.Session `json:"session"` + Session SessionView `json:"session"` } // RenameSessionRequest is the body of PATCH /api/v1/sessions/{sessionId}. @@ -147,7 +156,7 @@ type RenameSessionResponse struct { type RestoreSessionResponse struct { OK bool `json:"ok"` SessionID domain.SessionID `json:"sessionId"` - Session domain.Session `json:"session"` + Session SessionView `json:"session"` } // KillSessionResponse is the body of POST /api/v1/sessions/{sessionId}/kill. diff --git a/backend/internal/httpd/controllers/sessions.go b/backend/internal/httpd/controllers/sessions.go index bf033584..81aa285b 100644 --- a/backend/internal/httpd/controllers/sessions.go +++ b/backend/internal/httpd/controllers/sessions.go @@ -88,7 +88,7 @@ func (c *SessionsController) list(w http.ResponseWriter, r *http.Request) { envelope.WriteError(w, r, err) return } - envelope.WriteJSON(w, http.StatusOK, ListSessionsResponse{Sessions: sessions}) + envelope.WriteJSON(w, http.StatusOK, ListSessionsResponse{Sessions: sessionViews(sessions)}) } func (c *SessionsController) spawn(w http.ResponseWriter, r *http.Request) { @@ -117,7 +117,7 @@ func (c *SessionsController) spawn(w http.ResponseWriter, r *http.Request) { envelope.WriteError(w, r, err) return } - envelope.WriteJSON(w, http.StatusCreated, SessionResponse{Session: sess}) + envelope.WriteJSON(w, http.StatusCreated, SessionResponse{Session: sessionView(sess)}) } func (c *SessionsController) get(w http.ResponseWriter, r *http.Request) { @@ -130,7 +130,7 @@ func (c *SessionsController) get(w http.ResponseWriter, r *http.Request) { envelope.WriteError(w, r, err) return } - envelope.WriteJSON(w, http.StatusOK, SessionResponse{Session: sess}) + envelope.WriteJSON(w, http.StatusOK, SessionResponse{Session: sessionView(sess)}) } func (c *SessionsController) listPRs(w http.ResponseWriter, r *http.Request) { @@ -204,7 +204,7 @@ func (c *SessionsController) restore(w http.ResponseWriter, r *http.Request) { envelope.WriteError(w, r, err) return } - envelope.WriteJSON(w, http.StatusOK, RestoreSessionResponse{OK: true, SessionID: sessionID(r), Session: sess}) + envelope.WriteJSON(w, http.StatusOK, RestoreSessionResponse{OK: true, SessionID: sessionID(r), Session: sessionView(sess)}) } func (c *SessionsController) kill(w http.ResponseWriter, r *http.Request) { @@ -344,7 +344,7 @@ func (c *SessionsController) listOrchestrators(w http.ResponseWriter, r *http.Re envelope.WriteError(w, r, err) return } - envelope.WriteJSON(w, http.StatusOK, ListSessionsResponse{Sessions: sessions}) + envelope.WriteJSON(w, http.StatusOK, ListSessionsResponse{Sessions: sessionViews(sessions)}) } func (c *SessionsController) getOrchestrator(w http.ResponseWriter, r *http.Request) { @@ -361,7 +361,7 @@ func (c *SessionsController) getOrchestrator(w http.ResponseWriter, r *http.Requ envelope.WriteAPIError(w, r, http.StatusNotFound, "not_found", "SESSION_NOT_FOUND", "Unknown session", nil) return } - envelope.WriteJSON(w, http.StatusOK, SessionResponse{Session: sess}) + envelope.WriteJSON(w, http.StatusOK, SessionResponse{Session: sessionView(sess)}) } func sessionID(r *http.Request) domain.SessionID { @@ -432,6 +432,18 @@ func writeSessionPRError(w http.ResponseWriter, r *http.Request, err error) { } } +func sessionView(s domain.Session) SessionView { + return SessionView{Session: s, PRs: sessionPRFacts(s.PRs)} +} + +func sessionViews(sessions []domain.Session) []SessionView { + out := make([]SessionView, 0, len(sessions)) + for _, s := range sessions { + out = append(out, sessionView(s)) + } + return out +} + func sessionPRFacts(prs []domain.PRFacts) []SessionPRFacts { out := make([]SessionPRFacts, 0, len(prs)) for _, pr := range prs { diff --git a/backend/internal/integration/scm_observer_test.go b/backend/internal/integration/scm_observer_test.go index 94e83385..26f4d20d 100644 --- a/backend/internal/integration/scm_observer_test.go +++ b/backend/internal/integration/scm_observer_test.go @@ -67,7 +67,7 @@ func (s *scmMessengerSpy) snapshot() []scmCapturedNudge { } // cannedSCMProvider satisfies observe/scm.Provider with hand-built observations -// keyed by branch (for DetectPRByBranch) and by PR number (for everything else, +// keyed by branch (for ListOpenPRsByRepo) and by PR number (for everything else, // since every test case uses scmTestRepo). It is the integration-package analog // of observer_test.go's fakeProvider: the SCM adapter has its own httptest-based // coverage, so this test holds the provider constant and exercises every other @@ -101,14 +101,14 @@ func (p *cannedSCMProvider) RepoPRListGuard(_ context.Context, _ ports.SCMRepo, return ports.SCMGuardResult{ETag: "repo-etag"}, nil } -func (p *cannedSCMProvider) DetectPRByBranch(_ context.Context, _ ports.SCMRepo, branch string) (ports.SCMPRObservation, error) { +func (p *cannedSCMProvider) ListOpenPRsByRepo(_ context.Context, _ ports.SCMRepo) ([]ports.SCMPRObservation, error) { p.mu.Lock() defer p.mu.Unlock() - pr, ok := p.detected[branch] - if !ok { - return ports.SCMPRObservation{}, ports.ErrSCMNotFound + out := make([]ports.SCMPRObservation, 0, len(p.detected)) + for _, pr := range p.detected { + out = append(out, pr) } - return pr, nil + return out, nil } func (p *cannedSCMProvider) CommitChecksGuard(_ context.Context, _ ports.SCMRepo, _, _ string) (ports.SCMGuardResult, error) { @@ -350,7 +350,7 @@ func TestSCMObserverEndToEnd(t *testing.T) { // the production fallback the observer relies on when the upstream // ETag guard misses. The ETag-driven 304 short-circuit on the same // SHA is covered by the unit tests in observe/scm/observer_test.go - // (Poll_RepoETag304SkipsDetectPR, Poll_CIETagChangeRefreshesWhenRepoUnchanged). + // (Poll_RepoETag304SkipsListPRs, Poll_CIETagChangeRefreshesWhenRepoUnchanged). if err := f.observer.Poll(ctx); err != nil { t.Fatalf("second Poll: %v", err) } @@ -397,7 +397,7 @@ func TestSCMObserverEndToEnd(t *testing.T) { t.Run("Branch with no open PR writes nothing and sends no nudge", func(t *testing.T) { ctx := context.Background() f := newSCMFixture(t, "feat/quiet") - // No entry in provider.detected — DetectPRByBranch returns ErrSCMNotFound, + // No entry in provider.detected — ListOpenPRsByRepo returns an empty list, // the production "no PR yet" signal. if err := f.observer.Poll(ctx); err != nil { @@ -416,3 +416,224 @@ func TestSCMObserverEndToEnd(t *testing.T) { } }) } + +// openSCMObservation builds an open-PR observation with caller-chosen branches +// and mergeability, CI passing and no review. The multi-PR cases drive the stack +// model (target/source branch pairs) and the completion rule, so branches must +// be configurable rather than the fixed feat/x->main the single-PR helpers bake in. +func openSCMObservation(prURL string, num int, headSHA, src, tgt string, merge domain.Mergeability) ports.SCMObservation { + mo := ports.SCMMergeabilityObservation{State: string(merge)} + switch merge { + case domain.MergeMergeable: + mo.Mergeable = true + case domain.MergeConflicting: + mo.Conflict = true + mo.Blockers = []string{"conflicts"} + } + return ports.SCMObservation{ + Fetched: true, + Provider: "github", Host: "github.com", Repo: "octocat/hello", + PR: ports.SCMPRObservation{ + URL: prURL, + HTMLURL: prURL, + Number: num, + State: string(domain.PRStateOpen), + SourceBranch: src, + TargetBranch: tgt, + HeadSHA: headSHA, + Title: "wip", + }, + CI: ports.SCMCIObservation{Summary: string(domain.CIPassing), HeadSHA: headSHA}, + Review: ports.SCMReviewObservation{Decision: string(domain.ReviewNone)}, + Mergeability: mo, + } +} + +// mergedSCMObservationBranches is mergedSCMObservation with caller-chosen +// branches so a stacked child (feat/x/auth -> feat/x) can be merged distinctly +// from the root (feat/x -> main). +func mergedSCMObservationBranches(prURL string, num int, headSHA, src, tgt string) ports.SCMObservation { + o := mergedSCMObservation(prURL, num, headSHA) + o.PR.SourceBranch = src + o.PR.TargetBranch = tgt + return o +} + +// detectedPR is the open-PR-list discovery shape: the observer attributes a +// listed PR to a session by source-branch prefix, so only identity + branches +// matter here. +func detectedPR(prURL string, num int, src, tgt, headSHA string) ports.SCMPRObservation { + return ports.SCMPRObservation{URL: prURL, HTMLURL: prURL, Number: num, SourceBranch: src, TargetBranch: tgt, HeadSHA: headSHA} +} + +// TestSCMObserverMultiPREndToEnd is the functional regression guard for the +// multi-PR-per-session feature. It drives the real store + lifecycle + observer +// through the three behaviours the feature adds on top of the single-PR lane: +// branch-prefix attribution of several PRs to one session, the "all PRs +// merged/closed and at least one merged" completion bar, and the stacked-child +// merge-conflict nudge suppression. The SCM provider is canned (its own httptest +// coverage lives in observe/scm), so every other layer runs for real. +func TestSCMObserverMultiPREndToEnd(t *testing.T) { + t.Run("one session owns its root and stacked child PRs from a single repo list", func(t *testing.T) { + ctx := context.Background() + f := newSCMFixture(t, "feat/x") + const ( + rootURL = "https://github.com/octocat/hello/pull/101" + childURL = "https://github.com/octocat/hello/pull/102" + ) + // Root PR on the session branch, plus a stacked child whose source branch + // descends from it (feat/x/auth). matchSession claims both for the one + // session: the child by the "branch/..." stacking convention. + f.provider.detected["feat/x"] = detectedPR(rootURL, 101, "feat/x", "main", "sha-root") + f.provider.detected["feat/x/auth"] = detectedPR(childURL, 102, "feat/x/auth", "feat/x", "sha-child") + f.provider.observations[101] = openSCMObservation(rootURL, 101, "sha-root", "feat/x", "main", domain.MergeMergeable) + f.provider.observations[102] = openSCMObservation(childURL, 102, "sha-child", "feat/x/auth", "feat/x", domain.MergeBlocked) + + if err := f.observer.Poll(ctx); err != nil { + t.Fatalf("Poll: %v", err) + } + + prs, err := f.store.ListPRsBySession(ctx, f.session.ID) + if err != nil { + t.Fatalf("ListPRsBySession: %v", err) + } + if len(prs) != 2 { + t.Fatalf("one session should own both discovered PRs, got %d: %+v", len(prs), prs) + } + byURL := map[string]domain.PullRequest{} + for _, pr := range prs { + if pr.SessionID != f.session.ID { + t.Fatalf("PR %q attributed to %q, want %q", pr.URL, pr.SessionID, f.session.ID) + } + byURL[pr.URL] = pr + } + // The branch pair is what the stack model is derived from, so it must be + // persisted by the observer write path (not just discovered). + if byURL[rootURL].SourceBranch != "feat/x" || byURL[rootURL].TargetBranch != "main" { + t.Fatalf("root branch pair lost: %+v", byURL[rootURL]) + } + if byURL[childURL].SourceBranch != "feat/x/auth" || byURL[childURL].TargetBranch != "feat/x" { + t.Fatalf("child branch pair lost: %+v", byURL[childURL]) + } + if got := f.spy.count(); got != 0 { + t.Fatalf("clean PRs must not nudge, got %d: %+v", got, f.spy.snapshot()) + } + }) + + t.Run("session stays alive while a stacked PR is open and terminates once all are merged", func(t *testing.T) { + ctx := context.Background() + f := newSCMFixture(t, "feat/x") + const ( + rootURL = "https://github.com/octocat/hello/pull/201" + childURL = "https://github.com/octocat/hello/pull/202" + ) + f.provider.detected["feat/x"] = detectedPR(rootURL, 201, "feat/x", "main", "sha-root") + f.provider.detected["feat/x/auth"] = detectedPR(childURL, 202, "feat/x/auth", "feat/x", "sha-child") + f.provider.observations[201] = openSCMObservation(rootURL, 201, "sha-root", "feat/x", "main", domain.MergeMergeable) + f.provider.observations[202] = openSCMObservation(childURL, 202, "sha-child", "feat/x/auth", "feat/x", domain.MergeBlocked) + + // Poll 1: both PRs open and tracked. The session is live. + if err := f.observer.Poll(ctx); err != nil { + t.Fatalf("Poll 1: %v", err) + } + if rec, _, _ := f.store.GetSession(ctx, f.session.ID); rec.IsTerminated { + t.Fatal("session terminated with two open PRs") + } + + // Poll 2: the root merges while the child stays open. One merged PR does + // not satisfy the completion bar while another PR is still open. + f.provider.observations[201] = mergedSCMObservationBranches(rootURL, 201, "sha-root", "feat/x", "main") + if err := f.observer.Poll(ctx); err != nil { + t.Fatalf("Poll 2: %v", err) + } + rootPR, ok, err := f.store.GetPR(ctx, rootURL) + if err != nil || !ok { + t.Fatalf("GetPR root: ok=%v err=%v", ok, err) + } + if !rootPR.Merged { + t.Fatalf("root PR should be persisted merged: %+v", rootPR) + } + if rec, _, _ := f.store.GetSession(ctx, f.session.ID); rec.IsTerminated { + t.Fatal("session terminated while the stacked child PR is still open") + } + + // Poll 3: the child merges too. Now every PR is merged/closed and at least + // one merged, so the session completes and terminates. + f.provider.observations[202] = mergedSCMObservationBranches(childURL, 202, "sha-child", "feat/x/auth", "feat/x") + if err := f.observer.Poll(ctx); err != nil { + t.Fatalf("Poll 3: %v", err) + } + rec, ok, err := f.store.GetSession(ctx, f.session.ID) + if err != nil || !ok { + t.Fatalf("GetSession: ok=%v err=%v", ok, err) + } + if !rec.IsTerminated { + t.Fatalf("session should terminate once all PRs are merged: %+v", rec) + } + if got := f.spy.count(); got != 0 { + t.Fatalf("merge-driven completion must not nudge, got %d: %+v", got, f.spy.snapshot()) + } + }) + + t.Run("stacked child blocked by an open parent is exempt from the rebase nudge", func(t *testing.T) { + ctx := context.Background() + f := newSCMFixture(t, "feat/x") + const ( + rootURL = "https://github.com/octocat/hello/pull/301" + childURL = "https://github.com/octocat/hello/pull/302" + ) + f.provider.detected["feat/x"] = detectedPR(rootURL, 301, "feat/x", "main", "sha-root") + f.provider.detected["feat/x/auth"] = detectedPR(childURL, 302, "feat/x/auth", "feat/x", "sha-child") + // Poll 1 establishes both rows (open, mergeable) so the stack relationship + // is durable before conflicts appear, making the poll-2 reaction order + // independent of map iteration. + f.provider.observations[301] = openSCMObservation(rootURL, 301, "sha-root", "feat/x", "main", domain.MergeMergeable) + f.provider.observations[302] = openSCMObservation(childURL, 302, "sha-child", "feat/x/auth", "feat/x", domain.MergeMergeable) + if err := f.observer.Poll(ctx); err != nil { + t.Fatalf("Poll 1: %v", err) + } + if got := f.spy.count(); got != 0 { + t.Fatalf("clean establishing poll must not nudge, got %d: %+v", got, f.spy.snapshot()) + } + + // Poll 2: both PRs now report merge conflicts. Only the bottom of the + // stack (the root, targeting main) is eligible for the rebase nudge; the + // child targets feat/x, the still-open root's source branch, so it is + // expected to conflict against its parent until the parent merges and is + // suppressed. + f.provider.observations[301] = openSCMObservation(rootURL, 301, "sha-root", "feat/x", "main", domain.MergeConflicting) + f.provider.observations[302] = openSCMObservation(childURL, 302, "sha-child", "feat/x/auth", "feat/x", domain.MergeConflicting) + if err := f.observer.Poll(ctx); err != nil { + t.Fatalf("Poll 2: %v", err) + } + + msgs := f.spy.snapshot() + if len(msgs) != 1 { + t.Fatalf("exactly one PR (the stack bottom) should be nudged, got %d: %+v", len(msgs), msgs) + } + if msgs[0].session != f.session.ID { + t.Fatalf("nudge addressed to %q, want %q", msgs[0].session, f.session.ID) + } + if !strings.Contains(msgs[0].body, "merge conflicts") { + t.Fatalf("nudge body missing merge-conflict cue: %q", msgs[0].body) + } + + // The persisted dedup signature must be the root's, never the child's — + // proving the child was suppressed at the reaction layer, not merely + // deduped after sending. + rootSig, err := f.store.GetPRLastNudgeSignature(ctx, rootURL) + if err != nil { + t.Fatalf("GetPRLastNudgeSignature root: %v", err) + } + if rootSig == "" { + t.Fatal("root PR should have a persisted nudge signature") + } + childSig, err := f.store.GetPRLastNudgeSignature(ctx, childURL) + if err != nil { + t.Fatalf("GetPRLastNudgeSignature child: %v", err) + } + if childSig != "" { + t.Fatalf("stacked child must not record a nudge signature: %q", childSig) + } + }) +} diff --git a/backend/internal/lifecycle/manager.go b/backend/internal/lifecycle/manager.go index 925af553..d05ba1cd 100644 --- a/backend/internal/lifecycle/manager.go +++ b/backend/internal/lifecycle/manager.go @@ -17,6 +17,11 @@ import ( type sessionStore interface { GetSession(ctx context.Context, id domain.SessionID) (domain.SessionRecord, bool, error) UpdateSession(ctx context.Context, rec domain.SessionRecord) error + // ListPRsBySession returns every PR row tracked for the session. The + // reducer reads it to apply the multi-PR completion rule (terminate only + // when no open PR remains and at least one merged) and to suppress + // merge-conflict nudges on PRs stacked behind an open parent. + ListPRsBySession(ctx context.Context, id domain.SessionID) ([]domain.PullRequest, error) // GetPRLastNudgeSignature / UpdatePRLastNudgeSignature persist the // reaction-dedup map so nudges survive a daemon restart. GetPRLastNudgeSignature(ctx context.Context, prURL string) (string, error) diff --git a/backend/internal/lifecycle/manager_test.go b/backend/internal/lifecycle/manager_test.go index e739fa01..45816bcd 100644 --- a/backend/internal/lifecycle/manager_test.go +++ b/backend/internal/lifecycle/manager_test.go @@ -15,6 +15,7 @@ var ctx = context.Background() type fakeStore struct { sessions map[domain.SessionID]domain.SessionRecord + prs map[domain.SessionID][]domain.PullRequest signatures map[string]string signatureWriteErr error @@ -22,7 +23,7 @@ type fakeStore struct { } func newFakeStore() *fakeStore { - return &fakeStore{sessions: map[domain.SessionID]domain.SessionRecord{}, signatures: map[string]string{}} + return &fakeStore{sessions: map[domain.SessionID]domain.SessionRecord{}, prs: map[domain.SessionID][]domain.PullRequest{}, signatures: map[string]string{}} } func (f *fakeStore) GetSession(_ context.Context, id domain.SessionID) (domain.SessionRecord, bool, error) { @@ -30,6 +31,10 @@ func (f *fakeStore) GetSession(_ context.Context, id domain.SessionID) (domain.S return r, ok, nil } +func (f *fakeStore) ListPRsBySession(_ context.Context, id domain.SessionID) ([]domain.PullRequest, error) { + return f.prs[id], nil +} + func (f *fakeStore) UpdateSession(_ context.Context, rec domain.SessionRecord) error { f.sessions[rec.ID] = rec return nil @@ -224,6 +229,7 @@ func TestPRObservation_MergeConflictNudgesAgent(t *testing.T) { func TestPRObservation_MergedTerminatesWithoutNudge(t *testing.T) { m, st, msg := newManager() st.sessions["mer-1"] = working("mer-1") + st.prs["mer-1"] = []domain.PullRequest{{URL: "pr1", Merged: true}} if err := m.ApplyPRObservation(ctx, "mer-1", ports.PRObservation{Fetched: true, URL: "pr1", Merged: true}); err != nil { t.Fatal(err) } @@ -236,6 +242,91 @@ func TestPRObservation_MergedTerminatesWithoutNudge(t *testing.T) { } } +// A session with one merged PR and one still-open PR must NOT terminate: the +// completion bar is "no open PR remains AND at least one merged". +func TestPRObservation_MergedWithOpenSiblingDoesNotTerminate(t *testing.T) { + m, st, _ := newManager() + st.sessions["mer-1"] = working("mer-1") + st.prs["mer-1"] = []domain.PullRequest{ + {URL: "pr1", Merged: true}, + {URL: "pr2"}, + } + if err := m.ApplyPRObservation(ctx, "mer-1", ports.PRObservation{Fetched: true, URL: "pr1", Merged: true}); err != nil { + t.Fatal(err) + } + if got := st.sessions["mer-1"]; got.IsTerminated { + t.Fatalf("session with an open sibling PR must stay alive, got %+v", got) + } +} + +// Once the last open PR merges (all PRs now merged), the session terminates. +func TestPRObservation_LastMergeTerminatesSession(t *testing.T) { + m, st, _ := newManager() + st.sessions["mer-1"] = working("mer-1") + st.prs["mer-1"] = []domain.PullRequest{ + {URL: "pr1", Merged: true}, + {URL: "pr2", Merged: true}, + } + if err := m.ApplyPRObservation(ctx, "mer-1", ports.PRObservation{Fetched: true, URL: "pr2", Merged: true}); err != nil { + t.Fatal(err) + } + if got := st.sessions["mer-1"]; !got.IsTerminated { + t.Fatalf("session should terminate once all PRs are merged, got %+v", got) + } +} + +// A closed PR that leaves the session with an open sibling and no merge does not +// terminate; closing the last PR with no merge also does not terminate (nothing +// shipped). +func TestPRObservation_ClosedWithoutMergeDoesNotTerminate(t *testing.T) { + m, st, _ := newManager() + st.sessions["mer-1"] = working("mer-1") + st.prs["mer-1"] = []domain.PullRequest{{URL: "pr1", Closed: true}} + if err := m.ApplyPRObservation(ctx, "mer-1", ports.PRObservation{Fetched: true, URL: "pr1", Closed: true}); err != nil { + t.Fatal(err) + } + if got := st.sessions["mer-1"]; got.IsTerminated { + t.Fatalf("a closed-without-merge PR must not terminate the session, got %+v", got) + } +} + +// A PR stacked on an open parent (its target branch is the parent's source +// branch) is exempt from the merge-conflict nudge: conflicts there are expected +// until the parent merges. +func TestPRObservation_StackedChildConflictSuppressed(t *testing.T) { + m, st, msg := newManager() + st.sessions["mer-1"] = working("mer-1") + st.prs["mer-1"] = []domain.PullRequest{ + {URL: "parent", SourceBranch: "ao/x", TargetBranch: "main"}, + {URL: "child", SourceBranch: "ao/x/auth", TargetBranch: "ao/x"}, + } + o := ports.PRObservation{Fetched: true, URL: "child", Mergeability: domain.MergeConflicting} + if err := m.ApplyPRObservation(ctx, "mer-1", o); err != nil { + t.Fatal(err) + } + if len(msg.msgs) != 0 { + t.Fatalf("stacked child conflict should be suppressed, got %v", msg.msgs) + } +} + +// The bottom-of-stack PR (not stacked on any open parent) still gets the +// merge-conflict nudge even when it has open stacked children. +func TestPRObservation_BottomOfStackConflictNudges(t *testing.T) { + m, st, msg := newManager() + st.sessions["mer-1"] = working("mer-1") + st.prs["mer-1"] = []domain.PullRequest{ + {URL: "parent", SourceBranch: "ao/x", TargetBranch: "main"}, + {URL: "child", SourceBranch: "ao/x/auth", TargetBranch: "ao/x"}, + } + o := ports.PRObservation{Fetched: true, URL: "parent", Mergeability: domain.MergeConflicting} + if err := m.ApplyPRObservation(ctx, "mer-1", o); err != nil { + t.Fatal(err) + } + if len(msg.msgs) != 1 || !strings.Contains(msg.msgs[0], "merge conflicts") { + t.Fatalf("bottom-of-stack conflict should nudge, got %v", msg.msgs) + } +} + // TestPRObservation_DedupSurvivesManagerRestart simulates a daemon restart by // constructing a second Manager over the same store and asserts that an // identical PR observation does not re-fire the nudge — the dedup signature diff --git a/backend/internal/lifecycle/reactions.go b/backend/internal/lifecycle/reactions.go index 614d3db9..ea24a03c 100644 --- a/backend/internal/lifecycle/reactions.go +++ b/backend/internal/lifecycle/reactions.go @@ -43,10 +43,19 @@ func (m *Manager) ApplyPRObservation(ctx context.Context, id domain.SessionID, o if !o.Fetched { return nil } - if o.Merged { - return m.MarkTerminated(ctx, id) - } - if o.Closed { + // A PR reaching a terminal state (merged or closed) no longer ends the + // session on its own: a session may own several PRs. Terminate only when no + // open PR remains and at least one of them merged. The observer persists the + // PR row before calling lifecycle, so the store already reflects this + // transition when sessionComplete reads it. + if o.Merged || o.Closed { + done, err := m.sessionComplete(ctx, id) + if err != nil { + return err + } + if done { + return m.MarkTerminated(ctx, id) + } return nil } rec, ok, err := m.store.GetSession(ctx, id) @@ -79,11 +88,67 @@ func (m *Manager) ApplyPRObservation(ctx context.Context, id domain.SessionID, o return m.sendOnce(ctx, id, o.URL, "review:"+o.URL, sig, msg, reviewMaxNudge) } if o.Mergeability == domain.MergeConflicting { + // Only the bottom of a stack is eligible for the rebase nudge. A PR + // stacked on an open parent is expected to report conflicts against its + // parent branch until the parent merges and it retargets, so nudging the + // agent to rebase it now would be noise. Mergeability UNKNOWN (the brief + // post-retarget recompute window) never reaches here. + blocked, err := m.prBlockedByOpenParent(ctx, id, o.URL) + if err != nil { + return err + } + if blocked { + return nil + } return m.sendOnce(ctx, id, o.URL, "merge-conflict:"+o.URL, string(o.Mergeability), "Your PR has merge conflicts. Rebase onto the base branch and resolve them.", 0) } return nil } +// sessionComplete reports whether the session has reached the multi-PR +// completion bar: at least one PR merged and no PR still open. A session with no +// PRs, or with any open PR, is not complete. +func (m *Manager) sessionComplete(ctx context.Context, id domain.SessionID) (bool, error) { + prs, err := m.store.ListPRsBySession(ctx, id) + if err != nil { + return false, err + } + merged := false + for _, pr := range prs { + if !pr.Merged && !pr.Closed { + return false, nil + } + if pr.Merged { + merged = true + } + } + return merged, nil +} + +// prBlockedByOpenParent reports whether the PR at prURL is stacked on top of +// another still-open PR in the same session — i.e. its target branch is the +// source branch of a sibling open PR. Such a PR is not the bottom of its stack +// and is exempt from merge-conflict nudges. Branch facts are read from the +// store, which the observer has already updated for this observation. +func (m *Manager) prBlockedByOpenParent(ctx context.Context, id domain.SessionID, prURL string) (bool, error) { + prs, err := m.store.ListPRsBySession(ctx, id) + if err != nil { + return false, err + } + openSources := make(map[string]bool, len(prs)) + for _, pr := range prs { + if !pr.Merged && !pr.Closed && pr.SourceBranch != "" { + openSources[pr.SourceBranch] = true + } + } + for _, pr := range prs { + if pr.URL == prURL { + return pr.TargetBranch != "" && openSources[pr.TargetBranch], nil + } + } + return false, nil +} + // ApplySCMObservation is the provider-neutral lifecycle entrypoint used by the // SCM observer. The existing reaction logic still operates on PRObservation, so // lifecycle performs the compatibility projection internally instead of leaking diff --git a/backend/internal/observe/scm/observer.go b/backend/internal/observe/scm/observer.go index afc35d57..b7cb6b36 100644 --- a/backend/internal/observe/scm/observer.go +++ b/backend/internal/observe/scm/observer.go @@ -38,7 +38,7 @@ const ( type Provider interface { ParseRepository(remote string) (ports.SCMRepo, bool) RepoPRListGuard(ctx context.Context, repo ports.SCMRepo, etag string) (ports.SCMGuardResult, error) - DetectPRByBranch(ctx context.Context, repo ports.SCMRepo, branch string) (ports.SCMPRObservation, error) + ListOpenPRsByRepo(ctx context.Context, repo ports.SCMRepo) ([]ports.SCMPRObservation, error) CommitChecksGuard(ctx context.Context, repo ports.SCMRepo, headSHA, etag string) (ports.SCMGuardResult, error) FetchPullRequests(ctx context.Context, refs []ports.SCMPRRef) ([]ports.SCMObservation, error) FetchFailedCheckLogTail(ctx context.Context, repo ports.SCMRepo, check ports.SCMCheckObservation) (string, error) @@ -191,6 +191,14 @@ type subject struct { hasPR bool } +// sessionRepo pairs a live session with its parsed repo and branch for per-repo +// branch-prefix discovery of new (including stacked) pull requests. +type sessionRepo struct { + session domain.SessionRecord + repo ports.SCMRepo + branch string +} + type repoGuardState struct { result ports.SCMGuardResult hadETag bool @@ -226,11 +234,11 @@ func (o *Observer) Poll(ctx context.Context) error { if o.disabled { return nil } - subjects, err := o.discoverSubjects(ctx) + subjects, sessionRepos, err := o.discoverSubjects(ctx) if err != nil { return err } - if len(subjects) == 0 { + if len(sessionRepos) == 0 { return nil } proceed, err := o.checkCredentials(ctx) @@ -241,7 +249,7 @@ func (o *Observer) Poll(ctx context.Context) error { return nil } - repoGuards := o.guardRepos(ctx, subjects) + repoGuards := o.guardRepos(ctx, sessionRepos) repoRefreshOK := pendingRepoRefreshes(repoGuards) markRepoRefreshFailed := func(repo ports.SCMRepo) { key := prKey(repo, 0) @@ -252,7 +260,7 @@ func (o *Observer) Poll(ctx context.Context) error { if err := ctx.Err(); err != nil { return err } - o.detectMissingPRs(ctx, subjects, repoGuards, now, markRepoRefreshFailed) + o.discoverNewPRs(ctx, sessionRepos, subjects, repoGuards, now, markRepoRefreshFailed) if err := ctx.Err(); err != nil { return err } @@ -405,13 +413,19 @@ func (o *Observer) checkCredentials(ctx context.Context) (bool, error) { return observe.CheckCredentialsOnce(ctx, probe, &o.credentialsChecked, &o.disabled, o.logger, "scm observer") } -func (o *Observer) discoverSubjects(ctx context.Context) (map[string]*subject, error) { +// discoverSubjects builds the per-PR refresh subjects (one per open tracked PR) +// and the per-session repo list used for branch-prefix discovery of new PRs. A +// session may own several PRs, so each open tracked PR becomes its own subject; +// merged/closed PRs are not re-fetched since lifecycle already saw the terminal +// transition and the completion rule reads them from the store. +func (o *Observer) discoverSubjects(ctx context.Context) (map[string]*subject, []sessionRepo, error) { sessions, err := o.store.ListAllSessions(ctx) if err != nil { - return nil, err + return nil, nil, err } projects := map[domain.ProjectID]domain.ProjectRecord{} out := map[string]*subject{} + var sessionRepos []sessionRepo for _, sess := range sessions { if sess.IsTerminated { continue @@ -424,7 +438,7 @@ func (o *Observer) discoverSubjects(ctx context.Context) (map[string]*subject, e if !ok { p, found, err := o.store.GetProject(ctx, string(sess.ProjectID)) if err != nil { - return nil, err + return nil, nil, err } if !found || !p.ArchivedAt.IsZero() { continue @@ -445,47 +459,37 @@ func (o *Observer) discoverSubjects(ctx context.Context) (map[string]*subject, e o.logger.Debug("scm observer: project has no supported SCM origin", "project", proj.ID, "origin", proj.RepoOriginURL) continue } + sessionRepos = append(sessionRepos, sessionRepo{session: sess, repo: repo, branch: branch}) prs, err := o.store.ListPRsBySession(ctx, sess.ID) if err != nil { - return nil, err + return nil, nil, err } - known, hasPR := chooseKnownPR(prs) - s := &subject{session: sess, repo: repo, branch: branch, known: known, hasPR: hasPR} - if hasPR && known.Number > 0 { - key := prKey(repo, known.Number) + for _, pr := range openTrackedPRs(prs) { + key := prKey(repo, pr.Number) if existing, ok := out[key]; ok { o.logger.Warn("scm observer: duplicate tracked PR ownership skipped", "pr", key, "kept_session", existing.session.ID, "skipped_session", sess.ID) continue } - out[key] = s - } else { - out["session:"+string(sess.ID)] = s + out[key] = &subject{session: sess, repo: repo, branch: branch, known: pr, hasPR: true} } } - return out, nil + return out, sessionRepos, nil } -func chooseKnownPR(prs []domain.PullRequest) (domain.PullRequest, bool) { - if len(prs) == 0 { - return domain.PullRequest{}, false - } +func openTrackedPRs(prs []domain.PullRequest) []domain.PullRequest { + out := make([]domain.PullRequest, 0, len(prs)) for _, pr := range prs { if pr.Number > 0 && !pr.Merged && !pr.Closed { - return pr, true + out = append(out, pr) } } - for _, pr := range prs { - if pr.Number > 0 { - return pr, true - } - } - return domain.PullRequest{}, false + return out } -func (o *Observer) guardRepos(ctx context.Context, subjects map[string]*subject) map[string]repoGuardState { +func (o *Observer) guardRepos(ctx context.Context, sessionRepos []sessionRepo) map[string]repoGuardState { repos := map[string]ports.SCMRepo{} - for _, s := range subjects { - repos[prKey(s.repo, 0)] = s.repo + for _, sr := range sessionRepos { + repos[prKey(sr.repo, 0)] = sr.repo } out := map[string]repoGuardState{} for key, repo := range repos { @@ -511,39 +515,91 @@ func pendingRepoRefreshes(guards map[string]repoGuardState) map[string]bool { return out } -func (o *Observer) detectMissingPRs(ctx context.Context, subjects map[string]*subject, guards map[string]repoGuardState, now time.Time, markRepoFailed func(ports.SCMRepo)) { - for oldKey, s := range subjects { - if s.hasPR { - continue - } - g := guards[prKey(s.repo, 0)] +// discoverNewPRs lists each repo's open PRs once and attaches any not-yet-tracked +// PR to the session that owns its source branch. A session owns a PR when the +// PR's source branch equals the session branch or descends from it (the +// "branch/..." stacking convention). One session may therefore pick up several +// PRs (its root plus stacked children). Repos whose PR-list guard reports +// NotModified against a known ETag are skipped, since nothing new can have +// appeared since the last poll. +func (o *Observer) discoverNewPRs(ctx context.Context, sessionRepos []sessionRepo, subjects map[string]*subject, guards map[string]repoGuardState, now time.Time, markRepoFailed func(ports.SCMRepo)) { + byRepo := map[string][]sessionRepo{} + repos := map[string]ports.SCMRepo{} + for _, sr := range sessionRepos { + key := prKey(sr.repo, 0) + byRepo[key] = append(byRepo[key], sr) + repos[key] = sr.repo + } + for repoKey, repo := range repos { + g := guards[repoKey] if g.err != nil { continue } if g.result.NotModified && g.hadETag { continue } - pr, err := o.provider.DetectPRByBranch(ctx, s.repo, s.branch) + pulls, err := o.provider.ListOpenPRsByRepo(ctx, repo) if err != nil { - o.logger.Debug("scm observer: no PR detected for branch", "session", s.session.ID, "branch", s.branch, "err", err) + o.logger.Debug("scm observer: open PR list failed", "repo", repoFullName(repo), "err", err) if markRepoFailed != nil && !errors.Is(err, ports.ErrSCMNotFound) { - markRepoFailed(s.repo) + markRepoFailed(repo) } continue } - if pr.Number <= 0 { - continue + for _, pr := range pulls { + if pr.Number <= 0 || pr.SourceBranch == "" { + continue + } + key := prKey(repo, pr.Number) + if _, ok := subjects[key]; ok { + continue + } + sr, ok := matchSession(byRepo[repoKey], pr.SourceBranch) + if !ok { + continue + } + subjects[key] = &subject{ + session: sr.session, + repo: repo, + branch: sr.branch, + known: domain.PullRequest{ + URL: firstNonEmpty(pr.URL, pr.HTMLURL), + SessionID: sr.session.ID, + Number: pr.Number, + SourceBranch: pr.SourceBranch, + TargetBranch: pr.TargetBranch, + HeadSHA: pr.HeadSHA, + Provider: repo.Provider, + Host: repo.Host, + Repo: repoFullName(repo), + UpdatedAt: now, + }, + hasPR: true, + } } - newKey := prKey(s.repo, pr.Number) - if existing, ok := subjects[newKey]; ok && existing != s { - o.logger.Warn("scm observer: detected PR is already tracked by another session", "pr", newKey, "kept_session", existing.session.ID, "skipped_session", s.session.ID) + } +} + +// matchSession picks the session that owns sourceBranch. A session owns the +// branch when it is an exact match or a stacked descendant ("branch/..."). When +// several session branches are prefixes of the same source branch the longest +// (most specific) one wins, so a child session claims its own stacked PRs rather +// than the ancestor session. +func matchSession(candidates []sessionRepo, sourceBranch string) (sessionRepo, bool) { + var best sessionRepo + bestLen := -1 + for _, sr := range candidates { + if sr.branch == "" { continue } - s.known = domain.PullRequest{URL: pr.URL, SessionID: s.session.ID, Number: pr.Number, SourceBranch: pr.SourceBranch, TargetBranch: pr.TargetBranch, HeadSHA: pr.HeadSHA, Provider: s.repo.Provider, Host: s.repo.Host, Repo: repoFullName(s.repo), UpdatedAt: now} - s.hasPR = true - delete(subjects, oldKey) - subjects[newKey] = s + if sr.branch == sourceBranch || strings.HasPrefix(sourceBranch, sr.branch+"/") { + if len(sr.branch) > bestLen { + best = sr + bestLen = len(sr.branch) + } + } } + return best, bestLen >= 0 } func (o *Observer) selectRefreshCandidates(ctx context.Context, subjects map[string]*subject, guards map[string]repoGuardState, markRepoFailed func(ports.SCMRepo)) refreshSelection { diff --git a/backend/internal/observe/scm/observer_test.go b/backend/internal/observe/scm/observer_test.go index b8f9cc0c..cd46d579 100644 --- a/backend/internal/observe/scm/observer_test.go +++ b/backend/internal/observe/scm/observer_test.go @@ -108,7 +108,8 @@ type fakeProvider struct { mu sync.Mutex repoGuards map[string]ports.SCMGuardResult checkGuards map[string]ports.SCMGuardResult - detected map[string]ports.SCMPRObservation + openPRs map[string][]ports.SCMPRObservation + listErr error observations map[string]ports.SCMObservation reviews map[string]ports.SCMReviewObservation logTails map[string]string @@ -120,7 +121,7 @@ type fakeProvider struct { credentialErr error credentialChecks int repoGuardCalls int - detectCalls int + listCalls int fetchBatches [][]ports.SCMPRRef logCalls int reviewCalls int @@ -145,15 +146,14 @@ func (p *fakeProvider) RepoPRListGuard(_ context.Context, repo ports.SCMRepo, _ p.repoGuardCalls++ return p.repoGuards[prKey(repo, 0)], nil } -func (p *fakeProvider) DetectPRByBranch(_ context.Context, _ ports.SCMRepo, branch string) (ports.SCMPRObservation, error) { +func (p *fakeProvider) ListOpenPRsByRepo(_ context.Context, repo ports.SCMRepo) ([]ports.SCMPRObservation, error) { p.mu.Lock() defer p.mu.Unlock() - p.detectCalls++ - pr, ok := p.detected[branch] - if !ok { - return ports.SCMPRObservation{}, ports.ErrSCMNotFound + p.listCalls++ + if p.listErr != nil { + return nil, p.listErr } - return pr, nil + return p.openPRs[prKey(repo, 0)], nil } func (p *fakeProvider) CommitChecksGuard(_ context.Context, repo ports.SCMRepo, sha, _ string) (ports.SCMGuardResult, error) { return p.checkGuards[commitKey(repo, sha)], nil @@ -272,9 +272,9 @@ func TestPoll_DisablesOnceWhenCredentialsUnavailable(t *testing.T) { if provider.credentialChecks != 1 { t.Fatalf("credential checks = %d, want one lazy check", provider.credentialChecks) } - if provider.repoGuardCalls != 0 || provider.detectCalls != 0 || len(provider.fetchBatches) != 0 { - t.Fatalf("provider API calls should be skipped without credentials: guards=%d detects=%d batches=%d", - provider.repoGuardCalls, provider.detectCalls, len(provider.fetchBatches)) + if provider.repoGuardCalls != 0 || provider.listCalls != 0 || len(provider.fetchBatches) != 0 { + t.Fatalf("provider API calls should be skipped without credentials: guards=%d lists=%d batches=%d", + provider.repoGuardCalls, provider.listCalls, len(provider.fetchBatches)) } } @@ -389,13 +389,13 @@ func TestStart_LogsDisabledWarningWhenNoTokenAndNoSubjects(t *testing.T) { if provider.credentialChecks != 1 { t.Fatalf("credential checks = %d, want exactly one pre-poll check", provider.credentialChecks) } - if provider.repoGuardCalls != 0 || provider.detectCalls != 0 || len(provider.fetchBatches) != 0 { - t.Fatalf("no provider API calls expected when disabled: guards=%d detects=%d batches=%d", - provider.repoGuardCalls, provider.detectCalls, len(provider.fetchBatches)) + if provider.repoGuardCalls != 0 || provider.listCalls != 0 || len(provider.fetchBatches) != 0 { + t.Fatalf("no provider API calls expected when disabled: guards=%d lists=%d batches=%d", + provider.repoGuardCalls, provider.listCalls, len(provider.fetchBatches)) } } -func TestPoll_RepoETag304SkipsDetectPR(t *testing.T) { +func TestPoll_RepoETag304SkipsListPRs(t *testing.T) { store := testStoreWithSession() provider := &fakeProvider{repoGuards: map[string]ports.SCMGuardResult{prKey(testRepo, 0): {ETag: "v1", NotModified: true}}, observations: map[string]ports.SCMObservation{}} obs := newTestObserver(store, provider, &fakeLifecycle{}, time.Unix(1, 0).UTC()) @@ -403,16 +403,16 @@ func TestPoll_RepoETag304SkipsDetectPR(t *testing.T) { if err := obs.Poll(context.Background()); err != nil { t.Fatal(err) } - if provider.detectCalls != 0 { - t.Fatalf("detectPR called on 304: %d", provider.detectCalls) + if provider.listCalls != 0 { + t.Fatalf("ListOpenPRsByRepo called on 304: %d", provider.listCalls) } } -func TestPoll_DetectPRNotFoundCommitsRepoETag(t *testing.T) { +func TestPoll_NoOpenPRsCommitsRepoETag(t *testing.T) { store := testStoreWithSession() provider := &fakeProvider{ repoGuards: map[string]ports.SCMGuardResult{prKey(testRepo, 0): {ETag: "v2"}}, - detected: map[string]ports.SCMPRObservation{}, + openPRs: map[string][]ports.SCMPRObservation{}, observations: map[string]ports.SCMObservation{}, } obs := newTestObserver(store, provider, &fakeLifecycle{}, time.Unix(1, 0).UTC()) @@ -420,22 +420,22 @@ func TestPoll_DetectPRNotFoundCommitsRepoETag(t *testing.T) { if err := obs.Poll(context.Background()); err != nil { t.Fatal(err) } - if provider.detectCalls != 1 { - t.Fatalf("detectPR calls = %d, want 1", provider.detectCalls) + if provider.listCalls != 1 { + t.Fatalf("ListOpenPRsByRepo calls = %d, want 1", provider.listCalls) } if got := obs.Cache.RepoPRListETag[prKey(testRepo, 0)]; got != "v2" { - t.Fatalf("repo ETag after not-found detection = %q, want v2", got) + t.Fatalf("repo ETag after empty listing = %q, want v2", got) } if len(provider.fetchBatches) != 0 { - t.Fatalf("not-found branch should not fetch PR batch: %#v", provider.fetchBatches) + t.Fatalf("empty listing should not fetch PR batch: %#v", provider.fetchBatches) } } -func TestPoll_RepoETag200DetectsPRAndRefreshesSamePoll(t *testing.T) { +func TestPoll_RepoETag200DiscoversPRAndRefreshesSamePoll(t *testing.T) { store := testStoreWithSession() provider := &fakeProvider{ repoGuards: map[string]ports.SCMGuardResult{prKey(testRepo, 0): {ETag: "v2"}}, - detected: map[string]ports.SCMPRObservation{"feat": {URL: "https://github.com/o/r/pull/1", Number: 1, SourceBranch: "feat", TargetBranch: "main", HeadSHA: "sha1"}}, + openPRs: map[string][]ports.SCMPRObservation{prKey(testRepo, 0): {{URL: "https://github.com/o/r/pull/1", Number: 1, SourceBranch: "feat", TargetBranch: "main", HeadSHA: "sha1"}}}, observations: map[string]ports.SCMObservation{prKey(testRepo, 1): testObs(1)}, } lc := &fakeLifecycle{} @@ -443,8 +443,8 @@ func TestPoll_RepoETag200DetectsPRAndRefreshesSamePoll(t *testing.T) { if err := obs.Poll(context.Background()); err != nil { t.Fatal(err) } - if provider.detectCalls != 1 { - t.Fatalf("detectPR calls = %d, want 1", provider.detectCalls) + if provider.listCalls != 1 { + t.Fatalf("ListOpenPRsByRepo calls = %d, want 1", provider.listCalls) } if len(provider.fetchBatches) != 1 || len(provider.fetchBatches[0]) != 1 || provider.fetchBatches[0][0].Number != 1 { t.Fatalf("new PR not refreshed in same poll: %#v", provider.fetchBatches) @@ -454,6 +454,37 @@ func TestPoll_RepoETag200DetectsPRAndRefreshesSamePoll(t *testing.T) { } } +// A session whose branch is the prefix of two open PRs (its root plus a stacked +// child on branch "feat/child") picks up both PRs in a single poll. +func TestPoll_DiscoversStackedChildByBranchPrefix(t *testing.T) { + store := testStoreWithSession() + childObs := testObs(2) + childObs.PR.SourceBranch = "feat/child" + childObs.PR.TargetBranch = "feat" + provider := &fakeProvider{ + repoGuards: map[string]ports.SCMGuardResult{prKey(testRepo, 0): {ETag: "v2"}}, + openPRs: map[string][]ports.SCMPRObservation{prKey(testRepo, 0): { + {URL: "https://github.com/o/r/pull/1", Number: 1, SourceBranch: "feat", TargetBranch: "main", HeadSHA: "sha1"}, + {URL: "https://github.com/o/r/pull/2", Number: 2, SourceBranch: "feat/child", TargetBranch: "feat", HeadSHA: "sha2"}, + }}, + observations: map[string]ports.SCMObservation{prKey(testRepo, 1): testObs(1), prKey(testRepo, 2): childObs}, + } + lc := &fakeLifecycle{} + obs := newTestObserver(store, provider, lc, time.Unix(1, 0).UTC()) + if err := obs.Poll(context.Background()); err != nil { + t.Fatal(err) + } + fetched := map[int]bool{} + for _, batch := range provider.fetchBatches { + for _, ref := range batch { + fetched[ref.Number] = true + } + } + if !fetched[1] || !fetched[2] { + t.Fatalf("expected both root and stacked child fetched, got %#v", fetched) + } +} + func TestPoll_CIETagChangeRefreshesWhenRepoUnchanged(t *testing.T) { store := testStoreWithSession() store.prs["p-1"] = []domain.PullRequest{knownPR(1)} @@ -950,7 +981,7 @@ func TestDiscoverSubjects_BackfillsRepoOriginURL(t *testing.T) { provider := &fakeProvider{} obs := newTestObserver(store, provider, &fakeLifecycle{}, time.Unix(0, 0).UTC()) - if _, err := obs.discoverSubjects(context.Background()); err != nil { + if _, _, err := obs.discoverSubjects(context.Background()); err != nil { t.Fatalf("discoverSubjects: %v", err) } if got := store.projects["p"].RepoOriginURL; got != "https://github.com/o/r.git" { @@ -970,12 +1001,12 @@ func TestDiscoverSubjects_NonGitPathDoesNotBackfill(t *testing.T) { checks: map[string][]domain.PullRequestCheck{}, } obs := newTestObserver(store, &fakeProvider{}, &fakeLifecycle{}, time.Unix(0, 0).UTC()) - subjects, err := obs.discoverSubjects(context.Background()) + subjects, sessionRepos, err := obs.discoverSubjects(context.Background()) if err != nil { t.Fatalf("discoverSubjects: %v", err) } - if len(subjects) != 0 { - t.Fatalf("non-git project should be skipped, got %d subjects", len(subjects)) + if len(subjects) != 0 || len(sessionRepos) != 0 { + t.Fatalf("non-git project should be skipped, got %d subjects %d sessionRepos", len(subjects), len(sessionRepos)) } if got := store.projects["p"].RepoOriginURL; got != "" { t.Fatalf("RepoOriginURL = %q, want empty (no persist on failed backfill)", got) diff --git a/backend/internal/service/session/service.go b/backend/internal/service/session/service.go index 8512b302..5d2e874d 100644 --- a/backend/internal/service/session/service.go +++ b/backend/internal/service/session/service.go @@ -20,6 +20,7 @@ type Store interface { ListAllSessions(ctx context.Context) ([]domain.SessionRecord, error) RenameSession(ctx context.Context, id domain.SessionID, displayName string, updatedAt time.Time) (bool, error) GetDisplayPRFactsForSession(ctx context.Context, id domain.SessionID) (domain.PRFacts, bool, error) + ListPRFactsForSession(ctx context.Context, id domain.SessionID) ([]domain.PRFacts, error) ListPRsBySession(ctx context.Context, sessionID domain.SessionID) ([]domain.PullRequest, error) ListPRComments(ctx context.Context, prURL string) ([]domain.PullRequestComment, error) GetProject(ctx context.Context, id string) (domain.ProjectRecord, bool, error) @@ -350,14 +351,11 @@ func toAPIError(err error) error { } func (s *Service) toSession(ctx context.Context, rec domain.SessionRecord) (domain.Session, error) { - pr, ok, err := s.store.GetDisplayPRFactsForSession(ctx, rec.ID) + prs, err := s.store.ListPRFactsForSession(ctx, rec.ID) if err != nil { return domain.Session{}, fmt.Errorf("pr facts %s: %w", rec.ID, err) } - if !ok { - return domain.Session{SessionRecord: rec, Status: deriveStatus(rec, nil, s.now(), s.harnessSignals(rec.Harness)), TerminalHandleID: rec.Metadata.RuntimeHandleID}, nil - } - return domain.Session{SessionRecord: rec, Status: deriveStatus(rec, &pr, s.now(), s.harnessSignals(rec.Harness)), TerminalHandleID: rec.Metadata.RuntimeHandleID}, nil + return domain.Session{SessionRecord: rec, Status: deriveStatus(rec, prs, s.now(), s.harnessSignals(rec.Harness)), TerminalHandleID: rec.Metadata.RuntimeHandleID, PRs: prs}, nil } // now tolerates a zero-value Service (tests construct the struct literally diff --git a/backend/internal/service/session/service_test.go b/backend/internal/service/session/service_test.go index 58dc8e52..8e39045a 100644 --- a/backend/internal/service/session/service_test.go +++ b/backend/internal/service/session/service_test.go @@ -78,6 +78,14 @@ func (f *fakeStore) ListPRsBySession(_ context.Context, id domain.SessionID) ([] return []domain.PullRequest{{URL: pr.URL, SessionID: id, Number: pr.Number, Draft: pr.Draft, Merged: pr.Merged, Closed: pr.Closed, CI: pr.CI, Review: pr.Review, Mergeability: pr.Mergeability, UpdatedAt: pr.UpdatedAt}}, nil } +func (f *fakeStore) ListPRFactsForSession(_ context.Context, id domain.SessionID) ([]domain.PRFacts, error) { + pr, ok := f.pr[id] + if !ok { + return nil, nil + } + return []domain.PRFacts{pr}, nil +} + func (f *fakeStore) ListPRComments(context.Context, string) ([]domain.PullRequestComment, error) { return nil, nil } diff --git a/backend/internal/service/session/stack.go b/backend/internal/service/session/stack.go new file mode 100644 index 00000000..b4542d2a --- /dev/null +++ b/backend/internal/service/session/stack.go @@ -0,0 +1,34 @@ +package session + +import "github.com/aoagents/agent-orchestrator/backend/internal/domain" + +// stackInfo is the derived position of one PR within its session's set of PRs. +// PRs form a stack when one targets the source branch of another: PR B is a +// child of PR A when B.TargetBranch == A.SourceBranch and A is open. +type stackInfo struct { + // Blocked is true when an open PR in the set owns the branch this PR targets, + // i.e. this PR is a child stacked on a parent that has not merged yet. + Blocked bool + // BottomOfStack is true when no open PR sits below this one. It is the only + // PR in a stack that should receive a merge-conflict rebase nudge; an + // independent PR (targeting the base branch) is its own bottom. + BottomOfStack bool +} + +// buildStacks derives the stack position of every PR from the source/target +// branch columns alone. A parent counts only while open, matching the rule that +// a merged or closed parent no longer blocks its children. +func buildStacks(prs []domain.PRFacts) map[string]stackInfo { + openSources := make(map[string]bool, len(prs)) + for _, p := range prs { + if !p.Merged && !p.Closed && p.SourceBranch != "" { + openSources[p.SourceBranch] = true + } + } + out := make(map[string]stackInfo, len(prs)) + for _, p := range prs { + blocked := p.TargetBranch != "" && openSources[p.TargetBranch] + out[p.URL] = stackInfo{Blocked: blocked, BottomOfStack: !blocked} + } + return out +} diff --git a/backend/internal/service/session/stack_test.go b/backend/internal/service/session/stack_test.go new file mode 100644 index 00000000..c2d2f0cc --- /dev/null +++ b/backend/internal/service/session/stack_test.go @@ -0,0 +1,127 @@ +package session + +import ( + "testing" + + "github.com/aoagents/agent-orchestrator/backend/internal/domain" +) + +// live builds an idle, non-terminated session that has already signaled, so the +// derived status is governed purely by its PRs. +func live() domain.SessionRecord { + return statusRec(domain.ActivityIdle, false) +} + +func TestBuildStacksMarksBlockedChildren(t *testing.T) { + // #142 (root → main), #143 stacked on #142, #144 stacked on #143. + prs := []domain.PRFacts{ + {URL: "p142", SourceBranch: "ao/abc", TargetBranch: "main"}, + {URL: "p143", SourceBranch: "ao/abc/auth", TargetBranch: "ao/abc"}, + {URL: "p144", SourceBranch: "ao/abc/tests", TargetBranch: "ao/abc/auth"}, + } + st := buildStacks(prs) + if st["p142"].Blocked || !st["p142"].BottomOfStack { + t.Fatalf("root should be unblocked bottom-of-stack, got %+v", st["p142"]) + } + if !st["p143"].Blocked || st["p143"].BottomOfStack { + t.Fatalf("middle should be blocked, got %+v", st["p143"]) + } + if !st["p144"].Blocked { + t.Fatalf("top should be blocked, got %+v", st["p144"]) + } +} + +func TestBuildStacksMergedParentUnblocksChild(t *testing.T) { + prs := []domain.PRFacts{ + {URL: "p142", SourceBranch: "ao/abc", TargetBranch: "main", Merged: true}, + {URL: "p143", SourceBranch: "ao/abc/auth", TargetBranch: "ao/abc"}, + } + st := buildStacks(prs) + if st["p143"].Blocked { + t.Fatal("child should be unblocked once parent is merged") + } +} + +func TestDeriveStatusWorstWinsAcrossIndependentPRs(t *testing.T) { + // Two independent open PRs (both target main): mergeable vs ci_failed. + // CI failure is more urgent, so the session reports ci_failed. + prs := []domain.PRFacts{ + {URL: "a", SourceBranch: "ao/a", TargetBranch: "main", Mergeability: domain.MergeMergeable}, + {URL: "b", SourceBranch: "ao/b", TargetBranch: "main", CI: domain.CIFailing}, + } + if got := deriveStatus(live(), prs, statusNow, true); got != domain.StatusCIFailed { + t.Fatalf("got %q want ci_failed", got) + } +} + +func TestDeriveStatusAllMergeableReportsMergeable(t *testing.T) { + prs := []domain.PRFacts{ + {URL: "a", SourceBranch: "ao/a", TargetBranch: "main", Mergeability: domain.MergeMergeable}, + {URL: "b", SourceBranch: "ao/b", TargetBranch: "main", Mergeability: domain.MergeMergeable}, + } + if got := deriveStatus(live(), prs, statusNow, true); got != domain.StatusMergeable { + t.Fatalf("got %q want mergeable", got) + } +} + +func TestDeriveStatusStackedChildExemptFromAggregation(t *testing.T) { + // Root mergeable; blocked child is pr_open. Child is exempt, so the session + // reports mergeable rather than being dragged down to pr_open. + prs := []domain.PRFacts{ + {URL: "root", SourceBranch: "ao/abc", TargetBranch: "main", Mergeability: domain.MergeMergeable}, + {URL: "child", SourceBranch: "ao/abc/x", TargetBranch: "ao/abc"}, + } + if got := deriveStatus(live(), prs, statusNow, true); got != domain.StatusMergeable { + t.Fatalf("got %q want mergeable (child exempt)", got) + } +} + +func TestDeriveStatusMergedParentOpenChildStaysOnChild(t *testing.T) { + // Parent merged, child now unblocked and review_pending: still alive, status + // follows the open child. + prs := []domain.PRFacts{ + {URL: "root", SourceBranch: "ao/abc", TargetBranch: "main", Merged: true}, + {URL: "child", SourceBranch: "ao/abc/x", TargetBranch: "main", Review: domain.ReviewRequired}, + } + if got := deriveStatus(live(), prs, statusNow, true); got != domain.StatusReviewPending { + t.Fatalf("got %q want review_pending", got) + } +} + +func TestDeriveStatusAllMergedReportsMerged(t *testing.T) { + prs := []domain.PRFacts{ + {URL: "a", Merged: true}, + {URL: "b", Merged: true}, + } + if got := deriveStatus(live(), prs, statusNow, true); got != domain.StatusMerged { + t.Fatalf("got %q want merged", got) + } +} + +func TestDeriveStatusAllClosedNoneMergedFallsToActivity(t *testing.T) { + prs := []domain.PRFacts{ + {URL: "a", Closed: true}, + {URL: "b", Closed: true}, + } + if got := deriveStatus(statusRec(domain.ActivityActive, false), prs, statusNow, true); got != domain.StatusWorking { + t.Fatalf("got %q want working", got) + } +} + +func TestDeriveStatusEmptyPRsUsesActivity(t *testing.T) { + if got := deriveStatus(statusRec(domain.ActivityActive, false), nil, statusNow, true); got != domain.StatusWorking { + t.Fatalf("got %q want working", got) + } +} + +func TestDeriveStatusDegenerateAllBlockedStillAggregates(t *testing.T) { + // Two PRs each targeting the other's source branch (no visible root). The + // fallback aggregates across all so the session never goes dark. + prs := []domain.PRFacts{ + {URL: "a", SourceBranch: "x", TargetBranch: "y", CI: domain.CIFailing}, + {URL: "b", SourceBranch: "y", TargetBranch: "x", Mergeability: domain.MergeMergeable}, + } + if got := deriveStatus(live(), prs, statusNow, true); got != domain.StatusCIFailed { + t.Fatalf("got %q want ci_failed (degenerate fallback)", got) + } +} diff --git a/backend/internal/service/session/status.go b/backend/internal/service/session/status.go index 4510f7f5..dca5e062 100644 --- a/backend/internal/service/session/status.go +++ b/backend/internal/service/session/status.go @@ -19,9 +19,14 @@ const noSignalGrace = 90 * time.Second // session's harness has an activity hook pipeline at all; only then can // prolonged silence mean the pipeline is broken (no_signal) rather than the // permanent, normal silence of a hook-less harness. -func deriveStatus(rec domain.SessionRecord, pr *domain.PRFacts, now time.Time, signalCapable bool) domain.SessionStatus { +// +// A session may own several PRs at once (independent or stacked). The PR-derived +// status is the worst-wins aggregate across its open PRs; stacked children whose +// parent is still open are exempt from the aggregation since they cannot merge +// until the parent does. Merged/closed PRs only matter once no open PR remains. +func deriveStatus(rec domain.SessionRecord, prs []domain.PRFacts, now time.Time, signalCapable bool) domain.SessionStatus { if rec.IsTerminated { - if pr != nil && pr.Merged { + if anyMerged(prs) { return domain.StatusMerged } return domain.StatusTerminated @@ -31,13 +36,12 @@ func deriveStatus(rec domain.SessionRecord, pr *domain.PRFacts, now time.Time, s return domain.StatusNeedsInput } - if pr != nil { - if pr.Merged { - return domain.StatusMerged - } - if !pr.Closed { - return prPipelineStatus(*pr) - } + open := openPRs(prs) + if len(open) > 0 { + return aggregatePRStatus(open) + } + if anyMerged(prs) { + return domain.StatusMerged } if rec.Activity.State == domain.ActivityActive { @@ -53,6 +57,75 @@ func deriveStatus(rec domain.SessionRecord, pr *domain.PRFacts, now time.Time, s return domain.StatusIdle } +// openPRs returns the PRs that are neither merged nor closed, preserving order. +func openPRs(prs []domain.PRFacts) []domain.PRFacts { + out := make([]domain.PRFacts, 0, len(prs)) + for _, p := range prs { + if !p.Merged && !p.Closed { + out = append(out, p) + } + } + return out +} + +func anyMerged(prs []domain.PRFacts) bool { + for _, p := range prs { + if p.Merged { + return true + } + } + return false +} + +// aggregatePRStatus is the worst-wins reduction over a session's open PRs. +// Stacked children blocked by an open parent are excluded: they cannot merge +// yet, so their pipeline state is not a user-actionable signal for the session. +// If every open PR is blocked (a degenerate stack with no visible root), it +// falls back to aggregating across all of them so the session never goes dark. +func aggregatePRStatus(open []domain.PRFacts) domain.SessionStatus { + stacks := buildStacks(open) + candidates := make([]domain.PRFacts, 0, len(open)) + for _, p := range open { + if !stacks[p.URL].Blocked { + candidates = append(candidates, p) + } + } + if len(candidates) == 0 { + candidates = open + } + worst := prPipelineStatus(candidates[0]) + for _, p := range candidates[1:] { + if s := prPipelineStatus(p); statusSeverity(s) < statusSeverity(worst) { + worst = s + } + } + return worst +} + +// statusSeverity ranks PR pipeline statuses from most to least urgent so the +// aggregate surfaces the PR that most needs attention. mergeable is least urgent +// so a session only reports mergeable when every aggregated PR is mergeable. +func statusSeverity(s domain.SessionStatus) int { + switch s { + case domain.StatusCIFailed: + return 0 + case domain.StatusChangesRequested: + return 1 + case domain.StatusDraft: + return 2 + case domain.StatusReviewPending: + return 3 + case domain.StatusPROpen: + return 4 + case domain.StatusApproved: + return 5 + case domain.StatusMergeable: + return 6 + default: + return 7 + } +} + func prPipelineStatus(pr domain.PRFacts) domain.SessionStatus { switch { case pr.CI == domain.CIFailing: diff --git a/backend/internal/service/session/status_test.go b/backend/internal/service/session/status_test.go index f7a96e2a..2c0a5ebb 100644 --- a/backend/internal/service/session/status_test.go +++ b/backend/internal/service/session/status_test.go @@ -27,13 +27,13 @@ func silentRec(age time.Duration) domain.SessionRecord { } } -func statusPR(facts domain.PRFacts) *domain.PRFacts { return &facts } +func statusPR(facts domain.PRFacts) []domain.PRFacts { return []domain.PRFacts{facts} } func TestServiceDerivesStatusFromSessionFactsAndPR(t *testing.T) { tests := []struct { name string rec domain.SessionRecord - pr *domain.PRFacts + pr []domain.PRFacts // hookless marks a harness with no activity pipeline (signalCapable // false): silence is its permanent normal state, never no_signal. hookless bool diff --git a/backend/internal/session_manager/manager.go b/backend/internal/session_manager/manager.go index 9104a8c7..c7d163f2 100644 --- a/backend/internal/session_manager/manager.go +++ b/backend/internal/session_manager/manager.go @@ -623,8 +623,9 @@ func (m *Manager) buildSystemPrompt(ctx context.Context, kind domain.SessionKind return "", err } if ok { - return workerOrchestratorPrompt(orchestratorID), nil + return workerOrchestratorPrompt(orchestratorID) + "\n\n" + workerMultiPRPrompt(), nil } + return workerMultiPRPrompt(), nil } return "", nil } @@ -665,6 +666,23 @@ An active orchestrator session exists for this project. If you hit a true blocke Only ping the orchestrator for true blockers, cross-session coordination, or decisions that cannot be resolved within your own task.`, orchestratorID) } +// workerMultiPRPrompt explains the branch convention AO uses to attribute pull +// requests to this session. A worker may open several PRs in one session: AO +// tracks every open PR whose source branch is the session's own branch or a +// descendant of it. Stacking a PR on top of another therefore only requires +// branching off with a `/` name; PRs on unrelated +// branches are attributed to whichever session owns their branch prefix. +func workerMultiPRPrompt() string { + return `## Pull requests for this session + +You can open more than one pull request from this session. AO attributes a PR to you when its source branch is your session's working branch or a branch descended from it (a "/"-separated child like ` + "`your-branch/topic`" + `). + +- For independent PRs, branch off your base branch as usual and open each PR; all of them stay tracked under this session. +- To stack a PR on top of another (so it merges after its parent), create the child branch from the parent branch and name it ` + "`/`" + `, then target the parent branch in the PR. AO recognizes the stack from the branch relationship and will only nudge you to resolve conflicts on the bottom-most PR. + +Keep branch names within your session's branch namespace so AO can track every PR you open.` +} + // spawnEnv builds the runtime environment: the per-project env vars first, then // the AO-internal vars last so they always win (a project cannot override // AO_SESSION_ID and friends). diff --git a/backend/internal/storage/sqlite/gen/pr.sql.go b/backend/internal/storage/sqlite/gen/pr.sql.go index b5bdaad4..4a92810d 100644 --- a/backend/internal/storage/sqlite/gen/pr.sql.go +++ b/backend/internal/storage/sqlite/gen/pr.sql.go @@ -181,6 +181,79 @@ func (q *Queries) GetPRLastNudgeSignature(ctx context.Context, url string) (stri return last_nudge_signature, err } +const listPRFactsBySession = `-- name: ListPRFactsBySession :many +SELECT + pr.url, + pr.number, + pr.pr_state, + pr.review_decision, + pr.ci_state, + pr.mergeability, + pr.source_branch, + pr.target_branch, + pr.updated_at, + EXISTS ( + SELECT 1 + FROM pr_comment + WHERE pr_comment.pr_url = pr.url + AND pr_comment.resolved = 0 + AND pr_comment.is_bot = 0 + ) AS review_comments +FROM pr +WHERE pr.session_id = ? +ORDER BY pr.updated_at DESC +` + +type ListPRFactsBySessionRow struct { + URL string + Number int64 + PRState domain.PRState + ReviewDecision domain.ReviewDecision + CIState domain.CIState + Mergeability domain.Mergeability + SourceBranch string + TargetBranch string + UpdatedAt time.Time + ReviewComments bool +} + +// All PR snapshots for a session (every state), with source/target branch for +// stack derivation and the unresolved-comment flag. The status aggregator +// filters open vs merged/closed in Go and derives stacks from the branches. +func (q *Queries) ListPRFactsBySession(ctx context.Context, sessionID domain.SessionID) ([]ListPRFactsBySessionRow, error) { + rows, err := q.db.QueryContext(ctx, listPRFactsBySession, sessionID) + if err != nil { + return nil, err + } + defer rows.Close() + items := []ListPRFactsBySessionRow{} + for rows.Next() { + var i ListPRFactsBySessionRow + if err := rows.Scan( + &i.URL, + &i.Number, + &i.PRState, + &i.ReviewDecision, + &i.CIState, + &i.Mergeability, + &i.SourceBranch, + &i.TargetBranch, + &i.UpdatedAt, + &i.ReviewComments, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listPRsBySession = `-- name: ListPRsBySession :many SELECT url, session_id, number, pr_state, review_decision, ci_state, mergeability, updated_at, provider, host, repo, source_branch, target_branch, head_sha, title, additions, deletions, changed_files, author, base_sha, merge_commit_sha, is_draft, is_merged, is_closed, provider_state, provider_mergeable, provider_merge_state_status, html_url, created_at_provider, updated_at_provider, merged_at_provider, closed_at_provider, metadata_hash, ci_hash, review_hash, observed_at, ci_observed_at, review_observed_at, last_nudge_signature FROM pr WHERE session_id = ? diff --git a/backend/internal/storage/sqlite/queries/pr.sql b/backend/internal/storage/sqlite/queries/pr.sql index b4d745cd..8767b703 100644 --- a/backend/internal/storage/sqlite/queries/pr.sql +++ b/backend/internal/storage/sqlite/queries/pr.sql @@ -101,6 +101,31 @@ ORDER BY pr.updated_at DESC LIMIT 1; +-- name: ListPRFactsBySession :many +-- All PR snapshots for a session (every state), with source/target branch for +-- stack derivation and the unresolved-comment flag. The status aggregator +-- filters open vs merged/closed in Go and derives stacks from the branches. +SELECT + pr.url, + pr.number, + pr.pr_state, + pr.review_decision, + pr.ci_state, + pr.mergeability, + pr.source_branch, + pr.target_branch, + pr.updated_at, + EXISTS ( + SELECT 1 + FROM pr_comment + WHERE pr_comment.pr_url = pr.url + AND pr_comment.resolved = 0 + AND pr_comment.is_bot = 0 + ) AS review_comments +FROM pr +WHERE pr.session_id = ? +ORDER BY pr.updated_at DESC; + -- name: ClaimPRForSession :exec INSERT INTO pr (url, session_id, number, pr_state, review_decision, ci_state, mergeability, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?) diff --git a/backend/internal/storage/sqlite/store/pr_facts.go b/backend/internal/storage/sqlite/store/pr_facts.go index 894bd727..d17f0592 100644 --- a/backend/internal/storage/sqlite/store/pr_facts.go +++ b/backend/internal/storage/sqlite/store/pr_facts.go @@ -24,6 +24,34 @@ func (s *Store) GetDisplayPRFactsForSession(ctx context.Context, id domain.Sessi return prFactsFromGen(r), true, nil } +// ListPRFactsForSession returns the PR snapshot for every PR a session owns +// (open, merged, and closed), newest first. The status aggregator filters and +// builds stacks from these; an empty slice means the session has no PRs. +func (s *Store) ListPRFactsForSession(ctx context.Context, id domain.SessionID) ([]domain.PRFacts, error) { + rows, err := s.qr.ListPRFactsBySession(ctx, id) + if err != nil { + return nil, fmt.Errorf("list pr facts for %s: %w", id, err) + } + out := make([]domain.PRFacts, 0, len(rows)) + for _, r := range rows { + out = append(out, domain.PRFacts{ + URL: r.URL, + Number: int(r.Number), + Draft: r.PRState == domain.PRStateDraft, + Merged: r.PRState == domain.PRStateMerged, + Closed: r.PRState == domain.PRStateClosed, + CI: r.CIState, + Review: r.ReviewDecision, + Mergeability: r.Mergeability, + ReviewComments: r.ReviewComments, + SourceBranch: r.SourceBranch, + TargetBranch: r.TargetBranch, + UpdatedAt: r.UpdatedAt, + }) + } + return out, nil +} + func prFactsFromGen(r gen.GetDisplayPRFactsBySessionRow) domain.PRFacts { state := r.PRState return domain.PRFacts{ diff --git a/backend/internal/storage/sqlite/store/pr_facts_test.go b/backend/internal/storage/sqlite/store/pr_facts_test.go new file mode 100644 index 00000000..dd2405ac --- /dev/null +++ b/backend/internal/storage/sqlite/store/pr_facts_test.go @@ -0,0 +1,81 @@ +package store_test + +import ( + "context" + "testing" + "time" + + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/ports" +) + +// ListPRFactsForSession is the real-SQLite batch read the multi-PR status +// aggregator builds stacks from: every owned PR returned newest-first with its +// state flags and branch pair projected (the stack model needs both). +// +// The branch pair is written via WriteSCMObservation (the observer path, the +// source of truth for tracked PRs). The other writer, WritePR, deliberately +// omits source/target branch (UpsertLegacyPR), so the stack model depends on the +// observer having populated the row. +func TestListPRFactsForSessionProjectsAllPRsNewestFirst(t *testing.T) { + s := newTestStore(t) + ctx := context.Background() + seedProject(t, s, "mer") + r, _ := s.CreateSession(ctx, sampleRecord("mer")) + now := time.Now().UTC().Truncate(time.Second) + + // A stack: root (open) -> child targets the root branch (open) -> a merged + // historical PR. Distinct updated_at so newest-first ordering is observable. + write := func(pr domain.PullRequest) { + t.Helper() + if err := s.WriteSCMObservation(ctx, pr, nil, nil, nil, ports.ReviewWritePreserve); err != nil { + t.Fatalf("write %s: %v", pr.URL, err) + } + } + write(domain.PullRequest{URL: "root", SessionID: r.ID, Number: 1, CI: domain.CIPassing, SourceBranch: "feat/x", TargetBranch: "main", UpdatedAt: now, ObservedAt: now}) + write(domain.PullRequest{URL: "child", SessionID: r.ID, Number: 2, Draft: true, SourceBranch: "feat/x/child", TargetBranch: "feat/x", UpdatedAt: now.Add(time.Second), ObservedAt: now}) + write(domain.PullRequest{URL: "old", SessionID: r.ID, Number: 3, Merged: true, SourceBranch: "feat/old", TargetBranch: "main", UpdatedAt: now.Add(2 * time.Second), ObservedAt: now}) + + facts, err := s.ListPRFactsForSession(ctx, r.ID) + if err != nil { + t.Fatal(err) + } + if len(facts) != 3 { + t.Fatalf("ListPRFactsForSession = %d, want 3", len(facts)) + } + // Newest-first by updated_at: old, child, root. + if facts[0].URL != "old" || facts[1].URL != "child" || facts[2].URL != "root" { + t.Fatalf("order = [%s %s %s], want [old child root]", facts[0].URL, facts[1].URL, facts[2].URL) + } + byURL := map[string]domain.PRFacts{} + for _, f := range facts { + byURL[f.URL] = f + } + if !byURL["old"].Merged || byURL["old"].Closed || byURL["old"].Draft { + t.Fatalf("merged PR flags wrong: %+v", byURL["old"]) + } + if !byURL["child"].Draft || byURL["child"].Merged { + t.Fatalf("draft child flags wrong: %+v", byURL["child"]) + } + // The stack model is derived from the source/target branch pair, so it must + // survive the projection. + if byURL["child"].SourceBranch != "feat/x/child" || byURL["child"].TargetBranch != "feat/x" { + t.Fatalf("child branch pair lost: %+v", byURL["child"]) + } + if byURL["root"].SourceBranch != "feat/x" || byURL["root"].TargetBranch != "main" { + t.Fatalf("root branch pair lost: %+v", byURL["root"]) + } + if byURL["root"].CI != domain.CIPassing { + t.Fatalf("root CI = %q, want passing", byURL["root"].CI) + } + + // A session with no PRs returns an empty (non-nil) slice, never an error. + empty, _ := s.CreateSession(ctx, sampleRecord("mer")) + got, err := s.ListPRFactsForSession(ctx, empty.ID) + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Fatalf("no-PR session = %d facts, want 0", len(got)) + } +} diff --git a/frontend/src/api/schema.ts b/frontend/src/api/schema.ts index 93d306ad..92a4177f 100644 --- a/frontend/src/api/schema.ts +++ b/frontend/src/api/schema.ts @@ -410,6 +410,23 @@ export interface components { reason: string; sessionId: string; }; + ControllersSessionView: { + activity: components["schemas"]["DomainActivity"]; + /** Format: date-time */ + createdAt: string; + displayName?: string; + harness?: string; + id: string; + isTerminated: boolean; + issueId?: string; + kind: string; + projectId: string; + prs: components["schemas"]["SessionPRFacts"][]; + status: string; + terminalHandleId?: string; + /** Format: date-time */ + updatedAt: string; + }; DegradedProject: { id: string; kind: string; @@ -442,7 +459,7 @@ export interface components { sessionId: string; }; ListSessionsResponse: { - sessions: components["schemas"]["Session"][]; + sessions: components["schemas"]["ControllersSessionView"][]; }; MergePRResponse: { method: string; @@ -512,7 +529,7 @@ export interface components { }; RestoreSessionResponse: { ok: boolean; - session: components["schemas"]["Session"]; + session: components["schemas"]["ControllersSessionView"]; sessionId: string; }; ReviewFinding: { @@ -551,22 +568,6 @@ export interface components { ok: boolean; sessionId: string; }; - Session: { - activity: components["schemas"]["DomainActivity"]; - /** Format: date-time */ - createdAt: string; - displayName?: string; - harness?: string; - id: string; - isTerminated: boolean; - issueId?: string; - kind: string; - projectId: string; - status: string; - terminalHandleId?: string; - /** Format: date-time */ - updatedAt: string; - }; SessionPRFacts: { ci: string; mergeability: string; @@ -580,7 +581,7 @@ export interface components { url: string; }; SessionResponse: { - session: components["schemas"]["Session"]; + session: components["schemas"]["ControllersSessionView"]; }; SetActivityRequest: { /**