diff --git a/.env.example b/.env.example index 1dfdf52..3ad6d95 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,4 @@ -FMSG_DATA_DIR=/opt/fmsg/data +FMSG_DATA_DIR=/var/lib/fmsgd/ FMSG_API_JWT_SECRET=fmsg-dev-secret-do-not-use-in-production PGHOST=localhost PGUSER=fmsg diff --git a/README.md b/README.md index 1361067..4a1427a 100644 --- a/README.md +++ b/README.md @@ -8,12 +8,17 @@ HTTP API providing user/client message handling for an fmsg host. Exposes CRUD o | Variable | Default | Description | | ------------------- | ------------------------ | ------------------------------------------------------- | -| `FMSG_DATA_DIR` | *(required)* | Path where message data files are stored, e.g. `/opt/fmsg/data` | +| `FMSG_DATA_DIR` | *(required)* | Path where message data files are stored, e.g. `/var/lib/fmsgd/` | | `FMSG_API_JWT_SECRET` | *(required)* | HMAC secret used to validate JWT tokens. Prefix with `base64:` to supply a base64-encoded key (e.g. `base64:c2VjcmV0`); otherwise the raw string is used. | | `FMSG_TLS_CERT` | *(optional)* | Path to the TLS certificate file (e.g. `/etc/letsencrypt/live/example.com/fullchain.pem`). When set with `FMSG_TLS_KEY`, enables HTTPS on port 443. | | `FMSG_TLS_KEY` | *(optional)* | Path to the TLS private key file (e.g. `/etc/letsencrypt/live/example.com/privkey.pem`). Must be set together with `FMSG_TLS_CERT`. | | `FMSG_API_PORT` | `8000` | TCP port for plain HTTP mode (ignored when TLS is enabled) | | `FMSG_ID_URL` | `http://127.0.0.1:8080` | Base URL of the fmsgid identity service | +| `FMSG_API_RATE_LIMIT`| `10` | Max sustained requests per second per IP | +| `FMSG_API_RATE_BURST`| `20` | Max burst size for the per-IP rate limiter | +| `FMSG_API_MAX_DATA_SIZE`| `10` | Maximum message data size in megabytes | +| `FMSG_API_MAX_ATTACH_SIZE`| `10` | Maximum attachment file size in megabytes | +| `FMSG_API_MAX_MSG_SIZE`| `20` | Maximum total message size (data + attachments) in megabytes | Standard PostgreSQL environment variables (`PGHOST`, `PGPORT`, `PGUSER`, `PGPASSWORD`, `PGDATABASE`) are used for database connectivity. @@ -63,7 +68,7 @@ Omit the TLS variables to run a plain HTTP server. Override the port with `FMSG_API_PORT` (default `8000`). ```bash -export FMSG_DATA_DIR=/opt/fmsg/data +export FMSG_DATA_DIR=/var/lib/fmsgd/ export FMSG_API_JWT_SECRET=changeme export PGHOST=localhost export PGUSER=fmsg @@ -74,10 +79,23 @@ cd src go run . ``` +The server starts on port `8000` by default. Override with `FMSG_API_PORT`. + +The HTTP server is configured with `ReadHeaderTimeout: 10s`, `WriteTimeout: 65s`, +and `IdleTimeout: 120s`. The write timeout exceeds the `/wait` endpoint's +maximum long-poll duration (60 s) so connections are not dropped prematurely. + ## API Routes All routes are prefixed with `/fmsg` and require a valid `Authorization: Bearer ` header. +All routes are subject to per-IP rate limiting. When the limit is exceeded, the +server responds with `429 Too Many Requests`: + +```json +{"error": "rate limit exceeded"} +``` + | Method | Path | Description | | -------- | ------------------------------------------- | ------------------------ | | `GET` | `/fmsg` | List messages for user | diff --git a/src/handlers/attachments.go b/src/handlers/attachments.go index a6b25b8..f662215 100644 --- a/src/handlers/attachments.go +++ b/src/handlers/attachments.go @@ -19,13 +19,15 @@ import ( // AttachmentHandler holds dependencies for attachment routes. type AttachmentHandler struct { - DB *db.DB - DataDir string + DB *db.DB + DataDir string + MaxAttachSize int64 + MaxMsgSize int64 } // NewAttachmentHandler creates an AttachmentHandler. -func NewAttachmentHandler(database *db.DB, dataDir string) *AttachmentHandler { - return &AttachmentHandler{DB: database, DataDir: dataDir} +func NewAttachmentHandler(database *db.DB, dataDir string, maxAttachSize, maxMsgSize int64) *AttachmentHandler { + return &AttachmentHandler{DB: database, DataDir: dataDir, MaxAttachSize: maxAttachSize, MaxMsgSize: maxMsgSize} } // Upload handles POST /api/v1/messages/:id/attachments. @@ -89,14 +91,14 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { // Resolve collision-safe filepath. finalPath := resolveFilePath(dir, intendedFilename) - // Write file to disk. + // Write file to disk (limit read to MaxAttachSize + 1 to detect oversized uploads). dst, err := os.OpenFile(finalPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0640) if err != nil { log.Printf("upload attachment: open file: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save attachment"}) return } - written, err := io.Copy(dst, file) + written, err := io.Copy(dst, io.LimitReader(file, h.MaxAttachSize+1)) closeErr := dst.Close() if err != nil || closeErr != nil { _ = os.Remove(finalPath) @@ -105,8 +107,43 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { return } - // Persist to DB. - _, err = h.DB.Pool.Exec(ctx, + if written > h.MaxAttachSize { + _ = os.Remove(finalPath) + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "attachment exceeds maximum size"}) + return + } + + // Check total message size and persist attachment in a transaction to + // prevent concurrent uploads from exceeding MaxMsgSize. + tx, err := h.DB.Pool.Begin(ctx) + if err != nil { + _ = os.Remove(finalPath) + log.Printf("upload attachment: begin tx: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to record attachment"}) + return + } + defer tx.Rollback(ctx) + + // Lock the message row and compute current total size. + var currentTotal int64 + if err = tx.QueryRow(ctx, + `SELECT m.size + COALESCE((SELECT SUM(filesize) FROM msg_attachment WHERE msg_id = m.id), 0) + FROM msg m WHERE m.id = $1 FOR UPDATE`, + msgID, + ).Scan(¤tTotal); err != nil { + _ = os.Remove(finalPath) + log.Printf("upload attachment: total size check: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check message size"}) + return + } + if currentTotal+written > h.MaxMsgSize { + _ = os.Remove(finalPath) + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "total message size exceeds limit"}) + return + } + + // Persist attachment to DB. + _, err = tx.Exec(ctx, `INSERT INTO msg_attachment (msg_id, filename, filesize, filepath) VALUES ($1, $2, $3, $4) ON CONFLICT (msg_id, filename) DO UPDATE SET filesize=$3, filepath=$4`, @@ -119,6 +156,13 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { return } + if err = tx.Commit(ctx); err != nil { + _ = os.Remove(finalPath) + log.Printf("upload attachment: commit: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to record attachment"}) + return + } + c.JSON(http.StatusCreated, gin.H{"filename": intendedFilename, "size": written}) } diff --git a/src/handlers/messages.go b/src/handlers/messages.go index 9dae2e9..932977d 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -24,13 +24,15 @@ import ( // MessageHandler holds dependencies for message routes. type MessageHandler struct { - DB *db.DB - DataDir string + DB *db.DB + DataDir string + MaxDataSize int64 + MaxMsgSize int64 } // NewMessageHandler creates a MessageHandler. -func NewMessageHandler(database *db.DB, dataDir string) *MessageHandler { - return &MessageHandler{DB: database, DataDir: dataDir} +func NewMessageHandler(database *db.DB, dataDir string, maxDataSize, maxMsgSize int64) *MessageHandler { + return &MessageHandler{DB: database, DataDir: dataDir, MaxDataSize: maxDataSize, MaxMsgSize: maxMsgSize} } // messageListItem is the JSON shape for each message in the list response. @@ -334,6 +336,11 @@ func (h *MessageHandler) Create(c *gin.Context) { return } + if int64(len(msg.Data)) > h.MaxDataSize { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) + return + } + ctx := c.Request.Context() // Validate PID references an existing message. @@ -358,12 +365,13 @@ func (h *MessageHandler) Create(c *gin.Context) { ext := mimeToExt(msg.Type) // Insert message row with empty filepath; update after we know the ID. + dataSize := len(msg.Data) var msgID int64 err := h.DB.Pool.QueryRow(ctx, `INSERT INTO msg (version, pid, no_reply, is_important, is_deflate, from_addr, topic, type, size, filepath, time_sent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, '', NULL) RETURNING id`, - msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.From, msg.Topic, msg.Type, msg.Size, + msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.From, msg.Topic, msg.Type, dataSize, ).Scan(&msgID) if err != nil { log.Printf("create message: insert: %v", err) @@ -525,6 +533,26 @@ func (h *MessageHandler) Update(c *gin.Context) { return } + if int64(len(msg.Data)) > h.MaxDataSize { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) + return + } + + // Check total message size (data + existing attachments). + var attachTotal int64 + if err := h.DB.Pool.QueryRow(ctx, + "SELECT COALESCE(SUM(filesize), 0) FROM msg_attachment WHERE msg_id = $1", + msgID, + ).Scan(&attachTotal); err != nil { + log.Printf("update message %d: total size check: %v", msgID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check message size"}) + return + } + if int64(len(msg.Data))+attachTotal > h.MaxMsgSize { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "total message size exceeds limit"}) + return + } + msg.Deflate = isZip([]byte(msg.Data)) ext := mimeToExt(msg.Type) @@ -537,7 +565,7 @@ func (h *MessageHandler) Update(c *gin.Context) { _, err = h.DB.Pool.Exec(ctx, `UPDATE msg SET version=$1, pid=$2, no_reply=$3, is_important=$4, is_deflate=$5, topic=$6, type=$7, size=$8, filepath=$9 WHERE id=$10`, - msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.Topic, msg.Type, msg.Size, dataPath, msgID, + msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.Topic, msg.Type, len(msg.Data), dataPath, msgID, ) if err != nil { log.Printf("update message %d: %v", msgID, err) diff --git a/src/main.go b/src/main.go index e36a02d..d683d8a 100644 --- a/src/main.go +++ b/src/main.go @@ -2,10 +2,15 @@ package main import ( "context" + "crypto/tls" "encoding/base64" + "errors" "log" + "net/http" "os" + "strconv" "strings" + "time" "github.com/gin-gonic/gin" "github.com/joho/godotenv" @@ -34,6 +39,11 @@ func main() { // Optional configuration with defaults. idURL := envOrDefault("FMSG_ID_URL", "http://127.0.0.1:8080") + rateLimit := envOrDefaultInt("FMSG_API_RATE_LIMIT", 10) + rateBurst := envOrDefaultInt("FMSG_API_RATE_BURST", 20) + maxDataSize := int64(envOrDefaultInt("FMSG_API_MAX_DATA_SIZE", 10)) * 1024 * 1024 + maxAttachSize := int64(envOrDefaultInt("FMSG_API_MAX_ATTACH_SIZE", 10)) * 1024 * 1024 + maxMsgSize := int64(envOrDefaultInt("FMSG_API_MAX_MSG_SIZE", 20)) * 1024 * 1024 // Connect to PostgreSQL (uses standard PG* environment variables). ctx := context.Background() @@ -53,9 +63,12 @@ func main() { // Create Gin router. router := gin.Default() + // Global rate limiter. + router.Use(middleware.NewRateLimiter(ctx, float64(rateLimit), rateBurst)) + // Instantiate handlers. - msgHandler := handlers.NewMessageHandler(database, dataDir) - attHandler := handlers.NewAttachmentHandler(database, dataDir) + msgHandler := handlers.NewMessageHandler(database, dataDir, maxDataSize, maxMsgSize) + attHandler := handlers.NewAttachmentHandler(database, dataDir, maxAttachSize, maxMsgSize) // Register routes under /fmsg, all protected by JWT. fmsg := router.Group("/fmsg") @@ -76,15 +89,26 @@ func main() { fmsg.DELETE("/:id/attach/:filename", attHandler.DeleteAttachment) } + srv := &http.Server{ + Handler: router, + ReadHeaderTimeout: 10 * time.Second, + WriteTimeout: 65 * time.Second, // must exceed /wait max timeout (60s) + IdleTimeout: 120 * time.Second, + MaxHeaderBytes: 1 << 20, // 1 MB + } + if tlsEnabled { + srv.Addr = ":443" log.Println("fmsg-webapi starting on :443") - if err = router.RunTLS(":443", tlsCert, tlsKey); err != nil { + srv.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12} + if err = srv.ListenAndServeTLS(tlsCert, tlsKey); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("server error: %v", err) } } else { port := envOrDefault("FMSG_API_PORT", "8000") + srv.Addr = ":" + port log.Printf("fmsg-webapi starting on :%s (plain HTTP)", port) - if err = router.Run(":" + port); err != nil { + if err = srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("server error: %v", err) } } @@ -107,6 +131,19 @@ func envOrDefault(key, defaultValue string) string { return defaultValue } +// envOrDefaultInt returns the environment variable as an int or defaultValue when unset. +// Fatally exits if the value is set but not a valid integer. +func envOrDefaultInt(key string, defaultValue int) int { + if v := os.Getenv(key); v != "" { + n, err := strconv.Atoi(v) + if err != nil { + log.Fatalf("environment variable %s must be an integer: %v", key, err) + } + return n + } + return defaultValue +} + // parseSecret returns the HMAC key bytes for the given secret string. // If s begins with "base64:" the remainder is base64-decoded; otherwise the // raw string bytes are used. diff --git a/src/middleware/jwt.go b/src/middleware/jwt.go index 1e904b0..d367650 100644 --- a/src/middleware/jwt.go +++ b/src/middleware/jwt.go @@ -63,6 +63,7 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { } addr := v.Addr if !isValidAddr(addr) { + log.Printf("auth rejected: ip=%s reason=invalid_addr", c.ClientIP()) return false } // Store the validated identity in context for downstream handlers. @@ -76,10 +77,12 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { return false } if code == http.StatusNotFound { + log.Printf("auth rejected: ip=%s addr=%s reason=not_found", c.ClientIP(), addr) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("User %s not found", addr)}) return false } if code == http.StatusOK && !accepting { + log.Printf("auth rejected: ip=%s addr=%s reason=not_accepting", c.ClientIP(), addr) c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("User %s not authorised to send new messages", addr)}) return false } @@ -88,6 +91,7 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { // Unauthorized responds with 401 when JWT validation fails. Unauthorized: func(c *gin.Context, code int, message string) { + log.Printf("auth failure: ip=%s code=%d message=%s", c.ClientIP(), code, message) c.JSON(code, gin.H{"error": message}) }, diff --git a/src/middleware/ratelimit_test.go b/src/middleware/ratelimit_test.go new file mode 100644 index 0000000..f793001 --- /dev/null +++ b/src/middleware/ratelimit_test.go @@ -0,0 +1,100 @@ +package middleware + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func setupRateLimitRouter(rps float64, burst int) *gin.Engine { + r := gin.New() + r.Use(NewRateLimiter(context.Background(), rps, burst)) + r.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + return r +} + +func TestRateLimiterAllowsUnderLimit(t *testing.T) { + router := setupRateLimitRouter(10, 5) + + for i := 0; i < 5; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "1.2.3.4:1234" + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i, w.Code) + } + } +} + +func TestRateLimiterBlocksExcessBurst(t *testing.T) { + router := setupRateLimitRouter(1, 3) // 1 rps, burst of 3 + + // First 3 requests should succeed (burst). + for i := 0; i < 3; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "1.2.3.4:1234" + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i, w.Code) + } + } + + // Next request should be rate-limited. + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "1.2.3.4:1234" + router.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("expected 429, got %d", w.Code) + } + + var body map[string]string + if err := json.NewDecoder(w.Body).Decode(&body); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if body["error"] != "rate limit exceeded" { + t.Fatalf("unexpected error message: %s", body["error"]) + } +} + +func TestRateLimiterTracksIPsIndependently(t *testing.T) { + router := setupRateLimitRouter(1, 1) // 1 rps, burst of 1 + + // Exhaust IP A. + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.1:1000" + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("IP A first request: expected 200, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.1:1000" + router.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("IP A second request: expected 429, got %d", w.Code) + } + + // IP B should still be allowed. + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.2:2000" + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("IP B first request: expected 200, got %d", w.Code) + } +}