diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..509f8f4 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,5 @@ +.local/ +.env +dist/ +tmp/ +.git/ diff --git a/.github/.dependabot.yaml b/.github/.dependabot.yaml index 1230149..f3f3cd3 100644 --- a/.github/.dependabot.yaml +++ b/.github/.dependabot.yaml @@ -4,3 +4,8 @@ updates: directory: "/" schedule: interval: "daily" + + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9ba4e80..d166b19 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -46,10 +46,10 @@ jobs: with: go-version: '1.24' - - name: Run tests + - name: Run unit tests run: make test - - name: Run tests + - name: Run integration tests run: make test-verify-verbose build: diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..da8effe --- /dev/null +++ b/Dockerfile @@ -0,0 +1,11 @@ +FROM golang:1.24 AS builder + +WORKDIR /src +COPY go.mod go.sum ./ +RUN go mod download +COPY . . +RUN CGO_ENABLED=0 go build -trimpath -ldflags "-s -w" -o /server ./cmd/server + +FROM gcr.io/distroless/static-debian12 +COPY --from=builder /server /server +ENTRYPOINT ["/server"] diff --git a/Makefile b/Makefile index 121fa72..cdbd6db 100644 --- a/Makefile +++ b/Makefile @@ -43,3 +43,15 @@ test-verify: test-verify-verbose: go run ./cmd/verify -verbose +.PHONY: server-up +server-up: + docker compose up -d --build + +.PHONY: server-logs +server-logs: + docker compose logs -f server + +.PHONY: server-stop +server-stop: + docker compose down --rmi local --remove-orphans + diff --git a/assets/github/manafiest.json b/assets/github/manifest.json similarity index 100% rename from assets/github/manafiest.json rename to assets/github/manifest.json diff --git a/cmd/lambda/main.go b/cmd/lambda/main.go index 72e7f6e..8d41b21 100644 --- a/cmd/lambda/main.go +++ b/cmd/lambda/main.go @@ -1,15 +1,19 @@ package main import ( + "bytes" "context" "encoding/json" "fmt" "log/slog" + "net/http" + "net/http/httptest" "strings" "sync" awsevents "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" + "github.com/cockroachdb/errors" "github.com/cruxstack/github-ops-app/internal/app" "github.com/cruxstack/github-ops-app/internal/config" ) @@ -17,6 +21,7 @@ import ( var ( initOnce sync.Once appInst *app.App + router http.Handler logger *slog.Logger initErr error ) @@ -27,14 +32,19 @@ func initApp() { cfg, err := config.NewConfig() if err != nil { - initErr = fmt.Errorf("config init failed: %w", err) + initErr = errors.Wrap(err, "config init failed") return } - appInst, initErr = app.New(context.Background(), cfg) + appInst, initErr = app.NewApp(context.Background(), cfg, logger) + if initErr != nil { + return + } + router = appInst.Handler() }) } -// APIGatewayHandler converts API Gateway requests to unified app.Request. +// APIGatewayHandler converts API Gateway requests to stdlib *http.Request +// and routes them through the chi router. func APIGatewayHandler(ctx context.Context, req awsevents.APIGatewayV2HTTPRequest) (awsevents.APIGatewayV2HTTPResponse, error) { initApp() if initErr != nil { @@ -50,29 +60,35 @@ func APIGatewayHandler(ctx context.Context, req awsevents.APIGatewayV2HTTPReques logger.Debug("received api gateway request", slog.String("request", string(j))) } - headers := make(map[string]string) - for key, value := range req.Headers { - headers[strings.ToLower(key)] = value + httpReq, err := http.NewRequestWithContext( + ctx, + req.RequestContext.HTTP.Method, + req.RawPath, + strings.NewReader(req.Body), + ) + if err != nil { + return awsevents.APIGatewayV2HTTPResponse{ + StatusCode: 500, + Body: "failed to construct http request", + }, nil } - appReq := app.Request{ - Type: app.RequestTypeHTTP, - Method: req.RequestContext.HTTP.Method, - Path: req.RawPath, - Headers: headers, - Body: []byte(req.Body), + for key, value := range req.Headers { + httpReq.Header.Set(key, value) } - resp := appInst.HandleRequest(ctx, appReq) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) return awsevents.APIGatewayV2HTTPResponse{ - StatusCode: resp.StatusCode, - Headers: resp.Headers, - Body: string(resp.Body), + StatusCode: rec.Code, + Headers: flattenHeaders(rec.Header()), + Body: rec.Body.String(), }, nil } -// EventBridgeHandler converts EventBridge events to unified app.Request. +// EventBridgeHandler converts EventBridge events to POST /scheduled/{action} +// requests and routes them through the chi router. func EventBridgeHandler(ctx context.Context, evt awsevents.CloudWatchEvent) error { initApp() if initErr != nil { @@ -90,16 +106,30 @@ func EventBridgeHandler(ctx context.Context, evt awsevents.CloudWatchEvent) erro return err } - req := app.Request{ - Type: app.RequestTypeScheduled, - ScheduledAction: detail.Action, - ScheduledData: detail.Data, + path := fmt.Sprintf("%s/scheduled/%s", appInst.Config.BasePath, detail.Action) + + var body []byte + if detail.Data != nil { + body = detail.Data } - resp := appInst.HandleRequest(ctx, req) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, path, bytes.NewReader(body)) + if err != nil { + return errors.Wrap(err, "failed to construct http request") + } - if resp.StatusCode >= 400 { - return fmt.Errorf("scheduled event failed: %s", string(resp.Body)) + if appInst.Config.AdminToken != "" { + httpReq.Header.Set("Authorization", "Bearer "+appInst.Config.AdminToken) + } + if len(body) > 0 { + httpReq.Header.Set("Content-Type", "application/json") + } + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) + + if rec.Code >= 400 { + return errors.Newf("scheduled event failed: %s", rec.Body.String()) } return nil @@ -122,7 +152,18 @@ func UniversalHandler(ctx context.Context, event json.RawMessage) (any, error) { return nil, EventBridgeHandler(ctx, eventBridgeEvent) } - return nil, fmt.Errorf("unknown lambda event type") + return nil, errors.New("unknown lambda event type") +} + +// flattenHeaders converts multi-value http.Header to single-value map. +func flattenHeaders(h http.Header) map[string]string { + flat := make(map[string]string, len(h)) + for key, values := range h { + if len(values) > 0 { + flat[key] = values[0] + } + } + return flat } func main() { diff --git a/cmd/sample/main.go b/cmd/sample/main.go index e7fa906..40888c4 100644 --- a/cmd/sample/main.go +++ b/cmd/sample/main.go @@ -4,9 +4,13 @@ package main import ( + "bytes" "context" "encoding/json" + "fmt" "log/slog" + "net/http" + "net/http/httptest" "os" "path/filepath" @@ -32,12 +36,14 @@ func main() { os.Exit(1) } - a, err := app.New(ctx, cfg) + a, err := app.NewApp(ctx, cfg, logger) if err != nil { logger.Error("failed to initialize app", slog.String("error", err.Error())) os.Exit(1) } + router := a.Handler() + path := filepath.Join("fixtures", "samples.json") raw, err := os.ReadFile(path) if err != nil { @@ -52,28 +58,53 @@ func main() { } for i, sample := range samples { - eventType := sample["event_type"].(string) + eventType, ok := sample["event_type"].(string) + if !ok { + logger.Error("missing or invalid event_type", slog.Int("sample", i)) + os.Exit(1) + } switch eventType { case "okta_sync": - evt := app.ScheduledEvent{ - Action: "okta-sync", + reqPath := fmt.Sprintf("%s/scheduled/okta-sync", cfg.BasePath) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqPath, nil) + if err != nil { + logger.Error("failed to construct http request", + slog.Int("sample", i), + slog.String("error", err.Error())) + os.Exit(1) + } + if cfg.AdminToken != "" { + httpReq.Header.Set("Authorization", "Bearer "+cfg.AdminToken) } - if err := a.ProcessScheduledEvent(ctx, evt); err != nil { + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) + if rec.Code >= 400 { logger.Error("failed to process okta_sync sample", slog.Int("sample", i), - slog.String("error", err.Error())) + slog.String("response", rec.Body.String())) os.Exit(1) } case "pr_webhook": payload, _ := json.Marshal(sample["payload"]) - if err := a.ProcessWebhook(ctx, payload, "pull_request"); err != nil { - logger.Error("failed to process pr_webhook sample", + reqPath := fmt.Sprintf("%s/webhooks", cfg.BasePath) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqPath, bytes.NewReader(payload)) + if err != nil { + logger.Error("failed to construct http request", slog.Int("sample", i), slog.String("error", err.Error())) os.Exit(1) } + httpReq.Header.Set("X-GitHub-Event", "pull_request") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) + if rec.Code >= 400 { + logger.Error("failed to process pr_webhook sample", + slog.Int("sample", i), + slog.String("response", rec.Body.String())) + os.Exit(1) + } default: logger.Info("skipping unknown event type", slog.String("event_type", eventType)) diff --git a/cmd/server/main.go b/cmd/server/main.go index cc4b84c..23ab07c 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,12 +2,10 @@ package main import ( "context" - "io" "log/slog" "net/http" "os" "os/signal" - "strings" "syscall" "time" @@ -15,13 +13,8 @@ import ( "github.com/cruxstack/github-ops-app/internal/config" ) -var ( - appInst *app.App - logger *slog.Logger -) - func main() { - logger = config.NewLogger() + logger := config.NewLogger() ctx := context.Background() cfg, err := config.NewConfig() @@ -30,15 +23,12 @@ func main() { os.Exit(1) } - appInst, err = app.New(ctx, cfg) + appInst, err := app.NewApp(ctx, cfg, logger) if err != nil { logger.Error("app init failed", slog.String("error", err.Error())) os.Exit(1) } - mux := http.NewServeMux() - mux.HandleFunc("/", httpHandler) - port := os.Getenv("APP_PORT") if port == "" { port = "8080" @@ -46,7 +36,7 @@ func main() { srv := &http.Server{ Addr: ":" + port, - Handler: mux, + Handler: appInst.Handler(), ReadTimeout: 15 * time.Second, WriteTimeout: 15 * time.Second, IdleTimeout: 60 * time.Second, @@ -80,42 +70,3 @@ func main() { <-done logger.Info("server stopped") } - -// httpHandler converts http.Request to app.Request and handles the response. -func httpHandler(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "failed to read request body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - headers := make(map[string]string) - for key, values := range r.Header { - if len(values) > 0 { - headers[strings.ToLower(key)] = values[0] - } - } - - req := app.Request{ - Type: app.RequestTypeHTTP, - Method: r.Method, - Path: r.URL.Path, - Headers: headers, - Body: body, - } - - resp := appInst.HandleRequest(r.Context(), req) - - for key, value := range resp.Headers { - w.Header().Set(key, value) - } - if resp.ContentType != "" && w.Header().Get("Content-Type") == "" { - w.Header().Set("Content-Type", resp.ContentType) - } - - w.WriteHeader(resp.StatusCode) - if len(resp.Body) > 0 { - w.Write(resp.Body) - } -} diff --git a/cmd/verify/.env.test b/cmd/verify/.env.test index d070bb9..9a375a7 100644 --- a/cmd/verify/.env.test +++ b/cmd/verify/.env.test @@ -1,8 +1,9 @@ # config used during offline verification testing # -# - all api calls are mocked all config is fake -# - private keys is dynamically generated during test setup -# - endpoints are https with tls configured during setup +# - all api calls are mocked and all config is fake +# - private keys are dynamically generated during test setup +# - endpoints use dynamic ports assigned at runtime +# - base urls are set per-scenario by the test runner APP_DEBUG_ENABLED=false @@ -11,7 +12,6 @@ APP_GITHUB_APP_ID=123456 APP_GITHUB_INSTALLATION_ID=987654 APP_GITHUB_ORG=acme-ghorg APP_GITHUB_WEBHOOK_SECRET=test_webhook_secret -APP_GITHUB_BASE_URL=https://localhost:9001/ # github pr compliance configuration APP_PR_COMPLIANCE_ENABLED=true @@ -19,7 +19,6 @@ APP_PR_MONITORED_BRANCHES=main,master # okta configuration (oauth2 with private key) APP_OKTA_DOMAIN=dev-12345.okta.com -APP_OKTA_CLIENT_ID=test-client-id # okta sync rules APP_OKTA_GITHUB_USER_FIELD=githubUsername @@ -28,4 +27,3 @@ APP_OKTA_SYNC_RULES=[{"enabled":true,"okta_group_name":"Engineering","github_tea # slack configuration APP_SLACK_TOKEN=xoxb-test-token APP_SLACK_CHANNEL=C01234TEST -APP_SLACK_API_URL=https://localhost:9003/ diff --git a/cmd/verify/logger.go b/cmd/verify/logger.go index 1780d7d..d7b9f68 100644 --- a/cmd/verify/logger.go +++ b/cmd/verify/logger.go @@ -18,7 +18,7 @@ type testHandler struct { // Enabled returns true for all log levels when verbose mode is enabled. func (h *testHandler) Enabled(_ context.Context, _ slog.Level) bool { - return true + return h.verbose } // Handle formats and writes log records to output with test-appropriate diff --git a/cmd/verify/scenario.go b/cmd/verify/scenario.go index 4b00464..2ce8e8b 100644 --- a/cmd/verify/scenario.go +++ b/cmd/verify/scenario.go @@ -1,17 +1,21 @@ package main import ( + "bytes" "context" "crypto/tls" "encoding/json" "fmt" "log/slog" "net/http" + "net/http/httptest" "os" "time" + "github.com/cockroachdb/errors" "github.com/cruxstack/github-ops-app/internal/app" "github.com/cruxstack/github-ops-app/internal/config" + oktaclient "github.com/cruxstack/github-ops-app/internal/okta" ) // TestScenario defines a test case with input events and expected outcomes. @@ -65,74 +69,58 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge tlsCert, certPool, err := generateSelfSignedCert() if err != nil { - return fmt.Errorf("generate cert: %w", err) + return errors.Wrap(err, "failed to generate cert") } githubAppKey, err := generateOAuthPrivateKey() if err != nil { - return fmt.Errorf("generate github app key: %w", err) + return errors.Wrap(err, "failed to generate github app key") } - os.Setenv("APP_GITHUB_APP_PRIVATE_KEY", string(githubAppKey)) oauthKey, err := generateOAuthPrivateKey() if err != nil { - return fmt.Errorf("generate oauth key: %w", err) + return errors.Wrap(err, "failed to generate oauth key") } - os.Setenv("APP_OKTA_CLIENT_ID", "test-client-id") - os.Setenv("APP_OKTA_PRIVATE_KEY", string(oauthKey)) - githubServer := &http.Server{ - Addr: "localhost:9001", - Handler: githubMock, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - }, + // use dynamic ports: bind to :0 and extract the assigned port + tlsConfig := &tls.Config{Certificates: []tls.Certificate{tlsCert}} + + githubListener, err := tls.Listen("tcp", "localhost:0", tlsConfig) + if err != nil { + return errors.Wrap(err, "failed to listen github") } - oktaServer := &http.Server{ - Addr: "localhost:9002", - Handler: oktaMock, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - }, + oktaListener, err := tls.Listen("tcp", "localhost:0", tlsConfig) + if err != nil { + githubListener.Close() + return errors.Wrap(err, "failed to listen okta") } - slackServer := &http.Server{ - Addr: "localhost:9003", - Handler: slackMock, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - }, + slackListener, err := tls.Listen("tcp", "localhost:0", tlsConfig) + if err != nil { + githubListener.Close() + oktaListener.Close() + return errors.Wrap(err, "failed to listen slack") } - githubReady := make(chan bool) - oktaReady := make(chan bool) - slackReady := make(chan bool) + githubServer := &http.Server{Handler: githubMock} + oktaServer := &http.Server{Handler: oktaMock} + slackServer := &http.Server{Handler: slackMock} go func() { - githubReady <- true - if err := githubServer.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + if err := githubServer.Serve(githubListener); err != http.ErrServerClosed { logger.Error("github mock server error", slog.String("error", err.Error())) } }() - go func() { - oktaReady <- true - if err := oktaServer.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + if err := oktaServer.Serve(oktaListener); err != http.ErrServerClosed { logger.Error("okta mock server error", slog.String("error", err.Error())) } }() - go func() { - slackReady <- true - if err := slackServer.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + if err := slackServer.Serve(slackListener); err != http.ErrServerClosed { logger.Error("slack mock server error", slog.String("error", err.Error())) } }() - <-githubReady - <-oktaReady - <-slackReady - time.Sleep(100 * time.Millisecond) - defer func() { shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -141,17 +129,51 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge slackServer.Shutdown(shutdownCtx) }() + githubAddr := fmt.Sprintf("https://%s/", githubListener.Addr().String()) + oktaAddr := fmt.Sprintf("https://%s", oktaListener.Addr().String()) + slackAddr := fmt.Sprintf("https://%s/", slackListener.Addr().String()) + + // save and restore environment variables for isolation between scenarios + envKeys := []string{ + "APP_GITHUB_APP_PRIVATE_KEY", "APP_OKTA_CLIENT_ID", "APP_OKTA_PRIVATE_KEY", + "APP_GITHUB_BASE_URL", "APP_SLACK_API_URL", "APP_OKTA_BASE_URL", + "APP_OKTA_ORPHANED_USER_NOTIFICATIONS", + } + for key := range scenario.ConfigOverrides { + envKeys = append(envKeys, key) + } + savedEnv := make(map[string]string, len(envKeys)) + for _, key := range envKeys { + savedEnv[key] = os.Getenv(key) + } + defer func() { + for key, value := range savedEnv { + if value == "" { + os.Unsetenv(key) + } else { + os.Setenv(key, value) + } + } + }() + + os.Setenv("APP_GITHUB_APP_PRIVATE_KEY", string(githubAppKey)) + os.Setenv("APP_OKTA_CLIENT_ID", "test-client-id") + os.Setenv("APP_OKTA_PRIVATE_KEY", string(oauthKey)) + os.Setenv("APP_GITHUB_BASE_URL", githubAddr) + os.Setenv("APP_SLACK_API_URL", slackAddr) + os.Setenv("APP_OKTA_BASE_URL", oktaAddr) + + // save and restore http.DefaultTransport + savedTransport := http.DefaultTransport + defer func() { http.DefaultTransport = savedTransport }() + http.DefaultTransport = &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: certPool, }, } - os.Setenv("APP_GITHUB_BASE_URL", "https://localhost:9001/") - os.Setenv("APP_SLACK_API_URL", "https://localhost:9003/") - os.Setenv("APP_OKTA_BASE_URL", "https://localhost:9002") - - ctx = context.WithValue(ctx, "okta_tls_cert_pool", certPool) + ctx = oktaclient.WithCertPool(ctx, certPool) if os.Getenv("APP_OKTA_ORPHANED_USER_NOTIFICATIONS") == "" { os.Setenv("APP_OKTA_ORPHANED_USER_NOTIFICATIONS", "false") @@ -163,67 +185,77 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge cfg, err := config.NewConfig() if err != nil { - return fmt.Errorf("config creation failed: %w", err) + return errors.Wrap(err, "config creation failed") } - a, err := app.New(ctx, cfg) + appLogger := slog.New(&testHandler{prefix: " ", verbose: verbose, w: os.Stdout}) + + a, err := app.NewApp(ctx, cfg, appLogger) if err != nil { - return fmt.Errorf("app creation failed: %w", err) + return errors.Wrap(err, "app creation failed") } if verbose { fmt.Printf("\n Application Output:\n") } - appLogger := slog.New(&testHandler{prefix: " ", verbose: verbose, w: os.Stdout}) - a.Logger = appLogger + router := a.Handler() - var req app.Request + var httpReq *http.Request switch scenario.EventType { case "scheduled_event": var evt app.ScheduledEvent if err := json.Unmarshal(scenario.EventPayload, &evt); err != nil { - return fmt.Errorf("unmarshal event payload failed: %w", err) + return errors.Wrap(err, "failed to unmarshal event payload") + } + path := fmt.Sprintf("%s/scheduled/%s", cfg.BasePath, evt.Action) + var body []byte + if evt.Data != nil { + body = evt.Data + } + var err error + httpReq, err = http.NewRequestWithContext(ctx, http.MethodPost, path, bytes.NewReader(body)) + if err != nil { + return errors.Wrap(err, "failed to construct scheduled http request") } - req = app.Request{ - Type: app.RequestTypeScheduled, - ScheduledAction: evt.Action, - ScheduledData: evt.Data, + if cfg.AdminToken != "" { + httpReq.Header.Set("Authorization", "Bearer "+cfg.AdminToken) + } + if len(body) > 0 { + httpReq.Header.Set("Content-Type", "application/json") } case "webhook": - req = app.Request{ - Type: app.RequestTypeHTTP, - Method: "POST", - Path: "/webhooks", - Headers: map[string]string{ - "x-github-event": scenario.WebhookType, - "x-hub-signature-256": "", // signature validated separately in tests - }, - Body: scenario.WebhookPayload, + var err error + httpReq, err = http.NewRequestWithContext(ctx, http.MethodPost, cfg.BasePath+"/webhooks", bytes.NewReader(scenario.WebhookPayload)) + if err != nil { + return errors.Wrap(err, "failed to construct webhook http request") } + httpReq.Header.Set("X-GitHub-Event", scenario.WebhookType) + httpReq.Header.Set("X-Hub-Signature-256", "") // signature validated separately in tests default: - return fmt.Errorf("unknown event type: %s", scenario.EventType) + return errors.Newf("unknown event type: %s", scenario.EventType) } - resp := a.HandleRequest(ctx, req) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) var processErr error - if resp.StatusCode >= 400 { - processErr = fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(resp.Body)) + if rec.Code >= 400 { + processErr = errors.Newf("request failed with status %d: %s", rec.Code, rec.Body.String()) } if scenario.ExpectError { if processErr == nil { - return fmt.Errorf("expected error but processing succeeded") + return errors.New("expected error but processing succeeded") } if verbose { fmt.Printf(" ✓ Expected error occurred: %v\n", processErr) } } else { if processErr != nil { - return fmt.Errorf("process event failed: %w", processErr) + return errors.Wrap(processErr, "process event failed") } } @@ -244,6 +276,12 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge fmt.Printf("\n") } + if err := validateNoUnexpectedCalls(scenario.ExpectedCalls, allReqs); err != nil { + fmt.Printf("\n Validation:\n") + fmt.Printf(" ✗ FAILED: %v\n", err) + return err + } + if err := validateExpectedCalls(scenario.ExpectedCalls, allReqs); err != nil { fmt.Printf("\n Validation:\n") fmt.Printf(" ✗ FAILED: %v\n", err) @@ -294,7 +332,32 @@ func validateExpectedCalls(expected []ExpectedCall, allReqs map[string][]Request } } if !found { - return fmt.Errorf("expected call not found: %s %s %s", exp.Service, exp.Method, exp.Path) + return errors.Newf("expected call not found: %s %s %s", exp.Service, exp.Method, exp.Path) + } + } + return nil +} + +// validateNoUnexpectedCalls checks that no unexpected destructive API calls +// were made. only flags DELETE calls to catch unintended member removal or +// resource deletion. +func validateNoUnexpectedCalls(expected []ExpectedCall, allReqs map[string][]RequestRecord) error { + for service, reqs := range allReqs { + for _, req := range reqs { + // only flag unexpected destructive calls + if req.Method != "DELETE" { + continue + } + matched := false + for _, exp := range expected { + if exp.Service == service && exp.Method == req.Method && matchPath(req.Path, exp.Path) { + matched = true + break + } + } + if !matched { + return errors.Newf("unexpected destructive call: %s %s %s", service, req.Method, req.Path) + } } } return nil diff --git a/cmd/verify/tls.go b/cmd/verify/tls.go index 0cab7d2..acbb4a2 100644 --- a/cmd/verify/tls.go +++ b/cmd/verify/tls.go @@ -7,9 +7,11 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "fmt" "math/big" + "net" "time" + + "github.com/cockroachdb/errors" ) // generateOAuthPrivateKey creates an RSA private key for OAuth testing. @@ -17,7 +19,7 @@ import ( func generateOAuthPrivateKey() ([]byte, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return nil, fmt.Errorf("generate oauth key: %w", err) + return nil, errors.Wrap(err, "failed to generate oauth key") } keyPEM := pem.EncodeToMemory(&pem.Block{ @@ -33,7 +35,7 @@ func generateOAuthPrivateKey() ([]byte, error) { func generateSelfSignedCert() (tls.Certificate, *x509.CertPool, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("generate key: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to generate key") } notBefore := time.Now() @@ -41,7 +43,7 @@ func generateSelfSignedCert() (tls.Certificate, *x509.CertPool, error) { serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("generate serial: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to generate serial") } template := x509.Certificate{ @@ -56,11 +58,12 @@ func generateSelfSignedCert() (tls.Certificate, *x509.CertPool, error) { ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, } certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("create cert: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to create cert") } certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) @@ -68,12 +71,12 @@ func generateSelfSignedCert() (tls.Certificate, *x509.CertPool, error) { tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("create keypair: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to create keypair") } cert, err := x509.ParseCertificate(certDER) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("parse cert: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to parse cert") } certPool := x509.NewCertPool() diff --git a/compose.yaml b/compose.yaml new file mode 100644 index 0000000..28ef6a2 --- /dev/null +++ b/compose.yaml @@ -0,0 +1,10 @@ +services: + server: + build: . + container_name: github-ops-server + env_file: .env + ports: + - "${APP_PORT:-8080}:${APP_PORT:-8080}" + volumes: + - ./.local:/.local:ro + restart: unless-stopped diff --git a/docs/github-app-setup.md b/docs/github-app-setup.md index 360bab7..feac482 100644 --- a/docs/github-app-setup.md +++ b/docs/github-app-setup.md @@ -1,6 +1,7 @@ # GitHub App Setup -This guide walks through creating and installing a GitHub App for github-ops-app. +This guide walks through creating and installing a GitHub App for +github-ops-app. ## Prerequisites @@ -9,39 +10,50 @@ This guide walks through creating and installing a GitHub App for github-ops-app ## Step 1: Create the GitHub App 1. Navigate to your organization's settings: + - Go to `https://github.com/organizations/YOUR_ORG/settings/apps` - - Or: **Organization** → **Settings** → **Developer settings** → **GitHub Apps** + - Or: **Organization** → **Settings** → **Developer settings** → **GitHub + Apps** 2. Click **New GitHub App** 3. Fill in the basic information: - | Field | Value | - |-----------------------|-------------------------------------------------| - | GitHub App name | `github-ops-app` (must be unique across GitHub) | - | Homepage URL | Your organization's URL or repo URL | - | Webhook > Webhook URL | Leave blank for now | - | Webhook > Secret | Generate a strong secret (save this for later) | - | Webhook > Active | **Uncheck** to disable webhooks initially | - - > **Note**: Disable webhooks during creation since you may not know your - > endpoint URL until after deployment. You'll configure webhooks and - > subscribe to events in [Step 7](#step-7-configure-webhook-and-events). - -4. Under **Permissions**, set the following: - - Repository Permissions - - Contents: Read - - Read branch protection rules - - Pull requests: Read - - Access PR details for compliance - - Organization Permissions - - Administration: Read - - Read organization settings - - Members: Read/Write - - Manage team membership - -4. Under Set installation scope: - - Where can this GitHub App be installed?: Only on this account + | Field | Value | + | --------------------- | ----------------------------------------------- | + | GitHub App name | `github-ops-app` (must be unique across GitHub) | + | Homepage URL | Your organization's URL or repo URL | + | Webhook > Webhook URL | Leave blank for now | + | Webhook > Secret | Generate a strong secret (save this for later) | + | Webhook > Active | **Uncheck** to disable webhooks initially | + + > **Note**: Disable webhooks during creation since you may not know your + > endpoint URL until after deployment. You'll configure webhooks and + > subscribe to events in [Step 7](#step-7-configure-webhook-and-events). + +### Configure Permissions + +Under **Permissions**, set the following: + +### Repository Permissions + +| Permission | Access | Purpose | +| ------------- | ------ | -------------------------------- | +| Contents | Read | Read branch protection rules | +| Pull requests | Read | Access PR details for compliance | + +#### Organization Permissions + +| Permission | Access | Purpose | +| -------------- | ---------- | -------------------------- | +| Members | Read/Write | Manage team membership | +| Administration | Read | Read organization settings | + +4. Set installation scope: + + | Setting | Value | + | --------------------------------------- | -------------------- | + | Where can this GitHub App be installed? | Only on this account | 5. Click **Create GitHub App** @@ -73,6 +85,7 @@ On the app's settings page, find and save: ## Step 5: Get Installation ID After installation, you'll be redirected to a URL like: + ``` https://github.com/organizations/YOUR_ORG/settings/installations/12345678 ``` @@ -80,6 +93,7 @@ https://github.com/organizations/YOUR_ORG/settings/installations/12345678 The number at the end (`12345678`) is your **Installation ID**. Alternatively, use the GitHub API: + ```bash # List installations (requires app JWT authentication) curl -H "Authorization: Bearer YOUR_JWT" \ @@ -115,25 +129,40 @@ After deploying your server, configure and enable webhooks: 1. Go to your GitHub App settings: `https://github.com/organizations/YOUR_ORG/settings/apps/YOUR_APP` -2. On the **General** tab, under **Webhook**: - - Set **Webhook URL** to your endpoint: - - Lambda: `https://xxx.execute-api.region.amazonaws.com/webhooks` - - Server: `https://your-domain.com/webhooks` - - Check **Active** to enable webhooks - - Click **Save changes** -3. Go to the **Permissions & events** tab -4. Scroll to **Subscribe to events** and check: + +2. Set **Webhook URL** to your endpoint: + + - Lambda: `https://xxx.execute-api.region.amazonaws.com/webhooks` + - Server: `https://your-domain.com/webhooks` + +3. Check **Active** to enable webhooks + +4. Click **Save changes** + +5. Under **Subscribe to events**, check: + - [x] **Pull request** - PR open, close, merge events - [x] **Team** - Team creation, deletion, changes - [x] **Membership** - Team membership changes -5. Click **Save changes** + +6. Click **Save changes** + +## Using the App Manifest (Alternative) + +For automated setup, use the manifest at `assets/github/manifest.json`: + +1. Go to `https://github.com/settings/apps/new` +2. Append `?manifest=` with URL-encoded manifest JSON +3. Or use the manifest creation API + +The manifest pre-configures all required permissions and events. ## Verification Test your setup: -1. **Webhook delivery**: Check **Settings** → **Developer settings** → - **GitHub Apps** → your app → **Advanced** → **Recent Deliveries** +1. **Webhook delivery**: Check **Settings** → **Developer settings** → **GitHub + Apps** → your app → **Advanced** → **Recent Deliveries** 2. **Create a test PR**: Open and merge a PR to a monitored branch to verify webhook reception diff --git a/docs/okta-setup.md b/docs/okta-setup.md index b8955a9..ba660b9 100644 --- a/docs/okta-setup.md +++ b/docs/okta-setup.md @@ -11,14 +11,18 @@ github-ops-app to sync Okta groups with GitHub teams. ## Step 1: Create API Services Application 1. Log in to your **Okta Admin Console** + 2. Navigate to **Applications** → **Applications** + 3. Click **Create App Integration** + 4. Select **API Services** and click **Next** > API Services apps use OAuth 2.0 client credentials flow with no user > context, ideal for server-to-server integrations. 5. Enter application name: `github-ops-app` (or similar) + 6. Click **Save** ## Step 2: Configure Client Authentication @@ -64,12 +68,13 @@ On the **General** tab, find and save: ## Step 5: Grant API Scopes 1. Go to the **Okta API Scopes** tab + 2. Grant the following scopes: - | Scope | Purpose | - |--------------------|-------------------------------| - | `okta.groups.read` | Read group names and members | - | `okta.users.read` | Read user profiles | + | Scope | Purpose | + | ------------------ | ---------------------------- | + | `okta.groups.read` | Read group names and members | + | `okta.users.read` | Read user profiles | 3. Click **Grant** for each scope @@ -82,22 +87,26 @@ API Services applications require an admin role to access Okta APIs. Without this, API calls will fail with permission errors even if scopes are granted. 1. Go to the **Admin roles** tab for your application + 2. Click **Edit assignments** + 3. Select one of the following roles: - | Role | Access Level | - |---------------------|-----------------------------------------------| - | **Read Only Admin** | Read access to all resources (recommended) | - | **Group Admin** | Full access to groups only | + | Role | Access Level | + | ------------------- | ------------------------------------------ | + | **Read Only Admin** | Read access to all resources (recommended) | + | **Group Admin** | Full access to groups only | 4. If using **Group Admin**, optionally restrict to specific groups: - - Under **Edit constraints for Group Administrator**, select specific - groups or group types the app can access + + - Under **Edit constraints for Group Administrator**, select specific groups + or group types the app can access + 5. Click **Save changes** -> **Note**: Read Only Admin is recommended for sync operations since it -> provides sufficient access without write permissions. Group Admin is an -> alternative if you need to limit the app's scope to group resources only. +> **Note**: Read Only Admin is recommended for sync operations since it provides +> sufficient access without write permissions. Group Admin is an alternative if +> you need to limit the app's scope to group resources only. ## Step 7: Identify Your Okta Domain @@ -113,12 +122,12 @@ Use the domain without `https://` prefix for `APP_OKTA_DOMAIN`. The app needs to map Okta users to GitHub usernames. Determine which Okta user profile field contains GitHub usernames: -| Common Fields | Description | -|--------------------|------------------------------------------| -| `login` | Okta username (often email) | -| `email` | User's email address | -| `githubUsername` | Custom field (recommended) | -| `nickName` | Sometimes used for GitHub username | +| Common Fields | Description | +| ---------------- | ---------------------------------- | +| `login` | Okta username (often email) | +| `email` | User's email address | +| `githubUsername` | Custom field (recommended) | +| `nickName` | Sometimes used for GitHub username | ### Adding a Custom GitHub Username Field (Recommended) @@ -141,13 +150,14 @@ rules: **Example naming conventions:** -| Pattern | Example Groups | -|----------------------|----------------------------------------------| -| `github-{team}` | `github-engineering`, `github-platform` | -| `gh-eng-{team}` | `gh-eng-frontend`, `gh-eng-backend` | -| `Team - {name}` | `Team - Platform`, `Team - Security` | +| Pattern | Example Groups | +| --------------- | --------------------------------------- | +| `github-{team}` | `github-engineering`, `github-platform` | +| `gh-eng-{team}` | `gh-eng-frontend`, `gh-eng-backend` | +| `Team - {name}` | `Team - Platform`, `Team - Security` | Groups can be: + - Okta groups (manually managed) - Groups synced from Active Directory - Groups from other identity providers @@ -192,18 +202,18 @@ APP_OKTA_SYNC_RULES='[ ### Rule Fields -| Field | Description | -|-------------------------|------------------------------------------------------| -| `name` | Rule identifier (for logging) | -| `enabled` | Enable/disable rule (default: `true`) | -| `okta_group_pattern` | Regex to match Okta groups | -| `okta_group_name` | Exact Okta group name (alternative to pattern) | -| `github_team_prefix` | Prefix for generated GitHub team names | -| `github_team_name` | Exact GitHub team name (overrides pattern) | -| `strip_prefix` | Remove this prefix from Okta group name | -| `sync_members` | Sync members between Okta and GitHub (default: `true`)| -| `create_team_if_missing`| Auto-create GitHub teams if they don't exist | -| `team_privacy` | GitHub team visibility: `secret` or `closed` | +| Field | Description | +| ------------------------ | ------------------------------------------------------ | +| `name` | Rule identifier (for logging) | +| `enabled` | Enable/disable rule (default: `true`) | +| `okta_group_pattern` | Regex to match Okta groups | +| `okta_group_name` | Exact Okta group name (alternative to pattern) | +| `github_team_prefix` | Prefix for generated GitHub team names | +| `github_team_name` | Exact GitHub team name (overrides pattern) | +| `strip_prefix` | Remove this prefix from Okta group name | +| `sync_members` | Sync members between Okta and GitHub (default: `true`) | +| `create_team_if_missing` | Auto-create GitHub teams if they don't exist | +| `team_privacy` | GitHub team visibility: `secret` or `closed` | See the [main README](../README.md#okta-sync-rules) for additional examples. @@ -221,6 +231,7 @@ Test your Okta configuration: ``` Trigger a sync and verify: + 1. POST to `/scheduled/okta-sync` endpoint 2. Check logs for groups discovered and teams synced 3. Verify GitHub team memberships match Okta groups @@ -249,6 +260,7 @@ Trigger a sync and verify: ### Rate limiting Okta has API rate limits. If you hit limits: + - Reduce sync frequency - The app handles rate limit responses gracefully diff --git a/docs/slack-setup.md b/docs/slack-setup.md index c7208a6..1b4b4ee 100644 --- a/docs/slack-setup.md +++ b/docs/slack-setup.md @@ -15,8 +15,9 @@ This guide walks through creating a Slack app for github-ops-app notifications. 2. Click **Create New App** 3. Select **From an app manifest** 4. Select your workspace -5. Copy the contents of [`assets/slack/manifest.json`](../assets/slack/manifest.json) - and paste into the manifest editor +5. Copy the contents of + [`assets/slack/manifest.json`](../assets/slack/manifest.json) and paste into + the manifest editor 6. Click **Create** ### Option B: Manual Setup @@ -34,15 +35,17 @@ Then continue to configure OAuth scopes manually (Step 2). If you used the manifest, scopes are pre-configured. Otherwise: 1. Go to **OAuth & Permissions** in the sidebar + 2. Scroll to **Scopes** → **Bot Token Scopes** + 3. Add the following scopes: - | Scope | Purpose | - |---------------------|----------------------------------------------| - | `chat:write` | Post messages to channels bot is member of | - | `chat:write.public` | Post to public channels without joining | - | `channels:read` | View basic channel info | - | `channels:join` | Join public channels | + | Scope | Purpose | + | ------------------- | ------------------------------------------ | + | `chat:write` | Post messages to channels bot is member of | + | `chat:write.public` | Post to public channels without joining | + | `channels:read` | View basic channel info | + | `channels:join` | Join public channels | ## Step 3: Install to Workspace @@ -51,6 +54,7 @@ If you used the manifest, scopes are pre-configured. Otherwise: 3. Review permissions and click **Allow** If your workspace requires admin approval: + - Submit the app for approval - Wait for workspace admin to approve - Return to install after approval @@ -129,7 +133,8 @@ Make notifications more recognizable: 2. Under **Display Information**: - **App name**: `GitHub Ops Bot` (or your preference) - **Short description**: Brief description of the bot - - **App icon**: Upload a custom icon (use `assets/slack/icon.png` or your own) + - **App icon**: Upload a custom icon (use `assets/slack/icon.png` or your + own) - **Background color**: `#10203B` or your brand color ## Verification @@ -148,6 +153,7 @@ curl -X POST https://slack.com/api/chat.postMessage \ ``` Expected response: + ```json { "ok": true, @@ -161,12 +167,12 @@ Expected response: The bot sends these notification types: -| Event | Description | -|-----------------------|------------------------------------------------| -| PR Compliance Alert | PR merged bypassing branch protection | -| Okta Sync Report | Summary of team membership changes | -| Orphaned Users Alert | Org members not in any synced teams | -| Sync Error | Errors during Okta sync process | +| Event | Description | +| -------------------- | ------------------------------------- | +| PR Compliance Alert | PR merged bypassing branch protection | +| Okta Sync Report | Summary of team membership changes | +| Orphaned Users Alert | Org members not in any synced teams | +| Sync Error | Errors during Okta sync process | ## Troubleshooting @@ -204,6 +210,7 @@ The bot sends these notification types: Slack has rate limits (typically 1 message per second per channel). The app handles rate limits gracefully, but if you see delays: + - Notifications are queued and retried - Consider consolidating notifications for high-volume events diff --git a/go.mod b/go.mod index b63beae..7f069fb 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,7 @@ require ( github.com/cockroachdb/redact v1.1.5 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/getsentry/sentry-go v0.27.0 // indirect + github.com/go-chi/chi/v5 v5.2.5 // indirect github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/go.sum b/go.sum index a2ee53d..ffd662f 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvw github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK3r3Ps= github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= diff --git a/internal/app/app.go b/internal/app/app.go index 8d6516b..d71ca38 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -6,79 +6,24 @@ import ( "context" "encoding/json" "log/slog" + "net/http" + "sync" "github.com/cockroachdb/errors" "github.com/cruxstack/github-ops-app/internal/config" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/notifiers" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/domain" ) // App is the main application instance containing all clients and -// configuration. +// configuration. depends on domain interfaces, not concrete implementations. type App struct { Config *config.Config Logger *slog.Logger - GitHubClient *client.Client - OktaClient *okta.Client - Notifier *notifiers.SlackNotifier -} - -// New creates a new App instance with configured clients. -// Initializes GitHub, Okta, and Slack clients based on config. -func New(ctx context.Context, cfg *config.Config) (*App, error) { - logger := config.NewLogger() - - app := &App{ - Config: cfg, - Logger: logger, - } - - if cfg.IsGitHubConfigured() { - ghClient, err := client.NewAppClientWithBaseURL( - cfg.GitHubAppID, - cfg.GitHubInstallationID, - cfg.GitHubAppPrivateKey, - cfg.GitHubOrg, - cfg.GitHubBaseURL, - ) - if err != nil { - return nil, errors.Wrap(err, "failed to create github app client") - } - app.GitHubClient = ghClient - } - - if cfg.IsOktaSyncEnabled() { - oktaClient, err := okta.NewClientWithContext(ctx, &okta.ClientConfig{ - Domain: cfg.OktaDomain, - ClientID: cfg.OktaClientID, - PrivateKey: cfg.OktaPrivateKey, - PrivateKeyID: cfg.OktaPrivateKeyID, - Scopes: cfg.OktaScopes, - GitHubUserField: cfg.OktaGitHubUserField, - BaseURL: cfg.OktaBaseURL, - }) - if err != nil { - return nil, errors.Wrap(err, "failed to create okta client") - } - app.OktaClient = oktaClient - } - - if cfg.SlackEnabled { - channels := notifiers.SlackChannels{ - Default: cfg.SlackChannel, - PRBypass: cfg.SlackChannelPRBypass, - OktaSync: cfg.SlackChannelOktaSync, - OrphanedUsers: cfg.SlackChannelOrphanedUsers, - } - messages := notifiers.SlackMessages{ - PRBypassFooterNote: cfg.SlackPRBypassFooterNote, - } - app.Notifier = notifiers.NewSlackNotifierWithAPIURL(cfg.SlackToken, channels, messages, cfg.SlackAPIURL) - } - - return app, nil + GitHubClient domain.GitHubClient + OktaClient domain.OktaClient + Notifier domain.Notifier + router http.Handler + routerOnce sync.Once } // ScheduledEvent represents a generic scheduled event. @@ -87,9 +32,9 @@ type ScheduledEvent struct { Data json.RawMessage `json:"data,omitempty"` } -// ProcessScheduledEvent handles scheduled events (e.g., cron jobs). -// Routes to appropriate handlers based on event action. -func (a *App) ProcessScheduledEvent(ctx context.Context, evt ScheduledEvent) error { +// processScheduledEvent handles scheduled events (e.g., cron jobs). +// routes to appropriate handlers based on event action. +func (a *App) processScheduledEvent(ctx context.Context, evt ScheduledEvent) error { if a.Config.DebugEnabled { j, _ := json.Marshal(evt) a.Logger.Debug("received scheduled event", slog.String("event", string(j))) @@ -105,9 +50,9 @@ func (a *App) ProcessScheduledEvent(ctx context.Context, evt ScheduledEvent) err } } -// ProcessWebhook handles incoming GitHub webhook events. -// Supports pull_request, team, and membership events. -func (a *App) ProcessWebhook(ctx context.Context, payload []byte, eventType string) error { +// processWebhook handles incoming GitHub webhook events. +// supports pull_request, team, and membership events. +func (a *App) processWebhook(ctx context.Context, payload []byte, eventType string) error { if a.Config.DebugEnabled { a.Logger.Debug("received webhook", slog.String("event_type", eventType)) } @@ -120,7 +65,7 @@ func (a *App) ProcessWebhook(ctx context.Context, payload []byte, eventType stri case "membership": return a.handleMembershipWebhook(ctx, payload) default: - return errors.Wrapf(internalerrors.ErrInvalidEventType, "%s", eventType) + return errors.Wrapf(domain.ErrInvalidEventType, "%s", eventType) } } diff --git a/internal/app/app_test.go b/internal/app/app_test.go index ee8bf0c..3ababf0 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -3,12 +3,13 @@ package app import ( "context" "log/slog" + "net/http" + "net/http/httptest" "os" "testing" "github.com/cruxstack/github-ops-app/internal/config" - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/domain" ) func TestHandleSlackTest_NotConfigured(t *testing.T) { @@ -136,7 +137,7 @@ func TestProcessScheduledEvent_SlackTest(t *testing.T) { } evt := ScheduledEvent{Action: "slack-test"} - err := app.ProcessScheduledEvent(context.Background(), evt) + err := app.processScheduledEvent(context.Background(), evt) // should fail because slack is not configured if err == nil { @@ -151,26 +152,21 @@ func TestProcessScheduledEvent_UnknownAction(t *testing.T) { } evt := ScheduledEvent{Action: "unknown-action"} - err := app.ProcessScheduledEvent(context.Background(), evt) + err := app.processScheduledEvent(context.Background(), evt) if err == nil { t.Error("expected error for unknown action") } } -// verify fake data types match expected interfaces +// verify fake data types match expected domain types func TestFakeDataTypes(t *testing.T) { - // ensure fake PR result is compatible with notifier - var _ *client.PRComplianceResult = fakePRComplianceResult() - - // ensure fake sync reports are compatible with notifier - var _ []*okta.SyncReport = fakeOktaSyncReports() - - // ensure fake orphaned users report is compatible with notifier - var _ *okta.OrphanedUsersReport = fakeOrphanedUsersReport() + var _ *domain.PRComplianceResult = fakePRComplianceResult() + var _ []*domain.SyncReport = fakeOktaSyncReports() + var _ *domain.OrphanedUsersReport = fakeOrphanedUsersReport() } -func TestCheckAdminAuth(t *testing.T) { +func TestAdminAuthMiddleware(t *testing.T) { tests := []struct { name string adminToken string @@ -217,33 +213,33 @@ func TestCheckAdminAuth(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := &App{ + a := &App{ Config: &config.Config{AdminToken: tt.adminToken}, Logger: slog.New(slog.NewTextHandler(os.Stderr, nil)), } - headers := map[string]string{} + // test via /server/status which uses admin auth middleware + router := a.Handler() + + req := httptest.NewRequest(http.MethodGet, "/server/status", nil) if tt.authHeader != "" { - headers["authorization"] = tt.authHeader + req.Header.Set("Authorization", tt.authHeader) } - req := Request{Headers: headers} - resp := app.checkAdminAuth(req) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) - if tt.expectError && resp == nil { - t.Error("expected error response, got nil") - } - if !tt.expectError && resp != nil { - t.Errorf("expected no error, got status %d", resp.StatusCode) + if tt.expectError && rec.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rec.Code) } - if tt.expectError && resp != nil && resp.StatusCode != 401 { - t.Errorf("expected status 401, got %d", resp.StatusCode) + if !tt.expectError && rec.Code == http.StatusUnauthorized { + t.Errorf("expected success, got 401") } }) } } -func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { +func TestRouter_AdminAuthOnProtectedEndpoints(t *testing.T) { tests := []struct { name string path string @@ -255,7 +251,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "status endpoint, no token configured", path: "/server/status", - method: "GET", + method: http.MethodGet, adminToken: "", authHeader: "", expectedStatus: 200, @@ -263,7 +259,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "status endpoint, token required, missing", path: "/server/status", - method: "GET", + method: http.MethodGet, adminToken: "secret", authHeader: "", expectedStatus: 401, @@ -271,7 +267,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "status endpoint, token required, correct", path: "/server/status", - method: "GET", + method: http.MethodGet, adminToken: "secret", authHeader: "Bearer secret", expectedStatus: 200, @@ -279,7 +275,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "config endpoint, token required, missing", path: "/server/config", - method: "GET", + method: http.MethodGet, adminToken: "secret", authHeader: "", expectedStatus: 401, @@ -287,7 +283,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "config endpoint, token required, correct", path: "/server/config", - method: "GET", + method: http.MethodGet, adminToken: "secret", authHeader: "Bearer secret", expectedStatus: 200, @@ -295,7 +291,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "scheduled endpoint, token required, missing", path: "/scheduled/slack-test", - method: "POST", + method: http.MethodPost, adminToken: "secret", authHeader: "", expectedStatus: 401, @@ -304,27 +300,24 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := &App{ + a := &App{ Config: &config.Config{AdminToken: tt.adminToken}, Logger: slog.New(slog.NewTextHandler(os.Stderr, nil)), } - headers := map[string]string{} - if tt.authHeader != "" { - headers["authorization"] = tt.authHeader - } + router := a.Handler() - req := Request{ - Type: RequestTypeHTTP, - Method: tt.method, - Path: tt.path, - Headers: headers, + req := httptest.NewRequest(tt.method, tt.path, nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) } - resp := app.HandleRequest(context.Background(), req) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) - if resp.StatusCode != tt.expectedStatus { - t.Errorf("expected status %d, got %d", tt.expectedStatus, resp.StatusCode) + if rec.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d (body: %s)", + tt.expectedStatus, rec.Code, rec.Body.String()) } }) } diff --git a/internal/app/factory.go b/internal/app/factory.go new file mode 100644 index 0000000..ac2beb4 --- /dev/null +++ b/internal/app/factory.go @@ -0,0 +1,69 @@ +package app + +import ( + "context" + "log/slog" + + "github.com/cruxstack/github-ops-app/internal/config" + ghclient "github.com/cruxstack/github-ops-app/internal/github/client" + "github.com/cruxstack/github-ops-app/internal/notifiers" + "github.com/cruxstack/github-ops-app/internal/okta" +) + +// NewApp creates a new App instance with configured clients (composition +// root). concrete implementations are instantiated here and wired as domain +// interfaces. this is the single shared factory used by all entry points. +func NewApp(ctx context.Context, cfg *config.Config, logger *slog.Logger) (*App, error) { + a := &App{ + Config: cfg, + Logger: logger, + } + + if cfg.IsGitHubConfigured() { + ghClient, err := ghclient.NewAppClientWithBaseURL( + cfg.GitHubAppID, + cfg.GitHubInstallationID, + cfg.GitHubAppPrivateKey, + cfg.GitHubOrg, + cfg.GitHubBaseURL, + ) + if err != nil { + return nil, err + } + a.GitHubClient = ghClient + } + + if cfg.IsOktaSyncEnabled() { + oktaClient, err := okta.NewClientWithContext(ctx, &okta.ClientConfig{ + Domain: cfg.OktaDomain, + ClientID: cfg.OktaClientID, + PrivateKey: cfg.OktaPrivateKey, + PrivateKeyID: cfg.OktaPrivateKeyID, + Scopes: cfg.OktaScopes, + GitHubUserField: cfg.OktaGitHubUserField, + BaseURL: cfg.OktaBaseURL, + Logger: logger, + }) + if err != nil { + return nil, err + } + a.OktaClient = oktaClient + } + + if cfg.SlackEnabled { + channels := notifiers.SlackChannels{ + Default: cfg.SlackChannel, + PRBypass: cfg.SlackChannelPRBypass, + OktaSync: cfg.SlackChannelOktaSync, + OrphanedUsers: cfg.SlackChannelOrphanedUsers, + } + messages := notifiers.SlackMessages{ + PRBypassFooterNote: cfg.SlackPRBypassFooterNote, + } + a.Notifier = notifiers.NewSlackNotifierWithAPIURL( + cfg.SlackToken, channels, messages, cfg.SlackAPIURL, + ) + } + + return a, nil +} diff --git a/internal/app/handlers.go b/internal/app/handlers.go index 6fa6c01..82b5a5e 100644 --- a/internal/app/handlers.go +++ b/internal/app/handlers.go @@ -5,10 +5,9 @@ import ( "log/slog" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" - "github.com/cruxstack/github-ops-app/internal/github/client" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/cruxstack/github-ops-app/internal/github/webhooks" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/sync" ) // handleOktaSync executes Okta group synchronization to GitHub teams. @@ -20,10 +19,10 @@ func (a *App) handleOktaSync(ctx context.Context) error { } if a.OktaClient == nil || a.GitHubClient == nil { - return errors.Wrap(internalerrors.ErrClientNotInit, "okta or github client") + return errors.Wrap(domain.ErrClientNotInit, "okta or github client") } - syncer := okta.NewSyncer(a.OktaClient, a.GitHubClient, a.Config.OktaSyncRules, a.Config.OktaSyncSafetyThreshold, a.Logger) + syncer := sync.NewSyncer(a.OktaClient, a.GitHubClient, a.Config.OktaSyncRules, a.Config.OktaSyncSafetyThreshold, a.Logger) syncResult, err := syncer.Sync(ctx) if err != nil { return errors.Wrap(err, "okta sync failed") @@ -83,30 +82,14 @@ func (a *App) handlePullRequestWebhook(ctx context.Context, payload []byte) erro return nil } - ghClient := a.GitHubClient - - if prEvent.GetInstallationID() != 0 && prEvent.GetInstallationID() != a.Config.GitHubInstallationID { - installClient, err := client.NewAppClientWithBaseURL( - a.Config.GitHubAppID, - prEvent.GetInstallationID(), - a.Config.GitHubAppPrivateKey, - a.Config.GitHubOrg, - a.Config.GitHubBaseURL, - ) - if err != nil { - return errors.Wrapf(err, "failed to create client for installation %d", prEvent.GetInstallationID()) - } - ghClient = installClient - } - - if ghClient == nil { - return errors.Wrap(internalerrors.ErrClientNotInit, "github client") + if a.GitHubClient == nil { + return errors.Wrap(domain.ErrClientNotInit, "github client") } owner := prEvent.GetRepoOwner() repo := prEvent.GetRepoName() - result, err := ghClient.CheckPRCompliance(ctx, owner, repo, prEvent.Number) + result, err := a.GitHubClient.CheckPRCompliance(ctx, owner, repo, prEvent.Number) if err != nil { return errors.Wrapf(err, "failed to check pr #%d compliance", prEvent.Number) } diff --git a/internal/app/handlers_test.go b/internal/app/handlers_test.go new file mode 100644 index 0000000..3a58fb3 --- /dev/null +++ b/internal/app/handlers_test.go @@ -0,0 +1,596 @@ +package app + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/cockroachdb/errors" + "github.com/cruxstack/github-ops-app/internal/config" + "github.com/cruxstack/github-ops-app/internal/domain" + "github.com/google/go-github/v79/github" +) + +// --- mock implementations --- + +type mockGitHubClient struct { + checkPRComplianceFn func(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) + getOrCreateTeamFn func(ctx context.Context, teamName, privacy string) (*github.Team, error) + syncTeamMembersFn func(ctx context.Context, teamSlug string, desiredMembers []string, threshold float64) (*domain.TeamSyncResult, error) + getTeamMembersFn func(ctx context.Context, teamSlug string) ([]string, error) + listOrgMembersFn func(ctx context.Context) ([]string, error) + isExternalCollaboratorFn func(ctx context.Context, username string) (bool, error) + getAppSlugFn func(ctx context.Context) (string, error) +} + +func (m *mockGitHubClient) CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) { + if m.checkPRComplianceFn != nil { + return m.checkPRComplianceFn(ctx, owner, repo, prNumber) + } + return &domain.PRComplianceResult{}, nil +} +func (m *mockGitHubClient) GetOrCreateTeam(ctx context.Context, teamName, privacy string) (*github.Team, error) { + if m.getOrCreateTeamFn != nil { + return m.getOrCreateTeamFn(ctx, teamName, privacy) + } + slug := teamName + return &github.Team{Slug: &slug}, nil +} +func (m *mockGitHubClient) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, threshold float64) (*domain.TeamSyncResult, error) { + if m.syncTeamMembersFn != nil { + return m.syncTeamMembersFn(ctx, teamSlug, desiredMembers, threshold) + } + return &domain.TeamSyncResult{TeamName: teamSlug, MembersAdded: desiredMembers}, nil +} +func (m *mockGitHubClient) GetTeamMembers(ctx context.Context, teamSlug string) ([]string, error) { + if m.getTeamMembersFn != nil { + return m.getTeamMembersFn(ctx, teamSlug) + } + return []string{}, nil +} +func (m *mockGitHubClient) ListOrgMembers(ctx context.Context) ([]string, error) { + if m.listOrgMembersFn != nil { + return m.listOrgMembersFn(ctx) + } + return []string{}, nil +} +func (m *mockGitHubClient) IsExternalCollaborator(ctx context.Context, username string) (bool, error) { + if m.isExternalCollaboratorFn != nil { + return m.isExternalCollaboratorFn(ctx, username) + } + return false, nil +} +func (m *mockGitHubClient) GetAppSlug(ctx context.Context) (string, error) { + if m.getAppSlugFn != nil { + return m.getAppSlugFn(ctx) + } + return "test-app", nil +} +func (m *mockGitHubClient) GetOrg() string { return "test-org" } + +type mockOktaClient struct { + getGroupsByPatternFn func(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) + getGroupInfoFn func(ctx context.Context, groupName string) (*domain.GroupInfo, error) +} + +func (m *mockOktaClient) GetGroupsByPattern(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) { + if m.getGroupsByPatternFn != nil { + return m.getGroupsByPatternFn(ctx, pattern) + } + return []*domain.GroupInfo{}, nil +} +func (m *mockOktaClient) GetGroupInfo(ctx context.Context, groupName string) (*domain.GroupInfo, error) { + if m.getGroupInfoFn != nil { + return m.getGroupInfoFn(ctx, groupName) + } + return &domain.GroupInfo{ID: "g1", Name: groupName, Members: []string{}}, nil +} + +type mockNotifier struct { + notifyPRBypassFn func(ctx context.Context, result *domain.PRComplianceResult, repoFullName string) error + notifyOktaSyncFn func(ctx context.Context, reports []*domain.SyncReport, githubOrg string) error + notifyOrphanedUsersFn func(ctx context.Context, report *domain.OrphanedUsersReport) error +} + +func (m *mockNotifier) NotifyPRBypass(ctx context.Context, result *domain.PRComplianceResult, repoFullName string) error { + if m.notifyPRBypassFn != nil { + return m.notifyPRBypassFn(ctx, result, repoFullName) + } + return nil +} +func (m *mockNotifier) NotifyOktaSync(ctx context.Context, reports []*domain.SyncReport, githubOrg string) error { + if m.notifyOktaSyncFn != nil { + return m.notifyOktaSyncFn(ctx, reports, githubOrg) + } + return nil +} +func (m *mockNotifier) NotifyOrphanedUsers(ctx context.Context, report *domain.OrphanedUsersReport) error { + if m.notifyOrphanedUsersFn != nil { + return m.notifyOrphanedUsersFn(ctx, report) + } + return nil +} + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +// --- handler tests --- + +func TestHandleOktaSync_Disabled(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + err := a.handleOktaSync(context.Background()) + if err != nil { + t.Fatalf("expected nil error when sync disabled, got: %v", err) + } +} + +func TestHandleOktaSync_ClientsNil(t *testing.T) { + a := &App{ + Config: &config.Config{ + OktaDomain: "test.okta.com", + OktaClientID: "cid", + OktaPrivateKey: []byte("key"), + OktaSyncRules: []domain.SyncRule{{OktaGroupName: "g1", GitHubTeamName: "t1"}}, + }, + Logger: discardLogger(), + } + + err := a.handleOktaSync(context.Background()) + if err == nil { + t.Fatal("expected error when clients are nil") + } + if !errors.Is(err, domain.ErrClientNotInit) { + t.Errorf("expected ErrClientNotInit, got: %v", err) + } +} + +func TestHandleOktaSync_Success(t *testing.T) { + notified := false + a := &App{ + Config: &config.Config{ + OktaDomain: "test.okta.com", + OktaClientID: "cid", + OktaPrivateKey: []byte("key"), + OktaSyncRules: []domain.SyncRule{{OktaGroupName: "Engineering", GitHubTeamName: "engineering"}}, + GitHubOrg: "test-org", + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + OktaClient: &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + return &domain.GroupInfo{ID: "g1", Name: name, Members: []string{"alice"}}, nil + }, + }, + Notifier: &mockNotifier{ + notifyOktaSyncFn: func(_ context.Context, reports []*domain.SyncReport, org string) error { + notified = true + if len(reports) != 1 { + t.Errorf("expected 1 report, got %d", len(reports)) + } + return nil + }, + }, + } + + err := a.handleOktaSync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !notified { + t.Error("expected slack notification to be sent") + } +} + +func TestHandleOktaSync_NotifierFailureDoesNotFail(t *testing.T) { + a := &App{ + Config: &config.Config{ + OktaDomain: "test.okta.com", + OktaClientID: "cid", + OktaPrivateKey: []byte("key"), + OktaSyncRules: []domain.SyncRule{{OktaGroupName: "Eng", GitHubTeamName: "eng"}}, + GitHubOrg: "org", + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + OktaClient: &mockOktaClient{}, + Notifier: &mockNotifier{ + notifyOktaSyncFn: func(_ context.Context, _ []*domain.SyncReport, _ string) error { + return errors.New("slack api error") + }, + }, + } + + err := a.handleOktaSync(context.Background()) + if err != nil { + t.Fatalf("notifier failure should not propagate, got: %v", err) + } +} + +func TestHandlePullRequestWebhook_NotMerged(t *testing.T) { + prNumber := 10 + baseBranch := "main" + payload := map[string]any{ + "action": "opened", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "base": map[string]any{"ref": baseBranch}, + }, + "repository": map[string]any{ + "name": "repo", + "full_name": "org/repo", + "owner": map[string]any{"login": "org"}, + }, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{PRComplianceEnabled: true, PRMonitoredBranches: []string{"main"}, + GitHubOrg: "org", GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1}, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + } + + err := a.handlePullRequestWebhook(context.Background(), data) + if err != nil { + t.Fatalf("non-merged PR should not error, got: %v", err) + } +} + +func TestHandlePullRequestWebhook_UnmonitoredBranch(t *testing.T) { + merged := true + prNumber := 10 + payload := map[string]any{ + "action": "closed", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "merged": merged, + "base": map[string]any{"ref": "develop"}, + }, + "repository": map[string]any{ + "name": "repo", + "full_name": "org/repo", + "owner": map[string]any{"login": "org"}, + }, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{PRComplianceEnabled: true, PRMonitoredBranches: []string{"main"}, + GitHubOrg: "org", GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1}, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + } + + err := a.handlePullRequestWebhook(context.Background(), data) + if err != nil { + t.Fatalf("unmonitored branch should not error, got: %v", err) + } +} + +func TestHandlePullRequestWebhook_ComplianceCheck(t *testing.T) { + merged := true + prNumber := 42 + checked := false + payload := map[string]any{ + "action": "closed", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "merged": merged, + "base": map[string]any{"ref": "main"}, + }, + "repository": map[string]any{ + "name": "repo", + "full_name": "org/repo", + "owner": map[string]any{"login": "org"}, + }, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{ + PRComplianceEnabled: true, PRMonitoredBranches: []string{"main"}, + GitHubOrg: "org", GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1, + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{ + checkPRComplianceFn: func(_ context.Context, owner, repo string, num int) (*domain.PRComplianceResult, error) { + checked = true + if num != prNumber { + t.Errorf("expected pr %d, got %d", prNumber, num) + } + return &domain.PRComplianceResult{ + Violations: []domain.ComplianceViolation{}, + }, nil + }, + }, + } + + err := a.handlePullRequestWebhook(context.Background(), data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !checked { + t.Error("expected compliance check to be called") + } +} + +func TestHandlePullRequestWebhook_BypassNotification(t *testing.T) { + merged := true + prNumber := 42 + notified := false + prURL := "https://github.com/org/repo/pull/42" + payload := map[string]any{ + "action": "closed", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "merged": merged, + "base": map[string]any{"ref": "main"}, + "html_url": prURL, + "merged_by": map[string]any{"login": "admin-user"}, + }, + "repository": map[string]any{ + "name": "repo", + "full_name": "org/repo", + "owner": map[string]any{"login": "org"}, + }, + } + data, _ := json.Marshal(payload) + + pr := &github.PullRequest{} + a := &App{ + Config: &config.Config{ + PRComplianceEnabled: true, PRMonitoredBranches: []string{"main"}, + GitHubOrg: "org", GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1, + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{ + checkPRComplianceFn: func(_ context.Context, _, _ string, _ int) (*domain.PRComplianceResult, error) { + return &domain.PRComplianceResult{ + PR: pr, + UserHasBypass: true, + Violations: []domain.ComplianceViolation{{Type: "test", Description: "test"}}, + }, nil + }, + }, + Notifier: &mockNotifier{ + notifyPRBypassFn: func(_ context.Context, result *domain.PRComplianceResult, repoName string) error { + notified = true + if repoName != "org/repo" { + t.Errorf("expected repo org/repo, got %s", repoName) + } + return nil + }, + }, + } + + err := a.handlePullRequestWebhook(context.Background(), data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !notified { + t.Error("expected bypass notification to be sent") + } +} + +func TestHandleTeamWebhook_SyncDisabled(t *testing.T) { + payload := map[string]any{ + "action": "edited", + "team": map[string]any{"slug": "engineering"}, + "sender": map[string]any{"login": "user1", "type": "User"}, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{}, // no okta sync configured + Logger: discardLogger(), + } + + err := a.handleTeamWebhook(context.Background(), data) + if err != nil { + t.Fatalf("expected nil when sync disabled, got: %v", err) + } +} + +func TestHandleTeamWebhook_IgnoresBotSender(t *testing.T) { + payload := map[string]any{ + "action": "edited", + "team": map[string]any{"slug": "engineering"}, + "sender": map[string]any{"login": "dependabot", "type": "Bot"}, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{ + OktaDomain: "test.okta.com", OktaClientID: "cid", + OktaPrivateKey: []byte("key"), + OktaSyncRules: []domain.SyncRule{{OktaGroupName: "g", GitHubTeamName: "t"}}, + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + OktaClient: &mockOktaClient{}, + } + + err := a.handleTeamWebhook(context.Background(), data) + if err != nil { + t.Fatalf("bot sender should be ignored, got: %v", err) + } +} + +func TestHandleMembershipWebhook_NonTeamScope(t *testing.T) { + payload := map[string]any{ + "action": "added", + "scope": "organization", + "member": map[string]any{"login": "user1"}, + "team": map[string]any{"slug": "eng"}, + "sender": map[string]any{"login": "admin", "type": "User"}, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + err := a.handleMembershipWebhook(context.Background(), data) + if err != nil { + t.Fatalf("non-team scope should be skipped, got: %v", err) + } +} + +func TestShouldIgnoreWebhookChange_BotSender(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + event := &stubSender{senderType: "Bot", senderLogin: "some-bot"} + if !a.shouldIgnoreWebhookChange(context.Background(), event) { + t.Error("expected Bot sender to be ignored") + } +} + +func TestShouldIgnoreWebhookChange_AppSlugMatch(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{ + getAppSlugFn: func(_ context.Context) (string, error) { + return "my-app", nil + }, + }, + } + + event := &stubSender{senderType: "User", senderLogin: "my-app[bot]"} + if !a.shouldIgnoreWebhookChange(context.Background(), event) { + t.Error("expected app slug match to be ignored") + } +} + +func TestShouldIgnoreWebhookChange_HumanUser(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{ + getAppSlugFn: func(_ context.Context) (string, error) { + return "my-app", nil + }, + }, + } + + event := &stubSender{senderType: "User", senderLogin: "human-user"} + if a.shouldIgnoreWebhookChange(context.Background(), event) { + t.Error("expected human user NOT to be ignored") + } +} + +func TestProcessWebhook_UnknownEventType(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + err := a.processWebhook(context.Background(), []byte(`{}`), "deployment") + if err == nil { + t.Fatal("expected error for unknown event type") + } + if !errors.Is(err, domain.ErrInvalidEventType) { + t.Errorf("expected ErrInvalidEventType, got: %v", err) + } +} + +func TestWriteErrorFromDomain(t *testing.T) { + tests := []struct { + name string + err error + wantStatus int + }{ + {"auth error direct", domain.ErrInvalidSignature, 401}, + {"auth error wrapped", errors.Wrap(domain.ErrInvalidSignature, "wrapped"), 401}, + {"validation error", domain.ErrMissingPRData, 400}, + {"validation error wrapped", errors.Wrap(domain.ErrInvalidEventType, "wrapped"), 400}, + {"config error", domain.ErrClientNotInit, 503}, + {"api error", errors.Wrap(domain.ErrTeamNotFound, "wrapped"), 502}, + {"generic error", errors.New("something"), 500}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + writeErrorFromDomain(rec, tt.err, "fallback") + if rec.Code != tt.wantStatus { + t.Errorf("expected status %d, got %d", tt.wantStatus, rec.Code) + } + }) + } +} + +func TestRouter_NotFound(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + router := a.Handler() + req := httptest.NewRequest(http.MethodGet, "/nonexistent", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", rec.Code) + } +} + +func TestRouter_MethodNotAllowed(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + router := a.Handler() + req := httptest.NewRequest(http.MethodDelete, "/webhooks", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusMethodNotAllowed { + t.Errorf("expected 405, got %d", rec.Code) + } +} + +func TestRouter_BasePathStripping(t *testing.T) { + a := &App{ + Config: &config.Config{BasePath: "/api/v1"}, + Logger: discardLogger(), + } + + router := a.Handler() + req := httptest.NewRequest(http.MethodGet, "/api/v1/server/status", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200 for base-path-stripped status, got %d (body: %s)", + rec.Code, rec.Body.String()) + } +} + +// stubSender is a simple webhookSender for testing shouldIgnoreWebhookChange +type stubSender struct { + senderType string + senderLogin string +} + +func (s *stubSender) GetSenderType() string { return s.senderType } +func (s *stubSender) GetSenderLogin() string { return s.senderLogin } diff --git a/internal/app/request.go b/internal/app/request.go index d3a17d7..351b24f 100644 --- a/internal/app/request.go +++ b/internal/app/request.go @@ -1,237 +1,195 @@ package app import ( - "context" "encoding/json" + "io" "log/slog" + "net/http" "strings" + "github.com/cockroachdb/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/cruxstack/github-ops-app/internal/github/webhooks" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" ) -// RequestType identifies the category of incoming request. -type RequestType string - -const ( - // RequestTypeHTTP represents HTTP requests (webhooks, status, config). - RequestTypeHTTP RequestType = "http" - // RequestTypeScheduled represents scheduled/cron events. - RequestTypeScheduled RequestType = "scheduled" -) - -// Request is a unified request type that abstracts HTTP and scheduled events. -// Runtimes (server, lambda) convert their native formats to this type. -type Request struct { - Type RequestType `json:"type"` - Method string `json:"method,omitempty"` - Path string `json:"path,omitempty"` - Headers map[string]string `json:"headers,omitempty"` - Body []byte `json:"body,omitempty"` - - // ScheduledAction is used for scheduled events (e.g., "okta-sync"). - ScheduledAction string `json:"scheduled_action,omitempty"` - // ScheduledData contains optional payload for scheduled events. - ScheduledData json.RawMessage `json:"scheduled_data,omitempty"` -} - -// Response is a unified response type returned by HandleRequest. -// Runtimes convert this to their native response format. -type Response struct { - StatusCode int `json:"status_code"` - Headers map[string]string `json:"headers,omitempty"` - Body []byte `json:"body,omitempty"` - ContentType string `json:"content_type,omitempty"` +// Handler returns the HTTP handler for the application. +// this is the single entry point for all HTTP request processing. +// both the server and lambda entry points feed requests into this handler. +func (a *App) Handler() http.Handler { + a.routerOnce.Do(func() { + a.router = a.buildRouter() + }) + return a.router } -// HandleRequest routes incoming requests to the appropriate handler. -// This is the single entry point for all request processing. -func (a *App) HandleRequest(ctx context.Context, req Request) Response { - if a.Config.DebugEnabled { - j, _ := json.Marshal(req) - a.Logger.Debug("handling request", slog.String("request", string(j))) - } - - switch req.Type { - case RequestTypeScheduled: - return a.handleScheduledRequest(ctx, req) - case RequestTypeHTTP: - return a.handleHTTPRequest(ctx, req) - default: - return errorResponse(400, "unknown request type") - } -} +// buildRouter constructs the chi router with all routes and middleware. +func (a *App) buildRouter() http.Handler { + r := chi.NewRouter() + r.Use(middleware.Recoverer) -// handleScheduledRequest processes scheduled/cron events. -func (a *App) handleScheduledRequest(ctx context.Context, req Request) Response { - evt := ScheduledEvent{ - Action: req.ScheduledAction, - Data: req.ScheduledData, + basePath := a.Config.BasePath + if basePath == "" { + basePath = "/" } - if err := a.ProcessScheduledEvent(ctx, evt); err != nil { - a.Logger.Error("scheduled event processing failed", - slog.String("action", evt.Action), - slog.String("error", err.Error())) - return errorResponse(500, "scheduled event processing failed") - } + r.Route(basePath, func(r chi.Router) { + r.Post("/webhooks", a.handleWebhookHTTP) - return jsonResponse(200, map[string]string{ - "status": "success", - "message": evt.Action + " completed", + r.Group(func(r chi.Router) { + r.Use(a.adminAuthMiddleware) + r.Get("/server/status", a.handleStatusHTTP) + r.Get("/server/config", a.handleConfigHTTP) + r.Post("/scheduled/{action}", a.handleScheduledHTTP) + }) }) + + return r } -// handleHTTPRequest routes HTTP requests based on path. -// strips BasePath prefix if configured (e.g., "/api/v1" -> "/"). -func (a *App) handleHTTPRequest(ctx context.Context, req Request) Response { - path := req.Path - if a.Config.BasePath != "" { - path = strings.TrimPrefix(path, a.Config.BasePath) - if path == "" { - path = "/" +// adminAuthMiddleware validates the admin bearer token on protected routes. +// if no admin token is configured, all requests pass through. +func (a *App) adminAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if a.Config.AdminToken == "" { + next.ServeHTTP(w, r) + return } - } - switch path { - case "/server/status": - return a.handleStatusRequest(req) - case "/server/config": - return a.handleConfigRequest(req) - case "/webhooks", "/": - return a.handleWebhookRequest(ctx, req) - default: - if strings.HasPrefix(path, "/scheduled/") { - return a.handleScheduledHTTPRequest(ctx, req, path) + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeError(w, http.StatusUnauthorized, "unauthorized") + return } - return errorResponse(404, "not found") - } + + token := strings.TrimPrefix(authHeader, "Bearer ") + if token == authHeader { + token = strings.TrimPrefix(authHeader, "bearer ") + } + + if token != a.Config.AdminToken { + writeError(w, http.StatusUnauthorized, "unauthorized") + return + } + + next.ServeHTTP(w, r) + }) } -// handleStatusRequest returns application status. -func (a *App) handleStatusRequest(req Request) Response { - if req.Method != "GET" { - return errorResponse(405, "method not allowed") - } - if resp := a.checkAdminAuth(req); resp != nil { - return *resp - } - return jsonResponse(200, a.GetStatus()) +// handleStatusHTTP returns application status. +func (a *App) handleStatusHTTP(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, a.GetStatus()) } -// handleConfigRequest returns redacted configuration. -func (a *App) handleConfigRequest(req Request) Response { - if req.Method != "GET" { - return errorResponse(405, "method not allowed") - } - if resp := a.checkAdminAuth(req); resp != nil { - return *resp - } - return jsonResponse(200, a.Config.Redacted()) +// handleConfigHTTP returns redacted configuration. +func (a *App) handleConfigHTTP(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, a.Config.Redacted()) } -// handleWebhookRequest processes GitHub webhook POST requests. -func (a *App) handleWebhookRequest(ctx context.Context, req Request) Response { - if req.Method != "POST" { - return errorResponse(405, "method not allowed") +// handleWebhookHTTP processes GitHub webhook POST requests. +func (a *App) handleWebhookHTTP(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(io.LimitReader(r.Body, 10<<20)) // 10MB limit + if err != nil { + writeError(w, http.StatusBadRequest, "failed to read request body") + return } - eventType := req.Headers["x-github-event"] - signature := req.Headers["x-hub-signature-256"] + eventType := r.Header.Get("X-GitHub-Event") + signature := r.Header.Get("X-Hub-Signature-256") if err := webhooks.ValidateWebhookSignature( - req.Body, + body, signature, a.Config.GitHubWebhookSecret, ); err != nil { a.Logger.Warn("webhook signature validation failed", slog.String("error", err.Error())) - return errorResponse(401, "unauthorized") + writeErrorFromDomain(w, err, "unauthorized") + return } - if err := a.ProcessWebhook(ctx, req.Body, eventType); err != nil { + if err := a.processWebhook(r.Context(), body, eventType); err != nil { a.Logger.Error("webhook processing failed", slog.String("event_type", eventType), slog.String("error", err.Error())) - return errorResponse(500, "webhook processing failed") + writeErrorFromDomain(w, err, "webhook processing failed") + return } - return Response{ - StatusCode: 200, - ContentType: "text/plain", - Body: []byte("ok"), - } + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) } -// handleScheduledHTTPRequest processes scheduled events via HTTP POST. -// path is the normalized path with BasePath already stripped. -func (a *App) handleScheduledHTTPRequest(ctx context.Context, req Request, path string) Response { - if req.Method != "POST" { - return errorResponse(405, "method not allowed") - } - if resp := a.checkAdminAuth(req); resp != nil { - return *resp +// handleScheduledHTTP processes scheduled events via HTTP POST. +// the action is extracted from the chi URL parameter. +func (a *App) handleScheduledHTTP(w http.ResponseWriter, r *http.Request) { + action := chi.URLParam(r, "action") + if action == "" { + writeError(w, http.StatusBadRequest, "missing scheduled action") + return } - // extract action from path (e.g., "/scheduled/okta-sync" -> "okta-sync") - action := strings.TrimPrefix(path, "/scheduled/") - if action == "" { - return errorResponse(400, "missing scheduled action") + evt := ScheduledEvent{Action: action} + + if r.Body != nil { + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + writeError(w, http.StatusBadRequest, "failed to read request body") + return + } + if len(body) > 0 { + evt.Data = json.RawMessage(body) + } } - scheduledReq := Request{ - Type: RequestTypeScheduled, - ScheduledAction: action, + if err := a.processScheduledEvent(r.Context(), evt); err != nil { + a.Logger.Error("scheduled event processing failed", + slog.String("action", evt.Action), + slog.String("error", err.Error())) + writeErrorFromDomain(w, err, "scheduled event processing failed") + return } - return a.handleScheduledRequest(ctx, scheduledReq) + writeJSON(w, http.StatusOK, map[string]string{ + "status": "success", + "message": evt.Action + " completed", + }) } -// jsonResponse creates a JSON response with the given status and data. -func jsonResponse(status int, data any) Response { +// writeJSON writes a JSON response with the given status code. +func writeJSON(w http.ResponseWriter, status int, data any) { body, err := json.Marshal(data) if err != nil { - return errorResponse(500, "failed to marshal response") - } - return Response{ - StatusCode: status, - ContentType: "application/json", - Headers: map[string]string{"Content-Type": "application/json"}, - Body: body, + writeError(w, http.StatusInternalServerError, "failed to marshal response") + return } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + w.Write(body) } -// errorResponse creates an error response with the given status and message. -func errorResponse(status int, message string) Response { - return Response{ - StatusCode: status, - ContentType: "text/plain", - Body: []byte(message), - } +// writeError writes a plain text error response. +func writeError(w http.ResponseWriter, status int, message string) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(status) + w.Write([]byte(message)) } -// checkAdminAuth validates the admin token from the request. -// returns nil if auth is disabled (no token configured) or if token is valid. -// returns an error response if token is required but missing or invalid. -func (a *App) checkAdminAuth(req Request) *Response { - if a.Config.AdminToken == "" { - return nil - } - - authHeader := req.Headers["authorization"] - if authHeader == "" { - resp := errorResponse(401, "unauthorized") - return &resp - } - - token := strings.TrimPrefix(authHeader, "Bearer ") - if token == authHeader { - token = strings.TrimPrefix(authHeader, "bearer ") - } - - if token != a.Config.AdminToken { - resp := errorResponse(401, "unauthorized") - return &resp +// writeErrorFromDomain translates domain error types to HTTP status codes +// and writes the response. centralizes error-to-HTTP mapping. +func writeErrorFromDomain(w http.ResponseWriter, err error, fallbackMsg string) { + switch { + case errors.Is(err, domain.AuthError): + writeError(w, http.StatusUnauthorized, "unauthorized") + case errors.Is(err, domain.ValidationError): + writeError(w, http.StatusBadRequest, fallbackMsg) + case errors.Is(err, domain.ConfigError): + writeError(w, http.StatusServiceUnavailable, "service not configured") + case errors.Is(err, domain.APIError): + writeError(w, http.StatusBadGateway, fallbackMsg) + default: + writeError(w, http.StatusInternalServerError, fallbackMsg) } - - return nil } diff --git a/internal/app/testdata.go b/internal/app/testdata.go index 6c264ef..bae49a2 100644 --- a/internal/app/testdata.go +++ b/internal/app/testdata.go @@ -1,19 +1,18 @@ package app import ( - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/domain" gh "github.com/google/go-github/v79/github" ) // fakePRComplianceResult returns sample PR compliance data for testing. -func fakePRComplianceResult() *client.PRComplianceResult { +func fakePRComplianceResult() *domain.PRComplianceResult { prNumber := 42 prTitle := "Add new authentication feature" prURL := "https://github.com/acme-corp/demo-repo/pull/42" mergedByLogin := "test-user" - return &client.PRComplianceResult{ + return &domain.PRComplianceResult{ PR: &gh.PullRequest{ Number: &prNumber, Title: &prTitle, @@ -25,7 +24,7 @@ func fakePRComplianceResult() *client.PRComplianceResult { BaseBranch: "main", UserHasBypass: true, UserBypassReason: "repository admin", - Violations: []client.ComplianceViolation{ + Violations: []domain.ComplianceViolation{ {Type: "insufficient_reviews", Description: "required 2 approving reviews, had 0"}, {Type: "missing_status_check", Description: "required check 'ci/build' did not pass"}, }, @@ -33,8 +32,8 @@ func fakePRComplianceResult() *client.PRComplianceResult { } // fakeOktaSyncReports returns sample Okta sync reports for testing. -func fakeOktaSyncReports() []*okta.SyncReport { - return []*okta.SyncReport{ +func fakeOktaSyncReports() []*domain.SyncReport { + return []*domain.SyncReport{ { Rule: "engineering-team", OktaGroup: "Engineering", @@ -61,8 +60,8 @@ func fakeOktaSyncReports() []*okta.SyncReport { } // fakeOrphanedUsersReport returns sample orphaned users data for testing. -func fakeOrphanedUsersReport() *okta.OrphanedUsersReport { - return &okta.OrphanedUsersReport{ +func fakeOrphanedUsersReport() *domain.OrphanedUsersReport { + return &domain.OrphanedUsersReport{ OrphanedUsers: []string{"orphan-user-1", "orphan-user-2", "legacy-bot"}, } } diff --git a/internal/config/config.go b/internal/config/config.go index f43dde0..2895a98 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,7 +15,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/cockroachdb/errors" - "github.com/cruxstack/github-ops-app/internal/types" + "github.com/cruxstack/github-ops-app/internal/domain" ) // Config holds all application configuration loaded from environment @@ -46,7 +46,7 @@ type Config struct { OktaScopes []string OktaBaseURL string OktaGitHubUserField string - OktaSyncRules []types.SyncRule + OktaSyncRules []domain.SyncRule OktaSyncSafetyThreshold float64 OktaOrphanedUserNotifications bool @@ -259,7 +259,7 @@ func NewConfigWithContext(ctx context.Context) (*Config, error) { syncRulesJSON := os.Getenv("APP_OKTA_SYNC_RULES") if syncRulesJSON != "" { - var rules []types.SyncRule + var rules []domain.SyncRule if err := json.Unmarshal([]byte(syncRulesJSON), &rules); err != nil { return nil, errors.Wrap(err, "failed to parse APP_OKTA_SYNC_RULES") } @@ -363,16 +363,16 @@ type RedactedConfig struct { PRMonitoredBranches []string `json:"pr_monitored_branches"` // Okta - OktaDomain string `json:"okta_domain"` - OktaClientID string `json:"okta_client_id"` - OktaPrivateKey string `json:"okta_private_key"` - OktaPrivateKeyID string `json:"okta_private_key_id"` - OktaScopes []string `json:"okta_scopes"` - OktaBaseURL string `json:"okta_base_url"` - OktaGitHubUserField string `json:"okta_github_user_field"` - OktaSyncRules []types.SyncRule `json:"okta_sync_rules"` - OktaSyncSafetyThreshold float64 `json:"okta_sync_safety_threshold"` - OktaOrphanedUserNotifications bool `json:"okta_orphaned_user_notifications"` + OktaDomain string `json:"okta_domain"` + OktaClientID string `json:"okta_client_id"` + OktaPrivateKey string `json:"okta_private_key"` + OktaPrivateKeyID string `json:"okta_private_key_id"` + OktaScopes []string `json:"okta_scopes"` + OktaBaseURL string `json:"okta_base_url"` + OktaGitHubUserField string `json:"okta_github_user_field"` + OktaSyncRules []domain.SyncRule `json:"okta_sync_rules"` + OktaSyncSafetyThreshold float64 `json:"okta_sync_safety_threshold"` + OktaOrphanedUserNotifications bool `json:"okta_orphaned_user_notifications"` // Slack SlackEnabled bool `json:"slack_enabled"` diff --git a/internal/config/config_helpers_test.go b/internal/config/config_helpers_test.go new file mode 100644 index 0000000..ed234a2 --- /dev/null +++ b/internal/config/config_helpers_test.go @@ -0,0 +1,298 @@ +package config + +import ( + "testing" +) + +func TestIsGitHubConfigured(t *testing.T) { + tests := []struct { + name string + cfg Config + want bool + }{ + { + name: "fully configured", + cfg: Config{ + GitHubOrg: "org", GitHubAppID: 1, + GitHubAppPrivateKey: []byte("key"), GitHubInstallationID: 1, + }, + want: true, + }, + {name: "empty", cfg: Config{}, want: false}, + { + name: "missing org", + cfg: Config{GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1}, + want: false, + }, + { + name: "missing private key", + cfg: Config{GitHubOrg: "org", GitHubAppID: 1, GitHubInstallationID: 1}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.IsGitHubConfigured(); got != tt.want { + t.Errorf("IsGitHubConfigured() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsOktaSyncEnabled(t *testing.T) { + tests := []struct { + name string + cfg Config + want bool + }{ + {name: "empty", cfg: Config{}, want: false}, + { + name: "missing rules", + cfg: Config{OktaDomain: "d", OktaClientID: "c", OktaPrivateKey: []byte("k")}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.IsOktaSyncEnabled(); got != tt.want { + t.Errorf("IsOktaSyncEnabled() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsPRComplianceEnabled(t *testing.T) { + tests := []struct { + name string + cfg Config + want bool + }{ + {name: "disabled flag", cfg: Config{PRComplianceEnabled: false}, want: false}, + { + name: "enabled but no github", + cfg: Config{PRComplianceEnabled: true}, + want: false, + }, + { + name: "enabled with github", + cfg: Config{ + PRComplianceEnabled: true, + GitHubOrg: "org", GitHubAppID: 1, + GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.IsPRComplianceEnabled(); got != tt.want { + t.Errorf("IsPRComplianceEnabled() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestShouldMonitorBranch(t *testing.T) { + cfg := Config{ + PRComplianceEnabled: true, + PRMonitoredBranches: []string{"main", "master"}, + GitHubOrg: "org", + GitHubAppID: 1, + GitHubAppPrivateKey: []byte("k"), + GitHubInstallationID: 1, + } + + tests := []struct { + branch string + want bool + }{ + {"main", true}, + {"master", true}, + {"develop", false}, + {"refs/heads/main", true}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.branch, func(t *testing.T) { + if got := cfg.ShouldMonitorBranch(tt.branch); got != tt.want { + t.Errorf("ShouldMonitorBranch(%q) = %v, want %v", tt.branch, got, tt.want) + } + }) + } +} + +func TestRedacted(t *testing.T) { + cfg := Config{ + GitHubOrg: "my-org", + GitHubAppPrivateKey: []byte("secret-key"), + GitHubWebhookSecret: "webhook-secret", + AdminToken: "admin-secret", + SlackToken: "xoxb-token", + OktaClientID: "client-id", + OktaPrivateKey: []byte("okta-key"), + DebugEnabled: true, + } + + redacted := cfg.Redacted() + + // non-secrets should be preserved + if redacted.GitHubOrg != "my-org" { + t.Errorf("expected org to be preserved, got %q", redacted.GitHubOrg) + } + if !redacted.DebugEnabled { + t.Error("expected DebugEnabled to be preserved") + } + + // secrets should be redacted + if redacted.GitHubAppPrivateKey != "***REDACTED***" { + t.Errorf("expected private key to be redacted, got %q", redacted.GitHubAppPrivateKey) + } + if redacted.GitHubWebhookSecret != "***REDACTED***" { + t.Errorf("expected webhook secret to be redacted, got %q", redacted.GitHubWebhookSecret) + } + if redacted.AdminToken != "***REDACTED***" { + t.Errorf("expected admin token to be redacted, got %q", redacted.AdminToken) + } + if redacted.SlackToken != "***REDACTED***" { + t.Errorf("expected slack token to be redacted, got %q", redacted.SlackToken) + } + if redacted.OktaClientID != "***REDACTED***" { + t.Errorf("expected okta client id to be redacted, got %q", redacted.OktaClientID) + } +} + +func TestRedacted_EmptyValues(t *testing.T) { + cfg := Config{} + redacted := cfg.Redacted() + + if redacted.GitHubAppPrivateKey != "" { + t.Error("expected empty private key to remain empty") + } + if redacted.SlackToken != "" { + t.Error("expected empty slack token to remain empty") + } +} + +func TestNewConfigWithContext_Defaults(t *testing.T) { + // clear all env vars that NewConfig reads + envKeys := []string{ + "APP_DEBUG_ENABLED", "APP_GITHUB_ORG", "APP_GITHUB_APP_ID", + "APP_GITHUB_APP_PRIVATE_KEY", "APP_GITHUB_APP_PRIVATE_KEY_PATH", + "APP_GITHUB_INSTALLATION_ID", "APP_GITHUB_WEBHOOK_SECRET", + "APP_GITHUB_BASE_URL", "APP_PR_COMPLIANCE_ENABLED", + "APP_PR_MONITORED_BRANCHES", "APP_OKTA_DOMAIN", "APP_OKTA_CLIENT_ID", + "APP_OKTA_PRIVATE_KEY", "APP_OKTA_PRIVATE_KEY_PATH", + "APP_OKTA_PRIVATE_KEY_ID", "APP_OKTA_SCOPES", + "APP_OKTA_BASE_URL", "APP_OKTA_GITHUB_USER_FIELD", + "APP_OKTA_SYNC_RULES", "APP_OKTA_SYNC_SAFETY_THRESHOLD", + "APP_OKTA_ORPHANED_USER_NOTIFICATIONS", "APP_SLACK_TOKEN", + "APP_SLACK_CHANNEL", "APP_SLACK_API_URL", "APP_BASE_PATH", + "APP_ADMIN_TOKEN", + } + for _, key := range envKeys { + t.Setenv(key, "") + } + + cfg, err := NewConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.OktaGitHubUserField != "githubUsername" { + t.Errorf("expected default github user field, got %q", cfg.OktaGitHubUserField) + } + if cfg.OktaSyncSafetyThreshold != 0.5 { + t.Errorf("expected default safety threshold 0.5, got %f", cfg.OktaSyncSafetyThreshold) + } + if len(cfg.PRMonitoredBranches) != 2 || cfg.PRMonitoredBranches[0] != "main" { + t.Errorf("expected default monitored branches [main, master], got %v", cfg.PRMonitoredBranches) + } + if len(cfg.OktaScopes) != 2 { + t.Errorf("expected default okta scopes, got %v", cfg.OktaScopes) + } +} + +func TestNewConfigWithContext_ParsesEnvVars(t *testing.T) { + t.Setenv("APP_DEBUG_ENABLED", "true") + t.Setenv("APP_GITHUB_ORG", "test-org") + t.Setenv("APP_GITHUB_APP_ID", "12345") + t.Setenv("APP_GITHUB_APP_PRIVATE_KEY", "test-key-data") + t.Setenv("APP_GITHUB_INSTALLATION_ID", "67890") + t.Setenv("APP_GITHUB_WEBHOOK_SECRET", "whsec") + t.Setenv("APP_PR_COMPLIANCE_ENABLED", "true") + t.Setenv("APP_PR_MONITORED_BRANCHES", "main,release") + t.Setenv("APP_OKTA_SYNC_SAFETY_THRESHOLD", "0.3") + t.Setenv("APP_BASE_PATH", "/api/v1") + t.Setenv("APP_OKTA_SYNC_RULES", `[{"okta_group_name":"Eng","github_team_name":"eng"}]`) + // clear keys that could interfere + t.Setenv("APP_GITHUB_APP_PRIVATE_KEY_PATH", "") + t.Setenv("APP_OKTA_PRIVATE_KEY_PATH", "") + t.Setenv("APP_OKTA_PRIVATE_KEY", "") + t.Setenv("APP_SLACK_TOKEN", "") + t.Setenv("APP_SLACK_CHANNEL", "") + t.Setenv("APP_ADMIN_TOKEN", "") + t.Setenv("APP_OKTA_DOMAIN", "") + t.Setenv("APP_OKTA_CLIENT_ID", "") + + cfg, err := NewConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !cfg.DebugEnabled { + t.Error("expected debug enabled") + } + if cfg.GitHubOrg != "test-org" { + t.Errorf("expected org test-org, got %q", cfg.GitHubOrg) + } + if cfg.GitHubAppID != 12345 { + t.Errorf("expected app id 12345, got %d", cfg.GitHubAppID) + } + if string(cfg.GitHubAppPrivateKey) != "test-key-data" { + t.Errorf("expected key data, got %q", string(cfg.GitHubAppPrivateKey)) + } + if cfg.GitHubInstallationID != 67890 { + t.Errorf("expected installation 67890, got %d", cfg.GitHubInstallationID) + } + if cfg.OktaSyncSafetyThreshold != 0.3 { + t.Errorf("expected threshold 0.3, got %f", cfg.OktaSyncSafetyThreshold) + } + if cfg.BasePath != "/api/v1" { + t.Errorf("expected base path /api/v1, got %q", cfg.BasePath) + } + if len(cfg.PRMonitoredBranches) != 2 || cfg.PRMonitoredBranches[1] != "release" { + t.Errorf("expected [main, release], got %v", cfg.PRMonitoredBranches) + } + if len(cfg.OktaSyncRules) != 1 { + t.Errorf("expected 1 sync rule, got %d", len(cfg.OktaSyncRules)) + } +} + +func TestNewConfigWithContext_InvalidAppID(t *testing.T) { + t.Setenv("APP_GITHUB_APP_ID", "not-a-number") + t.Setenv("APP_GITHUB_WEBHOOK_SECRET", "") + t.Setenv("APP_SLACK_TOKEN", "") + t.Setenv("APP_ADMIN_TOKEN", "") + + _, err := NewConfig() + if err == nil { + t.Fatal("expected error for invalid app id") + } +} + +func TestNewConfigWithContext_InvalidSyncRulesJSON(t *testing.T) { + t.Setenv("APP_OKTA_SYNC_RULES", "not-json") + t.Setenv("APP_GITHUB_APP_ID", "") + t.Setenv("APP_GITHUB_WEBHOOK_SECRET", "") + t.Setenv("APP_SLACK_TOKEN", "") + t.Setenv("APP_ADMIN_TOKEN", "") + + _, err := NewConfig() + if err == nil { + t.Fatal("expected error for invalid sync rules JSON") + } +} diff --git a/internal/errors/errors.go b/internal/domain/errors.go similarity index 88% rename from internal/errors/errors.go rename to internal/domain/errors.go index 3a05f20..18249f0 100644 --- a/internal/errors/errors.go +++ b/internal/domain/errors.go @@ -1,6 +1,7 @@ -// Package errors defines sentinel errors and domain types for the -// application. uses cockroachdb/errors for automatic stack trace capture. -package errors +// Package domain defines shared business types, errors, and interfaces. +// this package has zero internal imports and serves as the dependency +// inversion layer for the application. +package domain import "github.com/cockroachdb/errors" @@ -39,5 +40,4 @@ var ( ErrClientNotInit = errors.Mark(errors.New("client not initialized"), ConfigError) ErrInvalidEventType = errors.Mark(errors.New("unknown event type"), ValidationError) ErrMissingOAuthCreds = errors.Mark(errors.New("must provide either api token or oauth credentials"), ConfigError) - ErrOAuthTokenExpired = errors.Mark(errors.New("oauth token expired"), AuthError) ) diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go new file mode 100644 index 0000000..c40da02 --- /dev/null +++ b/internal/domain/interfaces.go @@ -0,0 +1,64 @@ +package domain + +import ( + "context" + + "github.com/google/go-github/v79/github" +) + +// GitHubClient defines the interface for GitHub API operations. +// implemented by internal/github/client.Client. +type GitHubClient interface { + // CheckPRCompliance verifies if a merged PR met branch protection + // requirements. + CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*PRComplianceResult, error) + + // GetOrCreateTeam fetches an existing team by slug or creates it if + // missing. + GetOrCreateTeam(ctx context.Context, teamName, privacy string) (*github.Team, error) + + // SyncTeamMembers adds and removes members to match desired state. + SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, safetyThreshold float64) (*TeamSyncResult, error) + + // GetTeamMembers returns GitHub usernames of all team members. + GetTeamMembers(ctx context.Context, teamSlug string) ([]string, error) + + // ListOrgMembers returns all organization members excluding external + // collaborators. + ListOrgMembers(ctx context.Context) ([]string, error) + + // IsExternalCollaborator checks if a user is an outside collaborator + // rather than an organization member. + IsExternalCollaborator(ctx context.Context, username string) (bool, error) + + // GetAppSlug fetches the GitHub App slug identifier. + GetAppSlug(ctx context.Context) (string, error) + + // GetOrg returns the GitHub organization name. + GetOrg() string +} + +// OktaClient defines the interface for Okta API operations. +// implemented by internal/okta.Client. +type OktaClient interface { + // GetGroupsByPattern fetches all Okta groups matching a regex pattern. + GetGroupsByPattern(ctx context.Context, pattern string) ([]*GroupInfo, error) + + // GetGroupInfo fetches details for a single Okta group by name. + GetGroupInfo(ctx context.Context, groupName string) (*GroupInfo, error) +} + +// Notifier defines the interface for sending notifications. +// implemented by internal/notifiers.SlackNotifier. +type Notifier interface { + // NotifyPRBypass sends a notification when branch protection is + // bypassed. + NotifyPRBypass(ctx context.Context, result *PRComplianceResult, repoFullName string) error + + // NotifyOktaSync sends a notification with Okta sync results. + NotifyOktaSync(ctx context.Context, reports []*SyncReport, githubOrg string) error + + // NotifyOrphanedUsers sends a notification about organization members + // not in any synced teams. + NotifyOrphanedUsers(ctx context.Context, report *OrphanedUsersReport) error +} diff --git a/internal/domain/types.go b/internal/domain/types.go new file mode 100644 index 0000000..f08b05f --- /dev/null +++ b/internal/domain/types.go @@ -0,0 +1,126 @@ +package domain + +import "github.com/google/go-github/v79/github" + +// SyncRule defines how to sync Okta groups to GitHub teams. +type SyncRule struct { + Name string `json:"name"` + Enabled *bool `json:"enabled,omitempty"` + OktaGroupPattern string `json:"okta_group_pattern,omitempty"` + OktaGroupName string `json:"okta_group_name,omitempty"` + GitHubTeamPrefix string `json:"github_team_prefix,omitempty"` + GitHubTeamName string `json:"github_team_name,omitempty"` + StripPrefix string `json:"strip_prefix,omitempty"` + SyncMembers *bool `json:"sync_members,omitempty"` + CreateTeamIfMissing bool `json:"create_team_if_missing"` + TeamPrivacy string `json:"team_privacy,omitempty"` +} + +// IsEnabled returns true if the rule is enabled (defaults to true). +func (r SyncRule) IsEnabled() bool { + return r.Enabled == nil || *r.Enabled +} + +// ShouldSyncMembers returns true if members should be synced (defaults to +// true). +func (r SyncRule) ShouldSyncMembers() bool { + return r.SyncMembers == nil || *r.SyncMembers +} + +// GetName returns the rule name, defaulting to GitHubTeamName if not set. +func (r SyncRule) GetName() string { + if r.Name != "" { + return r.Name + } + if r.GitHubTeamName != "" { + return r.GitHubTeamName + } + return r.OktaGroupName +} + +// ComplianceViolation represents a single branch protection rule violation. +type ComplianceViolation struct { + Type string + Description string +} + +// PRComplianceResult contains PR compliance check results including +// violations and user bypass permissions. +type PRComplianceResult struct { + PR *github.PullRequest + BaseBranch string + Protection *github.Protection + BranchRules *github.BranchRules + Violations []ComplianceViolation + UserHasBypass bool + UserBypassReason string +} + +// HasViolations returns true if any compliance violations were detected. +func (r *PRComplianceResult) HasViolations() bool { + return len(r.Violations) > 0 +} + +// WasBypassed returns true if violations exist and user had bypass +// permission. +func (r *PRComplianceResult) WasBypassed() bool { + return r.HasViolations() && r.UserHasBypass +} + +// TeamSyncResult contains the results of syncing team membership. +type TeamSyncResult struct { + TeamName string + MembersAdded []string + MembersRemoved []string + MembersSkippedExternal []string + Errors []string +} + +// GroupInfo contains Okta group details and member list. +type GroupInfo struct { + ID string + Name string + Members []string + SkippedNoGitHubUsername []string +} + +// GroupMembersResult contains the results of fetching group members. +type GroupMembersResult struct { + Members []string + SkippedNoGitHubUsername []string +} + +// SyncReport contains the results of syncing a single Okta group to GitHub +// team. +type SyncReport struct { + Rule string + OktaGroup string + GitHubTeam string + MembersAdded []string + MembersRemoved []string + MembersSkippedExternal []string + MembersSkippedNoGHUsername []string + Errors []string +} + +// HasErrors returns true if any errors occurred during sync. +func (r *SyncReport) HasErrors() bool { + return len(r.Errors) > 0 +} + +// HasChanges returns true if members were added or removed. +func (r *SyncReport) HasChanges() bool { + return len(r.MembersAdded) > 0 || len(r.MembersRemoved) > 0 +} + +// OrphanedUsersReport contains users who are org members but not in any +// synced teams. +type OrphanedUsersReport struct { + OrphanedUsers []string +} + +// SyncResult contains all sync reports and orphaned users report. +type SyncResult struct { + Reports []*SyncReport + OrphanedUsers *OrphanedUsersReport +} diff --git a/internal/github/client/client.go b/internal/github/client/client.go index fd3e7ce..93ec3a3 100644 --- a/internal/github/client/client.go +++ b/internal/github/client/client.go @@ -21,6 +21,7 @@ import ( // Client wraps the GitHub API client with App authentication. // automatically refreshes installation tokens before expiry. +// implements domain.GitHubClient. type Client struct { client *github.Client org string @@ -144,14 +145,21 @@ func (c *Client) refreshToken(ctx context.Context) error { } // ensureValidToken refreshes the installation token if it expires within 5 -// minutes. +// minutes. uses double-check pattern to avoid redundant refreshes under +// concurrent access. func (c *Client) ensureValidToken(ctx context.Context) error { c.tokenMu.RLock() needsRefresh := time.Now().Add(5 * time.Minute).After(c.tokenExpAt) c.tokenMu.RUnlock() if needsRefresh { - return c.refreshToken(ctx) + c.tokenMu.Lock() + // double-check after acquiring write lock + if time.Now().Add(5 * time.Minute).After(c.tokenExpAt) { + c.tokenMu.Unlock() + return c.refreshToken(ctx) + } + c.tokenMu.Unlock() } return nil diff --git a/internal/github/client/pr.go b/internal/github/client/pr.go index a0d745f..065c851 100644 --- a/internal/github/client/pr.go +++ b/internal/github/client/pr.go @@ -5,32 +5,14 @@ import ( "fmt" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/google/go-github/v79/github" ) -// ComplianceViolation represents a single branch protection rule violation. -type ComplianceViolation struct { - Type string - Description string -} - -// PRComplianceResult contains PR compliance check results including -// violations and user bypass permissions. -type PRComplianceResult struct { - PR *github.PullRequest - BaseBranch string - Protection *github.Protection - BranchRules *github.BranchRules - Violations []ComplianceViolation - UserHasBypass bool - UserBypassReason string -} - // CheckPRCompliance verifies if a merged PR met branch protection // requirements. checks review requirements, status checks, and user bypass // permissions. -func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*PRComplianceResult, error) { +func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) { if err := c.ensureValidToken(ctx); err != nil { return nil, err } @@ -41,19 +23,19 @@ func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNu } if pr == nil { - return nil, errors.Wrapf(internalerrors.ErrMissingPRData, "pr #%d returned nil", prNumber) + return nil, errors.Wrapf(domain.ErrMissingPRData, "pr #%d returned nil", prNumber) } if pr.Base == nil || pr.Base.Ref == nil { - return nil, errors.Wrapf(internalerrors.ErrMissingPRData, "pr #%d missing base branch", prNumber) + return nil, errors.Wrapf(domain.ErrMissingPRData, "pr #%d missing base branch", prNumber) } baseBranch := *pr.Base.Ref - result := &PRComplianceResult{ + result := &domain.PRComplianceResult{ PR: pr, BaseBranch: baseBranch, - Violations: []ComplianceViolation{}, + Violations: []domain.ComplianceViolation{}, } // fetch legacy branch protection rules @@ -68,8 +50,14 @@ func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNu result.BranchRules = branchRules } - c.checkReviewRequirements(ctx, owner, repo, pr, result) - c.checkStatusRequirements(ctx, owner, repo, pr, result) + if err := c.checkReviewRequirements(ctx, owner, repo, pr, result); err != nil { + return nil, errors.Wrapf(err, "failed to check review requirements for pr #%d", prNumber) + } + + if err := c.checkStatusRequirements(ctx, owner, repo, pr, result); err != nil { + return nil, errors.Wrapf(err, "failed to check status requirements for pr #%d", prNumber) + } + c.checkUserBypassPermission(ctx, owner, repo, pr, result) return result, nil @@ -77,7 +65,7 @@ func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNu // checkReviewRequirements validates that PR had required approving reviews. // checks both legacy branch protection and repository rulesets. -func (c *Client) checkReviewRequirements(ctx context.Context, owner, repo string, pr *github.PullRequest, result *PRComplianceResult) { +func (c *Client) checkReviewRequirements(ctx context.Context, owner, repo string, pr *github.PullRequest, result *domain.PRComplianceResult) error { requiredApprovals := 0 // check legacy branch protection @@ -95,34 +83,45 @@ func (c *Client) checkReviewRequirements(ctx context.Context, owner, repo string } if requiredApprovals == 0 { - return + return nil } - reviews, _, err := c.client.PullRequests.ListReviews(ctx, owner, repo, *pr.Number, nil) + reviews, _, err := c.client.PullRequests.ListReviews(ctx, owner, repo, *pr.Number, &github.ListOptions{PerPage: 100}) if err != nil { - return + return errors.Wrapf(err, "failed to list reviews for pr #%d in %s/%s", *pr.Number, owner, repo) } - approvedCount := 0 + // deduplicate reviews per user, keeping only the latest state + latestReviewByUser := make(map[string]string) for _, review := range reviews { - if review.State != nil && *review.State == "APPROVED" { + if review.User == nil || review.User.Login == nil || review.State == nil { + continue + } + latestReviewByUser[*review.User.Login] = *review.State + } + + approvedCount := 0 + for _, state := range latestReviewByUser { + if state == "APPROVED" { approvedCount++ } } if approvedCount < requiredApprovals { - result.Violations = append(result.Violations, ComplianceViolation{ + result.Violations = append(result.Violations, domain.ComplianceViolation{ Type: "insufficient_reviews", Description: fmt.Sprintf("required %d approving reviews, had %d", requiredApprovals, approvedCount), }) } + + return nil } // checkStatusRequirements validates that required status checks passed. // checks both legacy branch protection and repository rulesets. -func (c *Client) checkStatusRequirements(ctx context.Context, owner, repo string, pr *github.PullRequest, result *PRComplianceResult) { +func (c *Client) checkStatusRequirements(ctx context.Context, owner, repo string, pr *github.PullRequest, result *domain.PRComplianceResult) error { if pr.Head == nil || pr.Head.SHA == nil { - return + return nil } // collect required checks from both sources @@ -147,12 +146,12 @@ func (c *Client) checkStatusRequirements(ctx context.Context, owner, repo string } if len(requiredChecks) == 0 { - return + return nil } combinedStatus, _, err := c.client.Repositories.GetCombinedStatus(ctx, owner, repo, *pr.Head.SHA, nil) if err != nil { - return + return errors.Wrapf(err, "failed to get combined status for sha '%s' in %s/%s", *pr.Head.SHA, owner, repo) } passedChecks := make(map[string]bool) @@ -162,19 +161,34 @@ func (c *Client) checkStatusRequirements(ctx context.Context, owner, repo string } } + // also check GitHub Actions check runs (modern repos use these instead + // of commit statuses) + checkRuns, _, err := c.client.Checks.ListCheckRunsForRef(ctx, owner, repo, *pr.Head.SHA, &github.ListCheckRunsOptions{ + ListOptions: github.ListOptions{PerPage: 100}, + }) + if err == nil && checkRuns != nil { + for _, run := range checkRuns.CheckRuns { + if run.Name != nil && run.Conclusion != nil && *run.Conclusion == "success" { + passedChecks[*run.Name] = true + } + } + } + for required := range requiredChecks { if !passedChecks[required] { - result.Violations = append(result.Violations, ComplianceViolation{ + result.Violations = append(result.Violations, domain.ComplianceViolation{ Type: "missing_status_check", Description: fmt.Sprintf("required check '%s' did not pass", required), }) } } + + return nil } // checkUserBypassPermission checks if the user who merged the PR has admin or // maintainer permissions allowing bypass. -func (c *Client) checkUserBypassPermission(ctx context.Context, owner, repo string, pr *github.PullRequest, result *PRComplianceResult) { +func (c *Client) checkUserBypassPermission(ctx context.Context, owner, repo string, pr *github.PullRequest, result *domain.PRComplianceResult) { if pr.MergedBy == nil || pr.MergedBy.Login == nil { return } @@ -197,14 +211,3 @@ func (c *Client) checkUserBypassPermission(ctx context.Context, owner, repo stri } } } - -// HasViolations returns true if any compliance violations were detected. -func (r *PRComplianceResult) HasViolations() bool { - return len(r.Violations) > 0 -} - -// WasBypassed returns true if violations exist and user had bypass -// permission. -func (r *PRComplianceResult) WasBypassed() bool { - return r.HasViolations() && r.UserHasBypass -} diff --git a/internal/github/client/teams.go b/internal/github/client/teams.go index b2a8657..35a9971 100644 --- a/internal/github/client/teams.go +++ b/internal/github/client/teams.go @@ -3,20 +3,15 @@ package client import ( "context" "fmt" + "strings" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/google/go-github/v79/github" ) -// TeamSyncResult contains the results of syncing team membership. -type TeamSyncResult struct { - TeamName string - MembersAdded []string - MembersRemoved []string - MembersSkippedExternal []string - Errors []string -} +// compile-time assertion +var _ domain.GitHubClient = (*Client)(nil) // GetOrCreateTeam fetches an existing team by slug or creates it if missing. func (c *Client) GetOrCreateTeam(ctx context.Context, teamName, privacy string) (*github.Team, error) { @@ -29,52 +24,64 @@ func (c *Client) GetOrCreateTeam(ctx context.Context, teamName, privacy string) return team, nil } - if resp != nil && resp.StatusCode == 404 { - newTeam := &github.NewTeam{ - Name: teamName, - Privacy: &privacy, - } - team, _, err = c.client.Teams.CreateTeam(ctx, c.org, *newTeam) - if err != nil { - return nil, errors.Wrapf(err, "failed to create team '%s' in org '%s'", teamName, c.org) - } - return team, nil + if resp == nil || resp.StatusCode != 404 { + return nil, errors.Wrapf(err, "failed to fetch team '%s' from org '%s'", teamName, c.org) } - return nil, errors.Wrapf(internalerrors.ErrTeamNotFound, "failed to fetch team '%s' from org '%s'", teamName, c.org) + newTeam := &github.NewTeam{ + Name: teamName, + Privacy: &privacy, + } + team, _, err = c.client.Teams.CreateTeam(ctx, c.org, *newTeam) + if err != nil { + return nil, errors.Wrapf(err, "failed to create team '%s' in org '%s'", teamName, c.org) + } + return team, nil } // GetTeamMembers returns GitHub usernames of all team members. +// paginates through all results to handle large teams. func (c *Client) GetTeamMembers(ctx context.Context, teamSlug string) ([]string, error) { if err := c.ensureValidToken(ctx); err != nil { return nil, err } - members, _, err := c.client.Teams.ListTeamMembersBySlug(ctx, c.org, teamSlug, nil) - if err != nil { - return nil, errors.Wrapf(err, "failed to list members for team '%s'", teamSlug) + opts := &github.TeamListTeamMembersOptions{ + ListOptions: github.ListOptions{PerPage: 100}, } - logins := make([]string, 0, len(members)) - for _, member := range members { - if member.Login != nil { - logins = append(logins, *member.Login) + var allMembers []string + for { + members, resp, err := c.client.Teams.ListTeamMembersBySlug(ctx, c.org, teamSlug, opts) + if err != nil { + return nil, errors.Wrapf(err, "failed to list members for team '%s'", teamSlug) + } + + for _, member := range members { + if member.Login != nil { + allMembers = append(allMembers, *member.Login) + } + } + + if resp.NextPage == 0 { + break } + opts.Page = resp.NextPage } - return logins, nil + return allMembers, nil } // SyncTeamMembers adds and removes members to match desired state. // collects errors for individual operations but continues processing. skips // removal of external collaborators (outside org members). applies safety // threshold to prevent mass removal during outages. -func (c *Client) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, safetyThreshold float64) (*TeamSyncResult, error) { +func (c *Client) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, safetyThreshold float64) (*domain.TeamSyncResult, error) { if err := c.ensureValidToken(ctx); err != nil { return nil, err } - result := &TeamSyncResult{ + result := &domain.TeamSyncResult{ TeamName: teamSlug, MembersAdded: []string{}, MembersRemoved: []string{}, @@ -89,16 +96,16 @@ func (c *Client) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMe currentSet := make(map[string]bool) for _, member := range currentMembers { - currentSet[member] = true + currentSet[strings.ToLower(member)] = true } desiredSet := make(map[string]bool) for _, member := range desiredMembers { - desiredSet[member] = true + desiredSet[strings.ToLower(member)] = true } for _, desired := range desiredMembers { - if !currentSet[desired] { + if !currentSet[strings.ToLower(desired)] { _, _, err := c.client.Teams.AddTeamMembershipBySlug(ctx, c.org, teamSlug, desired, nil) if err != nil { errMsg := fmt.Sprintf("failed to add '%s' to team '%s': %v", desired, teamSlug, err) @@ -111,7 +118,7 @@ func (c *Client) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMe var toRemove []string for _, current := range currentMembers { - if !desiredSet[current] { + if !desiredSet[strings.ToLower(current)] { toRemove = append(toRemove, current) } } diff --git a/internal/github/webhooks/webhooks.go b/internal/github/webhooks/webhooks.go index 5d5d6c5..979c36a 100644 --- a/internal/github/webhooks/webhooks.go +++ b/internal/github/webhooks/webhooks.go @@ -10,7 +10,7 @@ import ( "strings" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/google/go-github/v79/github" ) @@ -64,17 +64,17 @@ type MembershipEvent struct { func ValidateWebhookSignature(payload []byte, signature string, secret string) error { if secret == "" { if signature != "" { - return internalerrors.ErrUnexpectedSignature + return domain.ErrUnexpectedSignature } return nil } if signature == "" { - return internalerrors.ErrMissingSignature + return domain.ErrMissingSignature } if !strings.HasPrefix(signature, "sha256=") { - return errors.Wrap(internalerrors.ErrInvalidSignature, "must start with 'sha256='") + return errors.Wrap(domain.ErrInvalidSignature, "must start with 'sha256='") } mac := hmac.New(sha256.New, []byte(secret)) @@ -83,7 +83,7 @@ func ValidateWebhookSignature(payload []byte, signature string, secret string) e expectedSignature := "sha256=" + expectedMAC if !hmac.Equal([]byte(signature), []byte(expectedSignature)) { - return errors.Wrap(internalerrors.ErrInvalidSignature, "computed signature does not match") + return errors.Wrap(domain.ErrInvalidSignature, "computed signature does not match") } return nil @@ -97,16 +97,16 @@ func ParsePullRequestEvent(payload []byte) (*PullRequestEvent, error) { return nil, errors.Wrap(err, "failed to unmarshal pull request event") } if event.PullRequest == nil { - return nil, errors.Wrap(internalerrors.ErrMissingPRData, "missing pull_request field") + return nil, errors.Wrap(domain.ErrMissingPRData, "missing pull_request field") } if event.PullRequest.Number == nil { - return nil, errors.Wrap(internalerrors.ErrMissingPRData, "missing pr number") + return nil, errors.Wrap(domain.ErrMissingPRData, "missing pr number") } if event.PullRequest.Base == nil || event.PullRequest.Base.Ref == nil { - return nil, errors.Wrap(internalerrors.ErrMissingPRData, "missing base branch") + return nil, errors.Wrap(domain.ErrMissingPRData, "missing base branch") } if event.Repository == nil { - return nil, errors.Wrap(internalerrors.ErrMissingPRData, "missing repository") + return nil, errors.Wrap(domain.ErrMissingPRData, "missing repository") } return &event, nil } diff --git a/internal/github/webhooks/webhooks_test.go b/internal/github/webhooks/webhooks_test.go new file mode 100644 index 0000000..ed9169b --- /dev/null +++ b/internal/github/webhooks/webhooks_test.go @@ -0,0 +1,355 @@ +package webhooks + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "testing" + + "github.com/cockroachdb/errors" + "github.com/cruxstack/github-ops-app/internal/domain" + "github.com/google/go-github/v79/github" +) + +func TestValidateWebhookSignature(t *testing.T) { + tests := []struct { + name string + payload []byte + signature string + secret string + wantErr error + }{ + { + name: "no secret, no signature", + payload: []byte(`{"test": true}`), + signature: "", + secret: "", + wantErr: nil, + }, + { + name: "no secret, unexpected signature", + payload: []byte(`{"test": true}`), + signature: "sha256=abc", + secret: "", + wantErr: domain.ErrUnexpectedSignature, + }, + { + name: "secret configured, missing signature", + payload: []byte(`{"test": true}`), + signature: "", + secret: "mysecret", + wantErr: domain.ErrMissingSignature, + }, + { + name: "wrong prefix", + payload: []byte(`{"test": true}`), + signature: "sha1=abc123", + secret: "mysecret", + wantErr: domain.ErrInvalidSignature, + }, + { + name: "wrong signature value", + payload: []byte(`{"test": true}`), + signature: "sha256=0000000000000000000000000000000000000000000000000000000000000000", + secret: "mysecret", + wantErr: domain.ErrInvalidSignature, + }, + { + name: "valid signature", + payload: []byte(`{}`), + signature: computeSignature([]byte(`{}`), "test-secret"), + secret: "test-secret", + wantErr: nil, + }, + { + name: "empty payload with valid signature", + payload: []byte{}, + signature: computeSignature([]byte{}, "key"), + secret: "key", + wantErr: nil, + }, + { + name: "truncated signature", + payload: []byte(`test`), + signature: "sha256=abc", + secret: "mysecret", + wantErr: domain.ErrInvalidSignature, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWebhookSignature(tt.payload, tt.signature, tt.secret) + if tt.wantErr == nil { + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + return + } + if err == nil { + t.Errorf("expected error %v, got nil", tt.wantErr) + return + } + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error %v, got: %v", tt.wantErr, err) + } + }) + } +} + +// computeSignature generates a valid HMAC-SHA256 signature for testing. +func computeSignature(payload []byte, secret string) string { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(payload) + return "sha256=" + hex.EncodeToString(mac.Sum(nil)) +} + +func TestParsePullRequestEvent_Valid(t *testing.T) { + prNumber := 42 + baseBranch := "main" + payload := map[string]any{ + "action": "closed", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "merged": true, + "base": map[string]any{"ref": baseBranch}, + }, + "repository": map[string]any{ + "name": "test-repo", + "full_name": "owner/test-repo", + "owner": map[string]any{"login": "owner"}, + }, + } + + data, _ := json.Marshal(payload) + event, err := ParsePullRequestEvent(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if event.Number != prNumber { + t.Errorf("expected number %d, got %d", prNumber, event.Number) + } + if !event.IsMerged() { + t.Error("expected IsMerged() to return true") + } + if event.GetBaseBranch() != baseBranch { + t.Errorf("expected base branch %q, got %q", baseBranch, event.GetBaseBranch()) + } +} + +func TestParsePullRequestEvent_MissingFields(t *testing.T) { + tests := []struct { + name string + payload map[string]any + }{ + { + name: "missing pull_request", + payload: map[string]any{"action": "closed"}, + }, + { + name: "missing pr number", + payload: map[string]any{ + "action": "closed", + "pull_request": map[string]any{ + "base": map[string]any{"ref": "main"}, + }, + "repository": map[string]any{}, + }, + }, + { + name: "missing base branch", + payload: map[string]any{ + "action": "closed", + "pull_request": map[string]any{ + "number": 1, + }, + "repository": map[string]any{}, + }, + }, + { + name: "missing repository", + payload: map[string]any{ + "action": "closed", + "pull_request": map[string]any{ + "number": 1, + "base": map[string]any{"ref": "main"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.payload) + _, err := ParsePullRequestEvent(data) + if err == nil { + t.Error("expected error for missing fields") + } + if !errors.Is(err, domain.ErrMissingPRData) { + t.Errorf("expected ErrMissingPRData, got: %v", err) + } + }) + } +} + +func TestParsePullRequestEvent_InvalidJSON(t *testing.T) { + _, err := ParsePullRequestEvent([]byte(`{invalid`)) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestParseTeamEvent_Valid(t *testing.T) { + payload := map[string]any{ + "action": "edited", + "team": map[string]any{"slug": "engineering"}, + "sender": map[string]any{"login": "user1", "type": "User"}, + } + + data, _ := json.Marshal(payload) + event, err := ParseTeamEvent(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if event.Action != "edited" { + t.Errorf("expected action 'edited', got %q", event.Action) + } + if event.GetTeamSlug() != "engineering" { + t.Errorf("expected team slug 'engineering', got %q", event.GetTeamSlug()) + } + if event.GetSenderType() != "User" { + t.Errorf("expected sender type 'User', got %q", event.GetSenderType()) + } +} + +func TestParseTeamEvent_MissingFields(t *testing.T) { + tests := []struct { + name string + payload map[string]any + }{ + { + name: "missing team", + payload: map[string]any{"action": "edited", "sender": map[string]any{"login": "u1"}}, + }, + { + name: "missing sender", + payload: map[string]any{"action": "edited", "team": map[string]any{"slug": "t1"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.payload) + _, err := ParseTeamEvent(data) + if err == nil { + t.Error("expected error for missing fields") + } + }) + } +} + +func TestParseMembershipEvent_Valid(t *testing.T) { + payload := map[string]any{ + "action": "added", + "scope": "team", + "member": map[string]any{"login": "newuser"}, + "team": map[string]any{"slug": "engineering"}, + "sender": map[string]any{"login": "admin", "type": "User"}, + } + + data, _ := json.Marshal(payload) + event, err := ParseMembershipEvent(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !event.IsTeamScope() { + t.Error("expected IsTeamScope() to return true") + } + if event.GetSenderLogin() != "admin" { + t.Errorf("expected sender login 'admin', got %q", event.GetSenderLogin()) + } +} + +func TestParseMembershipEvent_MissingFields(t *testing.T) { + tests := []struct { + name string + payload map[string]any + }{ + { + name: "missing team", + payload: map[string]any{"action": "added", "member": map[string]any{"login": "u"}, "sender": map[string]any{"login": "s"}}, + }, + { + name: "missing member", + payload: map[string]any{"action": "added", "team": map[string]any{"slug": "t"}, "sender": map[string]any{"login": "s"}}, + }, + { + name: "missing sender", + payload: map[string]any{"action": "added", "team": map[string]any{"slug": "t"}, "member": map[string]any{"login": "u"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.payload) + _, err := ParseMembershipEvent(data) + if err == nil { + t.Error("expected error for missing fields") + } + }) + } +} + +func TestPullRequestEvent_IsMerged(t *testing.T) { + merged := true + notMerged := false + + tests := []struct { + name string + action string + merged *bool + want bool + }{ + {"closed and merged", "closed", &merged, true}, + {"closed not merged", "closed", ¬Merged, false}, + {"opened", "opened", nil, false}, + {"closed with nil merged", "closed", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := &PullRequestEvent{ + Action: tt.action, + PullRequest: &github.PullRequest{ + Merged: tt.merged, + }, + } + if got := event.IsMerged(); got != tt.want { + t.Errorf("IsMerged() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMembershipEvent_IsTeamScope(t *testing.T) { + tests := []struct { + scope string + want bool + }{ + {"team", true}, + {"organization", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.scope, func(t *testing.T) { + event := &MembershipEvent{Scope: tt.scope} + if got := event.IsTeamScope(); got != tt.want { + t.Errorf("IsTeamScope() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/notifiers/slack_messages.go b/internal/notifiers/slack_messages.go index 256754d..2500268 100644 --- a/internal/notifiers/slack_messages.go +++ b/internal/notifiers/slack_messages.go @@ -5,17 +5,15 @@ import ( "fmt" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/slack-go/slack" ) // NotifyPRBypass sends a Slack notification when branch protection is // bypassed. -func (s *SlackNotifier) NotifyPRBypass(ctx context.Context, result *client.PRComplianceResult, repoFullName string) error { +func (s *SlackNotifier) NotifyPRBypass(ctx context.Context, result *domain.PRComplianceResult, repoFullName string) error { if result.PR == nil { - return errors.Wrap(internalerrors.ErrMissingPRData, "pr result missing") + return errors.Wrap(domain.ErrMissingPRData, "pr result missing") } prURL := "" @@ -90,14 +88,14 @@ func (s *SlackNotifier) NotifyPRBypass(ctx context.Context, result *client.PRCom } // NotifyOktaSync sends a Slack notification with Okta sync results. -func (s *SlackNotifier) NotifyOktaSync(ctx context.Context, reports []*okta.SyncReport, githubOrg string) error { +func (s *SlackNotifier) NotifyOktaSync(ctx context.Context, reports []*domain.SyncReport, githubOrg string) error { if len(reports) == 0 { return nil } // aggregate stats var totalAdded, totalRemoved int - var rulesWithChanges, rulesWithoutChanges []*okta.SyncReport + var rulesWithChanges, rulesWithoutChanges []*domain.SyncReport var allErrors []string var allSkippedExternal, allSkippedNoGHUsername []string @@ -239,7 +237,7 @@ func (s *SlackNotifier) NotifyOktaSync(ctx context.Context, reports []*okta.Sync // NotifyOrphanedUsers sends a Slack notification about organization members // not in any synced teams. -func (s *SlackNotifier) NotifyOrphanedUsers(ctx context.Context, report *okta.OrphanedUsersReport) error { +func (s *SlackNotifier) NotifyOrphanedUsers(ctx context.Context, report *domain.OrphanedUsersReport) error { if report == nil || len(report.OrphanedUsers) == 0 { return nil } diff --git a/internal/notifiers/slack_messages_test.go b/internal/notifiers/slack_messages_test.go new file mode 100644 index 0000000..18fc39d --- /dev/null +++ b/internal/notifiers/slack_messages_test.go @@ -0,0 +1,235 @@ +package notifiers + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/cruxstack/github-ops-app/internal/domain" + "github.com/google/go-github/v79/github" +) + +// slackTestServer creates a mock Slack API server that captures posted +// messages and returns them via the messages channel. +func slackTestServer(t *testing.T) (*httptest.Server, chan map[string]any) { + t.Helper() + messages := make(chan map[string]any, 10) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + r.Body.Close() + + // parse form-encoded body from slack client + var msg map[string]any + if strings.Contains(r.Header.Get("Content-Type"), "json") { + json.Unmarshal(body, &msg) + } else { + msg = map[string]any{"raw_body": string(body), "path": r.URL.Path} + } + + messages <- msg + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"ok":true,"channel":"C123","ts":"1234.5678"}`)) + })) + + return srv, messages +} + +func TestNotifyPRBypass_Success(t *testing.T) { + srv, messages := slackTestServer(t) + defer srv.Close() + + notifier := NewSlackNotifierWithAPIURL( + "xoxb-test", + SlackChannels{Default: "C_DEFAULT"}, + SlackMessages{PRBypassFooterNote: "review this"}, + srv.URL+"/", + ) + + prNumber := 42 + prTitle := "fix: patch auth" + prURL := "https://github.com/org/repo/pull/42" + mergedBy := "admin-user" + result := &domain.PRComplianceResult{ + PR: &github.PullRequest{ + Number: &prNumber, + Title: &prTitle, + HTMLURL: &prURL, + MergedBy: &github.User{Login: &mergedBy}, + }, + UserHasBypass: true, + UserBypassReason: "repository admin", + Violations: []domain.ComplianceViolation{ + {Type: "insufficient_reviews", Description: "required 2, had 0"}, + }, + } + + err := notifier.NotifyPRBypass(context.Background(), result, "org/repo") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := <-messages + if msg == nil { + t.Fatal("expected a message to be posted") + } +} + +func TestNotifyPRBypass_NilPR(t *testing.T) { + notifier := NewSlackNotifier("xoxb-test", SlackChannels{Default: "C"}, SlackMessages{}) + + result := &domain.PRComplianceResult{PR: nil} + err := notifier.NotifyPRBypass(context.Background(), result, "org/repo") + if err == nil { + t.Fatal("expected error for nil PR") + } +} + +func TestNotifyOktaSync_EmptyReports(t *testing.T) { + notifier := NewSlackNotifier("xoxb-test", SlackChannels{Default: "C"}, SlackMessages{}) + + err := notifier.NotifyOktaSync(context.Background(), nil, "org") + if err != nil { + t.Fatalf("expected nil for empty reports, got: %v", err) + } + + err = notifier.NotifyOktaSync(context.Background(), []*domain.SyncReport{}, "org") + if err != nil { + t.Fatalf("expected nil for zero reports, got: %v", err) + } +} + +func TestNotifyOktaSync_Success(t *testing.T) { + srv, messages := slackTestServer(t) + defer srv.Close() + + notifier := NewSlackNotifierWithAPIURL( + "xoxb-test", + SlackChannels{Default: "C_DEFAULT", OktaSync: "C_OKTA"}, + SlackMessages{}, + srv.URL+"/", + ) + + reports := []*domain.SyncReport{ + { + Rule: "eng", OktaGroup: "Engineering", GitHubTeam: "engineering", + MembersAdded: []string{"alice"}, MembersRemoved: []string{"bob"}, + }, + { + Rule: "platform", OktaGroup: "Platform", GitHubTeam: "platform", + }, + } + + err := notifier.NotifyOktaSync(context.Background(), reports, "my-org") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := <-messages + if msg == nil { + t.Fatal("expected a message to be posted") + } +} + +func TestNotifyOrphanedUsers_NilReport(t *testing.T) { + notifier := NewSlackNotifier("xoxb-test", SlackChannels{Default: "C"}, SlackMessages{}) + + err := notifier.NotifyOrphanedUsers(context.Background(), nil) + if err != nil { + t.Fatalf("expected nil for nil report, got: %v", err) + } +} + +func TestNotifyOrphanedUsers_EmptyUsers(t *testing.T) { + notifier := NewSlackNotifier("xoxb-test", SlackChannels{Default: "C"}, SlackMessages{}) + + err := notifier.NotifyOrphanedUsers(context.Background(), &domain.OrphanedUsersReport{}) + if err != nil { + t.Fatalf("expected nil for empty users, got: %v", err) + } +} + +func TestNotifyOrphanedUsers_Success(t *testing.T) { + srv, messages := slackTestServer(t) + defer srv.Close() + + notifier := NewSlackNotifierWithAPIURL( + "xoxb-test", + SlackChannels{Default: "C_DEFAULT"}, + SlackMessages{}, + srv.URL+"/", + ) + + report := &domain.OrphanedUsersReport{ + OrphanedUsers: []string{"orphan-1", "orphan-2"}, + } + + err := notifier.NotifyOrphanedUsers(context.Background(), report) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := <-messages + if msg == nil { + t.Fatal("expected a message to be posted") + } +} + +func TestNotifyOktaSync_WithErrors(t *testing.T) { + srv, messages := slackTestServer(t) + defer srv.Close() + + notifier := NewSlackNotifierWithAPIURL( + "xoxb-test", + SlackChannels{Default: "C"}, + SlackMessages{}, + srv.URL+"/", + ) + + reports := []*domain.SyncReport{ + { + Rule: "eng", OktaGroup: "Engineering", GitHubTeam: "engineering", + MembersAdded: []string{"alice"}, + Errors: []string{"rate limited"}, + MembersSkippedExternal: []string{"external-1"}, + MembersSkippedNoGHUsername: []string{"newhire@co.com"}, + }, + } + + err := notifier.NotifyOktaSync(context.Background(), reports, "org") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := <-messages + if msg == nil { + t.Fatal("expected a message") + } +} + +func TestChannelFor_CustomChannels(t *testing.T) { + n := &SlackNotifier{ + channels: SlackChannels{ + Default: "C_DEFAULT", + PRBypass: "C_PR", + OktaSync: "C_OKTA", + OrphanedUsers: "C_ORPHAN", + }, + } + + if got := n.channelFor(n.channels.PRBypass); got != "C_PR" { + t.Errorf("expected C_PR, got %s", got) + } + if got := n.channelFor(n.channels.OktaSync); got != "C_OKTA" { + t.Errorf("expected C_OKTA, got %s", got) + } + if got := n.channelFor(n.channels.OrphanedUsers); got != "C_ORPHAN" { + t.Errorf("expected C_ORPHAN, got %s", got) + } +} diff --git a/internal/okta/client.go b/internal/okta/client.go index d1f1adf..d7f085f 100644 --- a/internal/okta/client.go +++ b/internal/okta/client.go @@ -1,5 +1,5 @@ -// Package okta provides Okta API client and group synchronization to GitHub -// teams. Uses OAuth 2.0 with private key authentication. +// Package okta provides Okta API client for group and user management. +// Uses OAuth 2.0 with private key authentication. package okta import ( @@ -9,11 +9,12 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "log/slog" "net/http" "net/url" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/okta/okta-sdk-golang/v6/okta" ) @@ -21,6 +22,15 @@ import ( // these scopes are necessary for group sync functionality. var DefaultScopes = []string{"okta.groups.read", "okta.users.read"} +// certPoolKey is a typed context key for TLS certificate pool injection. +type certPoolKey struct{} + +// WithCertPool returns a new context with the given TLS certificate pool. +// used by integration tests to inject self-signed certs. +func WithCertPool(ctx context.Context, pool *x509.CertPool) context.Context { + return context.WithValue(ctx, certPoolKey{}, pool) +} + // convertToPKCS1 converts a PEM-encoded private key to PKCS#1 format if needed. // the Okta SDK requires PKCS#1 format (BEGIN RSA PRIVATE KEY), but Okta's // console generates PKCS#8 keys (BEGIN PRIVATE KEY). this function detects the @@ -61,12 +71,16 @@ func convertToPKCS1(keyPEM []byte) ([]byte, error) { } // Client wraps the Okta SDK client with custom configuration. +// implements domain.OktaClient. type Client struct { apiClient *okta.APIClient - ctx context.Context githubUserField string + logger *slog.Logger } +// compile-time assertion +var _ domain.OktaClient = (*Client)(nil) + // ClientConfig contains Okta client configuration. type ClientConfig struct { Domain string @@ -76,6 +90,7 @@ type ClientConfig struct { Scopes []string GitHubUserField string BaseURL string + Logger *slog.Logger } // NewClient creates an Okta client with background context. @@ -88,7 +103,7 @@ func NewClient(cfg *ClientConfig) (*Client, error) { // testing. func NewClientWithContext(ctx context.Context, cfg *ClientConfig) (*Client, error) { if cfg.ClientID == "" || len(cfg.PrivateKey) == 0 { - return nil, internalerrors.ErrMissingOAuthCreds + return nil, domain.ErrMissingOAuthCreds } orgURL := cfg.BaseURL @@ -119,7 +134,7 @@ func NewClientWithContext(ctx context.Context, cfg *ClientConfig) (*Client, erro opts = append(opts, okta.WithPrivateKeyId(cfg.PrivateKeyID)) } - if certPool, ok := ctx.Value("okta_tls_cert_pool").(*x509.CertPool); ok && certPool != nil { + if certPool, ok := ctx.Value(certPoolKey{}).(*x509.CertPool); ok && certPool != nil { httpClient := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -161,10 +176,15 @@ func NewClientWithContext(ctx context.Context, cfg *ClientConfig) (*Client, erro apiClient := okta.NewAPIClient(oktaCfg) + logger := cfg.Logger + if logger == nil { + logger = slog.Default() + } + return &Client{ apiClient: apiClient, - ctx: ctx, githubUserField: cfg.GitHubUserField, + logger: logger, }, nil } @@ -173,75 +193,96 @@ func (c *Client) GetAPIClient() *okta.APIClient { return c.apiClient } -// GetContext returns the context used for API requests. -func (c *Client) GetContext() context.Context { - return c.ctx -} - -// ListGroups fetches all Okta groups. -func (c *Client) ListGroups() ([]okta.Group, error) { - groups, _, err := c.apiClient.GroupAPI.ListGroups(c.ctx).Execute() +// ListGroups fetches all Okta groups with pagination. +func (c *Client) ListGroups(ctx context.Context) ([]okta.Group, error) { + groups, resp, err := c.apiClient.GroupAPI.ListGroups(ctx).Limit(200).Execute() if err != nil { return nil, errors.Wrap(err, "failed to list groups") } - return groups, nil + + allGroups := make([]okta.Group, 0, len(groups)) + allGroups = append(allGroups, groups...) + + for resp != nil && resp.HasNextPage() { + var nextGroups []okta.Group + resp, err = resp.Next(&nextGroups) + if err != nil { + return nil, errors.Wrap(err, "failed to fetch next page of groups") + } + allGroups = append(allGroups, nextGroups...) + } + + return allGroups, nil } // GetGroupByName searches for an Okta group by exact name match. -func (c *Client) GetGroupByName(name string) (*okta.Group, error) { - groups, _, err := c.apiClient.GroupAPI.ListGroups(c.ctx).Q(name).Execute() +// paginates through results in case the group is not on the first page. +func (c *Client) GetGroupByName(ctx context.Context, name string) (*okta.Group, error) { + groups, resp, err := c.apiClient.GroupAPI.ListGroups(ctx).Q(name).Limit(200).Execute() if err != nil { return nil, errors.Wrapf(err, "failed to search for group '%s'", name) } - for i := range groups { - group := &groups[i] - // check if profile is nil - if group.Profile == nil { - continue - } - - // try OktaUserGroupProfile first - if group.Profile.OktaUserGroupProfile != nil { - groupName := group.Profile.OktaUserGroupProfile.GetName() + for { + for i := range groups { + group := &groups[i] + groupName := extractGroupName(group) if groupName == name { return group, nil } } - // try OktaActiveDirectoryGroupProfile as fallback - if group.Profile.OktaActiveDirectoryGroupProfile != nil { - groupName := group.Profile.OktaActiveDirectoryGroupProfile.GetName() - if groupName == name { - return group, nil - } + if resp == nil || !resp.HasNextPage() { + break } + + var nextGroups []okta.Group + resp, err = resp.Next(&nextGroups) + if err != nil { + return nil, errors.Wrapf(err, "failed to fetch next page while searching for group '%s'", name) + } + groups = nextGroups } return nil, errors.Newf("group '%s' not found", name) } -// GroupMembersResult contains the results of fetching group members. -type GroupMembersResult struct { - Members []string - SkippedNoGitHubUsername []string -} - // GetGroupMembers fetches GitHub usernames for all active members of an Okta -// group. only includes users with status "ACTIVE" to exclude -// suspended/deprovisioned users. skips users without a GitHub username in -// their profile and tracks them separately. -func (c *Client) GetGroupMembers(groupID string) (*GroupMembersResult, error) { - users, _, err := c.apiClient.GroupAPI.ListGroupUsers(c.ctx, groupID).Execute() +// group. paginates through all members. only includes users with status +// "ACTIVE" to exclude suspended/deprovisioned users. skips users without a +// GitHub username in their profile and tracks them separately. +func (c *Client) GetGroupMembers(ctx context.Context, groupID string) (*domain.GroupMembersResult, error) { + users, resp, err := c.apiClient.GroupAPI.ListGroupUsers(ctx, groupID).Limit(200).Execute() if err != nil { return nil, errors.Wrapf(err, "failed to list members for group '%s'", groupID) } - result := &GroupMembersResult{ + result := &domain.GroupMembersResult{ Members: make([]string, 0, len(users)), SkippedNoGitHubUsername: []string{}, } + for { + c.processGroupUsers(users, result) + + if resp == nil || !resp.HasNextPage() { + break + } + + var nextUsers []okta.User + resp, err = resp.Next(&nextUsers) + if err != nil { + return nil, errors.Wrapf(err, "failed to fetch next page of members for group '%s'", groupID) + } + users = nextUsers + } + + return result, nil +} + +// processGroupUsers extracts GitHub usernames from a batch of Okta users +// and appends them to the result. +func (c *Client) processGroupUsers(users []okta.User, result *domain.GroupMembersResult) { for _, user := range users { if user.GetStatus() != "ACTIVE" { continue @@ -270,6 +311,4 @@ func (c *Client) GetGroupMembers(groupID string) (*GroupMembersResult, error) { result.SkippedNoGitHubUsername, profile.GetEmail()) } } - - return result, nil } diff --git a/internal/okta/groups.go b/internal/okta/groups.go index cd3e84a..a0a5acb 100644 --- a/internal/okta/groups.go +++ b/internal/okta/groups.go @@ -1,62 +1,62 @@ package okta import ( + "context" + "log/slog" "regexp" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" - "github.com/okta/okta-sdk-golang/v6/okta" + "github.com/cruxstack/github-ops-app/internal/domain" + oktasdk "github.com/okta/okta-sdk-golang/v6/okta" ) -// GroupInfo contains Okta group details and member list. -type GroupInfo struct { - ID string - Name string - Members []string - SkippedNoGitHubUsername []string +// extractGroupName returns the group name from either profile type. +func extractGroupName(group *oktasdk.Group) string { + if group == nil || group.Profile == nil { + return "" + } + if group.Profile.OktaUserGroupProfile != nil { + return group.Profile.OktaUserGroupProfile.GetName() + } + if group.Profile.OktaActiveDirectoryGroupProfile != nil { + return group.Profile.OktaActiveDirectoryGroupProfile.GetName() + } + return "" } // GetGroupsByPattern fetches all Okta groups matching a regex pattern. -func (c *Client) GetGroupsByPattern(pattern string) ([]*GroupInfo, error) { +func (c *Client) GetGroupsByPattern(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) { if pattern == "" { - return nil, internalerrors.ErrEmptyPattern + return nil, domain.ErrEmptyPattern } re, err := regexp.Compile(pattern) if err != nil { - return nil, errors.Wrapf(internalerrors.ErrInvalidPattern, "'%s'", pattern) + return nil, errors.Wrapf(domain.ErrInvalidPattern, "'%s'", pattern) } - allGroups, err := c.ListGroups() + allGroups, err := c.ListGroups(ctx) if err != nil { return nil, err } - var matched []*GroupInfo + var matched []*domain.GroupInfo for _, group := range allGroups { - if group.Profile == nil { - continue - } - - // extract group name from either profile type - var groupName string - if group.Profile.OktaUserGroupProfile != nil { - groupName = group.Profile.OktaUserGroupProfile.GetName() - } else if group.Profile.OktaActiveDirectoryGroupProfile != nil { - groupName = group.Profile.OktaActiveDirectoryGroupProfile.GetName() - } - + groupName := extractGroupName(&group) if groupName == "" { continue } if re.MatchString(groupName) { - result, err := c.GetGroupMembers(group.GetId()) + result, err := c.GetGroupMembers(ctx, group.GetId()) if err != nil { + c.logger.Warn("failed to get group members, skipping", + slog.String("group", groupName), + slog.String("error", err.Error())) continue } - matched = append(matched, &GroupInfo{ + matched = append(matched, &domain.GroupInfo{ ID: group.GetId(), Name: groupName, Members: result.Members, @@ -69,28 +69,23 @@ func (c *Client) GetGroupsByPattern(pattern string) ([]*GroupInfo, error) { } // GetGroupInfo fetches details for a single Okta group by name. -func (c *Client) GetGroupInfo(groupName string) (*GroupInfo, error) { - group, err := c.GetGroupByName(groupName) +func (c *Client) GetGroupInfo(ctx context.Context, groupName string) (*domain.GroupInfo, error) { + group, err := c.GetGroupByName(ctx, groupName) if err != nil { return nil, err } - result, err := c.GetGroupMembers(group.GetId()) + result, err := c.GetGroupMembers(ctx, group.GetId()) if err != nil { return nil, err } - // extract group name from either profile type - var name string - if group.Profile != nil { - if group.Profile.OktaUserGroupProfile != nil { - name = group.Profile.OktaUserGroupProfile.GetName() - } else if group.Profile.OktaActiveDirectoryGroupProfile != nil { - name = group.Profile.OktaActiveDirectoryGroupProfile.GetName() - } + name := extractGroupName(group) + if name == "" { + name = groupName } - return &GroupInfo{ + return &domain.GroupInfo{ ID: group.GetId(), Name: name, Members: result.Members, @@ -100,7 +95,7 @@ func (c *Client) GetGroupInfo(groupName string) (*GroupInfo, error) { // FilterEnabledGroups filters Okta groups to only those in the enabled list. // returns all groups if enabled list is empty. -func FilterEnabledGroups(groups []okta.Group, enabledNames []string) []okta.Group { +func FilterEnabledGroups(groups []oktasdk.Group, enabledNames []string) []oktasdk.Group { if len(enabledNames) == 0 { return groups } @@ -110,19 +105,11 @@ func FilterEnabledGroups(groups []okta.Group, enabledNames []string) []okta.Grou enabledMap[name] = true } - var filtered []okta.Group + var filtered []oktasdk.Group for _, group := range groups { - if group.Profile != nil { - var groupName string - if group.Profile.OktaUserGroupProfile != nil { - groupName = group.Profile.OktaUserGroupProfile.GetName() - } else if group.Profile.OktaActiveDirectoryGroupProfile != nil { - groupName = group.Profile.OktaActiveDirectoryGroupProfile.GetName() - } - - if groupName != "" && enabledMap[groupName] { - filtered = append(filtered, group) - } + groupName := extractGroupName(&group) + if groupName != "" && enabledMap[groupName] { + filtered = append(filtered, group) } } diff --git a/internal/okta/sync.go b/internal/sync/sync.go similarity index 68% rename from internal/okta/sync.go rename to internal/sync/sync.go index a2f4794..d2ea4c5 100644 --- a/internal/okta/sync.go +++ b/internal/sync/sync.go @@ -1,4 +1,6 @@ -package okta +// Package sync coordinates synchronization of Okta groups to GitHub teams. +// depends only on domain interfaces, not concrete client implementations. +package sync import ( "context" @@ -8,53 +10,24 @@ import ( "strings" "github.com/cockroachdb/errors" - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/types" + "github.com/cruxstack/github-ops-app/internal/domain" ) -// SyncRule is an alias to types.SyncRule for convenience. -type SyncRule = types.SyncRule - -// SyncReport contains the results of syncing a single Okta group to GitHub -// team. -type SyncReport struct { - Rule string - OktaGroup string - GitHubTeam string - MembersAdded []string - MembersRemoved []string - MembersSkippedExternal []string - MembersSkippedNoGHUsername []string - Errors []string -} - -// OrphanedUsersReport contains users who are org members but not in any synced -// teams. -type OrphanedUsersReport struct { - OrphanedUsers []string -} - -// HasErrors returns true if any errors occurred during sync. -func (r *SyncReport) HasErrors() bool { - return len(r.Errors) > 0 -} - -// HasChanges returns true if members were added or removed. -func (r *SyncReport) HasChanges() bool { - return len(r.MembersAdded) > 0 || len(r.MembersRemoved) > 0 -} +// teamNameNormalizer replaces non-alphanumeric characters (except hyphens) +// in team names. compiled once at package init. +var teamNameNormalizer = regexp.MustCompile(`[^a-z0-9-]+`) // Syncer coordinates synchronization of Okta groups to GitHub teams. type Syncer struct { - oktaClient *Client - githubClient *client.Client - rules []SyncRule + oktaClient domain.OktaClient + githubClient domain.GitHubClient + rules []domain.SyncRule safetyThreshold float64 logger *slog.Logger } // NewSyncer creates a new Okta to GitHub syncer. -func NewSyncer(oktaClient *Client, githubClient *client.Client, rules []SyncRule, safetyThreshold float64, logger *slog.Logger) *Syncer { +func NewSyncer(oktaClient domain.OktaClient, githubClient domain.GitHubClient, rules []domain.SyncRule, safetyThreshold float64, logger *slog.Logger) *Syncer { return &Syncer{ oktaClient: oktaClient, githubClient: githubClient, @@ -64,16 +37,11 @@ func NewSyncer(oktaClient *Client, githubClient *client.Client, rules []SyncRule } } -// SyncResult contains all sync reports and orphaned users report. -type SyncResult struct { - Reports []*SyncReport - OrphanedUsers *OrphanedUsersReport -} - // Sync executes all enabled sync rules and returns reports. // continues processing remaining rules even if some fail. -func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { - var reports []*SyncReport +func (s *Syncer) Sync(ctx context.Context) (*domain.SyncResult, error) { + var reports []*domain.SyncReport + var enabledRuleCount int var failedRuleCount int for _, rule := range s.rules { @@ -81,6 +49,8 @@ func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { continue } + enabledRuleCount++ + ruleReports, err := s.syncRule(ctx, rule) if err != nil { failedRuleCount++ @@ -88,8 +58,7 @@ func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { slog.String("rule", rule.GetName()), slog.String("error", err.Error())) - // create a report for the failed rule so error is visible - reports = append(reports, &SyncReport{ + reports = append(reports, &domain.SyncReport{ Rule: rule.GetName(), OktaGroup: rule.OktaGroupName, GitHubTeam: rule.GitHubTeamName, @@ -101,11 +70,11 @@ func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { reports = append(reports, ruleReports...) } - if failedRuleCount > 0 && failedRuleCount == len(reports) { + if enabledRuleCount > 0 && failedRuleCount == enabledRuleCount { return nil, errors.Newf("all sync rules failed: %d errors", failedRuleCount) } - return &SyncResult{ + return &domain.SyncResult{ Reports: reports, OrphanedUsers: nil, }, nil @@ -113,7 +82,7 @@ func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { // DetectOrphanedUsers finds organization members not in any synced teams. // excludes external collaborators. -func (s *Syncer) DetectOrphanedUsers(ctx context.Context, syncedTeams []string) (*OrphanedUsersReport, error) { +func (s *Syncer) DetectOrphanedUsers(ctx context.Context, syncedTeams []string) (*domain.OrphanedUsersReport, error) { orgMembers, err := s.githubClient.ListOrgMembers(ctx) if err != nil { return nil, errors.Wrap(err, "failed to list organization members") @@ -150,34 +119,34 @@ func (s *Syncer) DetectOrphanedUsers(ctx context.Context, syncedTeams []string) } } - return &OrphanedUsersReport{ + return &domain.OrphanedUsersReport{ OrphanedUsers: orphanedUsers, }, nil } // syncRule executes a single sync rule. // supports both pattern matching and exact group name matching. -func (s *Syncer) syncRule(ctx context.Context, rule SyncRule) ([]*SyncReport, error) { - var reports []*SyncReport +func (s *Syncer) syncRule(ctx context.Context, rule domain.SyncRule) ([]*domain.SyncReport, error) { + var reports []*domain.SyncReport if rule.OktaGroupPattern != "" { - groups, err := s.oktaClient.GetGroupsByPattern(rule.OktaGroupPattern) + groups, err := s.oktaClient.GetGroupsByPattern(ctx, rule.OktaGroupPattern) if err != nil { return nil, errors.Wrapf(err, "failed to match groups with pattern '%s'", rule.OktaGroupPattern) } for _, group := range groups { - teamName := s.computeTeamName(group.Name, rule) + teamName := computeTeamName(group.Name, rule) report := s.syncGroupToTeam(ctx, rule, group, teamName) reports = append(reports, report) } } else if rule.OktaGroupName != "" { - group, err := s.oktaClient.GetGroupInfo(rule.OktaGroupName) + group, err := s.oktaClient.GetGroupInfo(ctx, rule.OktaGroupName) if err != nil { return nil, errors.Wrapf(err, "failed to fetch group '%s'", rule.OktaGroupName) } - teamName := s.computeTeamName(group.Name, rule) + teamName := computeTeamName(group.Name, rule) report := s.syncGroupToTeam(ctx, rule, group, teamName) reports = append(reports, report) } @@ -187,7 +156,7 @@ func (s *Syncer) syncRule(ctx context.Context, rule SyncRule) ([]*SyncReport, er // computeTeamName generates GitHub team name from Okta group name. // applies prefix stripping, prefix addition, and normalization. -func (s *Syncer) computeTeamName(oktaGroupName string, rule SyncRule) string { +func computeTeamName(oktaGroupName string, rule domain.SyncRule) string { if rule.GitHubTeamName != "" { return rule.GitHubTeamName } @@ -203,15 +172,16 @@ func (s *Syncer) computeTeamName(oktaGroupName string, rule SyncRule) string { } teamName = strings.ToLower(teamName) - teamName = regexp.MustCompile(`[^a-z0-9-]`).ReplaceAllString(teamName, "-") + teamName = teamNameNormalizer.ReplaceAllString(teamName, "-") + teamName = strings.Trim(teamName, "-") return teamName } // syncGroupToTeam synchronizes a single Okta group to a GitHub team. // creates team if missing and syncs members if enabled. -func (s *Syncer) syncGroupToTeam(ctx context.Context, rule SyncRule, group *GroupInfo, teamName string) *SyncReport { - report := &SyncReport{ +func (s *Syncer) syncGroupToTeam(ctx context.Context, rule domain.SyncRule, group *domain.GroupInfo, teamName string) *domain.SyncReport { + report := &domain.SyncReport{ Rule: rule.GetName(), OktaGroup: group.Name, GitHubTeam: teamName, diff --git a/internal/sync/sync_test.go b/internal/sync/sync_test.go new file mode 100644 index 0000000..59c2576 --- /dev/null +++ b/internal/sync/sync_test.go @@ -0,0 +1,390 @@ +package sync + +import ( + "context" + "log/slog" + "os" + "testing" + + "github.com/cockroachdb/errors" + "github.com/cruxstack/github-ops-app/internal/domain" + "github.com/google/go-github/v79/github" +) + +// mockGitHubClient implements domain.GitHubClient with overridable function +// fields for testing. +type mockGitHubClient struct { + getOrCreateTeamFn func(ctx context.Context, teamName, privacy string) (*github.Team, error) + syncTeamMembersFn func(ctx context.Context, teamSlug string, desiredMembers []string, threshold float64) (*domain.TeamSyncResult, error) + getTeamMembersFn func(ctx context.Context, teamSlug string) ([]string, error) + listOrgMembersFn func(ctx context.Context) ([]string, error) + isExternalCollaboratorFn func(ctx context.Context, username string) (bool, error) + getAppSlugFn func(ctx context.Context) (string, error) + checkPRComplianceFn func(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) +} + +func (m *mockGitHubClient) CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) { + if m.checkPRComplianceFn != nil { + return m.checkPRComplianceFn(ctx, owner, repo, prNumber) + } + return &domain.PRComplianceResult{}, nil +} + +func (m *mockGitHubClient) GetOrCreateTeam(ctx context.Context, teamName, privacy string) (*github.Team, error) { + if m.getOrCreateTeamFn != nil { + return m.getOrCreateTeamFn(ctx, teamName, privacy) + } + slug := teamName + return &github.Team{Slug: &slug}, nil +} + +func (m *mockGitHubClient) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, threshold float64) (*domain.TeamSyncResult, error) { + if m.syncTeamMembersFn != nil { + return m.syncTeamMembersFn(ctx, teamSlug, desiredMembers, threshold) + } + return &domain.TeamSyncResult{ + TeamName: teamSlug, + MembersAdded: desiredMembers, + }, nil +} + +func (m *mockGitHubClient) GetTeamMembers(ctx context.Context, teamSlug string) ([]string, error) { + if m.getTeamMembersFn != nil { + return m.getTeamMembersFn(ctx, teamSlug) + } + return []string{}, nil +} + +func (m *mockGitHubClient) ListOrgMembers(ctx context.Context) ([]string, error) { + if m.listOrgMembersFn != nil { + return m.listOrgMembersFn(ctx) + } + return []string{}, nil +} + +func (m *mockGitHubClient) IsExternalCollaborator(ctx context.Context, username string) (bool, error) { + if m.isExternalCollaboratorFn != nil { + return m.isExternalCollaboratorFn(ctx, username) + } + return false, nil +} + +func (m *mockGitHubClient) GetAppSlug(ctx context.Context) (string, error) { + if m.getAppSlugFn != nil { + return m.getAppSlugFn(ctx) + } + return "test-app", nil +} + +func (m *mockGitHubClient) GetOrg() string { + return "test-org" +} + +// mockOktaClient implements domain.OktaClient with overridable function +// fields for testing. +type mockOktaClient struct { + getGroupsByPatternFn func(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) + getGroupInfoFn func(ctx context.Context, groupName string) (*domain.GroupInfo, error) +} + +func (m *mockOktaClient) GetGroupsByPattern(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) { + if m.getGroupsByPatternFn != nil { + return m.getGroupsByPatternFn(ctx, pattern) + } + return []*domain.GroupInfo{}, nil +} + +func (m *mockOktaClient) GetGroupInfo(ctx context.Context, groupName string) (*domain.GroupInfo, error) { + if m.getGroupInfoFn != nil { + return m.getGroupInfoFn(ctx, groupName) + } + return &domain.GroupInfo{ + ID: "group-1", + Name: groupName, + Members: []string{}, + }, nil +} + +func testLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) +} + +func TestSync_SingleRule(t *testing.T) { + oktaClient := &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + return &domain.GroupInfo{ + ID: "g1", + Name: "Engineering", + Members: []string{"alice", "bob"}, + }, nil + }, + } + + ghClient := &mockGitHubClient{ + syncTeamMembersFn: func(_ context.Context, teamSlug string, members []string, _ float64) (*domain.TeamSyncResult, error) { + return &domain.TeamSyncResult{ + TeamName: teamSlug, + MembersAdded: members, + }, nil + }, + } + + rules := []domain.SyncRule{ + { + OktaGroupName: "Engineering", + GitHubTeamName: "engineering", + }, + } + + syncer := NewSyncer(oktaClient, ghClient, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports) != 1 { + t.Fatalf("expected 1 report, got %d", len(result.Reports)) + } + + report := result.Reports[0] + if report.GitHubTeam != "engineering" { + t.Errorf("expected team 'engineering', got %q", report.GitHubTeam) + } + if len(report.MembersAdded) != 2 { + t.Errorf("expected 2 members added, got %d", len(report.MembersAdded)) + } +} + +func TestSync_DisabledRule(t *testing.T) { + disabled := false + rules := []domain.SyncRule{ + { + Enabled: &disabled, + OktaGroupName: "Engineering", + GitHubTeamName: "engineering", + }, + } + + syncer := NewSyncer(&mockOktaClient{}, &mockGitHubClient{}, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports) != 0 { + t.Errorf("expected 0 reports for disabled rule, got %d", len(result.Reports)) + } +} + +func TestSync_OktaError_PartialSuccess(t *testing.T) { + callCount := 0 + oktaClient := &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + callCount++ + if callCount == 1 { + return nil, errors.New("okta api error") + } + return &domain.GroupInfo{ID: "g2", Name: name, Members: []string{"charlie"}}, nil + }, + } + + rules := []domain.SyncRule{ + {OktaGroupName: "BadGroup", GitHubTeamName: "bad-team"}, + {OktaGroupName: "GoodGroup", GitHubTeamName: "good-team"}, + } + + syncer := NewSyncer(oktaClient, &mockGitHubClient{}, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports) != 2 { + t.Fatalf("expected 2 reports, got %d", len(result.Reports)) + } + + // first report should have errors + if !result.Reports[0].HasErrors() { + t.Error("expected first report to have errors") + } + // second report should succeed + if result.Reports[1].HasErrors() { + t.Errorf("expected second report to have no errors, got: %v", result.Reports[1].Errors) + } +} + +func TestSync_AllRulesFail(t *testing.T) { + oktaClient := &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + return nil, errors.New("api failure") + }, + } + + rules := []domain.SyncRule{ + {OktaGroupName: "Group1", GitHubTeamName: "team1"}, + } + + syncer := NewSyncer(oktaClient, &mockGitHubClient{}, rules, 0.5, testLogger()) + _, err := syncer.Sync(context.Background()) + if err == nil { + t.Fatal("expected error when all rules fail") + } +} + +func TestSync_PatternMatching(t *testing.T) { + oktaClient := &mockOktaClient{ + getGroupsByPatternFn: func(_ context.Context, pattern string) ([]*domain.GroupInfo, error) { + return []*domain.GroupInfo{ + {ID: "g1", Name: "eng-frontend", Members: []string{"alice"}}, + {ID: "g2", Name: "eng-backend", Members: []string{"bob"}}, + }, nil + }, + } + + rules := []domain.SyncRule{ + {OktaGroupPattern: "^eng-.*$"}, + } + + syncer := NewSyncer(oktaClient, &mockGitHubClient{}, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports) != 2 { + t.Fatalf("expected 2 reports for pattern match, got %d", len(result.Reports)) + } +} + +func TestSync_SkippedUsersTracked(t *testing.T) { + oktaClient := &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + return &domain.GroupInfo{ + ID: "g1", + Name: name, + Members: []string{"alice"}, + SkippedNoGitHubUsername: []string{"new-hire@example.com"}, + }, nil + }, + } + + rules := []domain.SyncRule{ + {OktaGroupName: "Engineering", GitHubTeamName: "engineering"}, + } + + syncer := NewSyncer(oktaClient, &mockGitHubClient{}, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports[0].MembersSkippedNoGHUsername) != 1 { + t.Errorf("expected 1 skipped user, got %d", len(result.Reports[0].MembersSkippedNoGHUsername)) + } +} + +func TestDetectOrphanedUsers(t *testing.T) { + ghClient := &mockGitHubClient{ + listOrgMembersFn: func(_ context.Context) ([]string, error) { + return []string{"alice", "bob", "charlie"}, nil + }, + getTeamMembersFn: func(_ context.Context, teamSlug string) ([]string, error) { + return []string{"alice", "bob"}, nil + }, + isExternalCollaboratorFn: func(_ context.Context, username string) (bool, error) { + return false, nil + }, + } + + syncer := NewSyncer(&mockOktaClient{}, ghClient, nil, 0.5, testLogger()) + report, err := syncer.DetectOrphanedUsers(context.Background(), []string{"engineering"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(report.OrphanedUsers) != 1 { + t.Fatalf("expected 1 orphaned user, got %d", len(report.OrphanedUsers)) + } + if report.OrphanedUsers[0] != "charlie" { + t.Errorf("expected orphaned user 'charlie', got %q", report.OrphanedUsers[0]) + } +} + +func TestDetectOrphanedUsers_ExternalCollaboratorExcluded(t *testing.T) { + ghClient := &mockGitHubClient{ + listOrgMembersFn: func(_ context.Context) ([]string, error) { + return []string{"alice", "external-user"}, nil + }, + getTeamMembersFn: func(_ context.Context, _ string) ([]string, error) { + return []string{"alice"}, nil + }, + isExternalCollaboratorFn: func(_ context.Context, username string) (bool, error) { + return username == "external-user", nil + }, + } + + syncer := NewSyncer(&mockOktaClient{}, ghClient, nil, 0.5, testLogger()) + report, err := syncer.DetectOrphanedUsers(context.Background(), []string{"engineering"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(report.OrphanedUsers) != 0 { + t.Errorf("expected 0 orphaned users (external excluded), got %d", len(report.OrphanedUsers)) + } +} + +func TestComputeTeamName(t *testing.T) { + tests := []struct { + name string + groupName string + rule domain.SyncRule + want string + }{ + { + name: "explicit team name", + groupName: "Engineering", + rule: domain.SyncRule{GitHubTeamName: "custom-team"}, + want: "custom-team", + }, + { + name: "lowercase and normalize", + groupName: "My Team Name", + rule: domain.SyncRule{}, + want: "my-team-name", + }, + { + name: "strip prefix", + groupName: "DEPT-Engineering", + rule: domain.SyncRule{StripPrefix: "DEPT-"}, + want: "engineering", + }, + { + name: "add prefix", + groupName: "frontend", + rule: domain.SyncRule{GitHubTeamPrefix: "eng-"}, + want: "eng-frontend", + }, + { + name: "strip and add prefix", + groupName: "OKTA-backend", + rule: domain.SyncRule{StripPrefix: "OKTA-", GitHubTeamPrefix: "gh-"}, + want: "gh-backend", + }, + { + name: "special characters replaced and dashes collapsed", + groupName: "Team (US) & EU", + rule: domain.SyncRule{}, + want: "team-us-eu", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := computeTeamName(tt.groupName, tt.rule) + if got != tt.want { + t.Errorf("computeTeamName(%q) = %q, want %q", tt.groupName, got, tt.want) + } + }) + } +} diff --git a/internal/types/sync.go b/internal/types/sync.go deleted file mode 100644 index 303f1c6..0000000 --- a/internal/types/sync.go +++ /dev/null @@ -1,38 +0,0 @@ -// Package types provides shared type definitions used across packages. -package types - -// SyncRule defines how to sync Okta groups to GitHub teams. -type SyncRule struct { - Name string `json:"name"` - Enabled *bool `json:"enabled,omitempty"` - OktaGroupPattern string `json:"okta_group_pattern,omitempty"` - OktaGroupName string `json:"okta_group_name,omitempty"` - GitHubTeamPrefix string `json:"github_team_prefix,omitempty"` - GitHubTeamName string `json:"github_team_name,omitempty"` - StripPrefix string `json:"strip_prefix,omitempty"` - SyncMembers *bool `json:"sync_members,omitempty"` - CreateTeamIfMissing bool `json:"create_team_if_missing"` - TeamPrivacy string `json:"team_privacy,omitempty"` -} - -// IsEnabled returns true if the rule is enabled (defaults to true). -func (r SyncRule) IsEnabled() bool { - return r.Enabled == nil || *r.Enabled -} - -// ShouldSyncMembers returns true if members should be synced (defaults to -// true). -func (r SyncRule) ShouldSyncMembers() bool { - return r.SyncMembers == nil || *r.SyncMembers -} - -// GetName returns the rule name, defaulting to GitHubTeamName if not set. -func (r SyncRule) GetName() string { - if r.Name != "" { - return r.Name - } - if r.GitHubTeamName != "" { - return r.GitHubTeamName - } - return r.OktaGroupName -}