diff --git a/internal/api/oauthserver/authorize.go b/internal/api/oauthserver/authorize.go index 42e14e335b..2addacfdaa 100644 --- a/internal/api/oauthserver/authorize.go +++ b/internal/api/oauthserver/authorize.go @@ -190,61 +190,102 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ } authorizationID := chi.URLParam(r, "authorization_id") - authorization, err := s.validateAndFindAuthorization(r, db, authorizationID) - if err != nil { - return err + if authorizationID == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required") } - // Set user_id if not already set - if authorization.UserID == nil { - // Use transaction to atomically set user and check for auto-approve - var shouldAutoApprove bool - var existingConsent *models.OAuthServerConsent + var ( + authorization *models.OAuthServerAuthorization + shouldAutoApprove bool + ) - err := db.Transaction(func(tx *storage.Connection) error { - if err := authorization.SetUser(tx, user.ID); err != nil { - return err + // Lookup, user association, consent check, and optional auto-approve + // run under a FOR UPDATE SKIP LOCKED row lock so two concurrent callers + // cannot both claim the same pending authorization and each receive a + // valid authorization code. + err := db.Transaction(func(tx *storage.Connection) error { + auth, terr := models.FindOAuthServerAuthorizationByIDForUpdate(tx, authorizationID) + if terr != nil { + if models.IsNotFoundError(terr) { + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") } + return apierrors.NewInternalServerError("error finding authorization").WithInternalError(terr) + } - // Check for existing consent and auto-approve if available - var err error - existingConsent, err = models.FindActiveOAuthServerConsentByUserAndClient(tx, user.ID, authorization.ClientID) - if err != nil { - return err + if auth.IsExpired() { + if merr := auth.MarkExpired(tx); merr != nil { + observability.GetLogEntry(r).Entry.WithError(merr).Warn("failed to mark authorization as expired") + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") } + // Commit the MarkExpired update but still return not-found. + return storage.NewCommitWithError(apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")) + } - // Check if consent covers requested scopes - if existingConsent != nil && s.consentCoversScopes(existingConsent, authorization.Scope) { - shouldAutoApprove = true + if auth.Status != models.OAuthServerAuthorizationPending { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request cannot be processed") + } + + if auth.UserID == nil { + if err := auth.SetUser(tx, user.ID); err != nil { + return err } - return nil - }) + existingConsent, cerr := models.FindActiveOAuthServerConsentByUserAndClient(tx, user.ID, auth.ClientID) + if cerr != nil { + return cerr + } - if err != nil { - return apierrors.NewInternalServerError("error setting user and checking consent").WithInternalError(err) + if existingConsent != nil && s.consentCoversScopes(existingConsent, auth.Scope) { + shouldAutoApprove = true + } + } else if *auth.UserID != user.ID { + observability.GetLogEntry(r).Entry. + WithField("request_user_id", user.ID). + WithField("authorization_id", auth.AuthorizationID). + Warn("authorization belongs to different user") + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") } - // If we should auto-approve, do it now if shouldAutoApprove { - return s.autoApproveAndRedirect(w, r, authorization) + if err := auth.Approve(tx); err != nil { + return apierrors.NewInternalServerError("Error auto-approving authorization").WithInternalError(err) + } } - } else { - // Authorization already has user_id set, validate ownership - if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil { - return err + + authorization = auth + return nil + }) + + if err != nil { + return err + } + + observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) + observability.LogEntrySetField(r, "client_id", authorization.ClientID.String()) + + if shouldAutoApprove { + observability.LogEntrySetField(r, "auto_approved", true) + return shared.SendJSON(w, http.StatusOK, ConsentResponse{ + RedirectURL: s.buildSuccessRedirectURL(authorization), + }) + } + + client, err := models.FindOAuthServerClientByID(db, authorization.ClientID) + if err != nil { + if models.IsNotFoundError(err) { + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") } + return apierrors.NewInternalServerError("error finding client").WithInternalError(err) } - // Build response with client and user details response := AuthorizationDetailsResponse{ AuthorizationID: authorization.AuthorizationID, RedirectURI: authorization.RedirectURI, Client: ClientDetailsResponse{ - ID: authorization.Client.ID.String(), - Name: utilities.StringValue(authorization.Client.ClientName), - URI: utilities.StringValue(authorization.Client.ClientURI), - LogoURI: utilities.StringValue(authorization.Client.LogoURI), + ID: client.ID.String(), + Name: utilities.StringValue(client.ClientName), + URI: utilities.StringValue(client.ClientURI), + LogoURI: utilities.StringValue(client.LogoURI), }, User: UserDetailsResponse{ ID: user.ID.String(), @@ -253,9 +294,6 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ Scope: authorization.Scope, } - observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) - observability.LogEntrySetField(r, "client_id", authorization.Client.ID.String()) - return shared.SendJSON(w, http.StatusOK, response) } @@ -284,24 +322,18 @@ func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) erro return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "action must be 'approve' or 'deny'") } - // Validate and find authorization outside transaction first authorizationID := chi.URLParam(r, "authorization_id") observability.LogEntrySetField(r, "authorization_id", authorizationID) - authorization, err := s.validateAndFindAuthorization(r, db, authorizationID) - if err != nil { - return err - } - - // Ensure authorization belongs to authenticated user - if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil { - return err + if authorizationID == "" { + return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required") } - // Process consent in transaction + // Row is locked FOR UPDATE SKIP LOCKED so concurrent approve/deny + // requests for the same authorization are serialised and the ownership + // check below can't race a SetUser from another request. var redirectURL string - err = db.Transaction(func(tx *storage.Connection) error { - // Re-fetch in transaction to ensure consistency - authorization, err := models.FindOAuthServerAuthorizationByID(tx, authorizationID) + err := db.Transaction(func(tx *storage.Connection) error { + authorization, err := models.FindOAuthServerAuthorizationByIDForUpdate(tx, authorizationID) if err != nil { if models.IsNotFoundError(err) { return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") @@ -309,18 +341,27 @@ func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) erro return apierrors.NewInternalServerError("error finding authorization").WithInternalError(err) } - // Re-check expiration and status in transaction (state could have changed) if authorization.IsExpired() { - if err := authorization.MarkExpired(tx); err != nil { - observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired") + if merr := authorization.MarkExpired(tx); merr != nil { + observability.GetLogEntry(r).Entry.WithError(merr).Warn("failed to mark authorization as expired") + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") } - return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + // Commit the MarkExpired update but still return not-found. + return storage.NewCommitWithError(apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")) } if authorization.Status != models.OAuthServerAuthorizationPending { return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request is no longer pending") } + if authorization.UserID == nil || *authorization.UserID != user.ID { + observability.GetLogEntry(r).Entry. + WithField("request_user_id", user.ID). + WithField("authorization_id", authorization.AuthorizationID). + Warn("authorization belongs to different user") + return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") + } + if body.Action == OAuthServerConsentActionApprove { // Approve authorization if err := authorization.Approve(tx); err != nil { @@ -390,51 +431,6 @@ func (s *Server) validateRequestOrigin(r *http.Request) error { return nil } -// validateAndFindAuthorization validates the authorization_id parameter and finds the authorization, -// performing all necessary checks (existence, expiration, status) -func (s *Server) validateAndFindAuthorization(r *http.Request, db *storage.Connection, authorizationID string) (*models.OAuthServerAuthorization, error) { - if authorizationID == "" { - return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required") - } - - authorization, err := models.FindOAuthServerAuthorizationByID(db, authorizationID) - if err != nil { - if models.IsNotFoundError(err) { - return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") - } - return nil, apierrors.NewInternalServerError("error finding authorization").WithInternalError(err) - } - - // Check if expired first - no point processing expired authorizations - if authorization.IsExpired() { - // Mark as expired in database - if err := authorization.MarkExpired(db); err != nil { - observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired") - } - // returning not found to avoid leaking information about the existence of the authorization - return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") - } - - // Check if still pending - if authorization.Status != models.OAuthServerAuthorizationPending { - return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request cannot be processed") - } - - return authorization, nil -} - -// validateAuthorizationOwnership checks if the authorization belongs to the authenticated user -func (s *Server) validateAuthorizationOwnership(r *http.Request, authorization *models.OAuthServerAuthorization, user *models.User) error { - if authorization.UserID == nil || *authorization.UserID != user.ID { - observability.GetLogEntry(r).Entry. - WithField("request_user_id", user.ID). - WithField("authorization_id", authorization.AuthorizationID). - Warn("authorization belongs to different user") - return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found") - } - return nil -} - // validateBasicAuthorizeParams validates only client_id and redirect_uri (needed before we can redirect errors) func (s *Server) validateBasicAuthorizeParams(params *AuthorizeParams) (*AuthorizeParams, error) { if params.ClientID == "" { @@ -571,31 +567,6 @@ func (s *Server) consentCoversScopes(consent *models.OAuthServerConsent, request return consent.HasAllScopes(requestedScopes) } -func (s *Server) autoApproveAndRedirect(w http.ResponseWriter, r *http.Request, authorization *models.OAuthServerAuthorization) error { - ctx := r.Context() - db := s.db.WithContext(ctx) - - // Approve the authorization in a transaction - err := db.Transaction(func(tx *storage.Connection) error { - return authorization.Approve(tx) - }) - - if err != nil { - return apierrors.NewInternalServerError("Error auto-approving authorization").WithInternalError(err) - } - - observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID) - observability.LogEntrySetField(r, "auto_approved", true) - - // Return JSON with redirect URL (same format as consent endpoint) - redirectURL := s.buildSuccessRedirectURL(authorization) - response := ConsentResponse{ - RedirectURL: redirectURL, - } - - return shared.SendJSON(w, http.StatusOK, response) -} - func (s *Server) buildSuccessRedirectURL(authorization *models.OAuthServerAuthorization) string { u, _ := url.Parse(authorization.RedirectURI) q := u.Query() diff --git a/internal/api/oauthserver/authorize_test.go b/internal/api/oauthserver/authorize_test.go index 3cfa2aeaca..13d82c803b 100644 --- a/internal/api/oauthserver/authorize_test.go +++ b/internal/api/oauthserver/authorize_test.go @@ -1,15 +1,27 @@ package oauthserver import ( + "bytes" + "context" + "encoding/json" "net/http" "net/http/httptest" + "net/url" "testing" + "time" + "github.com/go-chi/chi/v5" "github.com/gobwas/glob" + "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/shared" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/hooks/v0hooks" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/storage/test" "github.com/supabase/auth/internal/tokens" ) @@ -185,3 +197,469 @@ func TestValidateRequestOriginEdgeCases(t *testing.T) { assert.Error(t, err) }) } + +type OAuthAuthorizeTestSuite struct { + suite.Suite + Server *Server + Config *conf.GlobalConfiguration + DB *storage.Connection +} + +func TestOAuthAuthorize(t *testing.T) { + globalConfig, err := conf.LoadGlobal(oauthServerTestConfig) + require.NoError(t, err) + + conn, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + + globalConfig.OAuthServer.Enabled = true + globalConfig.OAuthServer.AllowDynamicRegistration = true + if globalConfig.OAuthServer.AuthorizationTTL == 0 { + globalConfig.OAuthServer.AuthorizationTTL = 10 * time.Minute + } + // OAuthServerAuthorize bails to a server_error redirect if this is empty. + if globalConfig.OAuthServer.AuthorizationPath == "" { + globalConfig.OAuthServer.AuthorizationPath = "/oauth/authorize-frontend" + } + + hooksMgr := &v0hooks.Manager{} + tokenService := tokens.NewService(globalConfig, hooksMgr) + server := NewServer(globalConfig, conn, tokenService) + + ts := &OAuthAuthorizeTestSuite{ + Server: server, + Config: globalConfig, + DB: conn, + } + defer ts.DB.Close() + + suite.Run(t, ts) +} + +func (ts *OAuthAuthorizeTestSuite) SetupTest() { + require.NoError(ts.T(), models.TruncateAll(ts.DB)) + ts.Config.OAuthServer.Enabled = true + ts.Config.OAuthServer.AllowDynamicRegistration = true +} + +// ---------- helpers ---------- + +func (ts *OAuthAuthorizeTestSuite) createUser(email string) *models.User { + u, err := models.NewUser("", email, "password123", "authenticated", nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.DB.Create(u)) + return u +} + +func (ts *OAuthAuthorizeTestSuite) createClient() *models.OAuthServerClient { + params := &OAuthServerClientRegisterParams{ + ClientName: "Test Authorize Client", + RedirectURIs: []string{"https://example.com/callback"}, + RegistrationType: "dynamic", + } + client, _, err := ts.Server.registerOAuthServerClient(context.Background(), params) + require.NoError(ts.T(), err) + return client +} + +// createAuthorization creates an authorization using the OAuthServerAuthorize handler +func (ts *OAuthAuthorizeTestSuite) createAuthorization(clientID uuid.UUID, scope string) *models.OAuthServerAuthorization { + q := url.Values{ + "client_id": []string{clientID.String()}, + "redirect_uri": []string{"https://example.com/callback"}, + "scope": []string{scope}, + "code_challenge": []string{"abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG"}, + "code_challenge_method": []string{"S256"}, + } + req := httptest.NewRequest(http.MethodGet, "/oauth/authorize?"+q.Encode(), nil) + w := httptest.NewRecorder() + require.NoError(ts.T(), ts.Server.OAuthServerAuthorize(w, req)) + require.Equal(ts.T(), http.StatusFound, w.Code, "authorize: %s", w.Body.String()) + + loc, err := url.Parse(w.Header().Get("Location")) + require.NoError(ts.T(), err) + authID := loc.Query().Get("authorization_id") + require.NotEmpty(ts.T(), authID, "authorize redirect missing authorization_id (got %q, error=%q)", + w.Header().Get("Location"), loc.Query().Get("error")) + return ts.reload(authID) +} + +// expireAuthorization pushes expires_at into the past so IsExpired() returns true +func (ts *OAuthAuthorizeTestSuite) expireAuthorization(authorizationID string) { + require.NoError(ts.T(), ts.DB.RawQuery( + "UPDATE oauth_authorizations SET created_at = now() - interval '2 hours', expires_at = now() - interval '1 hour' WHERE authorization_id = ?", + authorizationID, + ).Exec()) +} + +// newRequest builds a request with the authorization_id and the user attached. +func (ts *OAuthAuthorizeTestSuite) newRequest(method, authorizationID string, user *models.User, body []byte) *http.Request { + req := httptest.NewRequest(method, "/oauth/authorizations/"+authorizationID, bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + + rctx := chi.NewRouteContext() + rctx.URLParams.Add("authorization_id", authorizationID) + ctx := context.WithValue(req.Context(), chi.RouteCtxKey, rctx) + if user != nil { + ctx = shared.WithUser(ctx, user) + } + return req.WithContext(ctx) +} + +func (ts *OAuthAuthorizeTestSuite) assertHTTPError(err error, status int, code apierrors.ErrorCode) { + ts.T().Helper() + httpErr, ok := err.(*apierrors.HTTPError) + require.True(ts.T(), ok, "expected *apierrors.HTTPError, got %T (%v)", err, err) + assert.Equal(ts.T(), status, httpErr.HTTPStatus) + assert.Equal(ts.T(), string(code), httpErr.ErrorCode) +} + +func (ts *OAuthAuthorizeTestSuite) reload(authorizationID string) *models.OAuthServerAuthorization { + a, err := models.FindOAuthServerAuthorizationByID(ts.DB, authorizationID) + require.NoError(ts.T(), err) + return a +} + +// ---------- OAuthServerGetAuthorization ---------- + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_SetsUserAndReturnsDetails() { + user := ts.createUser("get-details@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid profile") + + req := ts.newRequest(http.MethodGet, auth.AuthorizationID, user, nil) + w := httptest.NewRecorder() + + require.NoError(ts.T(), ts.Server.OAuthServerGetAuthorization(w, req)) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var resp AuthorizationDetailsResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &resp)) + assert.Equal(ts.T(), auth.AuthorizationID, resp.AuthorizationID) + assert.Equal(ts.T(), client.ID.String(), resp.Client.ID) + assert.Equal(ts.T(), "Test Authorize Client", resp.Client.Name) + assert.Equal(ts.T(), user.ID.String(), resp.User.ID) + assert.Equal(ts.T(), "openid profile", resp.Scope) + + var maybeConsent ConsentResponse + _ = json.Unmarshal(w.Body.Bytes(), &maybeConsent) + assert.Empty(ts.T(), maybeConsent.RedirectURL, "should not have auto-approved without consent") + + reloaded := ts.reload(auth.AuthorizationID) + require.NotNil(ts.T(), reloaded.UserID) + assert.Equal(ts.T(), user.ID, *reloaded.UserID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationPending, reloaded.Status) + assert.Nil(ts.T(), reloaded.AuthorizationCode) +} + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_AutoApprovesWhenConsentCoversScopes() { + user := ts.createUser("auto-approve@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid profile") + + consent := models.NewOAuthServerConsent(user.ID, client.ID, []string{"openid", "profile"}) + require.NoError(ts.T(), models.UpsertOAuthServerConsent(ts.DB, consent)) + + req := ts.newRequest(http.MethodGet, auth.AuthorizationID, user, nil) + w := httptest.NewRecorder() + + require.NoError(ts.T(), ts.Server.OAuthServerGetAuthorization(w, req)) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var resp ConsentResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &resp)) + require.NotEmpty(ts.T(), resp.RedirectURL) + parsed, err := url.Parse(resp.RedirectURL) + require.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), parsed.Query().Get("code"), "redirect_url must carry an authorization code") + + reloaded := ts.reload(auth.AuthorizationID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationApproved, reloaded.Status) + require.NotNil(ts.T(), reloaded.AuthorizationCode) + assert.NotEmpty(ts.T(), *reloaded.AuthorizationCode) + assert.NotNil(ts.T(), reloaded.ApprovedAt) + require.NotNil(ts.T(), reloaded.UserID) + assert.Equal(ts.T(), user.ID, *reloaded.UserID) +} + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_ConsentDoesNotCoverScopes() { + user := ts.createUser("partial-consent@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid profile") + + // Consent only covers "openid" — missing "profile". + consent := models.NewOAuthServerConsent(user.ID, client.ID, []string{"openid"}) + require.NoError(ts.T(), models.UpsertOAuthServerConsent(ts.DB, consent)) + + req := ts.newRequest(http.MethodGet, auth.AuthorizationID, user, nil) + w := httptest.NewRecorder() + + require.NoError(ts.T(), ts.Server.OAuthServerGetAuthorization(w, req)) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var resp AuthorizationDetailsResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &resp)) + assert.Equal(ts.T(), client.ID.String(), resp.Client.ID) + + reloaded := ts.reload(auth.AuthorizationID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationPending, reloaded.Status) + assert.Nil(ts.T(), reloaded.AuthorizationCode) +} + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_SameUserRepeatCall() { + user := ts.createUser("repeat@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid profile") + + for i := range 2 { + req := ts.newRequest(http.MethodGet, auth.AuthorizationID, user, nil) + w := httptest.NewRecorder() + require.NoError(ts.T(), ts.Server.OAuthServerGetAuthorization(w, req), "call %d", i) + assert.Equal(ts.T(), http.StatusOK, w.Code, "call %d", i) + + var resp AuthorizationDetailsResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &resp), "call %d", i) + assert.Equal(ts.T(), auth.AuthorizationID, resp.AuthorizationID, "call %d", i) + } + + reloaded := ts.reload(auth.AuthorizationID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationPending, reloaded.Status) + require.NotNil(ts.T(), reloaded.UserID) + assert.Equal(ts.T(), user.ID, *reloaded.UserID) +} + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_DifferentUserReturnsNotFound() { + owner := ts.createUser("owner@example.com") + otherUser := ts.createUser("other-user@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid") + require.NoError(ts.T(), auth.SetUser(ts.DB, owner.ID)) + + req := ts.newRequest(http.MethodGet, auth.AuthorizationID, otherUser, nil) + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerGetAuthorization(w, req) + ts.assertHTTPError(err, http.StatusNotFound, apierrors.ErrorCodeOAuthAuthorizationNotFound) + + reloaded := ts.reload(auth.AuthorizationID) + require.NotNil(ts.T(), reloaded.UserID) + assert.Equal(ts.T(), owner.ID, *reloaded.UserID, "row ownership should be unchanged") + assert.Equal(ts.T(), models.OAuthServerAuthorizationPending, reloaded.Status) +} + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_ExpiredCommitsMarkExpired() { + user := ts.createUser("expired-get@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid") + ts.expireAuthorization(auth.AuthorizationID) + + req := ts.newRequest(http.MethodGet, auth.AuthorizationID, user, nil) + w := httptest.NewRecorder() + + // Expired path returns a CommitWithError so MarkExpired persists. + err := ts.Server.OAuthServerGetAuthorization(w, req) + cwe, ok := err.(*storage.CommitWithError) + require.True(ts.T(), ok, "expected *storage.CommitWithError, got %T (%v)", err, err) + ts.assertHTTPError(cwe.Err, http.StatusNotFound, apierrors.ErrorCodeOAuthAuthorizationNotFound) + + // The MarkExpired update must be committed even though the handler + // returns an error — otherwise an expired row would remain pending. + reloaded := ts.reload(auth.AuthorizationID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationExpired, reloaded.Status) +} + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_NonPendingStatusRejected() { + user := ts.createUser("non-pending-get@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid") + require.NoError(ts.T(), auth.SetUser(ts.DB, user.ID)) + require.NoError(ts.T(), auth.Deny(ts.DB)) + + req := ts.newRequest(http.MethodGet, auth.AuthorizationID, user, nil) + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerGetAuthorization(w, req) + ts.assertHTTPError(err, http.StatusBadRequest, apierrors.ErrorCodeValidationFailed) +} + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_UnknownAuthorizationID() { + user := ts.createUser("unknown-get@example.com") + + req := ts.newRequest(http.MethodGet, "nonexistent-authorization-id", user, nil) + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerGetAuthorization(w, req) + ts.assertHTTPError(err, http.StatusNotFound, apierrors.ErrorCodeOAuthAuthorizationNotFound) +} + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_MissingAuthorizationID() { + user := ts.createUser("missing-get@example.com") + + req := ts.newRequest(http.MethodGet, "", user, nil) + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerGetAuthorization(w, req) + ts.assertHTTPError(err, http.StatusBadRequest, apierrors.ErrorCodeValidationFailed) +} + +func (ts *OAuthAuthorizeTestSuite) TestGetAuthorization_NoUserInContext() { + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid") + + req := ts.newRequest(http.MethodGet, auth.AuthorizationID, nil, nil) + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerGetAuthorization(w, req) + ts.assertHTTPError(err, http.StatusForbidden, apierrors.ErrorCodeBadJWT) +} + +// ---------- OAuthServerConsent ---------- + +func consentBody(action OAuthServerConsentAction) []byte { + b, _ := json.Marshal(ConsentRequest{Action: action}) + return b +} + +func (ts *OAuthAuthorizeTestSuite) TestConsent_ApproveIssuesCodeAndStoresConsent() { + user := ts.createUser("consent-approve@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid profile") + require.NoError(ts.T(), auth.SetUser(ts.DB, user.ID)) + + req := ts.newRequest(http.MethodPost, auth.AuthorizationID, user, consentBody(OAuthServerConsentActionApprove)) + w := httptest.NewRecorder() + + require.NoError(ts.T(), ts.Server.OAuthServerConsent(w, req)) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var resp ConsentResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &resp)) + require.NotEmpty(ts.T(), resp.RedirectURL) + parsed, err := url.Parse(resp.RedirectURL) + require.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), parsed.Query().Get("code"), "approve should return redirect with code") + assert.Empty(ts.T(), parsed.Query().Get("error")) + + reloaded := ts.reload(auth.AuthorizationID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationApproved, reloaded.Status) + require.NotNil(ts.T(), reloaded.AuthorizationCode) + assert.NotEmpty(ts.T(), *reloaded.AuthorizationCode) + + stored, err := models.FindActiveOAuthServerConsentByUserAndClient(ts.DB, user.ID, client.ID) + require.NoError(ts.T(), err) + require.NotNil(ts.T(), stored) + assert.True(ts.T(), stored.HasAllScopes([]string{"openid", "profile"})) +} + +func (ts *OAuthAuthorizeTestSuite) TestConsent_DenyReturnsAccessDenied() { + user := ts.createUser("consent-deny@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid") + require.NoError(ts.T(), auth.SetUser(ts.DB, user.ID)) + + req := ts.newRequest(http.MethodPost, auth.AuthorizationID, user, consentBody(OAuthServerConsentActionDeny)) + w := httptest.NewRecorder() + + require.NoError(ts.T(), ts.Server.OAuthServerConsent(w, req)) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + var resp ConsentResponse + require.NoError(ts.T(), json.Unmarshal(w.Body.Bytes(), &resp)) + require.NotEmpty(ts.T(), resp.RedirectURL) + parsed, err := url.Parse(resp.RedirectURL) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), oAuth2ErrorAccessDenied, parsed.Query().Get("error")) + assert.Empty(ts.T(), parsed.Query().Get("code")) + + reloaded := ts.reload(auth.AuthorizationID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationDenied, reloaded.Status) + + // Deny must NOT upsert a consent record. + stored, err := models.FindActiveOAuthServerConsentByUserAndClient(ts.DB, user.ID, client.ID) + require.NoError(ts.T(), err) + assert.Nil(ts.T(), stored) +} + +func (ts *OAuthAuthorizeTestSuite) TestConsent_UserIDMismatchReturnsNotFound() { + owner := ts.createUser("consent-owner@example.com") + otherUser := ts.createUser("consent-other-user@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid") + require.NoError(ts.T(), auth.SetUser(ts.DB, owner.ID)) + + req := ts.newRequest(http.MethodPost, auth.AuthorizationID, otherUser, consentBody(OAuthServerConsentActionApprove)) + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerConsent(w, req) + ts.assertHTTPError(err, http.StatusNotFound, apierrors.ErrorCodeOAuthAuthorizationNotFound) + + reloaded := ts.reload(auth.AuthorizationID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationPending, reloaded.Status, "row must be untouched") + assert.Nil(ts.T(), reloaded.AuthorizationCode) +} + +func (ts *OAuthAuthorizeTestSuite) TestConsent_UserIDNilReturnsNotFound() { + user := ts.createUser("consent-nil@example.com") + client := ts.createClient() + // No UserID set — user never went through GetAuthorization first. + auth := ts.createAuthorization(client.ID, "openid") + + req := ts.newRequest(http.MethodPost, auth.AuthorizationID, user, consentBody(OAuthServerConsentActionApprove)) + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerConsent(w, req) + ts.assertHTTPError(err, http.StatusNotFound, apierrors.ErrorCodeOAuthAuthorizationNotFound) + + reloaded := ts.reload(auth.AuthorizationID) + assert.Nil(ts.T(), reloaded.UserID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationPending, reloaded.Status) +} + +func (ts *OAuthAuthorizeTestSuite) TestConsent_ExpiredCommitsMarkExpired() { + user := ts.createUser("consent-expired@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid") + require.NoError(ts.T(), auth.SetUser(ts.DB, user.ID)) + ts.expireAuthorization(auth.AuthorizationID) + + req := ts.newRequest(http.MethodPost, auth.AuthorizationID, user, consentBody(OAuthServerConsentActionApprove)) + w := httptest.NewRecorder() + + // Expired path returns a CommitWithError so MarkExpired persists. + err := ts.Server.OAuthServerConsent(w, req) + cwe, ok := err.(*storage.CommitWithError) + require.True(ts.T(), ok, "expected *storage.CommitWithError, got %T (%v)", err, err) + ts.assertHTTPError(cwe.Err, http.StatusNotFound, apierrors.ErrorCodeOAuthAuthorizationNotFound) + + reloaded := ts.reload(auth.AuthorizationID) + assert.Equal(ts.T(), models.OAuthServerAuthorizationExpired, reloaded.Status, + "MarkExpired update must be committed via NewCommitWithError") +} + +func (ts *OAuthAuthorizeTestSuite) TestConsent_NonPendingStatusRejected() { + user := ts.createUser("consent-nonpending@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid") + require.NoError(ts.T(), auth.SetUser(ts.DB, user.ID)) + require.NoError(ts.T(), auth.Approve(ts.DB)) + + req := ts.newRequest(http.MethodPost, auth.AuthorizationID, user, consentBody(OAuthServerConsentActionApprove)) + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerConsent(w, req) + ts.assertHTTPError(err, http.StatusBadRequest, apierrors.ErrorCodeValidationFailed) +} + +func (ts *OAuthAuthorizeTestSuite) TestConsent_InvalidActionRejected() { + user := ts.createUser("consent-invalid-action@example.com") + client := ts.createClient() + auth := ts.createAuthorization(client.ID, "openid") + require.NoError(ts.T(), auth.SetUser(ts.DB, user.ID)) + + body, _ := json.Marshal(map[string]string{"action": "maybe"}) + req := ts.newRequest(http.MethodPost, auth.AuthorizationID, user, body) + w := httptest.NewRecorder() + + err := ts.Server.OAuthServerConsent(w, req) + ts.assertHTTPError(err, http.StatusBadRequest, apierrors.ErrorCodeValidationFailed) +} diff --git a/internal/models/oauth_authorization.go b/internal/models/oauth_authorization.go index d731719120..2b7fab2579 100644 --- a/internal/models/oauth_authorization.go +++ b/internal/models/oauth_authorization.go @@ -252,6 +252,23 @@ func FindOAuthServerAuthorizationByID(tx *storage.Connection, authorizationID st return auth, nil } +// FindOAuthServerAuthorizationByIDForUpdate finds an OAuth authorization by +// authorization_id and locks the row with FOR UPDATE SKIP LOCKED. +// Must be called inside a transaction. +func FindOAuthServerAuthorizationByIDForUpdate(tx *storage.Connection, authorizationID string) (*OAuthServerAuthorization, error) { + auth := &OAuthServerAuthorization{} + if err := tx.RawQuery( + fmt.Sprintf("SELECT * FROM %q WHERE authorization_id = ? LIMIT 1 FOR UPDATE SKIP LOCKED", auth.TableName()), + authorizationID, + ).First(auth); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OAuthServerAuthorizationNotFoundError{} + } + return nil, errors.Wrap(err, "error finding OAuth authorization") + } + return auth, nil +} + // FindOAuthServerAuthorizationByCode finds an OAuth authorization by authorization code func FindOAuthServerAuthorizationByCode(tx *storage.Connection, code string) (*OAuthServerAuthorization, error) { auth := &OAuthServerAuthorization{} diff --git a/internal/models/oauth_authorization_test.go b/internal/models/oauth_authorization_test.go index 02ed5d6682..25f7e01355 100644 --- a/internal/models/oauth_authorization_test.go +++ b/internal/models/oauth_authorization_test.go @@ -7,6 +7,10 @@ import ( "github.com/gofrs/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" + "github.com/supabase/auth/internal/storage/test" ) func TestNewOAuthServerAuthorization(t *testing.T) { @@ -327,6 +331,57 @@ func TestOAuthServerAuthorization_MarkExpiredLogic(t *testing.T) { } } +// FindOAuthServerAuthorizationByIDForUpdate locks the row with FOR UPDATE +// SKIP LOCKED so concurrent callers can't both claim the same pending +// authorization. Verify a second caller in a fresh transaction sees the +// locked row as "not found". +func TestFindOAuthServerAuthorizationByIDForUpdate_SkipLocked(t *testing.T) { + globalConfig, err := conf.LoadGlobal(modelsTestConfig) + require.NoError(t, err) + + db, err := test.SetupDBConnection(globalConfig) + require.NoError(t, err) + defer db.Close() + require.NoError(t, TruncateAll(db)) + + clientName := "Skip Locked Test Client" + secretHash, err := testHashClientSecret("test_secret") + require.NoError(t, err) + client := &OAuthServerClient{ + ID: uuid.Must(uuid.NewV4()), + ClientName: &clientName, + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: secretHash, + RedirectURIs: "https://example.com/callback", + } + require.NoError(t, CreateOAuthServerClient(db, client)) + + auth := NewOAuthServerAuthorization(NewOAuthServerAuthorizationParams{ + ClientID: client.ID, + RedirectURI: "https://example.com/callback", + Scope: "openid", + CodeChallenge: "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFG", + CodeChallengeMethod: "S256", + TTL: 10 * time.Minute, + }) + require.NoError(t, CreateOAuthServerAuthorization(db, auth)) + + holdTx, err := db.Connection.NewTransaction() + require.NoError(t, err) + defer func() { _ = holdTx.TX.Rollback() }() + held := &storage.Connection{Connection: holdTx} + _, err = FindOAuthServerAuthorizationByIDForUpdate(held, auth.AuthorizationID) + require.NoError(t, err) + + require.NoError(t, db.Transaction(func(tx *storage.Connection) error { + _, ferr := FindOAuthServerAuthorizationByIDForUpdate(tx, auth.AuthorizationID) + require.True(t, IsNotFoundError(ferr)) + return nil + })) +} + func TestOAuthServerAuthorization_Validate(t *testing.T) { userID := uuid.Must(uuid.NewV4()) clientID := uuid.Must(uuid.NewV4())