Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions cmd/spr/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ var (
date = "unknown"
)

func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n]
}

func init() {
zerolog.SetGlobalLevel(zerolog.InfoLevel)
log.Logger = log.With().Caller().Logger().Output(zerolog.ConsoleWriter{Out: os.Stderr})
Expand Down Expand Up @@ -98,6 +105,13 @@ func main() {
Usage: "Show detailed status bits output",
}

textFlag := &cli.BoolFlag{
Name: "text",
Aliases: []string{"t"},
Value: false,
Usage: "Show plain text output (URL : title)",
}

cli.AppHelpTemplate = `NAME:
{{.Name}} - {{.Usage}}

Expand All @@ -117,7 +131,7 @@ VERSION: fork of {{.Version}}
Name: "spr",
Usage: "Stacked Pull Requests on GitHub",
HideVersion: true,
Version: fmt.Sprintf("%s : %s : %s\n", version, date, commit[:8]),
Version: fmt.Sprintf("%s : %s : %s\n", version, date, truncate(commit, 8)),
EnableBashCompletion: true,
Authors: []*cli.Author{
{
Expand Down Expand Up @@ -170,12 +184,16 @@ VERSION: fork of {{.Version}}
Name: "status",
Aliases: []string{"s", "st"},
Usage: "Show status of open pull requests",
Action: func(c *cli.Context) error {
stackedpr.StatusPullRequests(ctx)
return nil
},
Action: func(c *cli.Context) error {
if c.IsSet("text") {
stackedpr.TextEnabled = true
}
stackedpr.StatusPullRequests(ctx)
return nil
},
Flags: []cli.Flag{
detailFlag,
textFlag,
},
},
{
Expand Down
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type RepoConfig struct {
GitHubBranch string `default:"main" yaml:"githubBranch"`

RequireChecks bool `default:"true" yaml:"requireChecks"`
RequiredChecks []string `yaml:"requiredChecks"`
RequireApproval bool `default:"true" yaml:"requireApproval"`
DefaultReviewers []string `yaml:"defaultReviewers"`

Expand Down
6 changes: 3 additions & 3 deletions git/mockgit/mockgit.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,21 @@ func (m *Mock) ExpectEditStart() {

// ExpectEditDoneAmend expects the amend + rebase continue sequence for a successful edit --done
func (m *Mock) ExpectEditDoneAmend() {
m.expect("git add -A")
m.expect("git add -u")
m.expect("git commit --amend --no-edit")
m.expect("git rebase --continue")
}

// ExpectEditDoneAmendWithConflict expects amend succeeds but rebase --continue fails (conflict)
func (m *Mock) ExpectEditDoneAmendWithConflict() {
m.expect("git add -A")
m.expect("git add -u")
m.expect("git commit --amend --no-edit")
m.expectError("git rebase --continue", errors.New("conflict"))
}

// ExpectEditDoneConflictResolved expects the conflict resolution path (no amend, just rebase continue)
func (m *Mock) ExpectEditDoneConflictResolved() {
m.expect("git add -A")
m.expect("git add -u")
m.expect("git rebase --continue")
}

Expand Down
262 changes: 256 additions & 6 deletions github/githubclient/client.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package githubclient

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"path"
Expand Down Expand Up @@ -152,8 +155,9 @@ func NewGitHubClient(ctx context.Context, config *config.Config) *client {
tc := oauth2.NewClient(ctx, ts)

var api genclient.Client
var endpoint string
if strings.HasSuffix(config.Repo.GitHubHost, "github.com") {
api = genclient.NewClient("https://api.github.com/graphql", tc)
endpoint = "https://api.github.com/graphql"
} else {
var scheme, host string
gitHubRemoteUrl, err := url.Parse(config.Repo.GitHubHost)
Expand All @@ -165,17 +169,22 @@ func NewGitHubClient(ctx context.Context, config *config.Config) *client {
host = gitHubRemoteUrl.Host
scheme = gitHubRemoteUrl.Scheme
}
api = genclient.NewClient(fmt.Sprintf("%s://%s/api/graphql", scheme, host), tc)
endpoint = fmt.Sprintf("%s://%s/api/graphql", scheme, host)
}
api = genclient.NewClient(endpoint, tc)
return &client{
config: config,
api: api,
config: config,
api: api,
graphqlEndpoint: endpoint,
httpClient: tc,
}
}

type client struct {
config *config.Config
api genclient.Client
config *config.Config
api genclient.Client
graphqlEndpoint string
httpClient *http.Client
}

func (c *client) GetInfo(ctx context.Context, gitcmd git.GitInterface) *github.GitHubInfo {
Expand Down Expand Up @@ -208,6 +217,22 @@ func (c *client) GetInfo(ctx context.Context, gitcmd git.GitInterface) *github.G
localCommitStack := git.GetLocalCommitStack(c.config, gitcmd)

pullRequests := matchPullRequestStack(c.config.Repo, c.config.User.BranchPrefix, targetBranch, localCommitStack, pullRequestConnection)

// When RequiredChecks is explicitly configured, fetch individual check contexts
// and only evaluate the listed checks. This allows non-required check failures
// to be ignored. When RequiredChecks is not set, the statusCheckRollup.state
// from the fezzik query is used as-is (all checks matter).
if c.config.Repo.RequireChecks && len(c.config.Repo.RequiredChecks) > 0 && len(pullRequests) > 0 {
requiredStatus := c.fetchRequiredChecksStatus(ctx, pullRequests)
if requiredStatus != nil {
for _, pr := range pullRequests {
if status, ok := requiredStatus[pr.Number]; ok {
pr.MergeStatus.ChecksPass = status
}
}
}
}

for _, pr := range pullRequests {
if pr.Ready(c.config) {
pr.MergeStatus.Stacked = true
Expand Down Expand Up @@ -592,6 +617,231 @@ func (c *client) ClosePullRequest(ctx context.Context, pr *github.PullRequest) {
}
}

// Response types for the raw GraphQL query that fetches individual check contexts.
// These are used instead of fezzik-generated types because fezzik does not support
// inline fragments on union types (StatusCheckRollupContext = CheckRun | StatusContext).

type checkContextNode struct {
TypeName string `json:"__typename"`
Name string `json:"name"` // CheckRun
Conclusion *string `json:"conclusion"` // CheckRun (nil when not completed)
Status string `json:"status"` // CheckRun: COMPLETED, IN_PROGRESS, QUEUED, etc.
Context string `json:"context"` // StatusContext
State string `json:"state"` // StatusContext: SUCCESS, FAILURE, PENDING, etc.
}

type checkContextsResult struct {
Number int `json:"number"`
Commits struct {
Nodes []struct {
Commit struct {
StatusCheckRollup *struct {
Contexts struct {
Nodes []checkContextNode `json:"nodes"`
} `json:"contexts"`
} `json:"statusCheckRollup"`
} `json:"commit"`
} `json:"nodes"`
} `json:"commits"`
}

type graphqlRequest struct {
Query string `json:"query"`
}

type graphqlResponse struct {
Data map[string]json.RawMessage `json:"data"`
Errors []struct {
Message string `json:"message"`
} `json:"errors"`
}

// fetchRequiredChecksStatus makes a single batched GraphQL query to fetch
// individual check contexts for all given pull requests. It evaluates only
// the checks listed in config.Repo.RequiredChecks and returns a map from
// PR number to the computed status.
func (c *client) fetchRequiredChecksStatus(ctx context.Context, pullRequests []*github.PullRequest) map[int]github.CheckStatus {
if len(pullRequests) == 0 {
return nil
}

if c.config.User.LogGitHubCalls {
fmt.Printf("> github fetch required check status\n")
}

// Build a single GraphQL query with one aliased field per PR.
var queryBuilder strings.Builder
queryBuilder.WriteString("query {")
for _, pr := range pullRequests {
fmt.Fprintf(&queryBuilder, `
pr_%d: node(id: %q) {
... on PullRequest {
number
commits(last: 1) {
nodes {
commit {
statusCheckRollup {
contexts(first: 100) {
nodes {
__typename
... on CheckRun {
name
conclusion
status
}
... on StatusContext {
context
state
}
}
}
}
}
}
}
}
}`, pr.Number, pr.ID)
}
queryBuilder.WriteString("\n}")

reqBody, err := json.Marshal(graphqlRequest{
Query: queryBuilder.String(),
})
if err != nil {
log.Warn().Err(err).Msg("failed to marshal required checks query")
return nil
}

httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.graphqlEndpoint, bytes.NewReader(reqBody))
if err != nil {
log.Warn().Err(err).Msg("failed to create required checks request")
return nil
}
httpReq.Header.Set("Content-Type", "application/json; charset=utf-8")
httpReq.Header.Set("Accept", "application/json; charset=utf-8")

httpResp, err := c.httpClient.Do(httpReq)
if err != nil {
log.Warn().Err(err).Msg("failed to fetch required checks status")
return nil
}
defer httpResp.Body.Close()

var gqlResp graphqlResponse
if err := json.NewDecoder(httpResp.Body).Decode(&gqlResp); err != nil {
log.Warn().Err(err).Msg("failed to decode required checks response")
return nil
}
if len(gqlResp.Errors) > 0 {
log.Warn().Str("error", gqlResp.Errors[0].Message).Msg("graphql error fetching required checks")
return nil
}

// Build the set of required check names from config.
requiredSet := make(map[string]bool, len(c.config.Repo.RequiredChecks))
for _, name := range c.config.Repo.RequiredChecks {
requiredSet[name] = true
}

result := make(map[int]github.CheckStatus)
for _, pr := range pullRequests {
alias := fmt.Sprintf("pr_%d", pr.Number)
raw, ok := gqlResp.Data[alias]
if !ok {
continue
}
var prResult checkContextsResult
if err := json.Unmarshal(raw, &prResult); err != nil {
log.Warn().Err(err).Int("pr", pr.Number).Msg("failed to unmarshal check contexts for PR")
continue
}
if len(prResult.Commits.Nodes) == 0 {
continue
}
commit := prResult.Commits.Nodes[0].Commit
if commit.StatusCheckRollup == nil {
// No checks configured — treat as pass
result[pr.Number] = github.CheckStatusPass
continue
}
result[pr.Number] = computeRequiredCheckStatus(commit.StatusCheckRollup.Contexts.Nodes, requiredSet)
}

return result
}

// contextName returns the display name for a check context node.
// For CheckRun nodes this is the Name field; for StatusContext nodes it's the Context field.
func contextName(ctx checkContextNode) string {
if ctx.TypeName == "StatusContext" {
return ctx.Context
}
return ctx.Name
}

// computeRequiredCheckStatus determines the aggregate check status considering
// only the checks whose name/context appears in requiredChecks.
// If a required check hasn't reported yet (not present in contexts), it is
// treated as pending.
func computeRequiredCheckStatus(contexts []checkContextNode, requiredChecks map[string]bool) github.CheckStatus {
// Track which required checks we've seen
seen := make(map[string]bool, len(requiredChecks))
hasPending := false
hasFail := false

for _, ctx := range contexts {
name := contextName(ctx)
if !requiredChecks[name] {
continue
}
seen[name] = true

switch ctx.TypeName {
case "CheckRun":
switch ctx.Status {
case "COMPLETED":
if ctx.Conclusion == nil {
hasFail = true
} else {
switch *ctx.Conclusion {
case "SUCCESS", "NEUTRAL", "SKIPPED":
// pass
default:
hasFail = true
}
}
default:
// IN_PROGRESS, QUEUED, REQUESTED, WAITING, PENDING
hasPending = true
}
case "StatusContext":
switch ctx.State {
case "SUCCESS":
// pass
case "PENDING", "EXPECTED":
hasPending = true
default:
hasFail = true
}
}
}

// Any required check that hasn't reported yet is pending
for name := range requiredChecks {
if !seen[name] {
hasPending = true
}
}

if hasFail {
return github.CheckStatusFail
}
if hasPending {
return github.CheckStatusPending
}
return github.CheckStatusPass
}

func check(err error) {
if err != nil {
msg := err.Error()
Expand Down
Loading
Loading