Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 95 additions & 124 deletions internal/api/oauthserver/authorize.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,61 +190,102 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
}

authorizationID := chi.URLParam(r, "authorization_id")
authorization, err := s.validateAndFindAuthorization(r, db, authorizationID)
if err != nil {
return err
if authorizationID == "" {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required")
}

// Set user_id if not already set
if authorization.UserID == nil {
// Use transaction to atomically set user and check for auto-approve
var shouldAutoApprove bool
var existingConsent *models.OAuthServerConsent
var (
authorization *models.OAuthServerAuthorization
shouldAutoApprove bool
)

err := db.Transaction(func(tx *storage.Connection) error {
if err := authorization.SetUser(tx, user.ID); err != nil {
return err
// Lookup, user association, consent check, and optional auto-approve
// run under a FOR UPDATE SKIP LOCKED row lock so two concurrent callers
// cannot both claim the same pending authorization and each receive a
// valid authorization code.
err := db.Transaction(func(tx *storage.Connection) error {
auth, terr := models.FindOAuthServerAuthorizationByIDForUpdate(tx, authorizationID)
if terr != nil {
if models.IsNotFoundError(terr) {
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}
return apierrors.NewInternalServerError("error finding authorization").WithInternalError(terr)
}

// Check for existing consent and auto-approve if available
var err error
existingConsent, err = models.FindActiveOAuthServerConsentByUserAndClient(tx, user.ID, authorization.ClientID)
if err != nil {
return err
if auth.IsExpired() {
if merr := auth.MarkExpired(tx); merr != nil {
observability.GetLogEntry(r).Entry.WithError(merr).Warn("failed to mark authorization as expired")
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}
// Commit the MarkExpired update but still return not-found.
return storage.NewCommitWithError(apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found"))
}

// Check if consent covers requested scopes
if existingConsent != nil && s.consentCoversScopes(existingConsent, authorization.Scope) {
shouldAutoApprove = true
if auth.Status != models.OAuthServerAuthorizationPending {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request cannot be processed")
}

if auth.UserID == nil {
if err := auth.SetUser(tx, user.ID); err != nil {
return err
}

return nil
})
existingConsent, cerr := models.FindActiveOAuthServerConsentByUserAndClient(tx, user.ID, auth.ClientID)
if cerr != nil {
return cerr
}

if err != nil {
return apierrors.NewInternalServerError("error setting user and checking consent").WithInternalError(err)
if existingConsent != nil && s.consentCoversScopes(existingConsent, auth.Scope) {
shouldAutoApprove = true
}
} else if *auth.UserID != user.ID {
observability.GetLogEntry(r).Entry.
WithField("request_user_id", user.ID).
WithField("authorization_id", auth.AuthorizationID).
Warn("authorization belongs to different user")
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}

// If we should auto-approve, do it now
if shouldAutoApprove {
return s.autoApproveAndRedirect(w, r, authorization)
if err := auth.Approve(tx); err != nil {
return apierrors.NewInternalServerError("Error auto-approving authorization").WithInternalError(err)
}
}
} else {
// Authorization already has user_id set, validate ownership
if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil {
return err

authorization = auth
return nil
})

if err != nil {
return err
}

observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
observability.LogEntrySetField(r, "client_id", authorization.ClientID.String())

if shouldAutoApprove {
observability.LogEntrySetField(r, "auto_approved", true)
return shared.SendJSON(w, http.StatusOK, ConsentResponse{
RedirectURL: s.buildSuccessRedirectURL(authorization),
})
}

client, err := models.FindOAuthServerClientByID(db, authorization.ClientID)
if err != nil {
if models.IsNotFoundError(err) {
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}
return apierrors.NewInternalServerError("error finding client").WithInternalError(err)
}

// Build response with client and user details
response := AuthorizationDetailsResponse{
AuthorizationID: authorization.AuthorizationID,
RedirectURI: authorization.RedirectURI,
Client: ClientDetailsResponse{
ID: authorization.Client.ID.String(),
Name: utilities.StringValue(authorization.Client.ClientName),
URI: utilities.StringValue(authorization.Client.ClientURI),
LogoURI: utilities.StringValue(authorization.Client.LogoURI),
ID: client.ID.String(),
Name: utilities.StringValue(client.ClientName),
URI: utilities.StringValue(client.ClientURI),
LogoURI: utilities.StringValue(client.LogoURI),
},
User: UserDetailsResponse{
ID: user.ID.String(),
Expand All @@ -253,9 +294,6 @@ func (s *Server) OAuthServerGetAuthorization(w http.ResponseWriter, r *http.Requ
Scope: authorization.Scope,
}

observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
observability.LogEntrySetField(r, "client_id", authorization.Client.ID.String())

return shared.SendJSON(w, http.StatusOK, response)
}

Expand Down Expand Up @@ -284,43 +322,46 @@ func (s *Server) OAuthServerConsent(w http.ResponseWriter, r *http.Request) erro
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "action must be 'approve' or 'deny'")
}

// Validate and find authorization outside transaction first
authorizationID := chi.URLParam(r, "authorization_id")
observability.LogEntrySetField(r, "authorization_id", authorizationID)
authorization, err := s.validateAndFindAuthorization(r, db, authorizationID)
if err != nil {
return err
}

// Ensure authorization belongs to authenticated user
if err := s.validateAuthorizationOwnership(r, authorization, user); err != nil {
return err
if authorizationID == "" {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required")
}

// Process consent in transaction
// Row is locked FOR UPDATE SKIP LOCKED so concurrent approve/deny
// requests for the same authorization are serialised and the ownership
// check below can't race a SetUser from another request.
var redirectURL string
err = db.Transaction(func(tx *storage.Connection) error {
// Re-fetch in transaction to ensure consistency
authorization, err := models.FindOAuthServerAuthorizationByID(tx, authorizationID)
err := db.Transaction(func(tx *storage.Connection) error {
authorization, err := models.FindOAuthServerAuthorizationByIDForUpdate(tx, authorizationID)
if err != nil {
if models.IsNotFoundError(err) {
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}
return apierrors.NewInternalServerError("error finding authorization").WithInternalError(err)
}

// Re-check expiration and status in transaction (state could have changed)
if authorization.IsExpired() {
if err := authorization.MarkExpired(tx); err != nil {
observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired")
if merr := authorization.MarkExpired(tx); merr != nil {
observability.GetLogEntry(r).Entry.WithError(merr).Warn("failed to mark authorization as expired")
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
// Commit the MarkExpired update but still return not-found.
return storage.NewCommitWithError(apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found"))
}

if authorization.Status != models.OAuthServerAuthorizationPending {
return apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request is no longer pending")
}

if authorization.UserID == nil || *authorization.UserID != user.ID {
observability.GetLogEntry(r).Entry.
WithField("request_user_id", user.ID).
WithField("authorization_id", authorization.AuthorizationID).
Warn("authorization belongs to different user")
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}

if body.Action == OAuthServerConsentActionApprove {
// Approve authorization
if err := authorization.Approve(tx); err != nil {
Expand Down Expand Up @@ -390,51 +431,6 @@ func (s *Server) validateRequestOrigin(r *http.Request) error {
return nil
}

// validateAndFindAuthorization validates the authorization_id parameter and finds the authorization,
// performing all necessary checks (existence, expiration, status)
func (s *Server) validateAndFindAuthorization(r *http.Request, db *storage.Connection, authorizationID string) (*models.OAuthServerAuthorization, error) {
if authorizationID == "" {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization_id is required")
}

authorization, err := models.FindOAuthServerAuthorizationByID(db, authorizationID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}
return nil, apierrors.NewInternalServerError("error finding authorization").WithInternalError(err)
}

// Check if expired first - no point processing expired authorizations
if authorization.IsExpired() {
// Mark as expired in database
if err := authorization.MarkExpired(db); err != nil {
observability.GetLogEntry(r).Entry.WithError(err).Warn("failed to mark authorization as expired")
}
// returning not found to avoid leaking information about the existence of the authorization
return nil, apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}

// Check if still pending
if authorization.Status != models.OAuthServerAuthorizationPending {
return nil, apierrors.NewBadRequestError(apierrors.ErrorCodeValidationFailed, "authorization request cannot be processed")
}

return authorization, nil
}

// validateAuthorizationOwnership checks if the authorization belongs to the authenticated user
func (s *Server) validateAuthorizationOwnership(r *http.Request, authorization *models.OAuthServerAuthorization, user *models.User) error {
if authorization.UserID == nil || *authorization.UserID != user.ID {
observability.GetLogEntry(r).Entry.
WithField("request_user_id", user.ID).
WithField("authorization_id", authorization.AuthorizationID).
Warn("authorization belongs to different user")
return apierrors.NewNotFoundError(apierrors.ErrorCodeOAuthAuthorizationNotFound, "authorization not found")
}
return nil
}

// validateBasicAuthorizeParams validates only client_id and redirect_uri (needed before we can redirect errors)
func (s *Server) validateBasicAuthorizeParams(params *AuthorizeParams) (*AuthorizeParams, error) {
if params.ClientID == "" {
Expand Down Expand Up @@ -571,31 +567,6 @@ func (s *Server) consentCoversScopes(consent *models.OAuthServerConsent, request
return consent.HasAllScopes(requestedScopes)
}

func (s *Server) autoApproveAndRedirect(w http.ResponseWriter, r *http.Request, authorization *models.OAuthServerAuthorization) error {
ctx := r.Context()
db := s.db.WithContext(ctx)

// Approve the authorization in a transaction
err := db.Transaction(func(tx *storage.Connection) error {
return authorization.Approve(tx)
})

if err != nil {
return apierrors.NewInternalServerError("Error auto-approving authorization").WithInternalError(err)
}

observability.LogEntrySetField(r, "authorization_id", authorization.AuthorizationID)
observability.LogEntrySetField(r, "auto_approved", true)

// Return JSON with redirect URL (same format as consent endpoint)
redirectURL := s.buildSuccessRedirectURL(authorization)
response := ConsentResponse{
RedirectURL: redirectURL,
}

return shared.SendJSON(w, http.StatusOK, response)
}

func (s *Server) buildSuccessRedirectURL(authorization *models.OAuthServerAuthorization) string {
u, _ := url.Parse(authorization.RedirectURI)
q := u.Query()
Expand Down
Loading
Loading