From 4fa386242506adb678461ab8e1b7b5352032c3c5 Mon Sep 17 00:00:00 2001 From: Brian Ojeda <9335829+sgtoj@users.noreply.github.com> Date: Mon, 16 Feb 2026 20:31:52 -0500 Subject: [PATCH 1/3] refactor: restructure packages, fix bugs, and improve test coverage Reorganize internal packages into domain/sync layers, fix multiple correctness bugs, harden the integration test infrastructure, and add comprehensive unit tests across all major packages. Architecture: - Extract shared App factory into internal/app/factory.go, removing duplicated newApp() from all four cmd/ entry points - Move domain types/interfaces/errors into internal/domain/ - Move sync logic into internal/sync/ with full unit tests - Add context.Context to OktaClient interface, remove stored-context anti-pattern - Add compile-time interface assertions for GitHubClient and OktaClient - Remove unused GitHubClientFactory and cross-installation code path Bug fixes: - Fix mapErrorResponse: errors.HasType() never matched errors.Mark() markers; switch to errors.Is() so domain errors return correct HTTP status codes (401/400/503/502) instead of always 500 - Fix GetOrCreateTeam: non-404 errors (403, 500) were misclassified as ErrTeamNotFound - Fix sync all-rules-failed detection: compared against len(reports) instead of tracking enabled rule count separately - Fix PR review deduplication: count only latest review per user - Add GitHub Check Runs support alongside legacy commit statuses - Fix token refresh TOCTOU race with double-check pattern - Add Okta API pagination to ListGroups, GetGroupByName, GetGroupMembers - Normalize username case in team membership comparisons Test infrastructure: - Replace hardcoded ports (9001-9003) with dynamic TLS listeners - Add IP SANs to self-signed test certificates - Add env var save/restore between integration test scenarios - Add unexpected-destructive-call validation - Fix test log handler Enabled() to respect verbose flag - Add request body size limit (10MB) to HTTP server Test coverage (before -> after): - internal/app: 24.5% -> 55.8% - internal/config: 14.8% -> 74.1% - internal/notifiers: 2.7% -> 97.3% - internal/sync: 82.8% -> 83.3% Cleanup: - Remove all fmt.Errorf usage, consistently use cockroachdb/errors - Remove dead ErrOAuthTokenExpired sentinel - Extract extractGroupName() helper in okta package - Pre-compile regex in sync.go computeTeamName - Trim leading/trailing dashes from generated team names --- .github/.dependabot.yaml | 5 + .github/workflows/ci.yaml | 4 +- .../github/{manafiest.json => manifest.json} | 0 cmd/lambda/main.go | 10 +- cmd/sample/main.go | 8 +- cmd/server/main.go | 6 +- cmd/verify/.env.test | 10 +- cmd/verify/logger.go | 2 +- cmd/verify/scenario.go | 164 +++-- cmd/verify/tls.go | 17 +- internal/app/app.go | 71 +-- internal/app/app_test.go | 16 +- internal/app/factory.go | 69 +++ internal/app/handlers.go | 31 +- internal/app/handlers_test.go | 583 ++++++++++++++++++ internal/app/request.go | 26 +- internal/app/testdata.go | 17 +- internal/config/config.go | 26 +- internal/config/config_helpers_test.go | 298 +++++++++ internal/{errors => domain}/errors.go | 8 +- internal/domain/interfaces.go | 64 ++ internal/domain/types.go | 126 ++++ internal/github/client/client.go | 12 +- internal/github/client/pr.go | 103 ++-- internal/github/client/teams.go | 75 ++- internal/github/webhooks/webhooks.go | 18 +- internal/github/webhooks/webhooks_test.go | 355 +++++++++++ internal/notifiers/slack_messages.go | 14 +- internal/notifiers/slack_messages_test.go | 235 +++++++ internal/okta/client.go | 135 ++-- internal/okta/groups.go | 91 ++- internal/{okta => sync}/sync.go | 94 +-- internal/sync/sync_test.go | 390 ++++++++++++ internal/types/sync.go | 38 -- 34 files changed, 2606 insertions(+), 515 deletions(-) rename assets/github/{manafiest.json => manifest.json} (100%) create mode 100644 internal/app/factory.go create mode 100644 internal/app/handlers_test.go create mode 100644 internal/config/config_helpers_test.go rename internal/{errors => domain}/errors.go (88%) create mode 100644 internal/domain/interfaces.go create mode 100644 internal/domain/types.go create mode 100644 internal/github/webhooks/webhooks_test.go create mode 100644 internal/notifiers/slack_messages_test.go rename internal/{okta => sync}/sync.go (68%) create mode 100644 internal/sync/sync_test.go delete mode 100644 internal/types/sync.go diff --git a/.github/.dependabot.yaml b/.github/.dependabot.yaml index 1230149..f3f3cd3 100644 --- a/.github/.dependabot.yaml +++ b/.github/.dependabot.yaml @@ -4,3 +4,8 @@ updates: directory: "/" schedule: interval: "daily" + + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 9ba4e80..d166b19 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -46,10 +46,10 @@ jobs: with: go-version: '1.24' - - name: Run tests + - name: Run unit tests run: make test - - name: Run tests + - name: Run integration tests run: make test-verify-verbose build: diff --git a/assets/github/manafiest.json b/assets/github/manifest.json similarity index 100% rename from assets/github/manafiest.json rename to assets/github/manifest.json diff --git a/cmd/lambda/main.go b/cmd/lambda/main.go index 72e7f6e..bac0b91 100644 --- a/cmd/lambda/main.go +++ b/cmd/lambda/main.go @@ -3,13 +3,13 @@ package main import ( "context" "encoding/json" - "fmt" "log/slog" "strings" "sync" awsevents "github.com/aws/aws-lambda-go/events" "github.com/aws/aws-lambda-go/lambda" + "github.com/cockroachdb/errors" "github.com/cruxstack/github-ops-app/internal/app" "github.com/cruxstack/github-ops-app/internal/config" ) @@ -27,10 +27,10 @@ func initApp() { cfg, err := config.NewConfig() if err != nil { - initErr = fmt.Errorf("config init failed: %w", err) + initErr = errors.Wrap(err, "config init failed") return } - appInst, initErr = app.New(context.Background(), cfg) + appInst, initErr = app.NewApp(context.Background(), cfg, logger) }) } @@ -99,7 +99,7 @@ func EventBridgeHandler(ctx context.Context, evt awsevents.CloudWatchEvent) erro resp := appInst.HandleRequest(ctx, req) if resp.StatusCode >= 400 { - return fmt.Errorf("scheduled event failed: %s", string(resp.Body)) + return errors.Newf("scheduled event failed: %s", string(resp.Body)) } return nil @@ -122,7 +122,7 @@ func UniversalHandler(ctx context.Context, event json.RawMessage) (any, error) { return nil, EventBridgeHandler(ctx, eventBridgeEvent) } - return nil, fmt.Errorf("unknown lambda event type") + return nil, errors.New("unknown lambda event type") } func main() { diff --git a/cmd/sample/main.go b/cmd/sample/main.go index e7fa906..3cf5a2d 100644 --- a/cmd/sample/main.go +++ b/cmd/sample/main.go @@ -32,7 +32,7 @@ func main() { os.Exit(1) } - a, err := app.New(ctx, cfg) + a, err := app.NewApp(ctx, cfg, logger) if err != nil { logger.Error("failed to initialize app", slog.String("error", err.Error())) os.Exit(1) @@ -52,7 +52,11 @@ func main() { } for i, sample := range samples { - eventType := sample["event_type"].(string) + eventType, ok := sample["event_type"].(string) + if !ok { + logger.Error("missing or invalid event_type", slog.Int("sample", i)) + os.Exit(1) + } switch eventType { case "okta_sync": diff --git a/cmd/server/main.go b/cmd/server/main.go index cc4b84c..cd28c79 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -30,7 +30,7 @@ func main() { os.Exit(1) } - appInst, err = app.New(ctx, cfg) + appInst, err = app.NewApp(ctx, cfg, logger) if err != nil { logger.Error("app init failed", slog.String("error", err.Error())) os.Exit(1) @@ -83,12 +83,12 @@ func main() { // httpHandler converts http.Request to app.Request and handles the response. func httpHandler(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) + defer r.Body.Close() + body, err := io.ReadAll(io.LimitReader(r.Body, 10<<20)) // 10MB limit if err != nil { http.Error(w, "failed to read request body", http.StatusBadRequest) return } - defer r.Body.Close() headers := make(map[string]string) for key, values := range r.Header { diff --git a/cmd/verify/.env.test b/cmd/verify/.env.test index d070bb9..9a375a7 100644 --- a/cmd/verify/.env.test +++ b/cmd/verify/.env.test @@ -1,8 +1,9 @@ # config used during offline verification testing # -# - all api calls are mocked all config is fake -# - private keys is dynamically generated during test setup -# - endpoints are https with tls configured during setup +# - all api calls are mocked and all config is fake +# - private keys are dynamically generated during test setup +# - endpoints use dynamic ports assigned at runtime +# - base urls are set per-scenario by the test runner APP_DEBUG_ENABLED=false @@ -11,7 +12,6 @@ APP_GITHUB_APP_ID=123456 APP_GITHUB_INSTALLATION_ID=987654 APP_GITHUB_ORG=acme-ghorg APP_GITHUB_WEBHOOK_SECRET=test_webhook_secret -APP_GITHUB_BASE_URL=https://localhost:9001/ # github pr compliance configuration APP_PR_COMPLIANCE_ENABLED=true @@ -19,7 +19,6 @@ APP_PR_MONITORED_BRANCHES=main,master # okta configuration (oauth2 with private key) APP_OKTA_DOMAIN=dev-12345.okta.com -APP_OKTA_CLIENT_ID=test-client-id # okta sync rules APP_OKTA_GITHUB_USER_FIELD=githubUsername @@ -28,4 +27,3 @@ APP_OKTA_SYNC_RULES=[{"enabled":true,"okta_group_name":"Engineering","github_tea # slack configuration APP_SLACK_TOKEN=xoxb-test-token APP_SLACK_CHANNEL=C01234TEST -APP_SLACK_API_URL=https://localhost:9003/ diff --git a/cmd/verify/logger.go b/cmd/verify/logger.go index 1780d7d..d7b9f68 100644 --- a/cmd/verify/logger.go +++ b/cmd/verify/logger.go @@ -18,7 +18,7 @@ type testHandler struct { // Enabled returns true for all log levels when verbose mode is enabled. func (h *testHandler) Enabled(_ context.Context, _ slog.Level) bool { - return true + return h.verbose } // Handle formats and writes log records to output with test-appropriate diff --git a/cmd/verify/scenario.go b/cmd/verify/scenario.go index 4b00464..03c7ee5 100644 --- a/cmd/verify/scenario.go +++ b/cmd/verify/scenario.go @@ -10,8 +10,10 @@ import ( "os" "time" + "github.com/cockroachdb/errors" "github.com/cruxstack/github-ops-app/internal/app" "github.com/cruxstack/github-ops-app/internal/config" + oktaclient "github.com/cruxstack/github-ops-app/internal/okta" ) // TestScenario defines a test case with input events and expected outcomes. @@ -65,74 +67,58 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge tlsCert, certPool, err := generateSelfSignedCert() if err != nil { - return fmt.Errorf("generate cert: %w", err) + return errors.Wrap(err, "failed to generate cert") } githubAppKey, err := generateOAuthPrivateKey() if err != nil { - return fmt.Errorf("generate github app key: %w", err) + return errors.Wrap(err, "failed to generate github app key") } - os.Setenv("APP_GITHUB_APP_PRIVATE_KEY", string(githubAppKey)) oauthKey, err := generateOAuthPrivateKey() if err != nil { - return fmt.Errorf("generate oauth key: %w", err) + return errors.Wrap(err, "failed to generate oauth key") } - os.Setenv("APP_OKTA_CLIENT_ID", "test-client-id") - os.Setenv("APP_OKTA_PRIVATE_KEY", string(oauthKey)) - githubServer := &http.Server{ - Addr: "localhost:9001", - Handler: githubMock, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - }, + // use dynamic ports: bind to :0 and extract the assigned port + tlsConfig := &tls.Config{Certificates: []tls.Certificate{tlsCert}} + + githubListener, err := tls.Listen("tcp", "localhost:0", tlsConfig) + if err != nil { + return errors.Wrap(err, "failed to listen github") } - oktaServer := &http.Server{ - Addr: "localhost:9002", - Handler: oktaMock, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - }, + oktaListener, err := tls.Listen("tcp", "localhost:0", tlsConfig) + if err != nil { + githubListener.Close() + return errors.Wrap(err, "failed to listen okta") } - slackServer := &http.Server{ - Addr: "localhost:9003", - Handler: slackMock, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - }, + slackListener, err := tls.Listen("tcp", "localhost:0", tlsConfig) + if err != nil { + githubListener.Close() + oktaListener.Close() + return errors.Wrap(err, "failed to listen slack") } - githubReady := make(chan bool) - oktaReady := make(chan bool) - slackReady := make(chan bool) + githubServer := &http.Server{Handler: githubMock} + oktaServer := &http.Server{Handler: oktaMock} + slackServer := &http.Server{Handler: slackMock} go func() { - githubReady <- true - if err := githubServer.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + if err := githubServer.Serve(githubListener); err != http.ErrServerClosed { logger.Error("github mock server error", slog.String("error", err.Error())) } }() - go func() { - oktaReady <- true - if err := oktaServer.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + if err := oktaServer.Serve(oktaListener); err != http.ErrServerClosed { logger.Error("okta mock server error", slog.String("error", err.Error())) } }() - go func() { - slackReady <- true - if err := slackServer.ListenAndServeTLS("", ""); err != http.ErrServerClosed { + if err := slackServer.Serve(slackListener); err != http.ErrServerClosed { logger.Error("slack mock server error", slog.String("error", err.Error())) } }() - <-githubReady - <-oktaReady - <-slackReady - time.Sleep(100 * time.Millisecond) - defer func() { shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() @@ -141,17 +127,51 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge slackServer.Shutdown(shutdownCtx) }() + githubAddr := fmt.Sprintf("https://%s/", githubListener.Addr().String()) + oktaAddr := fmt.Sprintf("https://%s", oktaListener.Addr().String()) + slackAddr := fmt.Sprintf("https://%s/", slackListener.Addr().String()) + + // save and restore environment variables for isolation between scenarios + envKeys := []string{ + "APP_GITHUB_APP_PRIVATE_KEY", "APP_OKTA_CLIENT_ID", "APP_OKTA_PRIVATE_KEY", + "APP_GITHUB_BASE_URL", "APP_SLACK_API_URL", "APP_OKTA_BASE_URL", + "APP_OKTA_ORPHANED_USER_NOTIFICATIONS", + } + for key := range scenario.ConfigOverrides { + envKeys = append(envKeys, key) + } + savedEnv := make(map[string]string, len(envKeys)) + for _, key := range envKeys { + savedEnv[key] = os.Getenv(key) + } + defer func() { + for key, value := range savedEnv { + if value == "" { + os.Unsetenv(key) + } else { + os.Setenv(key, value) + } + } + }() + + os.Setenv("APP_GITHUB_APP_PRIVATE_KEY", string(githubAppKey)) + os.Setenv("APP_OKTA_CLIENT_ID", "test-client-id") + os.Setenv("APP_OKTA_PRIVATE_KEY", string(oauthKey)) + os.Setenv("APP_GITHUB_BASE_URL", githubAddr) + os.Setenv("APP_SLACK_API_URL", slackAddr) + os.Setenv("APP_OKTA_BASE_URL", oktaAddr) + + // save and restore http.DefaultTransport + savedTransport := http.DefaultTransport + defer func() { http.DefaultTransport = savedTransport }() + http.DefaultTransport = &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: certPool, }, } - os.Setenv("APP_GITHUB_BASE_URL", "https://localhost:9001/") - os.Setenv("APP_SLACK_API_URL", "https://localhost:9003/") - os.Setenv("APP_OKTA_BASE_URL", "https://localhost:9002") - - ctx = context.WithValue(ctx, "okta_tls_cert_pool", certPool) + ctx = oktaclient.WithCertPool(ctx, certPool) if os.Getenv("APP_OKTA_ORPHANED_USER_NOTIFICATIONS") == "" { os.Setenv("APP_OKTA_ORPHANED_USER_NOTIFICATIONS", "false") @@ -163,27 +183,26 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge cfg, err := config.NewConfig() if err != nil { - return fmt.Errorf("config creation failed: %w", err) + return errors.Wrap(err, "config creation failed") } - a, err := app.New(ctx, cfg) + appLogger := slog.New(&testHandler{prefix: " ", verbose: verbose, w: os.Stdout}) + + a, err := app.NewApp(ctx, cfg, appLogger) if err != nil { - return fmt.Errorf("app creation failed: %w", err) + return errors.Wrap(err, "app creation failed") } if verbose { fmt.Printf("\n Application Output:\n") } - appLogger := slog.New(&testHandler{prefix: " ", verbose: verbose, w: os.Stdout}) - a.Logger = appLogger - var req app.Request switch scenario.EventType { case "scheduled_event": var evt app.ScheduledEvent if err := json.Unmarshal(scenario.EventPayload, &evt); err != nil { - return fmt.Errorf("unmarshal event payload failed: %w", err) + return errors.Wrap(err, "failed to unmarshal event payload") } req = app.Request{ Type: app.RequestTypeScheduled, @@ -204,26 +223,26 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge } default: - return fmt.Errorf("unknown event type: %s", scenario.EventType) + return errors.Newf("unknown event type: %s", scenario.EventType) } resp := a.HandleRequest(ctx, req) var processErr error if resp.StatusCode >= 400 { - processErr = fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(resp.Body)) + processErr = errors.Newf("request failed with status %d: %s", resp.StatusCode, string(resp.Body)) } if scenario.ExpectError { if processErr == nil { - return fmt.Errorf("expected error but processing succeeded") + return errors.New("expected error but processing succeeded") } if verbose { fmt.Printf(" ✓ Expected error occurred: %v\n", processErr) } } else { if processErr != nil { - return fmt.Errorf("process event failed: %w", processErr) + return errors.Wrap(processErr, "process event failed") } } @@ -244,6 +263,12 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge fmt.Printf("\n") } + if err := validateNoUnexpectedCalls(scenario.ExpectedCalls, allReqs); err != nil { + fmt.Printf("\n Validation:\n") + fmt.Printf(" ✗ FAILED: %v\n", err) + return err + } + if err := validateExpectedCalls(scenario.ExpectedCalls, allReqs); err != nil { fmt.Printf("\n Validation:\n") fmt.Printf(" ✗ FAILED: %v\n", err) @@ -294,7 +319,32 @@ func validateExpectedCalls(expected []ExpectedCall, allReqs map[string][]Request } } if !found { - return fmt.Errorf("expected call not found: %s %s %s", exp.Service, exp.Method, exp.Path) + return errors.Newf("expected call not found: %s %s %s", exp.Service, exp.Method, exp.Path) + } + } + return nil +} + +// validateNoUnexpectedCalls checks that no unexpected destructive API calls +// were made. only flags DELETE calls to catch unintended member removal or +// resource deletion. +func validateNoUnexpectedCalls(expected []ExpectedCall, allReqs map[string][]RequestRecord) error { + for service, reqs := range allReqs { + for _, req := range reqs { + // only flag unexpected destructive calls + if req.Method != "DELETE" { + continue + } + matched := false + for _, exp := range expected { + if exp.Service == service && exp.Method == req.Method && matchPath(req.Path, exp.Path) { + matched = true + break + } + } + if !matched { + return errors.Newf("unexpected destructive call: %s %s %s", service, req.Method, req.Path) + } } } return nil diff --git a/cmd/verify/tls.go b/cmd/verify/tls.go index 0cab7d2..acbb4a2 100644 --- a/cmd/verify/tls.go +++ b/cmd/verify/tls.go @@ -7,9 +7,11 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" - "fmt" "math/big" + "net" "time" + + "github.com/cockroachdb/errors" ) // generateOAuthPrivateKey creates an RSA private key for OAuth testing. @@ -17,7 +19,7 @@ import ( func generateOAuthPrivateKey() ([]byte, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return nil, fmt.Errorf("generate oauth key: %w", err) + return nil, errors.Wrap(err, "failed to generate oauth key") } keyPEM := pem.EncodeToMemory(&pem.Block{ @@ -33,7 +35,7 @@ func generateOAuthPrivateKey() ([]byte, error) { func generateSelfSignedCert() (tls.Certificate, *x509.CertPool, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("generate key: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to generate key") } notBefore := time.Now() @@ -41,7 +43,7 @@ func generateSelfSignedCert() (tls.Certificate, *x509.CertPool, error) { serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("generate serial: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to generate serial") } template := x509.Certificate{ @@ -56,11 +58,12 @@ func generateSelfSignedCert() (tls.Certificate, *x509.CertPool, error) { ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, DNSNames: []string{"localhost"}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1), net.IPv6loopback}, } certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("create cert: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to create cert") } certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) @@ -68,12 +71,12 @@ func generateSelfSignedCert() (tls.Certificate, *x509.CertPool, error) { tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("create keypair: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to create keypair") } cert, err := x509.ParseCertificate(certDER) if err != nil { - return tls.Certificate{}, nil, fmt.Errorf("parse cert: %w", err) + return tls.Certificate{}, nil, errors.Wrap(err, "failed to parse cert") } certPool := x509.NewCertPool() diff --git a/internal/app/app.go b/internal/app/app.go index 8d6516b..7e12234 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -9,76 +9,17 @@ import ( "github.com/cockroachdb/errors" "github.com/cruxstack/github-ops-app/internal/config" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/notifiers" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/domain" ) // App is the main application instance containing all clients and -// configuration. +// configuration. depends on domain interfaces, not concrete implementations. type App struct { Config *config.Config Logger *slog.Logger - GitHubClient *client.Client - OktaClient *okta.Client - Notifier *notifiers.SlackNotifier -} - -// New creates a new App instance with configured clients. -// Initializes GitHub, Okta, and Slack clients based on config. -func New(ctx context.Context, cfg *config.Config) (*App, error) { - logger := config.NewLogger() - - app := &App{ - Config: cfg, - Logger: logger, - } - - if cfg.IsGitHubConfigured() { - ghClient, err := client.NewAppClientWithBaseURL( - cfg.GitHubAppID, - cfg.GitHubInstallationID, - cfg.GitHubAppPrivateKey, - cfg.GitHubOrg, - cfg.GitHubBaseURL, - ) - if err != nil { - return nil, errors.Wrap(err, "failed to create github app client") - } - app.GitHubClient = ghClient - } - - if cfg.IsOktaSyncEnabled() { - oktaClient, err := okta.NewClientWithContext(ctx, &okta.ClientConfig{ - Domain: cfg.OktaDomain, - ClientID: cfg.OktaClientID, - PrivateKey: cfg.OktaPrivateKey, - PrivateKeyID: cfg.OktaPrivateKeyID, - Scopes: cfg.OktaScopes, - GitHubUserField: cfg.OktaGitHubUserField, - BaseURL: cfg.OktaBaseURL, - }) - if err != nil { - return nil, errors.Wrap(err, "failed to create okta client") - } - app.OktaClient = oktaClient - } - - if cfg.SlackEnabled { - channels := notifiers.SlackChannels{ - Default: cfg.SlackChannel, - PRBypass: cfg.SlackChannelPRBypass, - OktaSync: cfg.SlackChannelOktaSync, - OrphanedUsers: cfg.SlackChannelOrphanedUsers, - } - messages := notifiers.SlackMessages{ - PRBypassFooterNote: cfg.SlackPRBypassFooterNote, - } - app.Notifier = notifiers.NewSlackNotifierWithAPIURL(cfg.SlackToken, channels, messages, cfg.SlackAPIURL) - } - - return app, nil + GitHubClient domain.GitHubClient + OktaClient domain.OktaClient + Notifier domain.Notifier } // ScheduledEvent represents a generic scheduled event. @@ -120,7 +61,7 @@ func (a *App) ProcessWebhook(ctx context.Context, payload []byte, eventType stri case "membership": return a.handleMembershipWebhook(ctx, payload) default: - return errors.Wrapf(internalerrors.ErrInvalidEventType, "%s", eventType) + return errors.Wrapf(domain.ErrInvalidEventType, "%s", eventType) } } diff --git a/internal/app/app_test.go b/internal/app/app_test.go index ee8bf0c..7a12eec 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -7,8 +7,7 @@ import ( "testing" "github.com/cruxstack/github-ops-app/internal/config" - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/domain" ) func TestHandleSlackTest_NotConfigured(t *testing.T) { @@ -158,16 +157,11 @@ func TestProcessScheduledEvent_UnknownAction(t *testing.T) { } } -// verify fake data types match expected interfaces +// verify fake data types match expected domain types func TestFakeDataTypes(t *testing.T) { - // ensure fake PR result is compatible with notifier - var _ *client.PRComplianceResult = fakePRComplianceResult() - - // ensure fake sync reports are compatible with notifier - var _ []*okta.SyncReport = fakeOktaSyncReports() - - // ensure fake orphaned users report is compatible with notifier - var _ *okta.OrphanedUsersReport = fakeOrphanedUsersReport() + var _ *domain.PRComplianceResult = fakePRComplianceResult() + var _ []*domain.SyncReport = fakeOktaSyncReports() + var _ *domain.OrphanedUsersReport = fakeOrphanedUsersReport() } func TestCheckAdminAuth(t *testing.T) { diff --git a/internal/app/factory.go b/internal/app/factory.go new file mode 100644 index 0000000..ac2beb4 --- /dev/null +++ b/internal/app/factory.go @@ -0,0 +1,69 @@ +package app + +import ( + "context" + "log/slog" + + "github.com/cruxstack/github-ops-app/internal/config" + ghclient "github.com/cruxstack/github-ops-app/internal/github/client" + "github.com/cruxstack/github-ops-app/internal/notifiers" + "github.com/cruxstack/github-ops-app/internal/okta" +) + +// NewApp creates a new App instance with configured clients (composition +// root). concrete implementations are instantiated here and wired as domain +// interfaces. this is the single shared factory used by all entry points. +func NewApp(ctx context.Context, cfg *config.Config, logger *slog.Logger) (*App, error) { + a := &App{ + Config: cfg, + Logger: logger, + } + + if cfg.IsGitHubConfigured() { + ghClient, err := ghclient.NewAppClientWithBaseURL( + cfg.GitHubAppID, + cfg.GitHubInstallationID, + cfg.GitHubAppPrivateKey, + cfg.GitHubOrg, + cfg.GitHubBaseURL, + ) + if err != nil { + return nil, err + } + a.GitHubClient = ghClient + } + + if cfg.IsOktaSyncEnabled() { + oktaClient, err := okta.NewClientWithContext(ctx, &okta.ClientConfig{ + Domain: cfg.OktaDomain, + ClientID: cfg.OktaClientID, + PrivateKey: cfg.OktaPrivateKey, + PrivateKeyID: cfg.OktaPrivateKeyID, + Scopes: cfg.OktaScopes, + GitHubUserField: cfg.OktaGitHubUserField, + BaseURL: cfg.OktaBaseURL, + Logger: logger, + }) + if err != nil { + return nil, err + } + a.OktaClient = oktaClient + } + + if cfg.SlackEnabled { + channels := notifiers.SlackChannels{ + Default: cfg.SlackChannel, + PRBypass: cfg.SlackChannelPRBypass, + OktaSync: cfg.SlackChannelOktaSync, + OrphanedUsers: cfg.SlackChannelOrphanedUsers, + } + messages := notifiers.SlackMessages{ + PRBypassFooterNote: cfg.SlackPRBypassFooterNote, + } + a.Notifier = notifiers.NewSlackNotifierWithAPIURL( + cfg.SlackToken, channels, messages, cfg.SlackAPIURL, + ) + } + + return a, nil +} diff --git a/internal/app/handlers.go b/internal/app/handlers.go index 6fa6c01..82b5a5e 100644 --- a/internal/app/handlers.go +++ b/internal/app/handlers.go @@ -5,10 +5,9 @@ import ( "log/slog" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" - "github.com/cruxstack/github-ops-app/internal/github/client" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/cruxstack/github-ops-app/internal/github/webhooks" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/sync" ) // handleOktaSync executes Okta group synchronization to GitHub teams. @@ -20,10 +19,10 @@ func (a *App) handleOktaSync(ctx context.Context) error { } if a.OktaClient == nil || a.GitHubClient == nil { - return errors.Wrap(internalerrors.ErrClientNotInit, "okta or github client") + return errors.Wrap(domain.ErrClientNotInit, "okta or github client") } - syncer := okta.NewSyncer(a.OktaClient, a.GitHubClient, a.Config.OktaSyncRules, a.Config.OktaSyncSafetyThreshold, a.Logger) + syncer := sync.NewSyncer(a.OktaClient, a.GitHubClient, a.Config.OktaSyncRules, a.Config.OktaSyncSafetyThreshold, a.Logger) syncResult, err := syncer.Sync(ctx) if err != nil { return errors.Wrap(err, "okta sync failed") @@ -83,30 +82,14 @@ func (a *App) handlePullRequestWebhook(ctx context.Context, payload []byte) erro return nil } - ghClient := a.GitHubClient - - if prEvent.GetInstallationID() != 0 && prEvent.GetInstallationID() != a.Config.GitHubInstallationID { - installClient, err := client.NewAppClientWithBaseURL( - a.Config.GitHubAppID, - prEvent.GetInstallationID(), - a.Config.GitHubAppPrivateKey, - a.Config.GitHubOrg, - a.Config.GitHubBaseURL, - ) - if err != nil { - return errors.Wrapf(err, "failed to create client for installation %d", prEvent.GetInstallationID()) - } - ghClient = installClient - } - - if ghClient == nil { - return errors.Wrap(internalerrors.ErrClientNotInit, "github client") + if a.GitHubClient == nil { + return errors.Wrap(domain.ErrClientNotInit, "github client") } owner := prEvent.GetRepoOwner() repo := prEvent.GetRepoName() - result, err := ghClient.CheckPRCompliance(ctx, owner, repo, prEvent.Number) + result, err := a.GitHubClient.CheckPRCompliance(ctx, owner, repo, prEvent.Number) if err != nil { return errors.Wrapf(err, "failed to check pr #%d compliance", prEvent.Number) } diff --git a/internal/app/handlers_test.go b/internal/app/handlers_test.go new file mode 100644 index 0000000..07e9d14 --- /dev/null +++ b/internal/app/handlers_test.go @@ -0,0 +1,583 @@ +package app + +import ( + "context" + "encoding/json" + "io" + "log/slog" + "testing" + + "github.com/cockroachdb/errors" + "github.com/cruxstack/github-ops-app/internal/config" + "github.com/cruxstack/github-ops-app/internal/domain" + "github.com/google/go-github/v79/github" +) + +// --- mock implementations --- + +type mockGitHubClient struct { + checkPRComplianceFn func(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) + getOrCreateTeamFn func(ctx context.Context, teamName, privacy string) (*github.Team, error) + syncTeamMembersFn func(ctx context.Context, teamSlug string, desiredMembers []string, threshold float64) (*domain.TeamSyncResult, error) + getTeamMembersFn func(ctx context.Context, teamSlug string) ([]string, error) + listOrgMembersFn func(ctx context.Context) ([]string, error) + isExternalCollaboratorFn func(ctx context.Context, username string) (bool, error) + getAppSlugFn func(ctx context.Context) (string, error) +} + +func (m *mockGitHubClient) CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) { + if m.checkPRComplianceFn != nil { + return m.checkPRComplianceFn(ctx, owner, repo, prNumber) + } + return &domain.PRComplianceResult{}, nil +} +func (m *mockGitHubClient) GetOrCreateTeam(ctx context.Context, teamName, privacy string) (*github.Team, error) { + if m.getOrCreateTeamFn != nil { + return m.getOrCreateTeamFn(ctx, teamName, privacy) + } + slug := teamName + return &github.Team{Slug: &slug}, nil +} +func (m *mockGitHubClient) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, threshold float64) (*domain.TeamSyncResult, error) { + if m.syncTeamMembersFn != nil { + return m.syncTeamMembersFn(ctx, teamSlug, desiredMembers, threshold) + } + return &domain.TeamSyncResult{TeamName: teamSlug, MembersAdded: desiredMembers}, nil +} +func (m *mockGitHubClient) GetTeamMembers(ctx context.Context, teamSlug string) ([]string, error) { + if m.getTeamMembersFn != nil { + return m.getTeamMembersFn(ctx, teamSlug) + } + return []string{}, nil +} +func (m *mockGitHubClient) ListOrgMembers(ctx context.Context) ([]string, error) { + if m.listOrgMembersFn != nil { + return m.listOrgMembersFn(ctx) + } + return []string{}, nil +} +func (m *mockGitHubClient) IsExternalCollaborator(ctx context.Context, username string) (bool, error) { + if m.isExternalCollaboratorFn != nil { + return m.isExternalCollaboratorFn(ctx, username) + } + return false, nil +} +func (m *mockGitHubClient) GetAppSlug(ctx context.Context) (string, error) { + if m.getAppSlugFn != nil { + return m.getAppSlugFn(ctx) + } + return "test-app", nil +} +func (m *mockGitHubClient) GetOrg() string { return "test-org" } + +type mockOktaClient struct { + getGroupsByPatternFn func(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) + getGroupInfoFn func(ctx context.Context, groupName string) (*domain.GroupInfo, error) +} + +func (m *mockOktaClient) GetGroupsByPattern(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) { + if m.getGroupsByPatternFn != nil { + return m.getGroupsByPatternFn(ctx, pattern) + } + return []*domain.GroupInfo{}, nil +} +func (m *mockOktaClient) GetGroupInfo(ctx context.Context, groupName string) (*domain.GroupInfo, error) { + if m.getGroupInfoFn != nil { + return m.getGroupInfoFn(ctx, groupName) + } + return &domain.GroupInfo{ID: "g1", Name: groupName, Members: []string{}}, nil +} + +type mockNotifier struct { + notifyPRBypassFn func(ctx context.Context, result *domain.PRComplianceResult, repoFullName string) error + notifyOktaSyncFn func(ctx context.Context, reports []*domain.SyncReport, githubOrg string) error + notifyOrphanedUsersFn func(ctx context.Context, report *domain.OrphanedUsersReport) error +} + +func (m *mockNotifier) NotifyPRBypass(ctx context.Context, result *domain.PRComplianceResult, repoFullName string) error { + if m.notifyPRBypassFn != nil { + return m.notifyPRBypassFn(ctx, result, repoFullName) + } + return nil +} +func (m *mockNotifier) NotifyOktaSync(ctx context.Context, reports []*domain.SyncReport, githubOrg string) error { + if m.notifyOktaSyncFn != nil { + return m.notifyOktaSyncFn(ctx, reports, githubOrg) + } + return nil +} +func (m *mockNotifier) NotifyOrphanedUsers(ctx context.Context, report *domain.OrphanedUsersReport) error { + if m.notifyOrphanedUsersFn != nil { + return m.notifyOrphanedUsersFn(ctx, report) + } + return nil +} + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +// --- handler tests --- + +func TestHandleOktaSync_Disabled(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + err := a.handleOktaSync(context.Background()) + if err != nil { + t.Fatalf("expected nil error when sync disabled, got: %v", err) + } +} + +func TestHandleOktaSync_ClientsNil(t *testing.T) { + a := &App{ + Config: &config.Config{ + OktaDomain: "test.okta.com", + OktaClientID: "cid", + OktaPrivateKey: []byte("key"), + OktaSyncRules: []domain.SyncRule{{OktaGroupName: "g1", GitHubTeamName: "t1"}}, + }, + Logger: discardLogger(), + } + + err := a.handleOktaSync(context.Background()) + if err == nil { + t.Fatal("expected error when clients are nil") + } + if !errors.Is(err, domain.ErrClientNotInit) { + t.Errorf("expected ErrClientNotInit, got: %v", err) + } +} + +func TestHandleOktaSync_Success(t *testing.T) { + notified := false + a := &App{ + Config: &config.Config{ + OktaDomain: "test.okta.com", + OktaClientID: "cid", + OktaPrivateKey: []byte("key"), + OktaSyncRules: []domain.SyncRule{{OktaGroupName: "Engineering", GitHubTeamName: "engineering"}}, + GitHubOrg: "test-org", + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + OktaClient: &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + return &domain.GroupInfo{ID: "g1", Name: name, Members: []string{"alice"}}, nil + }, + }, + Notifier: &mockNotifier{ + notifyOktaSyncFn: func(_ context.Context, reports []*domain.SyncReport, org string) error { + notified = true + if len(reports) != 1 { + t.Errorf("expected 1 report, got %d", len(reports)) + } + return nil + }, + }, + } + + err := a.handleOktaSync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !notified { + t.Error("expected slack notification to be sent") + } +} + +func TestHandleOktaSync_NotifierFailureDoesNotFail(t *testing.T) { + a := &App{ + Config: &config.Config{ + OktaDomain: "test.okta.com", + OktaClientID: "cid", + OktaPrivateKey: []byte("key"), + OktaSyncRules: []domain.SyncRule{{OktaGroupName: "Eng", GitHubTeamName: "eng"}}, + GitHubOrg: "org", + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + OktaClient: &mockOktaClient{}, + Notifier: &mockNotifier{ + notifyOktaSyncFn: func(_ context.Context, _ []*domain.SyncReport, _ string) error { + return errors.New("slack api error") + }, + }, + } + + err := a.handleOktaSync(context.Background()) + if err != nil { + t.Fatalf("notifier failure should not propagate, got: %v", err) + } +} + +func TestHandlePullRequestWebhook_NotMerged(t *testing.T) { + prNumber := 10 + baseBranch := "main" + payload := map[string]any{ + "action": "opened", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "base": map[string]any{"ref": baseBranch}, + }, + "repository": map[string]any{ + "name": "repo", + "full_name": "org/repo", + "owner": map[string]any{"login": "org"}, + }, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{PRComplianceEnabled: true, PRMonitoredBranches: []string{"main"}, + GitHubOrg: "org", GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1}, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + } + + err := a.handlePullRequestWebhook(context.Background(), data) + if err != nil { + t.Fatalf("non-merged PR should not error, got: %v", err) + } +} + +func TestHandlePullRequestWebhook_UnmonitoredBranch(t *testing.T) { + merged := true + prNumber := 10 + payload := map[string]any{ + "action": "closed", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "merged": merged, + "base": map[string]any{"ref": "develop"}, + }, + "repository": map[string]any{ + "name": "repo", + "full_name": "org/repo", + "owner": map[string]any{"login": "org"}, + }, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{PRComplianceEnabled: true, PRMonitoredBranches: []string{"main"}, + GitHubOrg: "org", GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1}, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + } + + err := a.handlePullRequestWebhook(context.Background(), data) + if err != nil { + t.Fatalf("unmonitored branch should not error, got: %v", err) + } +} + +func TestHandlePullRequestWebhook_ComplianceCheck(t *testing.T) { + merged := true + prNumber := 42 + checked := false + payload := map[string]any{ + "action": "closed", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "merged": merged, + "base": map[string]any{"ref": "main"}, + }, + "repository": map[string]any{ + "name": "repo", + "full_name": "org/repo", + "owner": map[string]any{"login": "org"}, + }, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{ + PRComplianceEnabled: true, PRMonitoredBranches: []string{"main"}, + GitHubOrg: "org", GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1, + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{ + checkPRComplianceFn: func(_ context.Context, owner, repo string, num int) (*domain.PRComplianceResult, error) { + checked = true + if num != prNumber { + t.Errorf("expected pr %d, got %d", prNumber, num) + } + return &domain.PRComplianceResult{ + Violations: []domain.ComplianceViolation{}, + }, nil + }, + }, + } + + err := a.handlePullRequestWebhook(context.Background(), data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !checked { + t.Error("expected compliance check to be called") + } +} + +func TestHandlePullRequestWebhook_BypassNotification(t *testing.T) { + merged := true + prNumber := 42 + notified := false + prURL := "https://github.com/org/repo/pull/42" + payload := map[string]any{ + "action": "closed", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "merged": merged, + "base": map[string]any{"ref": "main"}, + "html_url": prURL, + "merged_by": map[string]any{"login": "admin-user"}, + }, + "repository": map[string]any{ + "name": "repo", + "full_name": "org/repo", + "owner": map[string]any{"login": "org"}, + }, + } + data, _ := json.Marshal(payload) + + pr := &github.PullRequest{} + a := &App{ + Config: &config.Config{ + PRComplianceEnabled: true, PRMonitoredBranches: []string{"main"}, + GitHubOrg: "org", GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1, + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{ + checkPRComplianceFn: func(_ context.Context, _, _ string, _ int) (*domain.PRComplianceResult, error) { + return &domain.PRComplianceResult{ + PR: pr, + UserHasBypass: true, + Violations: []domain.ComplianceViolation{{Type: "test", Description: "test"}}, + }, nil + }, + }, + Notifier: &mockNotifier{ + notifyPRBypassFn: func(_ context.Context, result *domain.PRComplianceResult, repoName string) error { + notified = true + if repoName != "org/repo" { + t.Errorf("expected repo org/repo, got %s", repoName) + } + return nil + }, + }, + } + + err := a.handlePullRequestWebhook(context.Background(), data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !notified { + t.Error("expected bypass notification to be sent") + } +} + +func TestHandleTeamWebhook_SyncDisabled(t *testing.T) { + payload := map[string]any{ + "action": "edited", + "team": map[string]any{"slug": "engineering"}, + "sender": map[string]any{"login": "user1", "type": "User"}, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{}, // no okta sync configured + Logger: discardLogger(), + } + + err := a.handleTeamWebhook(context.Background(), data) + if err != nil { + t.Fatalf("expected nil when sync disabled, got: %v", err) + } +} + +func TestHandleTeamWebhook_IgnoresBotSender(t *testing.T) { + payload := map[string]any{ + "action": "edited", + "team": map[string]any{"slug": "engineering"}, + "sender": map[string]any{"login": "dependabot", "type": "Bot"}, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{ + OktaDomain: "test.okta.com", OktaClientID: "cid", + OktaPrivateKey: []byte("key"), + OktaSyncRules: []domain.SyncRule{{OktaGroupName: "g", GitHubTeamName: "t"}}, + }, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{}, + OktaClient: &mockOktaClient{}, + } + + err := a.handleTeamWebhook(context.Background(), data) + if err != nil { + t.Fatalf("bot sender should be ignored, got: %v", err) + } +} + +func TestHandleMembershipWebhook_NonTeamScope(t *testing.T) { + payload := map[string]any{ + "action": "added", + "scope": "organization", + "member": map[string]any{"login": "user1"}, + "team": map[string]any{"slug": "eng"}, + "sender": map[string]any{"login": "admin", "type": "User"}, + } + data, _ := json.Marshal(payload) + + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + err := a.handleMembershipWebhook(context.Background(), data) + if err != nil { + t.Fatalf("non-team scope should be skipped, got: %v", err) + } +} + +func TestShouldIgnoreWebhookChange_BotSender(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + event := &stubSender{senderType: "Bot", senderLogin: "some-bot"} + if !a.shouldIgnoreWebhookChange(context.Background(), event) { + t.Error("expected Bot sender to be ignored") + } +} + +func TestShouldIgnoreWebhookChange_AppSlugMatch(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{ + getAppSlugFn: func(_ context.Context) (string, error) { + return "my-app", nil + }, + }, + } + + event := &stubSender{senderType: "User", senderLogin: "my-app[bot]"} + if !a.shouldIgnoreWebhookChange(context.Background(), event) { + t.Error("expected app slug match to be ignored") + } +} + +func TestShouldIgnoreWebhookChange_HumanUser(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + GitHubClient: &mockGitHubClient{ + getAppSlugFn: func(_ context.Context) (string, error) { + return "my-app", nil + }, + }, + } + + event := &stubSender{senderType: "User", senderLogin: "human-user"} + if a.shouldIgnoreWebhookChange(context.Background(), event) { + t.Error("expected human user NOT to be ignored") + } +} + +func TestProcessWebhook_UnknownEventType(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + err := a.ProcessWebhook(context.Background(), []byte(`{}`), "deployment") + if err == nil { + t.Fatal("expected error for unknown event type") + } + if !errors.Is(err, domain.ErrInvalidEventType) { + t.Errorf("expected ErrInvalidEventType, got: %v", err) + } +} + +func TestMapErrorResponse(t *testing.T) { + tests := []struct { + name string + err error + wantStatus int + }{ + {"auth error direct", domain.ErrInvalidSignature, 401}, + {"auth error wrapped", errors.Wrap(domain.ErrInvalidSignature, "wrapped"), 401}, + {"validation error", domain.ErrMissingPRData, 400}, + {"validation error wrapped", errors.Wrap(domain.ErrInvalidEventType, "wrapped"), 400}, + {"config error", domain.ErrClientNotInit, 503}, + {"api error", errors.Wrap(domain.ErrTeamNotFound, "wrapped"), 502}, + {"generic error", errors.New("something"), 500}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resp := mapErrorResponse(tt.err, "fallback") + if resp.StatusCode != tt.wantStatus { + t.Errorf("expected status %d, got %d", tt.wantStatus, resp.StatusCode) + } + }) + } +} + +func TestHandleHTTPRequest_NotFound(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + req := Request{Type: RequestTypeHTTP, Method: "GET", Path: "/nonexistent"} + resp := a.HandleRequest(context.Background(), req) + if resp.StatusCode != 404 { + t.Errorf("expected 404, got %d", resp.StatusCode) + } +} + +func TestHandleHTTPRequest_MethodNotAllowed(t *testing.T) { + a := &App{ + Config: &config.Config{}, + Logger: discardLogger(), + } + + req := Request{Type: RequestTypeHTTP, Method: "DELETE", Path: "/webhooks"} + resp := a.HandleRequest(context.Background(), req) + if resp.StatusCode != 405 { + t.Errorf("expected 405, got %d", resp.StatusCode) + } +} + +func TestHandleHTTPRequest_BasePathStripping(t *testing.T) { + a := &App{ + Config: &config.Config{BasePath: "/api/v1"}, + Logger: discardLogger(), + } + + req := Request{Type: RequestTypeHTTP, Method: "GET", Path: "/api/v1/server/status"} + resp := a.HandleRequest(context.Background(), req) + if resp.StatusCode != 200 { + t.Errorf("expected 200 for base-path-stripped status, got %d", resp.StatusCode) + } +} + +// stubSender is a simple webhookSender for testing shouldIgnoreWebhookChange +type stubSender struct { + senderType string + senderLogin string +} + +func (s *stubSender) GetSenderType() string { return s.senderType } +func (s *stubSender) GetSenderLogin() string { return s.senderLogin } diff --git a/internal/app/request.go b/internal/app/request.go index d3a17d7..6b831de 100644 --- a/internal/app/request.go +++ b/internal/app/request.go @@ -6,6 +6,8 @@ import ( "log/slog" "strings" + "github.com/cockroachdb/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/cruxstack/github-ops-app/internal/github/webhooks" ) @@ -72,7 +74,7 @@ func (a *App) handleScheduledRequest(ctx context.Context, req Request) Response a.Logger.Error("scheduled event processing failed", slog.String("action", evt.Action), slog.String("error", err.Error())) - return errorResponse(500, "scheduled event processing failed") + return mapErrorResponse(err, "scheduled event processing failed") } return jsonResponse(200, map[string]string{ @@ -145,14 +147,14 @@ func (a *App) handleWebhookRequest(ctx context.Context, req Request) Response { ); err != nil { a.Logger.Warn("webhook signature validation failed", slog.String("error", err.Error())) - return errorResponse(401, "unauthorized") + return mapErrorResponse(err, "unauthorized") } if err := a.ProcessWebhook(ctx, req.Body, eventType); err != nil { a.Logger.Error("webhook processing failed", slog.String("event_type", eventType), slog.String("error", err.Error())) - return errorResponse(500, "webhook processing failed") + return mapErrorResponse(err, "webhook processing failed") } return Response{ @@ -186,6 +188,24 @@ func (a *App) handleScheduledHTTPRequest(ctx context.Context, req Request, path return a.handleScheduledRequest(ctx, scheduledReq) } +// mapErrorResponse translates domain error types to appropriate HTTP status +// codes. centralizes error-to-HTTP mapping in one place. uses errors.Is() +// with domain marker instances since errors.Mark() sets identity markers. +func mapErrorResponse(err error, fallbackMsg string) Response { + switch { + case errors.Is(err, domain.AuthError): + return errorResponse(401, "unauthorized") + case errors.Is(err, domain.ValidationError): + return errorResponse(400, fallbackMsg) + case errors.Is(err, domain.ConfigError): + return errorResponse(503, "service not configured") + case errors.Is(err, domain.APIError): + return errorResponse(502, fallbackMsg) + default: + return errorResponse(500, fallbackMsg) + } +} + // jsonResponse creates a JSON response with the given status and data. func jsonResponse(status int, data any) Response { body, err := json.Marshal(data) diff --git a/internal/app/testdata.go b/internal/app/testdata.go index 6c264ef..bae49a2 100644 --- a/internal/app/testdata.go +++ b/internal/app/testdata.go @@ -1,19 +1,18 @@ package app import ( - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/domain" gh "github.com/google/go-github/v79/github" ) // fakePRComplianceResult returns sample PR compliance data for testing. -func fakePRComplianceResult() *client.PRComplianceResult { +func fakePRComplianceResult() *domain.PRComplianceResult { prNumber := 42 prTitle := "Add new authentication feature" prURL := "https://github.com/acme-corp/demo-repo/pull/42" mergedByLogin := "test-user" - return &client.PRComplianceResult{ + return &domain.PRComplianceResult{ PR: &gh.PullRequest{ Number: &prNumber, Title: &prTitle, @@ -25,7 +24,7 @@ func fakePRComplianceResult() *client.PRComplianceResult { BaseBranch: "main", UserHasBypass: true, UserBypassReason: "repository admin", - Violations: []client.ComplianceViolation{ + Violations: []domain.ComplianceViolation{ {Type: "insufficient_reviews", Description: "required 2 approving reviews, had 0"}, {Type: "missing_status_check", Description: "required check 'ci/build' did not pass"}, }, @@ -33,8 +32,8 @@ func fakePRComplianceResult() *client.PRComplianceResult { } // fakeOktaSyncReports returns sample Okta sync reports for testing. -func fakeOktaSyncReports() []*okta.SyncReport { - return []*okta.SyncReport{ +func fakeOktaSyncReports() []*domain.SyncReport { + return []*domain.SyncReport{ { Rule: "engineering-team", OktaGroup: "Engineering", @@ -61,8 +60,8 @@ func fakeOktaSyncReports() []*okta.SyncReport { } // fakeOrphanedUsersReport returns sample orphaned users data for testing. -func fakeOrphanedUsersReport() *okta.OrphanedUsersReport { - return &okta.OrphanedUsersReport{ +func fakeOrphanedUsersReport() *domain.OrphanedUsersReport { + return &domain.OrphanedUsersReport{ OrphanedUsers: []string{"orphan-user-1", "orphan-user-2", "legacy-bot"}, } } diff --git a/internal/config/config.go b/internal/config/config.go index f43dde0..2895a98 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -15,7 +15,7 @@ import ( "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/cockroachdb/errors" - "github.com/cruxstack/github-ops-app/internal/types" + "github.com/cruxstack/github-ops-app/internal/domain" ) // Config holds all application configuration loaded from environment @@ -46,7 +46,7 @@ type Config struct { OktaScopes []string OktaBaseURL string OktaGitHubUserField string - OktaSyncRules []types.SyncRule + OktaSyncRules []domain.SyncRule OktaSyncSafetyThreshold float64 OktaOrphanedUserNotifications bool @@ -259,7 +259,7 @@ func NewConfigWithContext(ctx context.Context) (*Config, error) { syncRulesJSON := os.Getenv("APP_OKTA_SYNC_RULES") if syncRulesJSON != "" { - var rules []types.SyncRule + var rules []domain.SyncRule if err := json.Unmarshal([]byte(syncRulesJSON), &rules); err != nil { return nil, errors.Wrap(err, "failed to parse APP_OKTA_SYNC_RULES") } @@ -363,16 +363,16 @@ type RedactedConfig struct { PRMonitoredBranches []string `json:"pr_monitored_branches"` // Okta - OktaDomain string `json:"okta_domain"` - OktaClientID string `json:"okta_client_id"` - OktaPrivateKey string `json:"okta_private_key"` - OktaPrivateKeyID string `json:"okta_private_key_id"` - OktaScopes []string `json:"okta_scopes"` - OktaBaseURL string `json:"okta_base_url"` - OktaGitHubUserField string `json:"okta_github_user_field"` - OktaSyncRules []types.SyncRule `json:"okta_sync_rules"` - OktaSyncSafetyThreshold float64 `json:"okta_sync_safety_threshold"` - OktaOrphanedUserNotifications bool `json:"okta_orphaned_user_notifications"` + OktaDomain string `json:"okta_domain"` + OktaClientID string `json:"okta_client_id"` + OktaPrivateKey string `json:"okta_private_key"` + OktaPrivateKeyID string `json:"okta_private_key_id"` + OktaScopes []string `json:"okta_scopes"` + OktaBaseURL string `json:"okta_base_url"` + OktaGitHubUserField string `json:"okta_github_user_field"` + OktaSyncRules []domain.SyncRule `json:"okta_sync_rules"` + OktaSyncSafetyThreshold float64 `json:"okta_sync_safety_threshold"` + OktaOrphanedUserNotifications bool `json:"okta_orphaned_user_notifications"` // Slack SlackEnabled bool `json:"slack_enabled"` diff --git a/internal/config/config_helpers_test.go b/internal/config/config_helpers_test.go new file mode 100644 index 0000000..ed234a2 --- /dev/null +++ b/internal/config/config_helpers_test.go @@ -0,0 +1,298 @@ +package config + +import ( + "testing" +) + +func TestIsGitHubConfigured(t *testing.T) { + tests := []struct { + name string + cfg Config + want bool + }{ + { + name: "fully configured", + cfg: Config{ + GitHubOrg: "org", GitHubAppID: 1, + GitHubAppPrivateKey: []byte("key"), GitHubInstallationID: 1, + }, + want: true, + }, + {name: "empty", cfg: Config{}, want: false}, + { + name: "missing org", + cfg: Config{GitHubAppID: 1, GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1}, + want: false, + }, + { + name: "missing private key", + cfg: Config{GitHubOrg: "org", GitHubAppID: 1, GitHubInstallationID: 1}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.IsGitHubConfigured(); got != tt.want { + t.Errorf("IsGitHubConfigured() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsOktaSyncEnabled(t *testing.T) { + tests := []struct { + name string + cfg Config + want bool + }{ + {name: "empty", cfg: Config{}, want: false}, + { + name: "missing rules", + cfg: Config{OktaDomain: "d", OktaClientID: "c", OktaPrivateKey: []byte("k")}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.IsOktaSyncEnabled(); got != tt.want { + t.Errorf("IsOktaSyncEnabled() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIsPRComplianceEnabled(t *testing.T) { + tests := []struct { + name string + cfg Config + want bool + }{ + {name: "disabled flag", cfg: Config{PRComplianceEnabled: false}, want: false}, + { + name: "enabled but no github", + cfg: Config{PRComplianceEnabled: true}, + want: false, + }, + { + name: "enabled with github", + cfg: Config{ + PRComplianceEnabled: true, + GitHubOrg: "org", GitHubAppID: 1, + GitHubAppPrivateKey: []byte("k"), GitHubInstallationID: 1, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.cfg.IsPRComplianceEnabled(); got != tt.want { + t.Errorf("IsPRComplianceEnabled() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestShouldMonitorBranch(t *testing.T) { + cfg := Config{ + PRComplianceEnabled: true, + PRMonitoredBranches: []string{"main", "master"}, + GitHubOrg: "org", + GitHubAppID: 1, + GitHubAppPrivateKey: []byte("k"), + GitHubInstallationID: 1, + } + + tests := []struct { + branch string + want bool + }{ + {"main", true}, + {"master", true}, + {"develop", false}, + {"refs/heads/main", true}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.branch, func(t *testing.T) { + if got := cfg.ShouldMonitorBranch(tt.branch); got != tt.want { + t.Errorf("ShouldMonitorBranch(%q) = %v, want %v", tt.branch, got, tt.want) + } + }) + } +} + +func TestRedacted(t *testing.T) { + cfg := Config{ + GitHubOrg: "my-org", + GitHubAppPrivateKey: []byte("secret-key"), + GitHubWebhookSecret: "webhook-secret", + AdminToken: "admin-secret", + SlackToken: "xoxb-token", + OktaClientID: "client-id", + OktaPrivateKey: []byte("okta-key"), + DebugEnabled: true, + } + + redacted := cfg.Redacted() + + // non-secrets should be preserved + if redacted.GitHubOrg != "my-org" { + t.Errorf("expected org to be preserved, got %q", redacted.GitHubOrg) + } + if !redacted.DebugEnabled { + t.Error("expected DebugEnabled to be preserved") + } + + // secrets should be redacted + if redacted.GitHubAppPrivateKey != "***REDACTED***" { + t.Errorf("expected private key to be redacted, got %q", redacted.GitHubAppPrivateKey) + } + if redacted.GitHubWebhookSecret != "***REDACTED***" { + t.Errorf("expected webhook secret to be redacted, got %q", redacted.GitHubWebhookSecret) + } + if redacted.AdminToken != "***REDACTED***" { + t.Errorf("expected admin token to be redacted, got %q", redacted.AdminToken) + } + if redacted.SlackToken != "***REDACTED***" { + t.Errorf("expected slack token to be redacted, got %q", redacted.SlackToken) + } + if redacted.OktaClientID != "***REDACTED***" { + t.Errorf("expected okta client id to be redacted, got %q", redacted.OktaClientID) + } +} + +func TestRedacted_EmptyValues(t *testing.T) { + cfg := Config{} + redacted := cfg.Redacted() + + if redacted.GitHubAppPrivateKey != "" { + t.Error("expected empty private key to remain empty") + } + if redacted.SlackToken != "" { + t.Error("expected empty slack token to remain empty") + } +} + +func TestNewConfigWithContext_Defaults(t *testing.T) { + // clear all env vars that NewConfig reads + envKeys := []string{ + "APP_DEBUG_ENABLED", "APP_GITHUB_ORG", "APP_GITHUB_APP_ID", + "APP_GITHUB_APP_PRIVATE_KEY", "APP_GITHUB_APP_PRIVATE_KEY_PATH", + "APP_GITHUB_INSTALLATION_ID", "APP_GITHUB_WEBHOOK_SECRET", + "APP_GITHUB_BASE_URL", "APP_PR_COMPLIANCE_ENABLED", + "APP_PR_MONITORED_BRANCHES", "APP_OKTA_DOMAIN", "APP_OKTA_CLIENT_ID", + "APP_OKTA_PRIVATE_KEY", "APP_OKTA_PRIVATE_KEY_PATH", + "APP_OKTA_PRIVATE_KEY_ID", "APP_OKTA_SCOPES", + "APP_OKTA_BASE_URL", "APP_OKTA_GITHUB_USER_FIELD", + "APP_OKTA_SYNC_RULES", "APP_OKTA_SYNC_SAFETY_THRESHOLD", + "APP_OKTA_ORPHANED_USER_NOTIFICATIONS", "APP_SLACK_TOKEN", + "APP_SLACK_CHANNEL", "APP_SLACK_API_URL", "APP_BASE_PATH", + "APP_ADMIN_TOKEN", + } + for _, key := range envKeys { + t.Setenv(key, "") + } + + cfg, err := NewConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.OktaGitHubUserField != "githubUsername" { + t.Errorf("expected default github user field, got %q", cfg.OktaGitHubUserField) + } + if cfg.OktaSyncSafetyThreshold != 0.5 { + t.Errorf("expected default safety threshold 0.5, got %f", cfg.OktaSyncSafetyThreshold) + } + if len(cfg.PRMonitoredBranches) != 2 || cfg.PRMonitoredBranches[0] != "main" { + t.Errorf("expected default monitored branches [main, master], got %v", cfg.PRMonitoredBranches) + } + if len(cfg.OktaScopes) != 2 { + t.Errorf("expected default okta scopes, got %v", cfg.OktaScopes) + } +} + +func TestNewConfigWithContext_ParsesEnvVars(t *testing.T) { + t.Setenv("APP_DEBUG_ENABLED", "true") + t.Setenv("APP_GITHUB_ORG", "test-org") + t.Setenv("APP_GITHUB_APP_ID", "12345") + t.Setenv("APP_GITHUB_APP_PRIVATE_KEY", "test-key-data") + t.Setenv("APP_GITHUB_INSTALLATION_ID", "67890") + t.Setenv("APP_GITHUB_WEBHOOK_SECRET", "whsec") + t.Setenv("APP_PR_COMPLIANCE_ENABLED", "true") + t.Setenv("APP_PR_MONITORED_BRANCHES", "main,release") + t.Setenv("APP_OKTA_SYNC_SAFETY_THRESHOLD", "0.3") + t.Setenv("APP_BASE_PATH", "/api/v1") + t.Setenv("APP_OKTA_SYNC_RULES", `[{"okta_group_name":"Eng","github_team_name":"eng"}]`) + // clear keys that could interfere + t.Setenv("APP_GITHUB_APP_PRIVATE_KEY_PATH", "") + t.Setenv("APP_OKTA_PRIVATE_KEY_PATH", "") + t.Setenv("APP_OKTA_PRIVATE_KEY", "") + t.Setenv("APP_SLACK_TOKEN", "") + t.Setenv("APP_SLACK_CHANNEL", "") + t.Setenv("APP_ADMIN_TOKEN", "") + t.Setenv("APP_OKTA_DOMAIN", "") + t.Setenv("APP_OKTA_CLIENT_ID", "") + + cfg, err := NewConfig() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if !cfg.DebugEnabled { + t.Error("expected debug enabled") + } + if cfg.GitHubOrg != "test-org" { + t.Errorf("expected org test-org, got %q", cfg.GitHubOrg) + } + if cfg.GitHubAppID != 12345 { + t.Errorf("expected app id 12345, got %d", cfg.GitHubAppID) + } + if string(cfg.GitHubAppPrivateKey) != "test-key-data" { + t.Errorf("expected key data, got %q", string(cfg.GitHubAppPrivateKey)) + } + if cfg.GitHubInstallationID != 67890 { + t.Errorf("expected installation 67890, got %d", cfg.GitHubInstallationID) + } + if cfg.OktaSyncSafetyThreshold != 0.3 { + t.Errorf("expected threshold 0.3, got %f", cfg.OktaSyncSafetyThreshold) + } + if cfg.BasePath != "/api/v1" { + t.Errorf("expected base path /api/v1, got %q", cfg.BasePath) + } + if len(cfg.PRMonitoredBranches) != 2 || cfg.PRMonitoredBranches[1] != "release" { + t.Errorf("expected [main, release], got %v", cfg.PRMonitoredBranches) + } + if len(cfg.OktaSyncRules) != 1 { + t.Errorf("expected 1 sync rule, got %d", len(cfg.OktaSyncRules)) + } +} + +func TestNewConfigWithContext_InvalidAppID(t *testing.T) { + t.Setenv("APP_GITHUB_APP_ID", "not-a-number") + t.Setenv("APP_GITHUB_WEBHOOK_SECRET", "") + t.Setenv("APP_SLACK_TOKEN", "") + t.Setenv("APP_ADMIN_TOKEN", "") + + _, err := NewConfig() + if err == nil { + t.Fatal("expected error for invalid app id") + } +} + +func TestNewConfigWithContext_InvalidSyncRulesJSON(t *testing.T) { + t.Setenv("APP_OKTA_SYNC_RULES", "not-json") + t.Setenv("APP_GITHUB_APP_ID", "") + t.Setenv("APP_GITHUB_WEBHOOK_SECRET", "") + t.Setenv("APP_SLACK_TOKEN", "") + t.Setenv("APP_ADMIN_TOKEN", "") + + _, err := NewConfig() + if err == nil { + t.Fatal("expected error for invalid sync rules JSON") + } +} diff --git a/internal/errors/errors.go b/internal/domain/errors.go similarity index 88% rename from internal/errors/errors.go rename to internal/domain/errors.go index 3a05f20..18249f0 100644 --- a/internal/errors/errors.go +++ b/internal/domain/errors.go @@ -1,6 +1,7 @@ -// Package errors defines sentinel errors and domain types for the -// application. uses cockroachdb/errors for automatic stack trace capture. -package errors +// Package domain defines shared business types, errors, and interfaces. +// this package has zero internal imports and serves as the dependency +// inversion layer for the application. +package domain import "github.com/cockroachdb/errors" @@ -39,5 +40,4 @@ var ( ErrClientNotInit = errors.Mark(errors.New("client not initialized"), ConfigError) ErrInvalidEventType = errors.Mark(errors.New("unknown event type"), ValidationError) ErrMissingOAuthCreds = errors.Mark(errors.New("must provide either api token or oauth credentials"), ConfigError) - ErrOAuthTokenExpired = errors.Mark(errors.New("oauth token expired"), AuthError) ) diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go new file mode 100644 index 0000000..c40da02 --- /dev/null +++ b/internal/domain/interfaces.go @@ -0,0 +1,64 @@ +package domain + +import ( + "context" + + "github.com/google/go-github/v79/github" +) + +// GitHubClient defines the interface for GitHub API operations. +// implemented by internal/github/client.Client. +type GitHubClient interface { + // CheckPRCompliance verifies if a merged PR met branch protection + // requirements. + CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*PRComplianceResult, error) + + // GetOrCreateTeam fetches an existing team by slug or creates it if + // missing. + GetOrCreateTeam(ctx context.Context, teamName, privacy string) (*github.Team, error) + + // SyncTeamMembers adds and removes members to match desired state. + SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, safetyThreshold float64) (*TeamSyncResult, error) + + // GetTeamMembers returns GitHub usernames of all team members. + GetTeamMembers(ctx context.Context, teamSlug string) ([]string, error) + + // ListOrgMembers returns all organization members excluding external + // collaborators. + ListOrgMembers(ctx context.Context) ([]string, error) + + // IsExternalCollaborator checks if a user is an outside collaborator + // rather than an organization member. + IsExternalCollaborator(ctx context.Context, username string) (bool, error) + + // GetAppSlug fetches the GitHub App slug identifier. + GetAppSlug(ctx context.Context) (string, error) + + // GetOrg returns the GitHub organization name. + GetOrg() string +} + +// OktaClient defines the interface for Okta API operations. +// implemented by internal/okta.Client. +type OktaClient interface { + // GetGroupsByPattern fetches all Okta groups matching a regex pattern. + GetGroupsByPattern(ctx context.Context, pattern string) ([]*GroupInfo, error) + + // GetGroupInfo fetches details for a single Okta group by name. + GetGroupInfo(ctx context.Context, groupName string) (*GroupInfo, error) +} + +// Notifier defines the interface for sending notifications. +// implemented by internal/notifiers.SlackNotifier. +type Notifier interface { + // NotifyPRBypass sends a notification when branch protection is + // bypassed. + NotifyPRBypass(ctx context.Context, result *PRComplianceResult, repoFullName string) error + + // NotifyOktaSync sends a notification with Okta sync results. + NotifyOktaSync(ctx context.Context, reports []*SyncReport, githubOrg string) error + + // NotifyOrphanedUsers sends a notification about organization members + // not in any synced teams. + NotifyOrphanedUsers(ctx context.Context, report *OrphanedUsersReport) error +} diff --git a/internal/domain/types.go b/internal/domain/types.go new file mode 100644 index 0000000..f08b05f --- /dev/null +++ b/internal/domain/types.go @@ -0,0 +1,126 @@ +package domain + +import "github.com/google/go-github/v79/github" + +// SyncRule defines how to sync Okta groups to GitHub teams. +type SyncRule struct { + Name string `json:"name"` + Enabled *bool `json:"enabled,omitempty"` + OktaGroupPattern string `json:"okta_group_pattern,omitempty"` + OktaGroupName string `json:"okta_group_name,omitempty"` + GitHubTeamPrefix string `json:"github_team_prefix,omitempty"` + GitHubTeamName string `json:"github_team_name,omitempty"` + StripPrefix string `json:"strip_prefix,omitempty"` + SyncMembers *bool `json:"sync_members,omitempty"` + CreateTeamIfMissing bool `json:"create_team_if_missing"` + TeamPrivacy string `json:"team_privacy,omitempty"` +} + +// IsEnabled returns true if the rule is enabled (defaults to true). +func (r SyncRule) IsEnabled() bool { + return r.Enabled == nil || *r.Enabled +} + +// ShouldSyncMembers returns true if members should be synced (defaults to +// true). +func (r SyncRule) ShouldSyncMembers() bool { + return r.SyncMembers == nil || *r.SyncMembers +} + +// GetName returns the rule name, defaulting to GitHubTeamName if not set. +func (r SyncRule) GetName() string { + if r.Name != "" { + return r.Name + } + if r.GitHubTeamName != "" { + return r.GitHubTeamName + } + return r.OktaGroupName +} + +// ComplianceViolation represents a single branch protection rule violation. +type ComplianceViolation struct { + Type string + Description string +} + +// PRComplianceResult contains PR compliance check results including +// violations and user bypass permissions. +type PRComplianceResult struct { + PR *github.PullRequest + BaseBranch string + Protection *github.Protection + BranchRules *github.BranchRules + Violations []ComplianceViolation + UserHasBypass bool + UserBypassReason string +} + +// HasViolations returns true if any compliance violations were detected. +func (r *PRComplianceResult) HasViolations() bool { + return len(r.Violations) > 0 +} + +// WasBypassed returns true if violations exist and user had bypass +// permission. +func (r *PRComplianceResult) WasBypassed() bool { + return r.HasViolations() && r.UserHasBypass +} + +// TeamSyncResult contains the results of syncing team membership. +type TeamSyncResult struct { + TeamName string + MembersAdded []string + MembersRemoved []string + MembersSkippedExternal []string + Errors []string +} + +// GroupInfo contains Okta group details and member list. +type GroupInfo struct { + ID string + Name string + Members []string + SkippedNoGitHubUsername []string +} + +// GroupMembersResult contains the results of fetching group members. +type GroupMembersResult struct { + Members []string + SkippedNoGitHubUsername []string +} + +// SyncReport contains the results of syncing a single Okta group to GitHub +// team. +type SyncReport struct { + Rule string + OktaGroup string + GitHubTeam string + MembersAdded []string + MembersRemoved []string + MembersSkippedExternal []string + MembersSkippedNoGHUsername []string + Errors []string +} + +// HasErrors returns true if any errors occurred during sync. +func (r *SyncReport) HasErrors() bool { + return len(r.Errors) > 0 +} + +// HasChanges returns true if members were added or removed. +func (r *SyncReport) HasChanges() bool { + return len(r.MembersAdded) > 0 || len(r.MembersRemoved) > 0 +} + +// OrphanedUsersReport contains users who are org members but not in any +// synced teams. +type OrphanedUsersReport struct { + OrphanedUsers []string +} + +// SyncResult contains all sync reports and orphaned users report. +type SyncResult struct { + Reports []*SyncReport + OrphanedUsers *OrphanedUsersReport +} diff --git a/internal/github/client/client.go b/internal/github/client/client.go index fd3e7ce..93ec3a3 100644 --- a/internal/github/client/client.go +++ b/internal/github/client/client.go @@ -21,6 +21,7 @@ import ( // Client wraps the GitHub API client with App authentication. // automatically refreshes installation tokens before expiry. +// implements domain.GitHubClient. type Client struct { client *github.Client org string @@ -144,14 +145,21 @@ func (c *Client) refreshToken(ctx context.Context) error { } // ensureValidToken refreshes the installation token if it expires within 5 -// minutes. +// minutes. uses double-check pattern to avoid redundant refreshes under +// concurrent access. func (c *Client) ensureValidToken(ctx context.Context) error { c.tokenMu.RLock() needsRefresh := time.Now().Add(5 * time.Minute).After(c.tokenExpAt) c.tokenMu.RUnlock() if needsRefresh { - return c.refreshToken(ctx) + c.tokenMu.Lock() + // double-check after acquiring write lock + if time.Now().Add(5 * time.Minute).After(c.tokenExpAt) { + c.tokenMu.Unlock() + return c.refreshToken(ctx) + } + c.tokenMu.Unlock() } return nil diff --git a/internal/github/client/pr.go b/internal/github/client/pr.go index a0d745f..065c851 100644 --- a/internal/github/client/pr.go +++ b/internal/github/client/pr.go @@ -5,32 +5,14 @@ import ( "fmt" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/google/go-github/v79/github" ) -// ComplianceViolation represents a single branch protection rule violation. -type ComplianceViolation struct { - Type string - Description string -} - -// PRComplianceResult contains PR compliance check results including -// violations and user bypass permissions. -type PRComplianceResult struct { - PR *github.PullRequest - BaseBranch string - Protection *github.Protection - BranchRules *github.BranchRules - Violations []ComplianceViolation - UserHasBypass bool - UserBypassReason string -} - // CheckPRCompliance verifies if a merged PR met branch protection // requirements. checks review requirements, status checks, and user bypass // permissions. -func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*PRComplianceResult, error) { +func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) { if err := c.ensureValidToken(ctx); err != nil { return nil, err } @@ -41,19 +23,19 @@ func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNu } if pr == nil { - return nil, errors.Wrapf(internalerrors.ErrMissingPRData, "pr #%d returned nil", prNumber) + return nil, errors.Wrapf(domain.ErrMissingPRData, "pr #%d returned nil", prNumber) } if pr.Base == nil || pr.Base.Ref == nil { - return nil, errors.Wrapf(internalerrors.ErrMissingPRData, "pr #%d missing base branch", prNumber) + return nil, errors.Wrapf(domain.ErrMissingPRData, "pr #%d missing base branch", prNumber) } baseBranch := *pr.Base.Ref - result := &PRComplianceResult{ + result := &domain.PRComplianceResult{ PR: pr, BaseBranch: baseBranch, - Violations: []ComplianceViolation{}, + Violations: []domain.ComplianceViolation{}, } // fetch legacy branch protection rules @@ -68,8 +50,14 @@ func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNu result.BranchRules = branchRules } - c.checkReviewRequirements(ctx, owner, repo, pr, result) - c.checkStatusRequirements(ctx, owner, repo, pr, result) + if err := c.checkReviewRequirements(ctx, owner, repo, pr, result); err != nil { + return nil, errors.Wrapf(err, "failed to check review requirements for pr #%d", prNumber) + } + + if err := c.checkStatusRequirements(ctx, owner, repo, pr, result); err != nil { + return nil, errors.Wrapf(err, "failed to check status requirements for pr #%d", prNumber) + } + c.checkUserBypassPermission(ctx, owner, repo, pr, result) return result, nil @@ -77,7 +65,7 @@ func (c *Client) CheckPRCompliance(ctx context.Context, owner, repo string, prNu // checkReviewRequirements validates that PR had required approving reviews. // checks both legacy branch protection and repository rulesets. -func (c *Client) checkReviewRequirements(ctx context.Context, owner, repo string, pr *github.PullRequest, result *PRComplianceResult) { +func (c *Client) checkReviewRequirements(ctx context.Context, owner, repo string, pr *github.PullRequest, result *domain.PRComplianceResult) error { requiredApprovals := 0 // check legacy branch protection @@ -95,34 +83,45 @@ func (c *Client) checkReviewRequirements(ctx context.Context, owner, repo string } if requiredApprovals == 0 { - return + return nil } - reviews, _, err := c.client.PullRequests.ListReviews(ctx, owner, repo, *pr.Number, nil) + reviews, _, err := c.client.PullRequests.ListReviews(ctx, owner, repo, *pr.Number, &github.ListOptions{PerPage: 100}) if err != nil { - return + return errors.Wrapf(err, "failed to list reviews for pr #%d in %s/%s", *pr.Number, owner, repo) } - approvedCount := 0 + // deduplicate reviews per user, keeping only the latest state + latestReviewByUser := make(map[string]string) for _, review := range reviews { - if review.State != nil && *review.State == "APPROVED" { + if review.User == nil || review.User.Login == nil || review.State == nil { + continue + } + latestReviewByUser[*review.User.Login] = *review.State + } + + approvedCount := 0 + for _, state := range latestReviewByUser { + if state == "APPROVED" { approvedCount++ } } if approvedCount < requiredApprovals { - result.Violations = append(result.Violations, ComplianceViolation{ + result.Violations = append(result.Violations, domain.ComplianceViolation{ Type: "insufficient_reviews", Description: fmt.Sprintf("required %d approving reviews, had %d", requiredApprovals, approvedCount), }) } + + return nil } // checkStatusRequirements validates that required status checks passed. // checks both legacy branch protection and repository rulesets. -func (c *Client) checkStatusRequirements(ctx context.Context, owner, repo string, pr *github.PullRequest, result *PRComplianceResult) { +func (c *Client) checkStatusRequirements(ctx context.Context, owner, repo string, pr *github.PullRequest, result *domain.PRComplianceResult) error { if pr.Head == nil || pr.Head.SHA == nil { - return + return nil } // collect required checks from both sources @@ -147,12 +146,12 @@ func (c *Client) checkStatusRequirements(ctx context.Context, owner, repo string } if len(requiredChecks) == 0 { - return + return nil } combinedStatus, _, err := c.client.Repositories.GetCombinedStatus(ctx, owner, repo, *pr.Head.SHA, nil) if err != nil { - return + return errors.Wrapf(err, "failed to get combined status for sha '%s' in %s/%s", *pr.Head.SHA, owner, repo) } passedChecks := make(map[string]bool) @@ -162,19 +161,34 @@ func (c *Client) checkStatusRequirements(ctx context.Context, owner, repo string } } + // also check GitHub Actions check runs (modern repos use these instead + // of commit statuses) + checkRuns, _, err := c.client.Checks.ListCheckRunsForRef(ctx, owner, repo, *pr.Head.SHA, &github.ListCheckRunsOptions{ + ListOptions: github.ListOptions{PerPage: 100}, + }) + if err == nil && checkRuns != nil { + for _, run := range checkRuns.CheckRuns { + if run.Name != nil && run.Conclusion != nil && *run.Conclusion == "success" { + passedChecks[*run.Name] = true + } + } + } + for required := range requiredChecks { if !passedChecks[required] { - result.Violations = append(result.Violations, ComplianceViolation{ + result.Violations = append(result.Violations, domain.ComplianceViolation{ Type: "missing_status_check", Description: fmt.Sprintf("required check '%s' did not pass", required), }) } } + + return nil } // checkUserBypassPermission checks if the user who merged the PR has admin or // maintainer permissions allowing bypass. -func (c *Client) checkUserBypassPermission(ctx context.Context, owner, repo string, pr *github.PullRequest, result *PRComplianceResult) { +func (c *Client) checkUserBypassPermission(ctx context.Context, owner, repo string, pr *github.PullRequest, result *domain.PRComplianceResult) { if pr.MergedBy == nil || pr.MergedBy.Login == nil { return } @@ -197,14 +211,3 @@ func (c *Client) checkUserBypassPermission(ctx context.Context, owner, repo stri } } } - -// HasViolations returns true if any compliance violations were detected. -func (r *PRComplianceResult) HasViolations() bool { - return len(r.Violations) > 0 -} - -// WasBypassed returns true if violations exist and user had bypass -// permission. -func (r *PRComplianceResult) WasBypassed() bool { - return r.HasViolations() && r.UserHasBypass -} diff --git a/internal/github/client/teams.go b/internal/github/client/teams.go index b2a8657..35a9971 100644 --- a/internal/github/client/teams.go +++ b/internal/github/client/teams.go @@ -3,20 +3,15 @@ package client import ( "context" "fmt" + "strings" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/google/go-github/v79/github" ) -// TeamSyncResult contains the results of syncing team membership. -type TeamSyncResult struct { - TeamName string - MembersAdded []string - MembersRemoved []string - MembersSkippedExternal []string - Errors []string -} +// compile-time assertion +var _ domain.GitHubClient = (*Client)(nil) // GetOrCreateTeam fetches an existing team by slug or creates it if missing. func (c *Client) GetOrCreateTeam(ctx context.Context, teamName, privacy string) (*github.Team, error) { @@ -29,52 +24,64 @@ func (c *Client) GetOrCreateTeam(ctx context.Context, teamName, privacy string) return team, nil } - if resp != nil && resp.StatusCode == 404 { - newTeam := &github.NewTeam{ - Name: teamName, - Privacy: &privacy, - } - team, _, err = c.client.Teams.CreateTeam(ctx, c.org, *newTeam) - if err != nil { - return nil, errors.Wrapf(err, "failed to create team '%s' in org '%s'", teamName, c.org) - } - return team, nil + if resp == nil || resp.StatusCode != 404 { + return nil, errors.Wrapf(err, "failed to fetch team '%s' from org '%s'", teamName, c.org) } - return nil, errors.Wrapf(internalerrors.ErrTeamNotFound, "failed to fetch team '%s' from org '%s'", teamName, c.org) + newTeam := &github.NewTeam{ + Name: teamName, + Privacy: &privacy, + } + team, _, err = c.client.Teams.CreateTeam(ctx, c.org, *newTeam) + if err != nil { + return nil, errors.Wrapf(err, "failed to create team '%s' in org '%s'", teamName, c.org) + } + return team, nil } // GetTeamMembers returns GitHub usernames of all team members. +// paginates through all results to handle large teams. func (c *Client) GetTeamMembers(ctx context.Context, teamSlug string) ([]string, error) { if err := c.ensureValidToken(ctx); err != nil { return nil, err } - members, _, err := c.client.Teams.ListTeamMembersBySlug(ctx, c.org, teamSlug, nil) - if err != nil { - return nil, errors.Wrapf(err, "failed to list members for team '%s'", teamSlug) + opts := &github.TeamListTeamMembersOptions{ + ListOptions: github.ListOptions{PerPage: 100}, } - logins := make([]string, 0, len(members)) - for _, member := range members { - if member.Login != nil { - logins = append(logins, *member.Login) + var allMembers []string + for { + members, resp, err := c.client.Teams.ListTeamMembersBySlug(ctx, c.org, teamSlug, opts) + if err != nil { + return nil, errors.Wrapf(err, "failed to list members for team '%s'", teamSlug) + } + + for _, member := range members { + if member.Login != nil { + allMembers = append(allMembers, *member.Login) + } + } + + if resp.NextPage == 0 { + break } + opts.Page = resp.NextPage } - return logins, nil + return allMembers, nil } // SyncTeamMembers adds and removes members to match desired state. // collects errors for individual operations but continues processing. skips // removal of external collaborators (outside org members). applies safety // threshold to prevent mass removal during outages. -func (c *Client) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, safetyThreshold float64) (*TeamSyncResult, error) { +func (c *Client) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, safetyThreshold float64) (*domain.TeamSyncResult, error) { if err := c.ensureValidToken(ctx); err != nil { return nil, err } - result := &TeamSyncResult{ + result := &domain.TeamSyncResult{ TeamName: teamSlug, MembersAdded: []string{}, MembersRemoved: []string{}, @@ -89,16 +96,16 @@ func (c *Client) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMe currentSet := make(map[string]bool) for _, member := range currentMembers { - currentSet[member] = true + currentSet[strings.ToLower(member)] = true } desiredSet := make(map[string]bool) for _, member := range desiredMembers { - desiredSet[member] = true + desiredSet[strings.ToLower(member)] = true } for _, desired := range desiredMembers { - if !currentSet[desired] { + if !currentSet[strings.ToLower(desired)] { _, _, err := c.client.Teams.AddTeamMembershipBySlug(ctx, c.org, teamSlug, desired, nil) if err != nil { errMsg := fmt.Sprintf("failed to add '%s' to team '%s': %v", desired, teamSlug, err) @@ -111,7 +118,7 @@ func (c *Client) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMe var toRemove []string for _, current := range currentMembers { - if !desiredSet[current] { + if !desiredSet[strings.ToLower(current)] { toRemove = append(toRemove, current) } } diff --git a/internal/github/webhooks/webhooks.go b/internal/github/webhooks/webhooks.go index 5d5d6c5..979c36a 100644 --- a/internal/github/webhooks/webhooks.go +++ b/internal/github/webhooks/webhooks.go @@ -10,7 +10,7 @@ import ( "strings" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/google/go-github/v79/github" ) @@ -64,17 +64,17 @@ type MembershipEvent struct { func ValidateWebhookSignature(payload []byte, signature string, secret string) error { if secret == "" { if signature != "" { - return internalerrors.ErrUnexpectedSignature + return domain.ErrUnexpectedSignature } return nil } if signature == "" { - return internalerrors.ErrMissingSignature + return domain.ErrMissingSignature } if !strings.HasPrefix(signature, "sha256=") { - return errors.Wrap(internalerrors.ErrInvalidSignature, "must start with 'sha256='") + return errors.Wrap(domain.ErrInvalidSignature, "must start with 'sha256='") } mac := hmac.New(sha256.New, []byte(secret)) @@ -83,7 +83,7 @@ func ValidateWebhookSignature(payload []byte, signature string, secret string) e expectedSignature := "sha256=" + expectedMAC if !hmac.Equal([]byte(signature), []byte(expectedSignature)) { - return errors.Wrap(internalerrors.ErrInvalidSignature, "computed signature does not match") + return errors.Wrap(domain.ErrInvalidSignature, "computed signature does not match") } return nil @@ -97,16 +97,16 @@ func ParsePullRequestEvent(payload []byte) (*PullRequestEvent, error) { return nil, errors.Wrap(err, "failed to unmarshal pull request event") } if event.PullRequest == nil { - return nil, errors.Wrap(internalerrors.ErrMissingPRData, "missing pull_request field") + return nil, errors.Wrap(domain.ErrMissingPRData, "missing pull_request field") } if event.PullRequest.Number == nil { - return nil, errors.Wrap(internalerrors.ErrMissingPRData, "missing pr number") + return nil, errors.Wrap(domain.ErrMissingPRData, "missing pr number") } if event.PullRequest.Base == nil || event.PullRequest.Base.Ref == nil { - return nil, errors.Wrap(internalerrors.ErrMissingPRData, "missing base branch") + return nil, errors.Wrap(domain.ErrMissingPRData, "missing base branch") } if event.Repository == nil { - return nil, errors.Wrap(internalerrors.ErrMissingPRData, "missing repository") + return nil, errors.Wrap(domain.ErrMissingPRData, "missing repository") } return &event, nil } diff --git a/internal/github/webhooks/webhooks_test.go b/internal/github/webhooks/webhooks_test.go new file mode 100644 index 0000000..ed9169b --- /dev/null +++ b/internal/github/webhooks/webhooks_test.go @@ -0,0 +1,355 @@ +package webhooks + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "testing" + + "github.com/cockroachdb/errors" + "github.com/cruxstack/github-ops-app/internal/domain" + "github.com/google/go-github/v79/github" +) + +func TestValidateWebhookSignature(t *testing.T) { + tests := []struct { + name string + payload []byte + signature string + secret string + wantErr error + }{ + { + name: "no secret, no signature", + payload: []byte(`{"test": true}`), + signature: "", + secret: "", + wantErr: nil, + }, + { + name: "no secret, unexpected signature", + payload: []byte(`{"test": true}`), + signature: "sha256=abc", + secret: "", + wantErr: domain.ErrUnexpectedSignature, + }, + { + name: "secret configured, missing signature", + payload: []byte(`{"test": true}`), + signature: "", + secret: "mysecret", + wantErr: domain.ErrMissingSignature, + }, + { + name: "wrong prefix", + payload: []byte(`{"test": true}`), + signature: "sha1=abc123", + secret: "mysecret", + wantErr: domain.ErrInvalidSignature, + }, + { + name: "wrong signature value", + payload: []byte(`{"test": true}`), + signature: "sha256=0000000000000000000000000000000000000000000000000000000000000000", + secret: "mysecret", + wantErr: domain.ErrInvalidSignature, + }, + { + name: "valid signature", + payload: []byte(`{}`), + signature: computeSignature([]byte(`{}`), "test-secret"), + secret: "test-secret", + wantErr: nil, + }, + { + name: "empty payload with valid signature", + payload: []byte{}, + signature: computeSignature([]byte{}, "key"), + secret: "key", + wantErr: nil, + }, + { + name: "truncated signature", + payload: []byte(`test`), + signature: "sha256=abc", + secret: "mysecret", + wantErr: domain.ErrInvalidSignature, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateWebhookSignature(tt.payload, tt.signature, tt.secret) + if tt.wantErr == nil { + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + return + } + if err == nil { + t.Errorf("expected error %v, got nil", tt.wantErr) + return + } + if !errors.Is(err, tt.wantErr) { + t.Errorf("expected error %v, got: %v", tt.wantErr, err) + } + }) + } +} + +// computeSignature generates a valid HMAC-SHA256 signature for testing. +func computeSignature(payload []byte, secret string) string { + mac := hmac.New(sha256.New, []byte(secret)) + mac.Write(payload) + return "sha256=" + hex.EncodeToString(mac.Sum(nil)) +} + +func TestParsePullRequestEvent_Valid(t *testing.T) { + prNumber := 42 + baseBranch := "main" + payload := map[string]any{ + "action": "closed", + "number": prNumber, + "pull_request": map[string]any{ + "number": prNumber, + "merged": true, + "base": map[string]any{"ref": baseBranch}, + }, + "repository": map[string]any{ + "name": "test-repo", + "full_name": "owner/test-repo", + "owner": map[string]any{"login": "owner"}, + }, + } + + data, _ := json.Marshal(payload) + event, err := ParsePullRequestEvent(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if event.Number != prNumber { + t.Errorf("expected number %d, got %d", prNumber, event.Number) + } + if !event.IsMerged() { + t.Error("expected IsMerged() to return true") + } + if event.GetBaseBranch() != baseBranch { + t.Errorf("expected base branch %q, got %q", baseBranch, event.GetBaseBranch()) + } +} + +func TestParsePullRequestEvent_MissingFields(t *testing.T) { + tests := []struct { + name string + payload map[string]any + }{ + { + name: "missing pull_request", + payload: map[string]any{"action": "closed"}, + }, + { + name: "missing pr number", + payload: map[string]any{ + "action": "closed", + "pull_request": map[string]any{ + "base": map[string]any{"ref": "main"}, + }, + "repository": map[string]any{}, + }, + }, + { + name: "missing base branch", + payload: map[string]any{ + "action": "closed", + "pull_request": map[string]any{ + "number": 1, + }, + "repository": map[string]any{}, + }, + }, + { + name: "missing repository", + payload: map[string]any{ + "action": "closed", + "pull_request": map[string]any{ + "number": 1, + "base": map[string]any{"ref": "main"}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.payload) + _, err := ParsePullRequestEvent(data) + if err == nil { + t.Error("expected error for missing fields") + } + if !errors.Is(err, domain.ErrMissingPRData) { + t.Errorf("expected ErrMissingPRData, got: %v", err) + } + }) + } +} + +func TestParsePullRequestEvent_InvalidJSON(t *testing.T) { + _, err := ParsePullRequestEvent([]byte(`{invalid`)) + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestParseTeamEvent_Valid(t *testing.T) { + payload := map[string]any{ + "action": "edited", + "team": map[string]any{"slug": "engineering"}, + "sender": map[string]any{"login": "user1", "type": "User"}, + } + + data, _ := json.Marshal(payload) + event, err := ParseTeamEvent(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if event.Action != "edited" { + t.Errorf("expected action 'edited', got %q", event.Action) + } + if event.GetTeamSlug() != "engineering" { + t.Errorf("expected team slug 'engineering', got %q", event.GetTeamSlug()) + } + if event.GetSenderType() != "User" { + t.Errorf("expected sender type 'User', got %q", event.GetSenderType()) + } +} + +func TestParseTeamEvent_MissingFields(t *testing.T) { + tests := []struct { + name string + payload map[string]any + }{ + { + name: "missing team", + payload: map[string]any{"action": "edited", "sender": map[string]any{"login": "u1"}}, + }, + { + name: "missing sender", + payload: map[string]any{"action": "edited", "team": map[string]any{"slug": "t1"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.payload) + _, err := ParseTeamEvent(data) + if err == nil { + t.Error("expected error for missing fields") + } + }) + } +} + +func TestParseMembershipEvent_Valid(t *testing.T) { + payload := map[string]any{ + "action": "added", + "scope": "team", + "member": map[string]any{"login": "newuser"}, + "team": map[string]any{"slug": "engineering"}, + "sender": map[string]any{"login": "admin", "type": "User"}, + } + + data, _ := json.Marshal(payload) + event, err := ParseMembershipEvent(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !event.IsTeamScope() { + t.Error("expected IsTeamScope() to return true") + } + if event.GetSenderLogin() != "admin" { + t.Errorf("expected sender login 'admin', got %q", event.GetSenderLogin()) + } +} + +func TestParseMembershipEvent_MissingFields(t *testing.T) { + tests := []struct { + name string + payload map[string]any + }{ + { + name: "missing team", + payload: map[string]any{"action": "added", "member": map[string]any{"login": "u"}, "sender": map[string]any{"login": "s"}}, + }, + { + name: "missing member", + payload: map[string]any{"action": "added", "team": map[string]any{"slug": "t"}, "sender": map[string]any{"login": "s"}}, + }, + { + name: "missing sender", + payload: map[string]any{"action": "added", "team": map[string]any{"slug": "t"}, "member": map[string]any{"login": "u"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, _ := json.Marshal(tt.payload) + _, err := ParseMembershipEvent(data) + if err == nil { + t.Error("expected error for missing fields") + } + }) + } +} + +func TestPullRequestEvent_IsMerged(t *testing.T) { + merged := true + notMerged := false + + tests := []struct { + name string + action string + merged *bool + want bool + }{ + {"closed and merged", "closed", &merged, true}, + {"closed not merged", "closed", ¬Merged, false}, + {"opened", "opened", nil, false}, + {"closed with nil merged", "closed", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + event := &PullRequestEvent{ + Action: tt.action, + PullRequest: &github.PullRequest{ + Merged: tt.merged, + }, + } + if got := event.IsMerged(); got != tt.want { + t.Errorf("IsMerged() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMembershipEvent_IsTeamScope(t *testing.T) { + tests := []struct { + scope string + want bool + }{ + {"team", true}, + {"organization", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.scope, func(t *testing.T) { + event := &MembershipEvent{Scope: tt.scope} + if got := event.IsTeamScope(); got != tt.want { + t.Errorf("IsTeamScope() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/notifiers/slack_messages.go b/internal/notifiers/slack_messages.go index 256754d..2500268 100644 --- a/internal/notifiers/slack_messages.go +++ b/internal/notifiers/slack_messages.go @@ -5,17 +5,15 @@ import ( "fmt" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/okta" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/slack-go/slack" ) // NotifyPRBypass sends a Slack notification when branch protection is // bypassed. -func (s *SlackNotifier) NotifyPRBypass(ctx context.Context, result *client.PRComplianceResult, repoFullName string) error { +func (s *SlackNotifier) NotifyPRBypass(ctx context.Context, result *domain.PRComplianceResult, repoFullName string) error { if result.PR == nil { - return errors.Wrap(internalerrors.ErrMissingPRData, "pr result missing") + return errors.Wrap(domain.ErrMissingPRData, "pr result missing") } prURL := "" @@ -90,14 +88,14 @@ func (s *SlackNotifier) NotifyPRBypass(ctx context.Context, result *client.PRCom } // NotifyOktaSync sends a Slack notification with Okta sync results. -func (s *SlackNotifier) NotifyOktaSync(ctx context.Context, reports []*okta.SyncReport, githubOrg string) error { +func (s *SlackNotifier) NotifyOktaSync(ctx context.Context, reports []*domain.SyncReport, githubOrg string) error { if len(reports) == 0 { return nil } // aggregate stats var totalAdded, totalRemoved int - var rulesWithChanges, rulesWithoutChanges []*okta.SyncReport + var rulesWithChanges, rulesWithoutChanges []*domain.SyncReport var allErrors []string var allSkippedExternal, allSkippedNoGHUsername []string @@ -239,7 +237,7 @@ func (s *SlackNotifier) NotifyOktaSync(ctx context.Context, reports []*okta.Sync // NotifyOrphanedUsers sends a Slack notification about organization members // not in any synced teams. -func (s *SlackNotifier) NotifyOrphanedUsers(ctx context.Context, report *okta.OrphanedUsersReport) error { +func (s *SlackNotifier) NotifyOrphanedUsers(ctx context.Context, report *domain.OrphanedUsersReport) error { if report == nil || len(report.OrphanedUsers) == 0 { return nil } diff --git a/internal/notifiers/slack_messages_test.go b/internal/notifiers/slack_messages_test.go new file mode 100644 index 0000000..18fc39d --- /dev/null +++ b/internal/notifiers/slack_messages_test.go @@ -0,0 +1,235 @@ +package notifiers + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/cruxstack/github-ops-app/internal/domain" + "github.com/google/go-github/v79/github" +) + +// slackTestServer creates a mock Slack API server that captures posted +// messages and returns them via the messages channel. +func slackTestServer(t *testing.T) (*httptest.Server, chan map[string]any) { + t.Helper() + messages := make(chan map[string]any, 10) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + r.Body.Close() + + // parse form-encoded body from slack client + var msg map[string]any + if strings.Contains(r.Header.Get("Content-Type"), "json") { + json.Unmarshal(body, &msg) + } else { + msg = map[string]any{"raw_body": string(body), "path": r.URL.Path} + } + + messages <- msg + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(200) + w.Write([]byte(`{"ok":true,"channel":"C123","ts":"1234.5678"}`)) + })) + + return srv, messages +} + +func TestNotifyPRBypass_Success(t *testing.T) { + srv, messages := slackTestServer(t) + defer srv.Close() + + notifier := NewSlackNotifierWithAPIURL( + "xoxb-test", + SlackChannels{Default: "C_DEFAULT"}, + SlackMessages{PRBypassFooterNote: "review this"}, + srv.URL+"/", + ) + + prNumber := 42 + prTitle := "fix: patch auth" + prURL := "https://github.com/org/repo/pull/42" + mergedBy := "admin-user" + result := &domain.PRComplianceResult{ + PR: &github.PullRequest{ + Number: &prNumber, + Title: &prTitle, + HTMLURL: &prURL, + MergedBy: &github.User{Login: &mergedBy}, + }, + UserHasBypass: true, + UserBypassReason: "repository admin", + Violations: []domain.ComplianceViolation{ + {Type: "insufficient_reviews", Description: "required 2, had 0"}, + }, + } + + err := notifier.NotifyPRBypass(context.Background(), result, "org/repo") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := <-messages + if msg == nil { + t.Fatal("expected a message to be posted") + } +} + +func TestNotifyPRBypass_NilPR(t *testing.T) { + notifier := NewSlackNotifier("xoxb-test", SlackChannels{Default: "C"}, SlackMessages{}) + + result := &domain.PRComplianceResult{PR: nil} + err := notifier.NotifyPRBypass(context.Background(), result, "org/repo") + if err == nil { + t.Fatal("expected error for nil PR") + } +} + +func TestNotifyOktaSync_EmptyReports(t *testing.T) { + notifier := NewSlackNotifier("xoxb-test", SlackChannels{Default: "C"}, SlackMessages{}) + + err := notifier.NotifyOktaSync(context.Background(), nil, "org") + if err != nil { + t.Fatalf("expected nil for empty reports, got: %v", err) + } + + err = notifier.NotifyOktaSync(context.Background(), []*domain.SyncReport{}, "org") + if err != nil { + t.Fatalf("expected nil for zero reports, got: %v", err) + } +} + +func TestNotifyOktaSync_Success(t *testing.T) { + srv, messages := slackTestServer(t) + defer srv.Close() + + notifier := NewSlackNotifierWithAPIURL( + "xoxb-test", + SlackChannels{Default: "C_DEFAULT", OktaSync: "C_OKTA"}, + SlackMessages{}, + srv.URL+"/", + ) + + reports := []*domain.SyncReport{ + { + Rule: "eng", OktaGroup: "Engineering", GitHubTeam: "engineering", + MembersAdded: []string{"alice"}, MembersRemoved: []string{"bob"}, + }, + { + Rule: "platform", OktaGroup: "Platform", GitHubTeam: "platform", + }, + } + + err := notifier.NotifyOktaSync(context.Background(), reports, "my-org") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := <-messages + if msg == nil { + t.Fatal("expected a message to be posted") + } +} + +func TestNotifyOrphanedUsers_NilReport(t *testing.T) { + notifier := NewSlackNotifier("xoxb-test", SlackChannels{Default: "C"}, SlackMessages{}) + + err := notifier.NotifyOrphanedUsers(context.Background(), nil) + if err != nil { + t.Fatalf("expected nil for nil report, got: %v", err) + } +} + +func TestNotifyOrphanedUsers_EmptyUsers(t *testing.T) { + notifier := NewSlackNotifier("xoxb-test", SlackChannels{Default: "C"}, SlackMessages{}) + + err := notifier.NotifyOrphanedUsers(context.Background(), &domain.OrphanedUsersReport{}) + if err != nil { + t.Fatalf("expected nil for empty users, got: %v", err) + } +} + +func TestNotifyOrphanedUsers_Success(t *testing.T) { + srv, messages := slackTestServer(t) + defer srv.Close() + + notifier := NewSlackNotifierWithAPIURL( + "xoxb-test", + SlackChannels{Default: "C_DEFAULT"}, + SlackMessages{}, + srv.URL+"/", + ) + + report := &domain.OrphanedUsersReport{ + OrphanedUsers: []string{"orphan-1", "orphan-2"}, + } + + err := notifier.NotifyOrphanedUsers(context.Background(), report) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := <-messages + if msg == nil { + t.Fatal("expected a message to be posted") + } +} + +func TestNotifyOktaSync_WithErrors(t *testing.T) { + srv, messages := slackTestServer(t) + defer srv.Close() + + notifier := NewSlackNotifierWithAPIURL( + "xoxb-test", + SlackChannels{Default: "C"}, + SlackMessages{}, + srv.URL+"/", + ) + + reports := []*domain.SyncReport{ + { + Rule: "eng", OktaGroup: "Engineering", GitHubTeam: "engineering", + MembersAdded: []string{"alice"}, + Errors: []string{"rate limited"}, + MembersSkippedExternal: []string{"external-1"}, + MembersSkippedNoGHUsername: []string{"newhire@co.com"}, + }, + } + + err := notifier.NotifyOktaSync(context.Background(), reports, "org") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + msg := <-messages + if msg == nil { + t.Fatal("expected a message") + } +} + +func TestChannelFor_CustomChannels(t *testing.T) { + n := &SlackNotifier{ + channels: SlackChannels{ + Default: "C_DEFAULT", + PRBypass: "C_PR", + OktaSync: "C_OKTA", + OrphanedUsers: "C_ORPHAN", + }, + } + + if got := n.channelFor(n.channels.PRBypass); got != "C_PR" { + t.Errorf("expected C_PR, got %s", got) + } + if got := n.channelFor(n.channels.OktaSync); got != "C_OKTA" { + t.Errorf("expected C_OKTA, got %s", got) + } + if got := n.channelFor(n.channels.OrphanedUsers); got != "C_ORPHAN" { + t.Errorf("expected C_ORPHAN, got %s", got) + } +} diff --git a/internal/okta/client.go b/internal/okta/client.go index d1f1adf..d7f085f 100644 --- a/internal/okta/client.go +++ b/internal/okta/client.go @@ -1,5 +1,5 @@ -// Package okta provides Okta API client and group synchronization to GitHub -// teams. Uses OAuth 2.0 with private key authentication. +// Package okta provides Okta API client for group and user management. +// Uses OAuth 2.0 with private key authentication. package okta import ( @@ -9,11 +9,12 @@ import ( "crypto/x509" "encoding/pem" "fmt" + "log/slog" "net/http" "net/url" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" + "github.com/cruxstack/github-ops-app/internal/domain" "github.com/okta/okta-sdk-golang/v6/okta" ) @@ -21,6 +22,15 @@ import ( // these scopes are necessary for group sync functionality. var DefaultScopes = []string{"okta.groups.read", "okta.users.read"} +// certPoolKey is a typed context key for TLS certificate pool injection. +type certPoolKey struct{} + +// WithCertPool returns a new context with the given TLS certificate pool. +// used by integration tests to inject self-signed certs. +func WithCertPool(ctx context.Context, pool *x509.CertPool) context.Context { + return context.WithValue(ctx, certPoolKey{}, pool) +} + // convertToPKCS1 converts a PEM-encoded private key to PKCS#1 format if needed. // the Okta SDK requires PKCS#1 format (BEGIN RSA PRIVATE KEY), but Okta's // console generates PKCS#8 keys (BEGIN PRIVATE KEY). this function detects the @@ -61,12 +71,16 @@ func convertToPKCS1(keyPEM []byte) ([]byte, error) { } // Client wraps the Okta SDK client with custom configuration. +// implements domain.OktaClient. type Client struct { apiClient *okta.APIClient - ctx context.Context githubUserField string + logger *slog.Logger } +// compile-time assertion +var _ domain.OktaClient = (*Client)(nil) + // ClientConfig contains Okta client configuration. type ClientConfig struct { Domain string @@ -76,6 +90,7 @@ type ClientConfig struct { Scopes []string GitHubUserField string BaseURL string + Logger *slog.Logger } // NewClient creates an Okta client with background context. @@ -88,7 +103,7 @@ func NewClient(cfg *ClientConfig) (*Client, error) { // testing. func NewClientWithContext(ctx context.Context, cfg *ClientConfig) (*Client, error) { if cfg.ClientID == "" || len(cfg.PrivateKey) == 0 { - return nil, internalerrors.ErrMissingOAuthCreds + return nil, domain.ErrMissingOAuthCreds } orgURL := cfg.BaseURL @@ -119,7 +134,7 @@ func NewClientWithContext(ctx context.Context, cfg *ClientConfig) (*Client, erro opts = append(opts, okta.WithPrivateKeyId(cfg.PrivateKeyID)) } - if certPool, ok := ctx.Value("okta_tls_cert_pool").(*x509.CertPool); ok && certPool != nil { + if certPool, ok := ctx.Value(certPoolKey{}).(*x509.CertPool); ok && certPool != nil { httpClient := &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -161,10 +176,15 @@ func NewClientWithContext(ctx context.Context, cfg *ClientConfig) (*Client, erro apiClient := okta.NewAPIClient(oktaCfg) + logger := cfg.Logger + if logger == nil { + logger = slog.Default() + } + return &Client{ apiClient: apiClient, - ctx: ctx, githubUserField: cfg.GitHubUserField, + logger: logger, }, nil } @@ -173,75 +193,96 @@ func (c *Client) GetAPIClient() *okta.APIClient { return c.apiClient } -// GetContext returns the context used for API requests. -func (c *Client) GetContext() context.Context { - return c.ctx -} - -// ListGroups fetches all Okta groups. -func (c *Client) ListGroups() ([]okta.Group, error) { - groups, _, err := c.apiClient.GroupAPI.ListGroups(c.ctx).Execute() +// ListGroups fetches all Okta groups with pagination. +func (c *Client) ListGroups(ctx context.Context) ([]okta.Group, error) { + groups, resp, err := c.apiClient.GroupAPI.ListGroups(ctx).Limit(200).Execute() if err != nil { return nil, errors.Wrap(err, "failed to list groups") } - return groups, nil + + allGroups := make([]okta.Group, 0, len(groups)) + allGroups = append(allGroups, groups...) + + for resp != nil && resp.HasNextPage() { + var nextGroups []okta.Group + resp, err = resp.Next(&nextGroups) + if err != nil { + return nil, errors.Wrap(err, "failed to fetch next page of groups") + } + allGroups = append(allGroups, nextGroups...) + } + + return allGroups, nil } // GetGroupByName searches for an Okta group by exact name match. -func (c *Client) GetGroupByName(name string) (*okta.Group, error) { - groups, _, err := c.apiClient.GroupAPI.ListGroups(c.ctx).Q(name).Execute() +// paginates through results in case the group is not on the first page. +func (c *Client) GetGroupByName(ctx context.Context, name string) (*okta.Group, error) { + groups, resp, err := c.apiClient.GroupAPI.ListGroups(ctx).Q(name).Limit(200).Execute() if err != nil { return nil, errors.Wrapf(err, "failed to search for group '%s'", name) } - for i := range groups { - group := &groups[i] - // check if profile is nil - if group.Profile == nil { - continue - } - - // try OktaUserGroupProfile first - if group.Profile.OktaUserGroupProfile != nil { - groupName := group.Profile.OktaUserGroupProfile.GetName() + for { + for i := range groups { + group := &groups[i] + groupName := extractGroupName(group) if groupName == name { return group, nil } } - // try OktaActiveDirectoryGroupProfile as fallback - if group.Profile.OktaActiveDirectoryGroupProfile != nil { - groupName := group.Profile.OktaActiveDirectoryGroupProfile.GetName() - if groupName == name { - return group, nil - } + if resp == nil || !resp.HasNextPage() { + break } + + var nextGroups []okta.Group + resp, err = resp.Next(&nextGroups) + if err != nil { + return nil, errors.Wrapf(err, "failed to fetch next page while searching for group '%s'", name) + } + groups = nextGroups } return nil, errors.Newf("group '%s' not found", name) } -// GroupMembersResult contains the results of fetching group members. -type GroupMembersResult struct { - Members []string - SkippedNoGitHubUsername []string -} - // GetGroupMembers fetches GitHub usernames for all active members of an Okta -// group. only includes users with status "ACTIVE" to exclude -// suspended/deprovisioned users. skips users without a GitHub username in -// their profile and tracks them separately. -func (c *Client) GetGroupMembers(groupID string) (*GroupMembersResult, error) { - users, _, err := c.apiClient.GroupAPI.ListGroupUsers(c.ctx, groupID).Execute() +// group. paginates through all members. only includes users with status +// "ACTIVE" to exclude suspended/deprovisioned users. skips users without a +// GitHub username in their profile and tracks them separately. +func (c *Client) GetGroupMembers(ctx context.Context, groupID string) (*domain.GroupMembersResult, error) { + users, resp, err := c.apiClient.GroupAPI.ListGroupUsers(ctx, groupID).Limit(200).Execute() if err != nil { return nil, errors.Wrapf(err, "failed to list members for group '%s'", groupID) } - result := &GroupMembersResult{ + result := &domain.GroupMembersResult{ Members: make([]string, 0, len(users)), SkippedNoGitHubUsername: []string{}, } + for { + c.processGroupUsers(users, result) + + if resp == nil || !resp.HasNextPage() { + break + } + + var nextUsers []okta.User + resp, err = resp.Next(&nextUsers) + if err != nil { + return nil, errors.Wrapf(err, "failed to fetch next page of members for group '%s'", groupID) + } + users = nextUsers + } + + return result, nil +} + +// processGroupUsers extracts GitHub usernames from a batch of Okta users +// and appends them to the result. +func (c *Client) processGroupUsers(users []okta.User, result *domain.GroupMembersResult) { for _, user := range users { if user.GetStatus() != "ACTIVE" { continue @@ -270,6 +311,4 @@ func (c *Client) GetGroupMembers(groupID string) (*GroupMembersResult, error) { result.SkippedNoGitHubUsername, profile.GetEmail()) } } - - return result, nil } diff --git a/internal/okta/groups.go b/internal/okta/groups.go index cd3e84a..a0a5acb 100644 --- a/internal/okta/groups.go +++ b/internal/okta/groups.go @@ -1,62 +1,62 @@ package okta import ( + "context" + "log/slog" "regexp" "github.com/cockroachdb/errors" - internalerrors "github.com/cruxstack/github-ops-app/internal/errors" - "github.com/okta/okta-sdk-golang/v6/okta" + "github.com/cruxstack/github-ops-app/internal/domain" + oktasdk "github.com/okta/okta-sdk-golang/v6/okta" ) -// GroupInfo contains Okta group details and member list. -type GroupInfo struct { - ID string - Name string - Members []string - SkippedNoGitHubUsername []string +// extractGroupName returns the group name from either profile type. +func extractGroupName(group *oktasdk.Group) string { + if group == nil || group.Profile == nil { + return "" + } + if group.Profile.OktaUserGroupProfile != nil { + return group.Profile.OktaUserGroupProfile.GetName() + } + if group.Profile.OktaActiveDirectoryGroupProfile != nil { + return group.Profile.OktaActiveDirectoryGroupProfile.GetName() + } + return "" } // GetGroupsByPattern fetches all Okta groups matching a regex pattern. -func (c *Client) GetGroupsByPattern(pattern string) ([]*GroupInfo, error) { +func (c *Client) GetGroupsByPattern(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) { if pattern == "" { - return nil, internalerrors.ErrEmptyPattern + return nil, domain.ErrEmptyPattern } re, err := regexp.Compile(pattern) if err != nil { - return nil, errors.Wrapf(internalerrors.ErrInvalidPattern, "'%s'", pattern) + return nil, errors.Wrapf(domain.ErrInvalidPattern, "'%s'", pattern) } - allGroups, err := c.ListGroups() + allGroups, err := c.ListGroups(ctx) if err != nil { return nil, err } - var matched []*GroupInfo + var matched []*domain.GroupInfo for _, group := range allGroups { - if group.Profile == nil { - continue - } - - // extract group name from either profile type - var groupName string - if group.Profile.OktaUserGroupProfile != nil { - groupName = group.Profile.OktaUserGroupProfile.GetName() - } else if group.Profile.OktaActiveDirectoryGroupProfile != nil { - groupName = group.Profile.OktaActiveDirectoryGroupProfile.GetName() - } - + groupName := extractGroupName(&group) if groupName == "" { continue } if re.MatchString(groupName) { - result, err := c.GetGroupMembers(group.GetId()) + result, err := c.GetGroupMembers(ctx, group.GetId()) if err != nil { + c.logger.Warn("failed to get group members, skipping", + slog.String("group", groupName), + slog.String("error", err.Error())) continue } - matched = append(matched, &GroupInfo{ + matched = append(matched, &domain.GroupInfo{ ID: group.GetId(), Name: groupName, Members: result.Members, @@ -69,28 +69,23 @@ func (c *Client) GetGroupsByPattern(pattern string) ([]*GroupInfo, error) { } // GetGroupInfo fetches details for a single Okta group by name. -func (c *Client) GetGroupInfo(groupName string) (*GroupInfo, error) { - group, err := c.GetGroupByName(groupName) +func (c *Client) GetGroupInfo(ctx context.Context, groupName string) (*domain.GroupInfo, error) { + group, err := c.GetGroupByName(ctx, groupName) if err != nil { return nil, err } - result, err := c.GetGroupMembers(group.GetId()) + result, err := c.GetGroupMembers(ctx, group.GetId()) if err != nil { return nil, err } - // extract group name from either profile type - var name string - if group.Profile != nil { - if group.Profile.OktaUserGroupProfile != nil { - name = group.Profile.OktaUserGroupProfile.GetName() - } else if group.Profile.OktaActiveDirectoryGroupProfile != nil { - name = group.Profile.OktaActiveDirectoryGroupProfile.GetName() - } + name := extractGroupName(group) + if name == "" { + name = groupName } - return &GroupInfo{ + return &domain.GroupInfo{ ID: group.GetId(), Name: name, Members: result.Members, @@ -100,7 +95,7 @@ func (c *Client) GetGroupInfo(groupName string) (*GroupInfo, error) { // FilterEnabledGroups filters Okta groups to only those in the enabled list. // returns all groups if enabled list is empty. -func FilterEnabledGroups(groups []okta.Group, enabledNames []string) []okta.Group { +func FilterEnabledGroups(groups []oktasdk.Group, enabledNames []string) []oktasdk.Group { if len(enabledNames) == 0 { return groups } @@ -110,19 +105,11 @@ func FilterEnabledGroups(groups []okta.Group, enabledNames []string) []okta.Grou enabledMap[name] = true } - var filtered []okta.Group + var filtered []oktasdk.Group for _, group := range groups { - if group.Profile != nil { - var groupName string - if group.Profile.OktaUserGroupProfile != nil { - groupName = group.Profile.OktaUserGroupProfile.GetName() - } else if group.Profile.OktaActiveDirectoryGroupProfile != nil { - groupName = group.Profile.OktaActiveDirectoryGroupProfile.GetName() - } - - if groupName != "" && enabledMap[groupName] { - filtered = append(filtered, group) - } + groupName := extractGroupName(&group) + if groupName != "" && enabledMap[groupName] { + filtered = append(filtered, group) } } diff --git a/internal/okta/sync.go b/internal/sync/sync.go similarity index 68% rename from internal/okta/sync.go rename to internal/sync/sync.go index a2f4794..d2ea4c5 100644 --- a/internal/okta/sync.go +++ b/internal/sync/sync.go @@ -1,4 +1,6 @@ -package okta +// Package sync coordinates synchronization of Okta groups to GitHub teams. +// depends only on domain interfaces, not concrete client implementations. +package sync import ( "context" @@ -8,53 +10,24 @@ import ( "strings" "github.com/cockroachdb/errors" - "github.com/cruxstack/github-ops-app/internal/github/client" - "github.com/cruxstack/github-ops-app/internal/types" + "github.com/cruxstack/github-ops-app/internal/domain" ) -// SyncRule is an alias to types.SyncRule for convenience. -type SyncRule = types.SyncRule - -// SyncReport contains the results of syncing a single Okta group to GitHub -// team. -type SyncReport struct { - Rule string - OktaGroup string - GitHubTeam string - MembersAdded []string - MembersRemoved []string - MembersSkippedExternal []string - MembersSkippedNoGHUsername []string - Errors []string -} - -// OrphanedUsersReport contains users who are org members but not in any synced -// teams. -type OrphanedUsersReport struct { - OrphanedUsers []string -} - -// HasErrors returns true if any errors occurred during sync. -func (r *SyncReport) HasErrors() bool { - return len(r.Errors) > 0 -} - -// HasChanges returns true if members were added or removed. -func (r *SyncReport) HasChanges() bool { - return len(r.MembersAdded) > 0 || len(r.MembersRemoved) > 0 -} +// teamNameNormalizer replaces non-alphanumeric characters (except hyphens) +// in team names. compiled once at package init. +var teamNameNormalizer = regexp.MustCompile(`[^a-z0-9-]+`) // Syncer coordinates synchronization of Okta groups to GitHub teams. type Syncer struct { - oktaClient *Client - githubClient *client.Client - rules []SyncRule + oktaClient domain.OktaClient + githubClient domain.GitHubClient + rules []domain.SyncRule safetyThreshold float64 logger *slog.Logger } // NewSyncer creates a new Okta to GitHub syncer. -func NewSyncer(oktaClient *Client, githubClient *client.Client, rules []SyncRule, safetyThreshold float64, logger *slog.Logger) *Syncer { +func NewSyncer(oktaClient domain.OktaClient, githubClient domain.GitHubClient, rules []domain.SyncRule, safetyThreshold float64, logger *slog.Logger) *Syncer { return &Syncer{ oktaClient: oktaClient, githubClient: githubClient, @@ -64,16 +37,11 @@ func NewSyncer(oktaClient *Client, githubClient *client.Client, rules []SyncRule } } -// SyncResult contains all sync reports and orphaned users report. -type SyncResult struct { - Reports []*SyncReport - OrphanedUsers *OrphanedUsersReport -} - // Sync executes all enabled sync rules and returns reports. // continues processing remaining rules even if some fail. -func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { - var reports []*SyncReport +func (s *Syncer) Sync(ctx context.Context) (*domain.SyncResult, error) { + var reports []*domain.SyncReport + var enabledRuleCount int var failedRuleCount int for _, rule := range s.rules { @@ -81,6 +49,8 @@ func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { continue } + enabledRuleCount++ + ruleReports, err := s.syncRule(ctx, rule) if err != nil { failedRuleCount++ @@ -88,8 +58,7 @@ func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { slog.String("rule", rule.GetName()), slog.String("error", err.Error())) - // create a report for the failed rule so error is visible - reports = append(reports, &SyncReport{ + reports = append(reports, &domain.SyncReport{ Rule: rule.GetName(), OktaGroup: rule.OktaGroupName, GitHubTeam: rule.GitHubTeamName, @@ -101,11 +70,11 @@ func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { reports = append(reports, ruleReports...) } - if failedRuleCount > 0 && failedRuleCount == len(reports) { + if enabledRuleCount > 0 && failedRuleCount == enabledRuleCount { return nil, errors.Newf("all sync rules failed: %d errors", failedRuleCount) } - return &SyncResult{ + return &domain.SyncResult{ Reports: reports, OrphanedUsers: nil, }, nil @@ -113,7 +82,7 @@ func (s *Syncer) Sync(ctx context.Context) (*SyncResult, error) { // DetectOrphanedUsers finds organization members not in any synced teams. // excludes external collaborators. -func (s *Syncer) DetectOrphanedUsers(ctx context.Context, syncedTeams []string) (*OrphanedUsersReport, error) { +func (s *Syncer) DetectOrphanedUsers(ctx context.Context, syncedTeams []string) (*domain.OrphanedUsersReport, error) { orgMembers, err := s.githubClient.ListOrgMembers(ctx) if err != nil { return nil, errors.Wrap(err, "failed to list organization members") @@ -150,34 +119,34 @@ func (s *Syncer) DetectOrphanedUsers(ctx context.Context, syncedTeams []string) } } - return &OrphanedUsersReport{ + return &domain.OrphanedUsersReport{ OrphanedUsers: orphanedUsers, }, nil } // syncRule executes a single sync rule. // supports both pattern matching and exact group name matching. -func (s *Syncer) syncRule(ctx context.Context, rule SyncRule) ([]*SyncReport, error) { - var reports []*SyncReport +func (s *Syncer) syncRule(ctx context.Context, rule domain.SyncRule) ([]*domain.SyncReport, error) { + var reports []*domain.SyncReport if rule.OktaGroupPattern != "" { - groups, err := s.oktaClient.GetGroupsByPattern(rule.OktaGroupPattern) + groups, err := s.oktaClient.GetGroupsByPattern(ctx, rule.OktaGroupPattern) if err != nil { return nil, errors.Wrapf(err, "failed to match groups with pattern '%s'", rule.OktaGroupPattern) } for _, group := range groups { - teamName := s.computeTeamName(group.Name, rule) + teamName := computeTeamName(group.Name, rule) report := s.syncGroupToTeam(ctx, rule, group, teamName) reports = append(reports, report) } } else if rule.OktaGroupName != "" { - group, err := s.oktaClient.GetGroupInfo(rule.OktaGroupName) + group, err := s.oktaClient.GetGroupInfo(ctx, rule.OktaGroupName) if err != nil { return nil, errors.Wrapf(err, "failed to fetch group '%s'", rule.OktaGroupName) } - teamName := s.computeTeamName(group.Name, rule) + teamName := computeTeamName(group.Name, rule) report := s.syncGroupToTeam(ctx, rule, group, teamName) reports = append(reports, report) } @@ -187,7 +156,7 @@ func (s *Syncer) syncRule(ctx context.Context, rule SyncRule) ([]*SyncReport, er // computeTeamName generates GitHub team name from Okta group name. // applies prefix stripping, prefix addition, and normalization. -func (s *Syncer) computeTeamName(oktaGroupName string, rule SyncRule) string { +func computeTeamName(oktaGroupName string, rule domain.SyncRule) string { if rule.GitHubTeamName != "" { return rule.GitHubTeamName } @@ -203,15 +172,16 @@ func (s *Syncer) computeTeamName(oktaGroupName string, rule SyncRule) string { } teamName = strings.ToLower(teamName) - teamName = regexp.MustCompile(`[^a-z0-9-]`).ReplaceAllString(teamName, "-") + teamName = teamNameNormalizer.ReplaceAllString(teamName, "-") + teamName = strings.Trim(teamName, "-") return teamName } // syncGroupToTeam synchronizes a single Okta group to a GitHub team. // creates team if missing and syncs members if enabled. -func (s *Syncer) syncGroupToTeam(ctx context.Context, rule SyncRule, group *GroupInfo, teamName string) *SyncReport { - report := &SyncReport{ +func (s *Syncer) syncGroupToTeam(ctx context.Context, rule domain.SyncRule, group *domain.GroupInfo, teamName string) *domain.SyncReport { + report := &domain.SyncReport{ Rule: rule.GetName(), OktaGroup: group.Name, GitHubTeam: teamName, diff --git a/internal/sync/sync_test.go b/internal/sync/sync_test.go new file mode 100644 index 0000000..59c2576 --- /dev/null +++ b/internal/sync/sync_test.go @@ -0,0 +1,390 @@ +package sync + +import ( + "context" + "log/slog" + "os" + "testing" + + "github.com/cockroachdb/errors" + "github.com/cruxstack/github-ops-app/internal/domain" + "github.com/google/go-github/v79/github" +) + +// mockGitHubClient implements domain.GitHubClient with overridable function +// fields for testing. +type mockGitHubClient struct { + getOrCreateTeamFn func(ctx context.Context, teamName, privacy string) (*github.Team, error) + syncTeamMembersFn func(ctx context.Context, teamSlug string, desiredMembers []string, threshold float64) (*domain.TeamSyncResult, error) + getTeamMembersFn func(ctx context.Context, teamSlug string) ([]string, error) + listOrgMembersFn func(ctx context.Context) ([]string, error) + isExternalCollaboratorFn func(ctx context.Context, username string) (bool, error) + getAppSlugFn func(ctx context.Context) (string, error) + checkPRComplianceFn func(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) +} + +func (m *mockGitHubClient) CheckPRCompliance(ctx context.Context, owner, repo string, prNumber int) (*domain.PRComplianceResult, error) { + if m.checkPRComplianceFn != nil { + return m.checkPRComplianceFn(ctx, owner, repo, prNumber) + } + return &domain.PRComplianceResult{}, nil +} + +func (m *mockGitHubClient) GetOrCreateTeam(ctx context.Context, teamName, privacy string) (*github.Team, error) { + if m.getOrCreateTeamFn != nil { + return m.getOrCreateTeamFn(ctx, teamName, privacy) + } + slug := teamName + return &github.Team{Slug: &slug}, nil +} + +func (m *mockGitHubClient) SyncTeamMembers(ctx context.Context, teamSlug string, desiredMembers []string, threshold float64) (*domain.TeamSyncResult, error) { + if m.syncTeamMembersFn != nil { + return m.syncTeamMembersFn(ctx, teamSlug, desiredMembers, threshold) + } + return &domain.TeamSyncResult{ + TeamName: teamSlug, + MembersAdded: desiredMembers, + }, nil +} + +func (m *mockGitHubClient) GetTeamMembers(ctx context.Context, teamSlug string) ([]string, error) { + if m.getTeamMembersFn != nil { + return m.getTeamMembersFn(ctx, teamSlug) + } + return []string{}, nil +} + +func (m *mockGitHubClient) ListOrgMembers(ctx context.Context) ([]string, error) { + if m.listOrgMembersFn != nil { + return m.listOrgMembersFn(ctx) + } + return []string{}, nil +} + +func (m *mockGitHubClient) IsExternalCollaborator(ctx context.Context, username string) (bool, error) { + if m.isExternalCollaboratorFn != nil { + return m.isExternalCollaboratorFn(ctx, username) + } + return false, nil +} + +func (m *mockGitHubClient) GetAppSlug(ctx context.Context) (string, error) { + if m.getAppSlugFn != nil { + return m.getAppSlugFn(ctx) + } + return "test-app", nil +} + +func (m *mockGitHubClient) GetOrg() string { + return "test-org" +} + +// mockOktaClient implements domain.OktaClient with overridable function +// fields for testing. +type mockOktaClient struct { + getGroupsByPatternFn func(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) + getGroupInfoFn func(ctx context.Context, groupName string) (*domain.GroupInfo, error) +} + +func (m *mockOktaClient) GetGroupsByPattern(ctx context.Context, pattern string) ([]*domain.GroupInfo, error) { + if m.getGroupsByPatternFn != nil { + return m.getGroupsByPatternFn(ctx, pattern) + } + return []*domain.GroupInfo{}, nil +} + +func (m *mockOktaClient) GetGroupInfo(ctx context.Context, groupName string) (*domain.GroupInfo, error) { + if m.getGroupInfoFn != nil { + return m.getGroupInfoFn(ctx, groupName) + } + return &domain.GroupInfo{ + ID: "group-1", + Name: groupName, + Members: []string{}, + }, nil +} + +func testLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError})) +} + +func TestSync_SingleRule(t *testing.T) { + oktaClient := &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + return &domain.GroupInfo{ + ID: "g1", + Name: "Engineering", + Members: []string{"alice", "bob"}, + }, nil + }, + } + + ghClient := &mockGitHubClient{ + syncTeamMembersFn: func(_ context.Context, teamSlug string, members []string, _ float64) (*domain.TeamSyncResult, error) { + return &domain.TeamSyncResult{ + TeamName: teamSlug, + MembersAdded: members, + }, nil + }, + } + + rules := []domain.SyncRule{ + { + OktaGroupName: "Engineering", + GitHubTeamName: "engineering", + }, + } + + syncer := NewSyncer(oktaClient, ghClient, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports) != 1 { + t.Fatalf("expected 1 report, got %d", len(result.Reports)) + } + + report := result.Reports[0] + if report.GitHubTeam != "engineering" { + t.Errorf("expected team 'engineering', got %q", report.GitHubTeam) + } + if len(report.MembersAdded) != 2 { + t.Errorf("expected 2 members added, got %d", len(report.MembersAdded)) + } +} + +func TestSync_DisabledRule(t *testing.T) { + disabled := false + rules := []domain.SyncRule{ + { + Enabled: &disabled, + OktaGroupName: "Engineering", + GitHubTeamName: "engineering", + }, + } + + syncer := NewSyncer(&mockOktaClient{}, &mockGitHubClient{}, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports) != 0 { + t.Errorf("expected 0 reports for disabled rule, got %d", len(result.Reports)) + } +} + +func TestSync_OktaError_PartialSuccess(t *testing.T) { + callCount := 0 + oktaClient := &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + callCount++ + if callCount == 1 { + return nil, errors.New("okta api error") + } + return &domain.GroupInfo{ID: "g2", Name: name, Members: []string{"charlie"}}, nil + }, + } + + rules := []domain.SyncRule{ + {OktaGroupName: "BadGroup", GitHubTeamName: "bad-team"}, + {OktaGroupName: "GoodGroup", GitHubTeamName: "good-team"}, + } + + syncer := NewSyncer(oktaClient, &mockGitHubClient{}, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports) != 2 { + t.Fatalf("expected 2 reports, got %d", len(result.Reports)) + } + + // first report should have errors + if !result.Reports[0].HasErrors() { + t.Error("expected first report to have errors") + } + // second report should succeed + if result.Reports[1].HasErrors() { + t.Errorf("expected second report to have no errors, got: %v", result.Reports[1].Errors) + } +} + +func TestSync_AllRulesFail(t *testing.T) { + oktaClient := &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + return nil, errors.New("api failure") + }, + } + + rules := []domain.SyncRule{ + {OktaGroupName: "Group1", GitHubTeamName: "team1"}, + } + + syncer := NewSyncer(oktaClient, &mockGitHubClient{}, rules, 0.5, testLogger()) + _, err := syncer.Sync(context.Background()) + if err == nil { + t.Fatal("expected error when all rules fail") + } +} + +func TestSync_PatternMatching(t *testing.T) { + oktaClient := &mockOktaClient{ + getGroupsByPatternFn: func(_ context.Context, pattern string) ([]*domain.GroupInfo, error) { + return []*domain.GroupInfo{ + {ID: "g1", Name: "eng-frontend", Members: []string{"alice"}}, + {ID: "g2", Name: "eng-backend", Members: []string{"bob"}}, + }, nil + }, + } + + rules := []domain.SyncRule{ + {OktaGroupPattern: "^eng-.*$"}, + } + + syncer := NewSyncer(oktaClient, &mockGitHubClient{}, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports) != 2 { + t.Fatalf("expected 2 reports for pattern match, got %d", len(result.Reports)) + } +} + +func TestSync_SkippedUsersTracked(t *testing.T) { + oktaClient := &mockOktaClient{ + getGroupInfoFn: func(_ context.Context, name string) (*domain.GroupInfo, error) { + return &domain.GroupInfo{ + ID: "g1", + Name: name, + Members: []string{"alice"}, + SkippedNoGitHubUsername: []string{"new-hire@example.com"}, + }, nil + }, + } + + rules := []domain.SyncRule{ + {OktaGroupName: "Engineering", GitHubTeamName: "engineering"}, + } + + syncer := NewSyncer(oktaClient, &mockGitHubClient{}, rules, 0.5, testLogger()) + result, err := syncer.Sync(context.Background()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(result.Reports[0].MembersSkippedNoGHUsername) != 1 { + t.Errorf("expected 1 skipped user, got %d", len(result.Reports[0].MembersSkippedNoGHUsername)) + } +} + +func TestDetectOrphanedUsers(t *testing.T) { + ghClient := &mockGitHubClient{ + listOrgMembersFn: func(_ context.Context) ([]string, error) { + return []string{"alice", "bob", "charlie"}, nil + }, + getTeamMembersFn: func(_ context.Context, teamSlug string) ([]string, error) { + return []string{"alice", "bob"}, nil + }, + isExternalCollaboratorFn: func(_ context.Context, username string) (bool, error) { + return false, nil + }, + } + + syncer := NewSyncer(&mockOktaClient{}, ghClient, nil, 0.5, testLogger()) + report, err := syncer.DetectOrphanedUsers(context.Background(), []string{"engineering"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(report.OrphanedUsers) != 1 { + t.Fatalf("expected 1 orphaned user, got %d", len(report.OrphanedUsers)) + } + if report.OrphanedUsers[0] != "charlie" { + t.Errorf("expected orphaned user 'charlie', got %q", report.OrphanedUsers[0]) + } +} + +func TestDetectOrphanedUsers_ExternalCollaboratorExcluded(t *testing.T) { + ghClient := &mockGitHubClient{ + listOrgMembersFn: func(_ context.Context) ([]string, error) { + return []string{"alice", "external-user"}, nil + }, + getTeamMembersFn: func(_ context.Context, _ string) ([]string, error) { + return []string{"alice"}, nil + }, + isExternalCollaboratorFn: func(_ context.Context, username string) (bool, error) { + return username == "external-user", nil + }, + } + + syncer := NewSyncer(&mockOktaClient{}, ghClient, nil, 0.5, testLogger()) + report, err := syncer.DetectOrphanedUsers(context.Background(), []string{"engineering"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(report.OrphanedUsers) != 0 { + t.Errorf("expected 0 orphaned users (external excluded), got %d", len(report.OrphanedUsers)) + } +} + +func TestComputeTeamName(t *testing.T) { + tests := []struct { + name string + groupName string + rule domain.SyncRule + want string + }{ + { + name: "explicit team name", + groupName: "Engineering", + rule: domain.SyncRule{GitHubTeamName: "custom-team"}, + want: "custom-team", + }, + { + name: "lowercase and normalize", + groupName: "My Team Name", + rule: domain.SyncRule{}, + want: "my-team-name", + }, + { + name: "strip prefix", + groupName: "DEPT-Engineering", + rule: domain.SyncRule{StripPrefix: "DEPT-"}, + want: "engineering", + }, + { + name: "add prefix", + groupName: "frontend", + rule: domain.SyncRule{GitHubTeamPrefix: "eng-"}, + want: "eng-frontend", + }, + { + name: "strip and add prefix", + groupName: "OKTA-backend", + rule: domain.SyncRule{StripPrefix: "OKTA-", GitHubTeamPrefix: "gh-"}, + want: "gh-backend", + }, + { + name: "special characters replaced and dashes collapsed", + groupName: "Team (US) & EU", + rule: domain.SyncRule{}, + want: "team-us-eu", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := computeTeamName(tt.groupName, tt.rule) + if got != tt.want { + t.Errorf("computeTeamName(%q) = %q, want %q", tt.groupName, got, tt.want) + } + }) + } +} diff --git a/internal/types/sync.go b/internal/types/sync.go deleted file mode 100644 index 303f1c6..0000000 --- a/internal/types/sync.go +++ /dev/null @@ -1,38 +0,0 @@ -// Package types provides shared type definitions used across packages. -package types - -// SyncRule defines how to sync Okta groups to GitHub teams. -type SyncRule struct { - Name string `json:"name"` - Enabled *bool `json:"enabled,omitempty"` - OktaGroupPattern string `json:"okta_group_pattern,omitempty"` - OktaGroupName string `json:"okta_group_name,omitempty"` - GitHubTeamPrefix string `json:"github_team_prefix,omitempty"` - GitHubTeamName string `json:"github_team_name,omitempty"` - StripPrefix string `json:"strip_prefix,omitempty"` - SyncMembers *bool `json:"sync_members,omitempty"` - CreateTeamIfMissing bool `json:"create_team_if_missing"` - TeamPrivacy string `json:"team_privacy,omitempty"` -} - -// IsEnabled returns true if the rule is enabled (defaults to true). -func (r SyncRule) IsEnabled() bool { - return r.Enabled == nil || *r.Enabled -} - -// ShouldSyncMembers returns true if members should be synced (defaults to -// true). -func (r SyncRule) ShouldSyncMembers() bool { - return r.SyncMembers == nil || *r.SyncMembers -} - -// GetName returns the rule name, defaulting to GitHubTeamName if not set. -func (r SyncRule) GetName() string { - if r.Name != "" { - return r.Name - } - if r.GitHubTeamName != "" { - return r.GitHubTeamName - } - return r.OktaGroupName -} From 2dc32b5febcd81e88752646d3f9e9d674e7eb511 Mon Sep 17 00:00:00 2001 From: Brian Ojeda <9335829+sgtoj@users.noreply.github.com> Date: Tue, 17 Feb 2026 15:48:04 -0500 Subject: [PATCH 2/3] dev: improve dx --- .dockerignore | 5 ++ Dockerfile | 11 ++++ Makefile | 12 +++++ compose.yaml | 10 ++++ docs/github-app-setup.md | 107 +++++++++++++++++++++++++-------------- docs/okta-setup.md | 84 +++++++++++++++++------------- docs/slack-setup.md | 37 ++++++++------ 7 files changed, 176 insertions(+), 90 deletions(-) create mode 100644 .dockerignore create mode 100644 Dockerfile create mode 100644 compose.yaml diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..509f8f4 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,5 @@ +.local/ +.env +dist/ +tmp/ +.git/ diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..da8effe --- /dev/null +++ b/Dockerfile @@ -0,0 +1,11 @@ +FROM golang:1.24 AS builder + +WORKDIR /src +COPY go.mod go.sum ./ +RUN go mod download +COPY . . +RUN CGO_ENABLED=0 go build -trimpath -ldflags "-s -w" -o /server ./cmd/server + +FROM gcr.io/distroless/static-debian12 +COPY --from=builder /server /server +ENTRYPOINT ["/server"] diff --git a/Makefile b/Makefile index 121fa72..cdbd6db 100644 --- a/Makefile +++ b/Makefile @@ -43,3 +43,15 @@ test-verify: test-verify-verbose: go run ./cmd/verify -verbose +.PHONY: server-up +server-up: + docker compose up -d --build + +.PHONY: server-logs +server-logs: + docker compose logs -f server + +.PHONY: server-stop +server-stop: + docker compose down --rmi local --remove-orphans + diff --git a/compose.yaml b/compose.yaml new file mode 100644 index 0000000..28ef6a2 --- /dev/null +++ b/compose.yaml @@ -0,0 +1,10 @@ +services: + server: + build: . + container_name: github-ops-server + env_file: .env + ports: + - "${APP_PORT:-8080}:${APP_PORT:-8080}" + volumes: + - ./.local:/.local:ro + restart: unless-stopped diff --git a/docs/github-app-setup.md b/docs/github-app-setup.md index 360bab7..feac482 100644 --- a/docs/github-app-setup.md +++ b/docs/github-app-setup.md @@ -1,6 +1,7 @@ # GitHub App Setup -This guide walks through creating and installing a GitHub App for github-ops-app. +This guide walks through creating and installing a GitHub App for +github-ops-app. ## Prerequisites @@ -9,39 +10,50 @@ This guide walks through creating and installing a GitHub App for github-ops-app ## Step 1: Create the GitHub App 1. Navigate to your organization's settings: + - Go to `https://github.com/organizations/YOUR_ORG/settings/apps` - - Or: **Organization** → **Settings** → **Developer settings** → **GitHub Apps** + - Or: **Organization** → **Settings** → **Developer settings** → **GitHub + Apps** 2. Click **New GitHub App** 3. Fill in the basic information: - | Field | Value | - |-----------------------|-------------------------------------------------| - | GitHub App name | `github-ops-app` (must be unique across GitHub) | - | Homepage URL | Your organization's URL or repo URL | - | Webhook > Webhook URL | Leave blank for now | - | Webhook > Secret | Generate a strong secret (save this for later) | - | Webhook > Active | **Uncheck** to disable webhooks initially | - - > **Note**: Disable webhooks during creation since you may not know your - > endpoint URL until after deployment. You'll configure webhooks and - > subscribe to events in [Step 7](#step-7-configure-webhook-and-events). - -4. Under **Permissions**, set the following: - - Repository Permissions - - Contents: Read - - Read branch protection rules - - Pull requests: Read - - Access PR details for compliance - - Organization Permissions - - Administration: Read - - Read organization settings - - Members: Read/Write - - Manage team membership - -4. Under Set installation scope: - - Where can this GitHub App be installed?: Only on this account + | Field | Value | + | --------------------- | ----------------------------------------------- | + | GitHub App name | `github-ops-app` (must be unique across GitHub) | + | Homepage URL | Your organization's URL or repo URL | + | Webhook > Webhook URL | Leave blank for now | + | Webhook > Secret | Generate a strong secret (save this for later) | + | Webhook > Active | **Uncheck** to disable webhooks initially | + + > **Note**: Disable webhooks during creation since you may not know your + > endpoint URL until after deployment. You'll configure webhooks and + > subscribe to events in [Step 7](#step-7-configure-webhook-and-events). + +### Configure Permissions + +Under **Permissions**, set the following: + +### Repository Permissions + +| Permission | Access | Purpose | +| ------------- | ------ | -------------------------------- | +| Contents | Read | Read branch protection rules | +| Pull requests | Read | Access PR details for compliance | + +#### Organization Permissions + +| Permission | Access | Purpose | +| -------------- | ---------- | -------------------------- | +| Members | Read/Write | Manage team membership | +| Administration | Read | Read organization settings | + +4. Set installation scope: + + | Setting | Value | + | --------------------------------------- | -------------------- | + | Where can this GitHub App be installed? | Only on this account | 5. Click **Create GitHub App** @@ -73,6 +85,7 @@ On the app's settings page, find and save: ## Step 5: Get Installation ID After installation, you'll be redirected to a URL like: + ``` https://github.com/organizations/YOUR_ORG/settings/installations/12345678 ``` @@ -80,6 +93,7 @@ https://github.com/organizations/YOUR_ORG/settings/installations/12345678 The number at the end (`12345678`) is your **Installation ID**. Alternatively, use the GitHub API: + ```bash # List installations (requires app JWT authentication) curl -H "Authorization: Bearer YOUR_JWT" \ @@ -115,25 +129,40 @@ After deploying your server, configure and enable webhooks: 1. Go to your GitHub App settings: `https://github.com/organizations/YOUR_ORG/settings/apps/YOUR_APP` -2. On the **General** tab, under **Webhook**: - - Set **Webhook URL** to your endpoint: - - Lambda: `https://xxx.execute-api.region.amazonaws.com/webhooks` - - Server: `https://your-domain.com/webhooks` - - Check **Active** to enable webhooks - - Click **Save changes** -3. Go to the **Permissions & events** tab -4. Scroll to **Subscribe to events** and check: + +2. Set **Webhook URL** to your endpoint: + + - Lambda: `https://xxx.execute-api.region.amazonaws.com/webhooks` + - Server: `https://your-domain.com/webhooks` + +3. Check **Active** to enable webhooks + +4. Click **Save changes** + +5. Under **Subscribe to events**, check: + - [x] **Pull request** - PR open, close, merge events - [x] **Team** - Team creation, deletion, changes - [x] **Membership** - Team membership changes -5. Click **Save changes** + +6. Click **Save changes** + +## Using the App Manifest (Alternative) + +For automated setup, use the manifest at `assets/github/manifest.json`: + +1. Go to `https://github.com/settings/apps/new` +2. Append `?manifest=` with URL-encoded manifest JSON +3. Or use the manifest creation API + +The manifest pre-configures all required permissions and events. ## Verification Test your setup: -1. **Webhook delivery**: Check **Settings** → **Developer settings** → - **GitHub Apps** → your app → **Advanced** → **Recent Deliveries** +1. **Webhook delivery**: Check **Settings** → **Developer settings** → **GitHub + Apps** → your app → **Advanced** → **Recent Deliveries** 2. **Create a test PR**: Open and merge a PR to a monitored branch to verify webhook reception diff --git a/docs/okta-setup.md b/docs/okta-setup.md index b8955a9..ba660b9 100644 --- a/docs/okta-setup.md +++ b/docs/okta-setup.md @@ -11,14 +11,18 @@ github-ops-app to sync Okta groups with GitHub teams. ## Step 1: Create API Services Application 1. Log in to your **Okta Admin Console** + 2. Navigate to **Applications** → **Applications** + 3. Click **Create App Integration** + 4. Select **API Services** and click **Next** > API Services apps use OAuth 2.0 client credentials flow with no user > context, ideal for server-to-server integrations. 5. Enter application name: `github-ops-app` (or similar) + 6. Click **Save** ## Step 2: Configure Client Authentication @@ -64,12 +68,13 @@ On the **General** tab, find and save: ## Step 5: Grant API Scopes 1. Go to the **Okta API Scopes** tab + 2. Grant the following scopes: - | Scope | Purpose | - |--------------------|-------------------------------| - | `okta.groups.read` | Read group names and members | - | `okta.users.read` | Read user profiles | + | Scope | Purpose | + | ------------------ | ---------------------------- | + | `okta.groups.read` | Read group names and members | + | `okta.users.read` | Read user profiles | 3. Click **Grant** for each scope @@ -82,22 +87,26 @@ API Services applications require an admin role to access Okta APIs. Without this, API calls will fail with permission errors even if scopes are granted. 1. Go to the **Admin roles** tab for your application + 2. Click **Edit assignments** + 3. Select one of the following roles: - | Role | Access Level | - |---------------------|-----------------------------------------------| - | **Read Only Admin** | Read access to all resources (recommended) | - | **Group Admin** | Full access to groups only | + | Role | Access Level | + | ------------------- | ------------------------------------------ | + | **Read Only Admin** | Read access to all resources (recommended) | + | **Group Admin** | Full access to groups only | 4. If using **Group Admin**, optionally restrict to specific groups: - - Under **Edit constraints for Group Administrator**, select specific - groups or group types the app can access + + - Under **Edit constraints for Group Administrator**, select specific groups + or group types the app can access + 5. Click **Save changes** -> **Note**: Read Only Admin is recommended for sync operations since it -> provides sufficient access without write permissions. Group Admin is an -> alternative if you need to limit the app's scope to group resources only. +> **Note**: Read Only Admin is recommended for sync operations since it provides +> sufficient access without write permissions. Group Admin is an alternative if +> you need to limit the app's scope to group resources only. ## Step 7: Identify Your Okta Domain @@ -113,12 +122,12 @@ Use the domain without `https://` prefix for `APP_OKTA_DOMAIN`. The app needs to map Okta users to GitHub usernames. Determine which Okta user profile field contains GitHub usernames: -| Common Fields | Description | -|--------------------|------------------------------------------| -| `login` | Okta username (often email) | -| `email` | User's email address | -| `githubUsername` | Custom field (recommended) | -| `nickName` | Sometimes used for GitHub username | +| Common Fields | Description | +| ---------------- | ---------------------------------- | +| `login` | Okta username (often email) | +| `email` | User's email address | +| `githubUsername` | Custom field (recommended) | +| `nickName` | Sometimes used for GitHub username | ### Adding a Custom GitHub Username Field (Recommended) @@ -141,13 +150,14 @@ rules: **Example naming conventions:** -| Pattern | Example Groups | -|----------------------|----------------------------------------------| -| `github-{team}` | `github-engineering`, `github-platform` | -| `gh-eng-{team}` | `gh-eng-frontend`, `gh-eng-backend` | -| `Team - {name}` | `Team - Platform`, `Team - Security` | +| Pattern | Example Groups | +| --------------- | --------------------------------------- | +| `github-{team}` | `github-engineering`, `github-platform` | +| `gh-eng-{team}` | `gh-eng-frontend`, `gh-eng-backend` | +| `Team - {name}` | `Team - Platform`, `Team - Security` | Groups can be: + - Okta groups (manually managed) - Groups synced from Active Directory - Groups from other identity providers @@ -192,18 +202,18 @@ APP_OKTA_SYNC_RULES='[ ### Rule Fields -| Field | Description | -|-------------------------|------------------------------------------------------| -| `name` | Rule identifier (for logging) | -| `enabled` | Enable/disable rule (default: `true`) | -| `okta_group_pattern` | Regex to match Okta groups | -| `okta_group_name` | Exact Okta group name (alternative to pattern) | -| `github_team_prefix` | Prefix for generated GitHub team names | -| `github_team_name` | Exact GitHub team name (overrides pattern) | -| `strip_prefix` | Remove this prefix from Okta group name | -| `sync_members` | Sync members between Okta and GitHub (default: `true`)| -| `create_team_if_missing`| Auto-create GitHub teams if they don't exist | -| `team_privacy` | GitHub team visibility: `secret` or `closed` | +| Field | Description | +| ------------------------ | ------------------------------------------------------ | +| `name` | Rule identifier (for logging) | +| `enabled` | Enable/disable rule (default: `true`) | +| `okta_group_pattern` | Regex to match Okta groups | +| `okta_group_name` | Exact Okta group name (alternative to pattern) | +| `github_team_prefix` | Prefix for generated GitHub team names | +| `github_team_name` | Exact GitHub team name (overrides pattern) | +| `strip_prefix` | Remove this prefix from Okta group name | +| `sync_members` | Sync members between Okta and GitHub (default: `true`) | +| `create_team_if_missing` | Auto-create GitHub teams if they don't exist | +| `team_privacy` | GitHub team visibility: `secret` or `closed` | See the [main README](../README.md#okta-sync-rules) for additional examples. @@ -221,6 +231,7 @@ Test your Okta configuration: ``` Trigger a sync and verify: + 1. POST to `/scheduled/okta-sync` endpoint 2. Check logs for groups discovered and teams synced 3. Verify GitHub team memberships match Okta groups @@ -249,6 +260,7 @@ Trigger a sync and verify: ### Rate limiting Okta has API rate limits. If you hit limits: + - Reduce sync frequency - The app handles rate limit responses gracefully diff --git a/docs/slack-setup.md b/docs/slack-setup.md index c7208a6..1b4b4ee 100644 --- a/docs/slack-setup.md +++ b/docs/slack-setup.md @@ -15,8 +15,9 @@ This guide walks through creating a Slack app for github-ops-app notifications. 2. Click **Create New App** 3. Select **From an app manifest** 4. Select your workspace -5. Copy the contents of [`assets/slack/manifest.json`](../assets/slack/manifest.json) - and paste into the manifest editor +5. Copy the contents of + [`assets/slack/manifest.json`](../assets/slack/manifest.json) and paste into + the manifest editor 6. Click **Create** ### Option B: Manual Setup @@ -34,15 +35,17 @@ Then continue to configure OAuth scopes manually (Step 2). If you used the manifest, scopes are pre-configured. Otherwise: 1. Go to **OAuth & Permissions** in the sidebar + 2. Scroll to **Scopes** → **Bot Token Scopes** + 3. Add the following scopes: - | Scope | Purpose | - |---------------------|----------------------------------------------| - | `chat:write` | Post messages to channels bot is member of | - | `chat:write.public` | Post to public channels without joining | - | `channels:read` | View basic channel info | - | `channels:join` | Join public channels | + | Scope | Purpose | + | ------------------- | ------------------------------------------ | + | `chat:write` | Post messages to channels bot is member of | + | `chat:write.public` | Post to public channels without joining | + | `channels:read` | View basic channel info | + | `channels:join` | Join public channels | ## Step 3: Install to Workspace @@ -51,6 +54,7 @@ If you used the manifest, scopes are pre-configured. Otherwise: 3. Review permissions and click **Allow** If your workspace requires admin approval: + - Submit the app for approval - Wait for workspace admin to approve - Return to install after approval @@ -129,7 +133,8 @@ Make notifications more recognizable: 2. Under **Display Information**: - **App name**: `GitHub Ops Bot` (or your preference) - **Short description**: Brief description of the bot - - **App icon**: Upload a custom icon (use `assets/slack/icon.png` or your own) + - **App icon**: Upload a custom icon (use `assets/slack/icon.png` or your + own) - **Background color**: `#10203B` or your brand color ## Verification @@ -148,6 +153,7 @@ curl -X POST https://slack.com/api/chat.postMessage \ ``` Expected response: + ```json { "ok": true, @@ -161,12 +167,12 @@ Expected response: The bot sends these notification types: -| Event | Description | -|-----------------------|------------------------------------------------| -| PR Compliance Alert | PR merged bypassing branch protection | -| Okta Sync Report | Summary of team membership changes | -| Orphaned Users Alert | Org members not in any synced teams | -| Sync Error | Errors during Okta sync process | +| Event | Description | +| -------------------- | ------------------------------------- | +| PR Compliance Alert | PR merged bypassing branch protection | +| Okta Sync Report | Summary of team membership changes | +| Orphaned Users Alert | Org members not in any synced teams | +| Sync Error | Errors during Okta sync process | ## Troubleshooting @@ -204,6 +210,7 @@ The bot sends these notification types: Slack has rate limits (typically 1 message per second per channel). The app handles rate limits gracefully, but if you see delays: + - Notifications are queued and retried - Consider consolidating notifications for high-volume events From 7a620fa5e32305b548ff14439c8163a4ec0f4015 Mon Sep 17 00:00:00 2001 From: Brian Ojeda <9335829+sgtoj@users.noreply.github.com> Date: Wed, 18 Feb 2026 12:33:13 -0500 Subject: [PATCH 3/3] refactor: replace custom request/response types with chi router Eliminate the hand-rolled Request/Response abstraction layer in favor of a standard net/http + chi router. All entry points (server, lambda, verify, sample) now construct stdlib *http.Request objects and route them through a shared chi handler returned by App.Handler(). This removes the custom HandleRequest dispatcher, consolidates routing and middleware (admin auth, panic recovery), and makes tests more realistic by exercising the actual router via httptest. --- cmd/lambda/main.go | 85 ++++++--- cmd/sample/main.go | 39 +++- cmd/server/main.go | 55 +----- cmd/verify/scenario.go | 47 +++-- go.mod | 1 + go.sum | 2 + internal/app/app.go | 16 +- internal/app/app_test.go | 69 ++++---- internal/app/handlers_test.go | 53 +++--- internal/app/request.go | 322 ++++++++++++++-------------------- 10 files changed, 339 insertions(+), 350 deletions(-) diff --git a/cmd/lambda/main.go b/cmd/lambda/main.go index bac0b91..8d41b21 100644 --- a/cmd/lambda/main.go +++ b/cmd/lambda/main.go @@ -1,9 +1,13 @@ package main import ( + "bytes" "context" "encoding/json" + "fmt" "log/slog" + "net/http" + "net/http/httptest" "strings" "sync" @@ -17,6 +21,7 @@ import ( var ( initOnce sync.Once appInst *app.App + router http.Handler logger *slog.Logger initErr error ) @@ -31,10 +36,15 @@ func initApp() { return } appInst, initErr = app.NewApp(context.Background(), cfg, logger) + if initErr != nil { + return + } + router = appInst.Handler() }) } -// APIGatewayHandler converts API Gateway requests to unified app.Request. +// APIGatewayHandler converts API Gateway requests to stdlib *http.Request +// and routes them through the chi router. func APIGatewayHandler(ctx context.Context, req awsevents.APIGatewayV2HTTPRequest) (awsevents.APIGatewayV2HTTPResponse, error) { initApp() if initErr != nil { @@ -50,29 +60,35 @@ func APIGatewayHandler(ctx context.Context, req awsevents.APIGatewayV2HTTPReques logger.Debug("received api gateway request", slog.String("request", string(j))) } - headers := make(map[string]string) - for key, value := range req.Headers { - headers[strings.ToLower(key)] = value + httpReq, err := http.NewRequestWithContext( + ctx, + req.RequestContext.HTTP.Method, + req.RawPath, + strings.NewReader(req.Body), + ) + if err != nil { + return awsevents.APIGatewayV2HTTPResponse{ + StatusCode: 500, + Body: "failed to construct http request", + }, nil } - appReq := app.Request{ - Type: app.RequestTypeHTTP, - Method: req.RequestContext.HTTP.Method, - Path: req.RawPath, - Headers: headers, - Body: []byte(req.Body), + for key, value := range req.Headers { + httpReq.Header.Set(key, value) } - resp := appInst.HandleRequest(ctx, appReq) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) return awsevents.APIGatewayV2HTTPResponse{ - StatusCode: resp.StatusCode, - Headers: resp.Headers, - Body: string(resp.Body), + StatusCode: rec.Code, + Headers: flattenHeaders(rec.Header()), + Body: rec.Body.String(), }, nil } -// EventBridgeHandler converts EventBridge events to unified app.Request. +// EventBridgeHandler converts EventBridge events to POST /scheduled/{action} +// requests and routes them through the chi router. func EventBridgeHandler(ctx context.Context, evt awsevents.CloudWatchEvent) error { initApp() if initErr != nil { @@ -90,16 +106,30 @@ func EventBridgeHandler(ctx context.Context, evt awsevents.CloudWatchEvent) erro return err } - req := app.Request{ - Type: app.RequestTypeScheduled, - ScheduledAction: detail.Action, - ScheduledData: detail.Data, + path := fmt.Sprintf("%s/scheduled/%s", appInst.Config.BasePath, detail.Action) + + var body []byte + if detail.Data != nil { + body = detail.Data } - resp := appInst.HandleRequest(ctx, req) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, path, bytes.NewReader(body)) + if err != nil { + return errors.Wrap(err, "failed to construct http request") + } - if resp.StatusCode >= 400 { - return errors.Newf("scheduled event failed: %s", string(resp.Body)) + if appInst.Config.AdminToken != "" { + httpReq.Header.Set("Authorization", "Bearer "+appInst.Config.AdminToken) + } + if len(body) > 0 { + httpReq.Header.Set("Content-Type", "application/json") + } + + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) + + if rec.Code >= 400 { + return errors.Newf("scheduled event failed: %s", rec.Body.String()) } return nil @@ -125,6 +155,17 @@ func UniversalHandler(ctx context.Context, event json.RawMessage) (any, error) { return nil, errors.New("unknown lambda event type") } +// flattenHeaders converts multi-value http.Header to single-value map. +func flattenHeaders(h http.Header) map[string]string { + flat := make(map[string]string, len(h)) + for key, values := range h { + if len(values) > 0 { + flat[key] = values[0] + } + } + return flat +} + func main() { lambda.Start(UniversalHandler) } diff --git a/cmd/sample/main.go b/cmd/sample/main.go index 3cf5a2d..40888c4 100644 --- a/cmd/sample/main.go +++ b/cmd/sample/main.go @@ -4,9 +4,13 @@ package main import ( + "bytes" "context" "encoding/json" + "fmt" "log/slog" + "net/http" + "net/http/httptest" "os" "path/filepath" @@ -38,6 +42,8 @@ func main() { os.Exit(1) } + router := a.Handler() + path := filepath.Join("fixtures", "samples.json") raw, err := os.ReadFile(path) if err != nil { @@ -60,24 +66,45 @@ func main() { switch eventType { case "okta_sync": - evt := app.ScheduledEvent{ - Action: "okta-sync", + reqPath := fmt.Sprintf("%s/scheduled/okta-sync", cfg.BasePath) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqPath, nil) + if err != nil { + logger.Error("failed to construct http request", + slog.Int("sample", i), + slog.String("error", err.Error())) + os.Exit(1) + } + if cfg.AdminToken != "" { + httpReq.Header.Set("Authorization", "Bearer "+cfg.AdminToken) } - if err := a.ProcessScheduledEvent(ctx, evt); err != nil { + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) + if rec.Code >= 400 { logger.Error("failed to process okta_sync sample", slog.Int("sample", i), - slog.String("error", err.Error())) + slog.String("response", rec.Body.String())) os.Exit(1) } case "pr_webhook": payload, _ := json.Marshal(sample["payload"]) - if err := a.ProcessWebhook(ctx, payload, "pull_request"); err != nil { - logger.Error("failed to process pr_webhook sample", + reqPath := fmt.Sprintf("%s/webhooks", cfg.BasePath) + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, reqPath, bytes.NewReader(payload)) + if err != nil { + logger.Error("failed to construct http request", slog.Int("sample", i), slog.String("error", err.Error())) os.Exit(1) } + httpReq.Header.Set("X-GitHub-Event", "pull_request") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) + if rec.Code >= 400 { + logger.Error("failed to process pr_webhook sample", + slog.Int("sample", i), + slog.String("response", rec.Body.String())) + os.Exit(1) + } default: logger.Info("skipping unknown event type", slog.String("event_type", eventType)) diff --git a/cmd/server/main.go b/cmd/server/main.go index cd28c79..23ab07c 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,12 +2,10 @@ package main import ( "context" - "io" "log/slog" "net/http" "os" "os/signal" - "strings" "syscall" "time" @@ -15,13 +13,8 @@ import ( "github.com/cruxstack/github-ops-app/internal/config" ) -var ( - appInst *app.App - logger *slog.Logger -) - func main() { - logger = config.NewLogger() + logger := config.NewLogger() ctx := context.Background() cfg, err := config.NewConfig() @@ -30,15 +23,12 @@ func main() { os.Exit(1) } - appInst, err = app.NewApp(ctx, cfg, logger) + appInst, err := app.NewApp(ctx, cfg, logger) if err != nil { logger.Error("app init failed", slog.String("error", err.Error())) os.Exit(1) } - mux := http.NewServeMux() - mux.HandleFunc("/", httpHandler) - port := os.Getenv("APP_PORT") if port == "" { port = "8080" @@ -46,7 +36,7 @@ func main() { srv := &http.Server{ Addr: ":" + port, - Handler: mux, + Handler: appInst.Handler(), ReadTimeout: 15 * time.Second, WriteTimeout: 15 * time.Second, IdleTimeout: 60 * time.Second, @@ -80,42 +70,3 @@ func main() { <-done logger.Info("server stopped") } - -// httpHandler converts http.Request to app.Request and handles the response. -func httpHandler(w http.ResponseWriter, r *http.Request) { - defer r.Body.Close() - body, err := io.ReadAll(io.LimitReader(r.Body, 10<<20)) // 10MB limit - if err != nil { - http.Error(w, "failed to read request body", http.StatusBadRequest) - return - } - - headers := make(map[string]string) - for key, values := range r.Header { - if len(values) > 0 { - headers[strings.ToLower(key)] = values[0] - } - } - - req := app.Request{ - Type: app.RequestTypeHTTP, - Method: r.Method, - Path: r.URL.Path, - Headers: headers, - Body: body, - } - - resp := appInst.HandleRequest(r.Context(), req) - - for key, value := range resp.Headers { - w.Header().Set(key, value) - } - if resp.ContentType != "" && w.Header().Get("Content-Type") == "" { - w.Header().Set("Content-Type", resp.ContentType) - } - - w.WriteHeader(resp.StatusCode) - if len(resp.Body) > 0 { - w.Write(resp.Body) - } -} diff --git a/cmd/verify/scenario.go b/cmd/verify/scenario.go index 03c7ee5..2ce8e8b 100644 --- a/cmd/verify/scenario.go +++ b/cmd/verify/scenario.go @@ -1,12 +1,14 @@ package main import ( + "bytes" "context" "crypto/tls" "encoding/json" "fmt" "log/slog" "net/http" + "net/http/httptest" "os" "time" @@ -197,40 +199,51 @@ func runScenario(ctx context.Context, scenario TestScenario, verbose bool, logge fmt.Printf("\n Application Output:\n") } - var req app.Request + router := a.Handler() + + var httpReq *http.Request switch scenario.EventType { case "scheduled_event": var evt app.ScheduledEvent if err := json.Unmarshal(scenario.EventPayload, &evt); err != nil { return errors.Wrap(err, "failed to unmarshal event payload") } - req = app.Request{ - Type: app.RequestTypeScheduled, - ScheduledAction: evt.Action, - ScheduledData: evt.Data, + path := fmt.Sprintf("%s/scheduled/%s", cfg.BasePath, evt.Action) + var body []byte + if evt.Data != nil { + body = evt.Data + } + var err error + httpReq, err = http.NewRequestWithContext(ctx, http.MethodPost, path, bytes.NewReader(body)) + if err != nil { + return errors.Wrap(err, "failed to construct scheduled http request") + } + if cfg.AdminToken != "" { + httpReq.Header.Set("Authorization", "Bearer "+cfg.AdminToken) + } + if len(body) > 0 { + httpReq.Header.Set("Content-Type", "application/json") } case "webhook": - req = app.Request{ - Type: app.RequestTypeHTTP, - Method: "POST", - Path: "/webhooks", - Headers: map[string]string{ - "x-github-event": scenario.WebhookType, - "x-hub-signature-256": "", // signature validated separately in tests - }, - Body: scenario.WebhookPayload, + var err error + httpReq, err = http.NewRequestWithContext(ctx, http.MethodPost, cfg.BasePath+"/webhooks", bytes.NewReader(scenario.WebhookPayload)) + if err != nil { + return errors.Wrap(err, "failed to construct webhook http request") } + httpReq.Header.Set("X-GitHub-Event", scenario.WebhookType) + httpReq.Header.Set("X-Hub-Signature-256", "") // signature validated separately in tests default: return errors.Newf("unknown event type: %s", scenario.EventType) } - resp := a.HandleRequest(ctx, req) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, httpReq) var processErr error - if resp.StatusCode >= 400 { - processErr = errors.Newf("request failed with status %d: %s", resp.StatusCode, string(resp.Body)) + if rec.Code >= 400 { + processErr = errors.Newf("request failed with status %d: %s", rec.Code, rec.Body.String()) } if scenario.ExpectError { diff --git a/go.mod b/go.mod index b63beae..7f069fb 100644 --- a/go.mod +++ b/go.mod @@ -34,6 +34,7 @@ require ( github.com/cockroachdb/redact v1.1.5 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect github.com/getsentry/sentry-go v0.27.0 // indirect + github.com/go-chi/chi/v5 v5.2.5 // indirect github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/go.sum b/go.sum index a2ee53d..ffd662f 100644 --- a/go.sum +++ b/go.sum @@ -46,6 +46,8 @@ github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvw github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK3r3Ps= github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= diff --git a/internal/app/app.go b/internal/app/app.go index 7e12234..d71ca38 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -6,6 +6,8 @@ import ( "context" "encoding/json" "log/slog" + "net/http" + "sync" "github.com/cockroachdb/errors" "github.com/cruxstack/github-ops-app/internal/config" @@ -20,6 +22,8 @@ type App struct { GitHubClient domain.GitHubClient OktaClient domain.OktaClient Notifier domain.Notifier + router http.Handler + routerOnce sync.Once } // ScheduledEvent represents a generic scheduled event. @@ -28,9 +32,9 @@ type ScheduledEvent struct { Data json.RawMessage `json:"data,omitempty"` } -// ProcessScheduledEvent handles scheduled events (e.g., cron jobs). -// Routes to appropriate handlers based on event action. -func (a *App) ProcessScheduledEvent(ctx context.Context, evt ScheduledEvent) error { +// processScheduledEvent handles scheduled events (e.g., cron jobs). +// routes to appropriate handlers based on event action. +func (a *App) processScheduledEvent(ctx context.Context, evt ScheduledEvent) error { if a.Config.DebugEnabled { j, _ := json.Marshal(evt) a.Logger.Debug("received scheduled event", slog.String("event", string(j))) @@ -46,9 +50,9 @@ func (a *App) ProcessScheduledEvent(ctx context.Context, evt ScheduledEvent) err } } -// ProcessWebhook handles incoming GitHub webhook events. -// Supports pull_request, team, and membership events. -func (a *App) ProcessWebhook(ctx context.Context, payload []byte, eventType string) error { +// processWebhook handles incoming GitHub webhook events. +// supports pull_request, team, and membership events. +func (a *App) processWebhook(ctx context.Context, payload []byte, eventType string) error { if a.Config.DebugEnabled { a.Logger.Debug("received webhook", slog.String("event_type", eventType)) } diff --git a/internal/app/app_test.go b/internal/app/app_test.go index 7a12eec..3ababf0 100644 --- a/internal/app/app_test.go +++ b/internal/app/app_test.go @@ -3,6 +3,8 @@ package app import ( "context" "log/slog" + "net/http" + "net/http/httptest" "os" "testing" @@ -135,7 +137,7 @@ func TestProcessScheduledEvent_SlackTest(t *testing.T) { } evt := ScheduledEvent{Action: "slack-test"} - err := app.ProcessScheduledEvent(context.Background(), evt) + err := app.processScheduledEvent(context.Background(), evt) // should fail because slack is not configured if err == nil { @@ -150,7 +152,7 @@ func TestProcessScheduledEvent_UnknownAction(t *testing.T) { } evt := ScheduledEvent{Action: "unknown-action"} - err := app.ProcessScheduledEvent(context.Background(), evt) + err := app.processScheduledEvent(context.Background(), evt) if err == nil { t.Error("expected error for unknown action") @@ -164,7 +166,7 @@ func TestFakeDataTypes(t *testing.T) { var _ *domain.OrphanedUsersReport = fakeOrphanedUsersReport() } -func TestCheckAdminAuth(t *testing.T) { +func TestAdminAuthMiddleware(t *testing.T) { tests := []struct { name string adminToken string @@ -211,33 +213,33 @@ func TestCheckAdminAuth(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := &App{ + a := &App{ Config: &config.Config{AdminToken: tt.adminToken}, Logger: slog.New(slog.NewTextHandler(os.Stderr, nil)), } - headers := map[string]string{} + // test via /server/status which uses admin auth middleware + router := a.Handler() + + req := httptest.NewRequest(http.MethodGet, "/server/status", nil) if tt.authHeader != "" { - headers["authorization"] = tt.authHeader + req.Header.Set("Authorization", tt.authHeader) } - req := Request{Headers: headers} - resp := app.checkAdminAuth(req) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) - if tt.expectError && resp == nil { - t.Error("expected error response, got nil") - } - if !tt.expectError && resp != nil { - t.Errorf("expected no error, got status %d", resp.StatusCode) + if tt.expectError && rec.Code != http.StatusUnauthorized { + t.Errorf("expected 401, got %d", rec.Code) } - if tt.expectError && resp != nil && resp.StatusCode != 401 { - t.Errorf("expected status 401, got %d", resp.StatusCode) + if !tt.expectError && rec.Code == http.StatusUnauthorized { + t.Errorf("expected success, got 401") } }) } } -func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { +func TestRouter_AdminAuthOnProtectedEndpoints(t *testing.T) { tests := []struct { name string path string @@ -249,7 +251,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "status endpoint, no token configured", path: "/server/status", - method: "GET", + method: http.MethodGet, adminToken: "", authHeader: "", expectedStatus: 200, @@ -257,7 +259,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "status endpoint, token required, missing", path: "/server/status", - method: "GET", + method: http.MethodGet, adminToken: "secret", authHeader: "", expectedStatus: 401, @@ -265,7 +267,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "status endpoint, token required, correct", path: "/server/status", - method: "GET", + method: http.MethodGet, adminToken: "secret", authHeader: "Bearer secret", expectedStatus: 200, @@ -273,7 +275,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "config endpoint, token required, missing", path: "/server/config", - method: "GET", + method: http.MethodGet, adminToken: "secret", authHeader: "", expectedStatus: 401, @@ -281,7 +283,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "config endpoint, token required, correct", path: "/server/config", - method: "GET", + method: http.MethodGet, adminToken: "secret", authHeader: "Bearer secret", expectedStatus: 200, @@ -289,7 +291,7 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { { name: "scheduled endpoint, token required, missing", path: "/scheduled/slack-test", - method: "POST", + method: http.MethodPost, adminToken: "secret", authHeader: "", expectedStatus: 401, @@ -298,27 +300,24 @@ func TestHandleRequest_AdminAuthOnProtectedEndpoints(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - app := &App{ + a := &App{ Config: &config.Config{AdminToken: tt.adminToken}, Logger: slog.New(slog.NewTextHandler(os.Stderr, nil)), } - headers := map[string]string{} - if tt.authHeader != "" { - headers["authorization"] = tt.authHeader - } + router := a.Handler() - req := Request{ - Type: RequestTypeHTTP, - Method: tt.method, - Path: tt.path, - Headers: headers, + req := httptest.NewRequest(tt.method, tt.path, nil) + if tt.authHeader != "" { + req.Header.Set("Authorization", tt.authHeader) } - resp := app.HandleRequest(context.Background(), req) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) - if resp.StatusCode != tt.expectedStatus { - t.Errorf("expected status %d, got %d", tt.expectedStatus, resp.StatusCode) + if rec.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d (body: %s)", + tt.expectedStatus, rec.Code, rec.Body.String()) } }) } diff --git a/internal/app/handlers_test.go b/internal/app/handlers_test.go index 07e9d14..3a58fb3 100644 --- a/internal/app/handlers_test.go +++ b/internal/app/handlers_test.go @@ -5,6 +5,8 @@ import ( "encoding/json" "io" "log/slog" + "net/http" + "net/http/httptest" "testing" "github.com/cockroachdb/errors" @@ -500,7 +502,7 @@ func TestProcessWebhook_UnknownEventType(t *testing.T) { Logger: discardLogger(), } - err := a.ProcessWebhook(context.Background(), []byte(`{}`), "deployment") + err := a.processWebhook(context.Background(), []byte(`{}`), "deployment") if err == nil { t.Fatal("expected error for unknown event type") } @@ -509,7 +511,7 @@ func TestProcessWebhook_UnknownEventType(t *testing.T) { } } -func TestMapErrorResponse(t *testing.T) { +func TestWriteErrorFromDomain(t *testing.T) { tests := []struct { name string err error @@ -526,50 +528,61 @@ func TestMapErrorResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp := mapErrorResponse(tt.err, "fallback") - if resp.StatusCode != tt.wantStatus { - t.Errorf("expected status %d, got %d", tt.wantStatus, resp.StatusCode) + rec := httptest.NewRecorder() + writeErrorFromDomain(rec, tt.err, "fallback") + if rec.Code != tt.wantStatus { + t.Errorf("expected status %d, got %d", tt.wantStatus, rec.Code) } }) } } -func TestHandleHTTPRequest_NotFound(t *testing.T) { +func TestRouter_NotFound(t *testing.T) { a := &App{ Config: &config.Config{}, Logger: discardLogger(), } - req := Request{Type: RequestTypeHTTP, Method: "GET", Path: "/nonexistent"} - resp := a.HandleRequest(context.Background(), req) - if resp.StatusCode != 404 { - t.Errorf("expected 404, got %d", resp.StatusCode) + router := a.Handler() + req := httptest.NewRequest(http.MethodGet, "/nonexistent", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Errorf("expected 404, got %d", rec.Code) } } -func TestHandleHTTPRequest_MethodNotAllowed(t *testing.T) { +func TestRouter_MethodNotAllowed(t *testing.T) { a := &App{ Config: &config.Config{}, Logger: discardLogger(), } - req := Request{Type: RequestTypeHTTP, Method: "DELETE", Path: "/webhooks"} - resp := a.HandleRequest(context.Background(), req) - if resp.StatusCode != 405 { - t.Errorf("expected 405, got %d", resp.StatusCode) + router := a.Handler() + req := httptest.NewRequest(http.MethodDelete, "/webhooks", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusMethodNotAllowed { + t.Errorf("expected 405, got %d", rec.Code) } } -func TestHandleHTTPRequest_BasePathStripping(t *testing.T) { +func TestRouter_BasePathStripping(t *testing.T) { a := &App{ Config: &config.Config{BasePath: "/api/v1"}, Logger: discardLogger(), } - req := Request{Type: RequestTypeHTTP, Method: "GET", Path: "/api/v1/server/status"} - resp := a.HandleRequest(context.Background(), req) - if resp.StatusCode != 200 { - t.Errorf("expected 200 for base-path-stripped status, got %d", resp.StatusCode) + router := a.Handler() + req := httptest.NewRequest(http.MethodGet, "/api/v1/server/status", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("expected 200 for base-path-stripped status, got %d (body: %s)", + rec.Code, rec.Body.String()) } } diff --git a/internal/app/request.go b/internal/app/request.go index 6b831de..351b24f 100644 --- a/internal/app/request.go +++ b/internal/app/request.go @@ -1,257 +1,195 @@ package app import ( - "context" "encoding/json" + "io" "log/slog" + "net/http" "strings" "github.com/cockroachdb/errors" "github.com/cruxstack/github-ops-app/internal/domain" "github.com/cruxstack/github-ops-app/internal/github/webhooks" + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" ) -// RequestType identifies the category of incoming request. -type RequestType string - -const ( - // RequestTypeHTTP represents HTTP requests (webhooks, status, config). - RequestTypeHTTP RequestType = "http" - // RequestTypeScheduled represents scheduled/cron events. - RequestTypeScheduled RequestType = "scheduled" -) - -// Request is a unified request type that abstracts HTTP and scheduled events. -// Runtimes (server, lambda) convert their native formats to this type. -type Request struct { - Type RequestType `json:"type"` - Method string `json:"method,omitempty"` - Path string `json:"path,omitempty"` - Headers map[string]string `json:"headers,omitempty"` - Body []byte `json:"body,omitempty"` - - // ScheduledAction is used for scheduled events (e.g., "okta-sync"). - ScheduledAction string `json:"scheduled_action,omitempty"` - // ScheduledData contains optional payload for scheduled events. - ScheduledData json.RawMessage `json:"scheduled_data,omitempty"` -} - -// Response is a unified response type returned by HandleRequest. -// Runtimes convert this to their native response format. -type Response struct { - StatusCode int `json:"status_code"` - Headers map[string]string `json:"headers,omitempty"` - Body []byte `json:"body,omitempty"` - ContentType string `json:"content_type,omitempty"` +// Handler returns the HTTP handler for the application. +// this is the single entry point for all HTTP request processing. +// both the server and lambda entry points feed requests into this handler. +func (a *App) Handler() http.Handler { + a.routerOnce.Do(func() { + a.router = a.buildRouter() + }) + return a.router } -// HandleRequest routes incoming requests to the appropriate handler. -// This is the single entry point for all request processing. -func (a *App) HandleRequest(ctx context.Context, req Request) Response { - if a.Config.DebugEnabled { - j, _ := json.Marshal(req) - a.Logger.Debug("handling request", slog.String("request", string(j))) - } - - switch req.Type { - case RequestTypeScheduled: - return a.handleScheduledRequest(ctx, req) - case RequestTypeHTTP: - return a.handleHTTPRequest(ctx, req) - default: - return errorResponse(400, "unknown request type") - } -} +// buildRouter constructs the chi router with all routes and middleware. +func (a *App) buildRouter() http.Handler { + r := chi.NewRouter() + r.Use(middleware.Recoverer) -// handleScheduledRequest processes scheduled/cron events. -func (a *App) handleScheduledRequest(ctx context.Context, req Request) Response { - evt := ScheduledEvent{ - Action: req.ScheduledAction, - Data: req.ScheduledData, + basePath := a.Config.BasePath + if basePath == "" { + basePath = "/" } - if err := a.ProcessScheduledEvent(ctx, evt); err != nil { - a.Logger.Error("scheduled event processing failed", - slog.String("action", evt.Action), - slog.String("error", err.Error())) - return mapErrorResponse(err, "scheduled event processing failed") - } + r.Route(basePath, func(r chi.Router) { + r.Post("/webhooks", a.handleWebhookHTTP) - return jsonResponse(200, map[string]string{ - "status": "success", - "message": evt.Action + " completed", + r.Group(func(r chi.Router) { + r.Use(a.adminAuthMiddleware) + r.Get("/server/status", a.handleStatusHTTP) + r.Get("/server/config", a.handleConfigHTTP) + r.Post("/scheduled/{action}", a.handleScheduledHTTP) + }) }) + + return r } -// handleHTTPRequest routes HTTP requests based on path. -// strips BasePath prefix if configured (e.g., "/api/v1" -> "/"). -func (a *App) handleHTTPRequest(ctx context.Context, req Request) Response { - path := req.Path - if a.Config.BasePath != "" { - path = strings.TrimPrefix(path, a.Config.BasePath) - if path == "" { - path = "/" +// adminAuthMiddleware validates the admin bearer token on protected routes. +// if no admin token is configured, all requests pass through. +func (a *App) adminAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if a.Config.AdminToken == "" { + next.ServeHTTP(w, r) + return } - } - switch path { - case "/server/status": - return a.handleStatusRequest(req) - case "/server/config": - return a.handleConfigRequest(req) - case "/webhooks", "/": - return a.handleWebhookRequest(ctx, req) - default: - if strings.HasPrefix(path, "/scheduled/") { - return a.handleScheduledHTTPRequest(ctx, req, path) + authHeader := r.Header.Get("Authorization") + if authHeader == "" { + writeError(w, http.StatusUnauthorized, "unauthorized") + return } - return errorResponse(404, "not found") - } + + token := strings.TrimPrefix(authHeader, "Bearer ") + if token == authHeader { + token = strings.TrimPrefix(authHeader, "bearer ") + } + + if token != a.Config.AdminToken { + writeError(w, http.StatusUnauthorized, "unauthorized") + return + } + + next.ServeHTTP(w, r) + }) } -// handleStatusRequest returns application status. -func (a *App) handleStatusRequest(req Request) Response { - if req.Method != "GET" { - return errorResponse(405, "method not allowed") - } - if resp := a.checkAdminAuth(req); resp != nil { - return *resp - } - return jsonResponse(200, a.GetStatus()) +// handleStatusHTTP returns application status. +func (a *App) handleStatusHTTP(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, a.GetStatus()) } -// handleConfigRequest returns redacted configuration. -func (a *App) handleConfigRequest(req Request) Response { - if req.Method != "GET" { - return errorResponse(405, "method not allowed") - } - if resp := a.checkAdminAuth(req); resp != nil { - return *resp - } - return jsonResponse(200, a.Config.Redacted()) +// handleConfigHTTP returns redacted configuration. +func (a *App) handleConfigHTTP(w http.ResponseWriter, r *http.Request) { + writeJSON(w, http.StatusOK, a.Config.Redacted()) } -// handleWebhookRequest processes GitHub webhook POST requests. -func (a *App) handleWebhookRequest(ctx context.Context, req Request) Response { - if req.Method != "POST" { - return errorResponse(405, "method not allowed") +// handleWebhookHTTP processes GitHub webhook POST requests. +func (a *App) handleWebhookHTTP(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(io.LimitReader(r.Body, 10<<20)) // 10MB limit + if err != nil { + writeError(w, http.StatusBadRequest, "failed to read request body") + return } - eventType := req.Headers["x-github-event"] - signature := req.Headers["x-hub-signature-256"] + eventType := r.Header.Get("X-GitHub-Event") + signature := r.Header.Get("X-Hub-Signature-256") if err := webhooks.ValidateWebhookSignature( - req.Body, + body, signature, a.Config.GitHubWebhookSecret, ); err != nil { a.Logger.Warn("webhook signature validation failed", slog.String("error", err.Error())) - return mapErrorResponse(err, "unauthorized") + writeErrorFromDomain(w, err, "unauthorized") + return } - if err := a.ProcessWebhook(ctx, req.Body, eventType); err != nil { + if err := a.processWebhook(r.Context(), body, eventType); err != nil { a.Logger.Error("webhook processing failed", slog.String("event_type", eventType), slog.String("error", err.Error())) - return mapErrorResponse(err, "webhook processing failed") + writeErrorFromDomain(w, err, "webhook processing failed") + return } - return Response{ - StatusCode: 200, - ContentType: "text/plain", - Body: []byte("ok"), - } + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte("ok")) } -// handleScheduledHTTPRequest processes scheduled events via HTTP POST. -// path is the normalized path with BasePath already stripped. -func (a *App) handleScheduledHTTPRequest(ctx context.Context, req Request, path string) Response { - if req.Method != "POST" { - return errorResponse(405, "method not allowed") - } - if resp := a.checkAdminAuth(req); resp != nil { - return *resp - } - - // extract action from path (e.g., "/scheduled/okta-sync" -> "okta-sync") - action := strings.TrimPrefix(path, "/scheduled/") +// handleScheduledHTTP processes scheduled events via HTTP POST. +// the action is extracted from the chi URL parameter. +func (a *App) handleScheduledHTTP(w http.ResponseWriter, r *http.Request) { + action := chi.URLParam(r, "action") if action == "" { - return errorResponse(400, "missing scheduled action") + writeError(w, http.StatusBadRequest, "missing scheduled action") + return } - scheduledReq := Request{ - Type: RequestTypeScheduled, - ScheduledAction: action, - } + evt := ScheduledEvent{Action: action} - return a.handleScheduledRequest(ctx, scheduledReq) -} + if r.Body != nil { + body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20)) + if err != nil { + writeError(w, http.StatusBadRequest, "failed to read request body") + return + } + if len(body) > 0 { + evt.Data = json.RawMessage(body) + } + } -// mapErrorResponse translates domain error types to appropriate HTTP status -// codes. centralizes error-to-HTTP mapping in one place. uses errors.Is() -// with domain marker instances since errors.Mark() sets identity markers. -func mapErrorResponse(err error, fallbackMsg string) Response { - switch { - case errors.Is(err, domain.AuthError): - return errorResponse(401, "unauthorized") - case errors.Is(err, domain.ValidationError): - return errorResponse(400, fallbackMsg) - case errors.Is(err, domain.ConfigError): - return errorResponse(503, "service not configured") - case errors.Is(err, domain.APIError): - return errorResponse(502, fallbackMsg) - default: - return errorResponse(500, fallbackMsg) + if err := a.processScheduledEvent(r.Context(), evt); err != nil { + a.Logger.Error("scheduled event processing failed", + slog.String("action", evt.Action), + slog.String("error", err.Error())) + writeErrorFromDomain(w, err, "scheduled event processing failed") + return } + + writeJSON(w, http.StatusOK, map[string]string{ + "status": "success", + "message": evt.Action + " completed", + }) } -// jsonResponse creates a JSON response with the given status and data. -func jsonResponse(status int, data any) Response { +// writeJSON writes a JSON response with the given status code. +func writeJSON(w http.ResponseWriter, status int, data any) { body, err := json.Marshal(data) if err != nil { - return errorResponse(500, "failed to marshal response") - } - return Response{ - StatusCode: status, - ContentType: "application/json", - Headers: map[string]string{"Content-Type": "application/json"}, - Body: body, + writeError(w, http.StatusInternalServerError, "failed to marshal response") + return } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + w.Write(body) } -// errorResponse creates an error response with the given status and message. -func errorResponse(status int, message string) Response { - return Response{ - StatusCode: status, - ContentType: "text/plain", - Body: []byte(message), - } +// writeError writes a plain text error response. +func writeError(w http.ResponseWriter, status int, message string) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(status) + w.Write([]byte(message)) } -// checkAdminAuth validates the admin token from the request. -// returns nil if auth is disabled (no token configured) or if token is valid. -// returns an error response if token is required but missing or invalid. -func (a *App) checkAdminAuth(req Request) *Response { - if a.Config.AdminToken == "" { - return nil - } - - authHeader := req.Headers["authorization"] - if authHeader == "" { - resp := errorResponse(401, "unauthorized") - return &resp - } - - token := strings.TrimPrefix(authHeader, "Bearer ") - if token == authHeader { - token = strings.TrimPrefix(authHeader, "bearer ") - } - - if token != a.Config.AdminToken { - resp := errorResponse(401, "unauthorized") - return &resp +// writeErrorFromDomain translates domain error types to HTTP status codes +// and writes the response. centralizes error-to-HTTP mapping. +func writeErrorFromDomain(w http.ResponseWriter, err error, fallbackMsg string) { + switch { + case errors.Is(err, domain.AuthError): + writeError(w, http.StatusUnauthorized, "unauthorized") + case errors.Is(err, domain.ValidationError): + writeError(w, http.StatusBadRequest, fallbackMsg) + case errors.Is(err, domain.ConfigError): + writeError(w, http.StatusServiceUnavailable, "service not configured") + case errors.Is(err, domain.APIError): + writeError(w, http.StatusBadGateway, fallbackMsg) + default: + writeError(w, http.StatusInternalServerError, fallbackMsg) } - - return nil }