Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions src/handlers/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,6 @@ type messageInput struct {
// authenticated user using PostgreSQL LISTEN/NOTIFY on new_msg_to.
func (h *MessageHandler) Wait(c *gin.Context) {
identity := middleware.GetIdentity(c)
if identity == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}

sinceID := int64(0)
if v := c.Query("since_id"); v != "" {
Expand Down Expand Up @@ -180,10 +176,6 @@ func (h *MessageHandler) latestMessageIDForRecipient(ctx context.Context, identi
// List handles GET /api/v1/messages — lists messages where the authenticated user is a recipient.
func (h *MessageHandler) List(c *gin.Context) {
identity := middleware.GetIdentity(c)
if identity == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}

// Parse limit query parameter (default 20, max 100).
limit := 20
Expand Down Expand Up @@ -314,10 +306,6 @@ func (h *MessageHandler) List(c *gin.Context) {
// Create handles POST /api/v1/messages — creates a draft message.
func (h *MessageHandler) Create(c *gin.Context) {
identity := middleware.GetIdentity(c)
if identity == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return
}

var msg messageInput
if err := c.ShouldBindJSON(&msg); err != nil {
Expand Down
17 changes: 13 additions & 4 deletions src/middleware/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,24 +73,33 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) {
code, accepting, err := checkFmsgID(idURL, addr)
if err != nil {
log.Printf("fmsgid check error for %s: %v", addr, err)
c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{"error": "identity service unavailable"})
c.Set("auth_error_code", http.StatusServiceUnavailable)
c.Set("auth_error_msg", "identity service unavailable")
return false
}
if code == http.StatusNotFound {
log.Printf("auth rejected: ip=%s addr=%s reason=not_found", c.ClientIP(), addr)
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("User %s not found", addr)})
c.Set("auth_error_code", http.StatusBadRequest)
c.Set("auth_error_msg", fmt.Sprintf("User %s not found", addr))
return false
}
if code == http.StatusOK && !accepting {
log.Printf("auth rejected: ip=%s addr=%s reason=not_accepting", c.ClientIP(), addr)
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("User %s not authorised to send new messages", addr)})
c.Set("auth_error_code", http.StatusForbidden)
c.Set("auth_error_msg", fmt.Sprintf("User %s not authorised to send new messages", addr))
return false
}
return true
},

// Unauthorized responds with 401 when JWT validation fails.
// Unauthorized responds when JWT validation or authorization fails.
Unauthorized: func(c *gin.Context, code int, message string) {
if errCode, exists := c.Get("auth_error_code"); exists {
code = errCode.(int)
}
if errMsg, exists := c.Get("auth_error_msg"); exists {
message = errMsg.(string)
}
log.Printf("auth failure: ip=%s code=%d message=%s", c.ClientIP(), code, message)
c.JSON(code, gin.H{"error": message})
},
Expand Down
Loading