-
Notifications
You must be signed in to change notification settings - Fork 0
rate-limiting, timeouts and log auth failures #6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5e09ef8
330e73f
72806f0
f9d7641
68ec68c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -24,13 +24,15 @@ import ( | |
|
|
||
| // MessageHandler holds dependencies for message routes. | ||
| type MessageHandler struct { | ||
| DB *db.DB | ||
| DataDir string | ||
| DB *db.DB | ||
| DataDir string | ||
| MaxDataSize int64 | ||
| MaxMsgSize int64 | ||
| } | ||
|
|
||
| // NewMessageHandler creates a MessageHandler. | ||
| func NewMessageHandler(database *db.DB, dataDir string) *MessageHandler { | ||
| return &MessageHandler{DB: database, DataDir: dataDir} | ||
| func NewMessageHandler(database *db.DB, dataDir string, maxDataSize, maxMsgSize int64) *MessageHandler { | ||
| return &MessageHandler{DB: database, DataDir: dataDir, MaxDataSize: maxDataSize, MaxMsgSize: maxMsgSize} | ||
| } | ||
|
|
||
| // messageListItem is the JSON shape for each message in the list response. | ||
|
|
@@ -334,6 +336,11 @@ func (h *MessageHandler) Create(c *gin.Context) { | |
| return | ||
| } | ||
|
|
||
| if int64(len(msg.Data)) > h.MaxDataSize { | ||
| c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) | ||
| return | ||
| } | ||
|
|
||
| ctx := c.Request.Context() | ||
|
|
||
| // Validate PID references an existing message. | ||
|
|
@@ -358,12 +365,13 @@ func (h *MessageHandler) Create(c *gin.Context) { | |
| ext := mimeToExt(msg.Type) | ||
|
|
||
| // Insert message row with empty filepath; update after we know the ID. | ||
| dataSize := len(msg.Data) | ||
| var msgID int64 | ||
| err := h.DB.Pool.QueryRow(ctx, | ||
| `INSERT INTO msg (version, pid, no_reply, is_important, is_deflate, from_addr, topic, type, size, filepath, time_sent) | ||
| VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, '', NULL) | ||
| RETURNING id`, | ||
| msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.From, msg.Topic, msg.Type, msg.Size, | ||
| msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.From, msg.Topic, msg.Type, dataSize, | ||
| ).Scan(&msgID) | ||
| if err != nil { | ||
| log.Printf("create message: insert: %v", err) | ||
|
|
@@ -525,6 +533,26 @@ func (h *MessageHandler) Update(c *gin.Context) { | |
| return | ||
| } | ||
|
|
||
| if int64(len(msg.Data)) > h.MaxDataSize { | ||
| c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) | ||
| return | ||
| } | ||
|
|
||
| // Check total message size (data + existing attachments). | ||
| var attachTotal int64 | ||
| if err := h.DB.Pool.QueryRow(ctx, | ||
| "SELECT COALESCE(SUM(filesize), 0) FROM msg_attachment WHERE msg_id = $1", | ||
| msgID, | ||
| ).Scan(&attachTotal); err != nil { | ||
| log.Printf("update message %d: total size check: %v", msgID, err) | ||
| c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check message size"}) | ||
| return | ||
| } | ||
| if int64(len(msg.Data))+attachTotal > h.MaxMsgSize { | ||
| c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "total message size exceeds limit"}) | ||
| return | ||
| } | ||
|
Comment on lines
+536
to
+554
|
||
|
|
||
| msg.Deflate = isZip([]byte(msg.Data)) | ||
| ext := mimeToExt(msg.Type) | ||
|
|
||
|
|
@@ -537,7 +565,7 @@ func (h *MessageHandler) Update(c *gin.Context) { | |
|
|
||
| _, err = h.DB.Pool.Exec(ctx, | ||
| `UPDATE msg SET version=$1, pid=$2, no_reply=$3, is_important=$4, is_deflate=$5, topic=$6, type=$7, size=$8, filepath=$9 WHERE id=$10`, | ||
| msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.Topic, msg.Type, msg.Size, dataPath, msgID, | ||
| msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.Topic, msg.Type, len(msg.Data), dataPath, msgID, | ||
| ) | ||
| if err != nil { | ||
| log.Printf("update message %d: %v", msgID, err) | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -2,10 +2,15 @@ package main | |||||
|
|
||||||
| import ( | ||||||
| "context" | ||||||
| "crypto/tls" | ||||||
| "encoding/base64" | ||||||
| "errors" | ||||||
| "log" | ||||||
| "net/http" | ||||||
| "os" | ||||||
| "strconv" | ||||||
| "strings" | ||||||
| "time" | ||||||
|
|
||||||
| "github.com/gin-gonic/gin" | ||||||
| "github.com/joho/godotenv" | ||||||
|
|
@@ -34,6 +39,11 @@ func main() { | |||||
|
|
||||||
| // Optional configuration with defaults. | ||||||
| idURL := envOrDefault("FMSG_ID_URL", "http://127.0.0.1:8080") | ||||||
| rateLimit := envOrDefaultInt("FMSG_API_RATE_LIMIT", 10) | ||||||
| rateBurst := envOrDefaultInt("FMSG_API_RATE_BURST", 20) | ||||||
| maxDataSize := int64(envOrDefaultInt("FMSG_API_MAX_DATA_SIZE", 10)) * 1024 * 1024 | ||||||
| maxAttachSize := int64(envOrDefaultInt("FMSG_API_MAX_ATTACH_SIZE", 10)) * 1024 * 1024 | ||||||
| maxMsgSize := int64(envOrDefaultInt("FMSG_API_MAX_MSG_SIZE", 20)) * 1024 * 1024 | ||||||
|
|
||||||
| // Connect to PostgreSQL (uses standard PG* environment variables). | ||||||
| ctx := context.Background() | ||||||
|
|
@@ -53,9 +63,12 @@ func main() { | |||||
| // Create Gin router. | ||||||
| router := gin.Default() | ||||||
|
|
||||||
| // Global rate limiter. | ||||||
| router.Use(middleware.NewRateLimiter(ctx, float64(rateLimit), rateBurst)) | ||||||
|
|
||||||
| // Instantiate handlers. | ||||||
| msgHandler := handlers.NewMessageHandler(database, dataDir) | ||||||
| attHandler := handlers.NewAttachmentHandler(database, dataDir) | ||||||
| msgHandler := handlers.NewMessageHandler(database, dataDir, maxDataSize, maxMsgSize) | ||||||
| attHandler := handlers.NewAttachmentHandler(database, dataDir, maxAttachSize, maxMsgSize) | ||||||
|
|
||||||
| // Register routes under /fmsg, all protected by JWT. | ||||||
| fmsg := router.Group("/fmsg") | ||||||
|
|
@@ -76,15 +89,26 @@ func main() { | |||||
| fmsg.DELETE("/:id/attach/:filename", attHandler.DeleteAttachment) | ||||||
| } | ||||||
|
|
||||||
| srv := &http.Server{ | ||||||
| Handler: router, | ||||||
| ReadHeaderTimeout: 10 * time.Second, | ||||||
| WriteTimeout: 65 * time.Second, // must exceed /wait max timeout (60s) | ||||||
|
||||||
| WriteTimeout: 65 * time.Second, // must exceed /wait max timeout (60s) | |
| WriteTimeout: 0, // disabled: a server-wide write timeout can prematurely terminate slow long-poll/download responses |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -63,6 +63,7 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { | |
| } | ||
| addr := v.Addr | ||
| if !isValidAddr(addr) { | ||
| log.Printf("auth rejected: ip=%s reason=invalid_addr", c.ClientIP()) | ||
| return false | ||
| } | ||
| // Store the validated identity in context for downstream handlers. | ||
|
|
@@ -76,10 +77,12 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { | |
| return false | ||
| } | ||
| if code == http.StatusNotFound { | ||
| log.Printf("auth rejected: ip=%s addr=%s reason=not_found", c.ClientIP(), addr) | ||
| c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("User %s not found", addr)}) | ||
| return false | ||
| } | ||
| if code == http.StatusOK && !accepting { | ||
| log.Printf("auth rejected: ip=%s addr=%s reason=not_accepting", c.ClientIP(), addr) | ||
| c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("User %s not authorised to send new messages", addr)}) | ||
| return false | ||
| } | ||
|
|
@@ -88,6 +91,7 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { | |
|
|
||
| // Unauthorized responds with 401 when JWT validation fails. | ||
| Unauthorized: func(c *gin.Context, code int, message string) { | ||
| log.Printf("auth failure: ip=%s code=%d message=%s", c.ClientIP(), code, message) | ||
| c.JSON(code, gin.H{"error": message}) | ||
|
Comment on lines
93
to
95
|
||
| }, | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Create enforces
MaxDataSizebut notMaxMsgSize. IfMaxMsgSizeis configured lower thanMaxDataSize, this allows creating messages whose data alone exceeds the total-message limit. Add aMaxMsgSizecheck here as well (and ideally ensureMaxDataSize <= MaxMsgSizeat startup).