From 5e09ef87599b5925142531b2e7181073ab042214 Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Sat, 18 Apr 2026 11:55:16 +0800 Subject: [PATCH 1/4] rate-limiting, timeouts and log auth failures --- README.md | 13 +++++ src/go.mod | 1 + src/go.sum | 2 + src/main.go | 30 +++++++++- src/middleware/jwt.go | 4 ++ src/middleware/ratelimit.go | 71 +++++++++++++++++++++++ src/middleware/ratelimit_test.go | 99 ++++++++++++++++++++++++++++++++ 7 files changed, 219 insertions(+), 1 deletion(-) create mode 100644 src/middleware/ratelimit.go create mode 100644 src/middleware/ratelimit_test.go diff --git a/README.md b/README.md index 4e2b89a..10aa201 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ fine-grained authorisation rules based on the user identity they contain. | `FMSG_API_JWT_SECRET` | *(required)* | HMAC secret used to validate JWT tokens. Prefix with `base64:` to supply a base64-encoded key (e.g. `base64:c2VjcmV0`); otherwise the raw string is used. | | `FMSG_API_PORT` | `8000` | TCP port the HTTP server listens on | | `FMSG_ID_URL` | `http://127.0.0.1:8080` | Base URL of the fmsgid identity service | +| `FMSG_API_RATE_LIMIT`| `10` | Max sustained requests per second per IP | +| `FMSG_API_RATE_BURST`| `20` | Max burst size for the per-IP rate limiter | Standard PostgreSQL environment variables (`PGHOST`, `PGPORT`, `PGUSER`, `PGPASSWORD`, `PGDATABASE`) are used for database connectivity. @@ -54,10 +56,21 @@ go run . The server starts on port `8000` by default. Override with `FMSG_API_PORT`. +The HTTP server is configured with `ReadTimeout: 10s`, `WriteTimeout: 65s`, +and `IdleTimeout: 120s`. The write timeout exceeds the `/wait` endpoint's +maximum long-poll duration (60 s) so connections are not dropped prematurely. + ## API Routes All routes are prefixed with `/fmsg` and require a valid `Authorization: Bearer ` header. +All routes are subject to per-IP rate limiting. When the limit is exceeded, the +server responds with `429 Too Many Requests`: + +```json +{"error": "rate limit exceeded"} +``` + | Method | Path | Description | | -------- | ------------------------------------------- | ------------------------ | | `GET` | `/fmsg` | List messages for user | diff --git a/src/go.mod b/src/go.mod index c3cbb6b..7386e77 100644 --- a/src/go.mod +++ b/src/go.mod @@ -44,5 +44,6 @@ require ( golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect + golang.org/x/time v0.15.0 // indirect google.golang.org/protobuf v1.36.10 // indirect ) diff --git a/src/go.sum b/src/go.sum index 476e95c..edc7d4d 100644 --- a/src/go.sum +++ b/src/go.sum @@ -108,6 +108,8 @@ golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/src/main.go b/src/main.go index 8142bb2..67de4cd 100644 --- a/src/main.go +++ b/src/main.go @@ -4,8 +4,11 @@ import ( "context" "encoding/base64" "log" + "net/http" "os" + "strconv" "strings" + "time" "github.com/gin-gonic/gin" "github.com/joho/godotenv" @@ -27,6 +30,8 @@ func main() { // Optional configuration with defaults. port := envOrDefault("FMSG_API_PORT", "8000") idURL := envOrDefault("FMSG_ID_URL", "http://127.0.0.1:8080") + rateLimit := envOrDefaultInt("FMSG_API_RATE_LIMIT", 10) + rateBurst := envOrDefaultInt("FMSG_API_RATE_BURST", 20) // Connect to PostgreSQL (uses standard PG* environment variables). ctx := context.Background() @@ -46,6 +51,9 @@ func main() { // Create Gin router. router := gin.Default() + // Global rate limiter. + router.Use(middleware.NewRateLimiter(float64(rateLimit), rateBurst)) + // Instantiate handlers. msgHandler := handlers.NewMessageHandler(database, dataDir) attHandler := handlers.NewAttachmentHandler(database, dataDir) @@ -70,7 +78,15 @@ func main() { } log.Printf("fmsg-webapi starting on :%s", port) - if err = router.Run(":" + port); err != nil { + srv := &http.Server{ + Addr: ":" + port, + Handler: router, + ReadTimeout: 10 * time.Second, + WriteTimeout: 65 * time.Second, // must exceed /wait max timeout (60s) + IdleTimeout: 120 * time.Second, + MaxHeaderBytes: 1 << 20, // 1 MB + } + if err = srv.ListenAndServe(); err != nil { log.Fatalf("server error: %v", err) } } @@ -92,6 +108,18 @@ func envOrDefault(key, defaultValue string) string { return defaultValue } +// envOrDefaultInt returns the environment variable as an int or defaultValue when unset/invalid. +func envOrDefaultInt(key string, defaultValue int) int { + if v := os.Getenv(key); v != "" { + n, err := strconv.Atoi(v) + if err != nil { + log.Fatalf("environment variable %s must be an integer: %v", key, err) + } + return n + } + return defaultValue +} + // parseSecret returns the HMAC key bytes for the given secret string. // If s begins with "base64:" the remainder is base64-decoded; otherwise the // raw string bytes are used. diff --git a/src/middleware/jwt.go b/src/middleware/jwt.go index 1e904b0..d367650 100644 --- a/src/middleware/jwt.go +++ b/src/middleware/jwt.go @@ -63,6 +63,7 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { } addr := v.Addr if !isValidAddr(addr) { + log.Printf("auth rejected: ip=%s reason=invalid_addr", c.ClientIP()) return false } // Store the validated identity in context for downstream handlers. @@ -76,10 +77,12 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { return false } if code == http.StatusNotFound { + log.Printf("auth rejected: ip=%s addr=%s reason=not_found", c.ClientIP(), addr) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("User %s not found", addr)}) return false } if code == http.StatusOK && !accepting { + log.Printf("auth rejected: ip=%s addr=%s reason=not_accepting", c.ClientIP(), addr) c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("User %s not authorised to send new messages", addr)}) return false } @@ -88,6 +91,7 @@ func SetupJWT(key []byte, idURL string) (*jwt.GinJWTMiddleware, error) { // Unauthorized responds with 401 when JWT validation fails. Unauthorized: func(c *gin.Context, code int, message string) { + log.Printf("auth failure: ip=%s code=%d message=%s", c.ClientIP(), code, message) c.JSON(code, gin.H{"error": message}) }, diff --git a/src/middleware/ratelimit.go b/src/middleware/ratelimit.go new file mode 100644 index 0000000..6617cfd --- /dev/null +++ b/src/middleware/ratelimit.go @@ -0,0 +1,71 @@ +package middleware + +import ( + "log" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +type visitor struct { + limiter *rate.Limiter + lastSeen time.Time +} + +type rateLimiter struct { + visitors sync.Map + rps rate.Limit + burst int +} + +// NewRateLimiter returns Gin middleware that enforces a per-IP token-bucket +// rate limit. rps is the sustained requests-per-second rate and burst is the +// maximum burst size allowed. +func NewRateLimiter(rps float64, burst int) gin.HandlerFunc { + rl := &rateLimiter{ + rps: rate.Limit(rps), + burst: burst, + } + go rl.cleanup() + return rl.handler +} + +func (rl *rateLimiter) getVisitor(ip string) *rate.Limiter { + val, ok := rl.visitors.Load(ip) + if ok { + v := val.(*visitor) + v.lastSeen = time.Now() + return v.limiter + } + limiter := rate.NewLimiter(rl.rps, rl.burst) + rl.visitors.Store(ip, &visitor{limiter: limiter, lastSeen: time.Now()}) + return limiter +} + +func (rl *rateLimiter) handler(c *gin.Context) { + ip := c.ClientIP() + limiter := rl.getVisitor(ip) + if !limiter.Allow() { + log.Printf("rate limit exceeded: ip=%s", ip) + c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{"error": "rate limit exceeded"}) + return + } + c.Next() +} + +// cleanup removes visitors that have not been seen for 5 minutes. +func (rl *rateLimiter) cleanup() { + for { + time.Sleep(1 * time.Minute) + rl.visitors.Range(func(key, value any) bool { + v := value.(*visitor) + if time.Since(v.lastSeen) > 5*time.Minute { + rl.visitors.Delete(key) + } + return true + }) + } +} diff --git a/src/middleware/ratelimit_test.go b/src/middleware/ratelimit_test.go new file mode 100644 index 0000000..5806e6a --- /dev/null +++ b/src/middleware/ratelimit_test.go @@ -0,0 +1,99 @@ +package middleware + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +func setupRateLimitRouter(rps float64, burst int) *gin.Engine { + r := gin.New() + r.Use(NewRateLimiter(rps, burst)) + r.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + return r +} + +func TestRateLimiterAllowsUnderLimit(t *testing.T) { + router := setupRateLimitRouter(10, 5) + + for i := 0; i < 5; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "1.2.3.4:1234" + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i, w.Code) + } + } +} + +func TestRateLimiterBlocksExcessBurst(t *testing.T) { + router := setupRateLimitRouter(1, 3) // 1 rps, burst of 3 + + // First 3 requests should succeed (burst). + for i := 0; i < 3; i++ { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "1.2.3.4:1234" + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("request %d: expected 200, got %d", i, w.Code) + } + } + + // Next request should be rate-limited. + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "1.2.3.4:1234" + router.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("expected 429, got %d", w.Code) + } + + var body map[string]string + if err := json.NewDecoder(w.Body).Decode(&body); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if body["error"] != "rate limit exceeded" { + t.Fatalf("unexpected error message: %s", body["error"]) + } +} + +func TestRateLimiterTracksIPsIndependently(t *testing.T) { + router := setupRateLimitRouter(1, 1) // 1 rps, burst of 1 + + // Exhaust IP A. + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.1:1000" + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("IP A first request: expected 200, got %d", w.Code) + } + + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.1:1000" + router.ServeHTTP(w, req) + if w.Code != http.StatusTooManyRequests { + t.Fatalf("IP A second request: expected 429, got %d", w.Code) + } + + // IP B should still be allowed. + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/test", nil) + req.RemoteAddr = "10.0.0.2:2000" + router.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("IP B first request: expected 200, got %d", w.Code) + } +} From 330e73fc464b13e4326f6e3f19a45175742b19c5 Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Sat, 18 Apr 2026 11:57:21 +0800 Subject: [PATCH 2/4] example FMSG_DATA_DIR to use /var/lib/fmsgd/ --- .env.example | 2 +- README.md | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.env.example b/.env.example index 1dfdf52..3ad6d95 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,4 @@ -FMSG_DATA_DIR=/opt/fmsg/data +FMSG_DATA_DIR=/var/lib/fmsgd/ FMSG_API_JWT_SECRET=fmsg-dev-secret-do-not-use-in-production PGHOST=localhost PGUSER=fmsg diff --git a/README.md b/README.md index 10aa201..5598827 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ fine-grained authorisation rules based on the user identity they contain. | Variable | Default | Description | | ------------------- | ------------------------ | ------------------------------------------------------- | -| `FMSG_DATA_DIR` | *(required)* | Path where message data files are stored, e.g. `/opt/fmsg/data` | +| `FMSG_DATA_DIR` | *(required)* | Path where message data files are stored, e.g. `/var/lib/fmsgd/` | | `FMSG_API_JWT_SECRET` | *(required)* | HMAC secret used to validate JWT tokens. Prefix with `base64:` to supply a base64-encoded key (e.g. `base64:c2VjcmV0`); otherwise the raw string is used. | | `FMSG_API_PORT` | `8000` | TCP port the HTTP server listens on | | `FMSG_ID_URL` | `http://127.0.0.1:8080` | Base URL of the fmsgid identity service | @@ -43,7 +43,7 @@ go test ./... ## Running ```bash -export FMSG_DATA_DIR=/opt/fmsg/data +export FMSG_DATA_DIR=/var/lib/fmsgd/ export FMSG_API_JWT_SECRET=changeme export PGHOST=localhost export PGUSER=fmsg From 72806f01301c7a5ee0cdefc3fc3b2c88c01d5eeb Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Sat, 18 Apr 2026 12:28:53 +0800 Subject: [PATCH 3/4] improve rate limiting, max message/attachment sizes --- README.md | 5 +++- src/go.mod | 4 +-- src/handlers/attachments.go | 38 ++++++++++++++++++++---- src/handlers/messages.go | 35 +++++++++++++++++++--- src/main.go | 24 ++++++++------- src/middleware/ratelimit.go | 51 +++++++++++++++++++------------- src/middleware/ratelimit_test.go | 3 +- 7 files changed, 115 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index 5598827..2891a34 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,9 @@ fine-grained authorisation rules based on the user identity they contain. | `FMSG_ID_URL` | `http://127.0.0.1:8080` | Base URL of the fmsgid identity service | | `FMSG_API_RATE_LIMIT`| `10` | Max sustained requests per second per IP | | `FMSG_API_RATE_BURST`| `20` | Max burst size for the per-IP rate limiter | +| `FMSG_API_MAX_DATA_SIZE`| `10` | Maximum message data size in megabytes | +| `FMSG_API_MAX_ATTACH_SIZE`| `10` | Maximum attachment file size in megabytes | +| `FMSG_API_MAX_MSG_SIZE`| `20` | Maximum total message size (data + attachments) in megabytes | Standard PostgreSQL environment variables (`PGHOST`, `PGPORT`, `PGUSER`, `PGPASSWORD`, `PGDATABASE`) are used for database connectivity. @@ -56,7 +59,7 @@ go run . The server starts on port `8000` by default. Override with `FMSG_API_PORT`. -The HTTP server is configured with `ReadTimeout: 10s`, `WriteTimeout: 65s`, +The HTTP server is configured with `ReadHeaderTimeout: 10s`, `WriteTimeout: 65s`, and `IdleTimeout: 120s`. The write timeout exceeds the `/wait` endpoint's maximum long-poll duration (60 s) so connections are not dropped prematurely. diff --git a/src/go.mod b/src/go.mod index 7386e77..69e0407 100644 --- a/src/go.mod +++ b/src/go.mod @@ -5,8 +5,10 @@ go 1.25.0 require ( github.com/appleboy/gin-jwt/v2 v2.10.3 github.com/gin-gonic/gin v1.12.0 + github.com/golang-jwt/jwt/v4 v4.5.2 github.com/jackc/pgx/v5 v5.8.0 github.com/joho/godotenv v1.5.1 + golang.org/x/time v0.15.0 ) require ( @@ -21,7 +23,6 @@ require ( github.com/go-playground/validator/v10 v10.30.1 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-yaml v1.19.2 // indirect - github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect @@ -44,6 +45,5 @@ require ( golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect - golang.org/x/time v0.15.0 // indirect google.golang.org/protobuf v1.36.10 // indirect ) diff --git a/src/handlers/attachments.go b/src/handlers/attachments.go index a6b25b8..93b1874 100644 --- a/src/handlers/attachments.go +++ b/src/handlers/attachments.go @@ -19,13 +19,15 @@ import ( // AttachmentHandler holds dependencies for attachment routes. type AttachmentHandler struct { - DB *db.DB - DataDir string + DB *db.DB + DataDir string + MaxAttachSize int64 + MaxMsgSize int64 } // NewAttachmentHandler creates an AttachmentHandler. -func NewAttachmentHandler(database *db.DB, dataDir string) *AttachmentHandler { - return &AttachmentHandler{DB: database, DataDir: dataDir} +func NewAttachmentHandler(database *db.DB, dataDir string, maxAttachSize, maxMsgSize int64) *AttachmentHandler { + return &AttachmentHandler{DB: database, DataDir: dataDir, MaxAttachSize: maxAttachSize, MaxMsgSize: maxMsgSize} } // Upload handles POST /api/v1/messages/:id/attachments. @@ -89,14 +91,14 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { // Resolve collision-safe filepath. finalPath := resolveFilePath(dir, intendedFilename) - // Write file to disk. + // Write file to disk (limit read to MaxAttachSize + 1 to detect oversized uploads). dst, err := os.OpenFile(finalPath, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0640) if err != nil { log.Printf("upload attachment: open file: %v", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save attachment"}) return } - written, err := io.Copy(dst, file) + written, err := io.Copy(dst, io.LimitReader(file, h.MaxAttachSize+1)) closeErr := dst.Close() if err != nil || closeErr != nil { _ = os.Remove(finalPath) @@ -105,6 +107,30 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { return } + if written > h.MaxAttachSize { + _ = os.Remove(finalPath) + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "attachment exceeds maximum size"}) + return + } + + // Check total message size (data + all attachments including this one). + var currentTotal int64 + if err = h.DB.Pool.QueryRow(ctx, + `SELECT m.size + COALESCE((SELECT SUM(filesize) FROM msg_attachment WHERE msg_id = m.id), 0) + FROM msg m WHERE m.id = $1`, + msgID, + ).Scan(¤tTotal); err != nil { + _ = os.Remove(finalPath) + log.Printf("upload attachment: total size check: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check message size"}) + return + } + if currentTotal+written > h.MaxMsgSize { + _ = os.Remove(finalPath) + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "total message size exceeds limit"}) + return + } + // Persist to DB. _, err = h.DB.Pool.Exec(ctx, `INSERT INTO msg_attachment (msg_id, filename, filesize, filepath) diff --git a/src/handlers/messages.go b/src/handlers/messages.go index 9dae2e9..0df2a40 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -24,13 +24,15 @@ import ( // MessageHandler holds dependencies for message routes. type MessageHandler struct { - DB *db.DB - DataDir string + DB *db.DB + DataDir string + MaxDataSize int64 + MaxMsgSize int64 } // NewMessageHandler creates a MessageHandler. -func NewMessageHandler(database *db.DB, dataDir string) *MessageHandler { - return &MessageHandler{DB: database, DataDir: dataDir} +func NewMessageHandler(database *db.DB, dataDir string, maxDataSize, maxMsgSize int64) *MessageHandler { + return &MessageHandler{DB: database, DataDir: dataDir, MaxDataSize: maxDataSize, MaxMsgSize: maxMsgSize} } // messageListItem is the JSON shape for each message in the list response. @@ -334,6 +336,11 @@ func (h *MessageHandler) Create(c *gin.Context) { return } + if int64(len(msg.Data)) > h.MaxDataSize { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) + return + } + ctx := c.Request.Context() // Validate PID references an existing message. @@ -525,6 +532,26 @@ func (h *MessageHandler) Update(c *gin.Context) { return } + if int64(len(msg.Data)) > h.MaxDataSize { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "message data exceeds maximum size"}) + return + } + + // Check total message size (data + existing attachments). + var attachTotal int64 + if err := h.DB.Pool.QueryRow(ctx, + "SELECT COALESCE(SUM(filesize), 0) FROM msg_attachment WHERE msg_id = $1", + msgID, + ).Scan(&attachTotal); err != nil { + log.Printf("update message %d: total size check: %v", msgID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check message size"}) + return + } + if int64(len(msg.Data))+attachTotal > h.MaxMsgSize { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "total message size exceeds limit"}) + return + } + msg.Deflate = isZip([]byte(msg.Data)) ext := mimeToExt(msg.Type) diff --git a/src/main.go b/src/main.go index 67de4cd..e527f8a 100644 --- a/src/main.go +++ b/src/main.go @@ -32,6 +32,9 @@ func main() { idURL := envOrDefault("FMSG_ID_URL", "http://127.0.0.1:8080") rateLimit := envOrDefaultInt("FMSG_API_RATE_LIMIT", 10) rateBurst := envOrDefaultInt("FMSG_API_RATE_BURST", 20) + maxDataSize := int64(envOrDefaultInt("FMSG_API_MAX_DATA_SIZE", 10)) * 1024 * 1024 + maxAttachSize := int64(envOrDefaultInt("FMSG_API_MAX_ATTACH_SIZE", 10)) * 1024 * 1024 + maxMsgSize := int64(envOrDefaultInt("FMSG_API_MAX_MSG_SIZE", 20)) * 1024 * 1024 // Connect to PostgreSQL (uses standard PG* environment variables). ctx := context.Background() @@ -52,11 +55,11 @@ func main() { router := gin.Default() // Global rate limiter. - router.Use(middleware.NewRateLimiter(float64(rateLimit), rateBurst)) + router.Use(middleware.NewRateLimiter(ctx, float64(rateLimit), rateBurst)) // Instantiate handlers. - msgHandler := handlers.NewMessageHandler(database, dataDir) - attHandler := handlers.NewAttachmentHandler(database, dataDir) + msgHandler := handlers.NewMessageHandler(database, dataDir, maxDataSize, maxMsgSize) + attHandler := handlers.NewAttachmentHandler(database, dataDir, maxAttachSize, maxMsgSize) // Register routes under /fmsg, all protected by JWT. fmsg := router.Group("/fmsg") @@ -79,12 +82,12 @@ func main() { log.Printf("fmsg-webapi starting on :%s", port) srv := &http.Server{ - Addr: ":" + port, - Handler: router, - ReadTimeout: 10 * time.Second, - WriteTimeout: 65 * time.Second, // must exceed /wait max timeout (60s) - IdleTimeout: 120 * time.Second, - MaxHeaderBytes: 1 << 20, // 1 MB + Addr: ":" + port, + Handler: router, + ReadHeaderTimeout: 10 * time.Second, + WriteTimeout: 65 * time.Second, // must exceed /wait max timeout (60s) + IdleTimeout: 120 * time.Second, + MaxHeaderBytes: 1 << 20, // 1 MB } if err = srv.ListenAndServe(); err != nil { log.Fatalf("server error: %v", err) @@ -108,7 +111,8 @@ func envOrDefault(key, defaultValue string) string { return defaultValue } -// envOrDefaultInt returns the environment variable as an int or defaultValue when unset/invalid. +// envOrDefaultInt returns the environment variable as an int or defaultValue when unset. +// Fatally exits if the value is set but not a valid integer. func envOrDefaultInt(key string, defaultValue int) int { if v := os.Getenv(key); v != "" { n, err := strconv.Atoi(v) diff --git a/src/middleware/ratelimit.go b/src/middleware/ratelimit.go index 6617cfd..43af179 100644 --- a/src/middleware/ratelimit.go +++ b/src/middleware/ratelimit.go @@ -1,9 +1,11 @@ package middleware import ( + "context" "log" "net/http" "sync" + "sync/atomic" "time" "github.com/gin-gonic/gin" @@ -12,7 +14,7 @@ import ( type visitor struct { limiter *rate.Limiter - lastSeen time.Time + lastSeen atomic.Int64 // UnixNano } type rateLimiter struct { @@ -23,26 +25,26 @@ type rateLimiter struct { // NewRateLimiter returns Gin middleware that enforces a per-IP token-bucket // rate limit. rps is the sustained requests-per-second rate and burst is the -// maximum burst size allowed. -func NewRateLimiter(rps float64, burst int) gin.HandlerFunc { +// maximum burst size allowed. The cleanup goroutine runs until ctx is cancelled. +func NewRateLimiter(ctx context.Context, rps float64, burst int) gin.HandlerFunc { rl := &rateLimiter{ rps: rate.Limit(rps), burst: burst, } - go rl.cleanup() + go rl.cleanup(ctx) return rl.handler } func (rl *rateLimiter) getVisitor(ip string) *rate.Limiter { - val, ok := rl.visitors.Load(ip) - if ok { - v := val.(*visitor) - v.lastSeen = time.Now() - return v.limiter + now := time.Now().UnixNano() + v := &visitor{limiter: rate.NewLimiter(rl.rps, rl.burst)} + v.lastSeen.Store(now) + + if actual, loaded := rl.visitors.LoadOrStore(ip, v); loaded { + v = actual.(*visitor) + v.lastSeen.Store(now) } - limiter := rate.NewLimiter(rl.rps, rl.burst) - rl.visitors.Store(ip, &visitor{limiter: limiter, lastSeen: time.Now()}) - return limiter + return v.limiter } func (rl *rateLimiter) handler(c *gin.Context) { @@ -57,15 +59,22 @@ func (rl *rateLimiter) handler(c *gin.Context) { } // cleanup removes visitors that have not been seen for 5 minutes. -func (rl *rateLimiter) cleanup() { +func (rl *rateLimiter) cleanup(ctx context.Context) { + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() for { - time.Sleep(1 * time.Minute) - rl.visitors.Range(func(key, value any) bool { - v := value.(*visitor) - if time.Since(v.lastSeen) > 5*time.Minute { - rl.visitors.Delete(key) - } - return true - }) + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now().UnixNano() + rl.visitors.Range(func(key, value any) bool { + v := value.(*visitor) + if now-v.lastSeen.Load() > int64(5*time.Minute) { + rl.visitors.Delete(key) + } + return true + }) + } } } diff --git a/src/middleware/ratelimit_test.go b/src/middleware/ratelimit_test.go index 5806e6a..f793001 100644 --- a/src/middleware/ratelimit_test.go +++ b/src/middleware/ratelimit_test.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -15,7 +16,7 @@ func init() { func setupRateLimitRouter(rps float64, burst int) *gin.Engine { r := gin.New() - r.Use(NewRateLimiter(rps, burst)) + r.Use(NewRateLimiter(context.Background(), rps, burst)) r.GET("/test", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) }) From f9d764168cad65262352416eefeb75f745a4e293 Mon Sep 17 00:00:00 2001 From: Mark Mennell Date: Sat, 18 Apr 2026 14:10:17 +0800 Subject: [PATCH 4/4] pr review --- src/handlers/attachments.go | 28 +++++++++++++++++++++++----- src/handlers/messages.go | 5 +++-- src/main.go | 3 ++- src/middleware/ratelimit.go | 6 +++++- 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/src/handlers/attachments.go b/src/handlers/attachments.go index 93b1874..f662215 100644 --- a/src/handlers/attachments.go +++ b/src/handlers/attachments.go @@ -113,11 +113,22 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { return } - // Check total message size (data + all attachments including this one). + // Check total message size and persist attachment in a transaction to + // prevent concurrent uploads from exceeding MaxMsgSize. + tx, err := h.DB.Pool.Begin(ctx) + if err != nil { + _ = os.Remove(finalPath) + log.Printf("upload attachment: begin tx: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to record attachment"}) + return + } + defer tx.Rollback(ctx) + + // Lock the message row and compute current total size. var currentTotal int64 - if err = h.DB.Pool.QueryRow(ctx, + if err = tx.QueryRow(ctx, `SELECT m.size + COALESCE((SELECT SUM(filesize) FROM msg_attachment WHERE msg_id = m.id), 0) - FROM msg m WHERE m.id = $1`, + FROM msg m WHERE m.id = $1 FOR UPDATE`, msgID, ).Scan(¤tTotal); err != nil { _ = os.Remove(finalPath) @@ -131,8 +142,8 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { return } - // Persist to DB. - _, err = h.DB.Pool.Exec(ctx, + // Persist attachment to DB. + _, err = tx.Exec(ctx, `INSERT INTO msg_attachment (msg_id, filename, filesize, filepath) VALUES ($1, $2, $3, $4) ON CONFLICT (msg_id, filename) DO UPDATE SET filesize=$3, filepath=$4`, @@ -145,6 +156,13 @@ func (h *AttachmentHandler) Upload(c *gin.Context) { return } + if err = tx.Commit(ctx); err != nil { + _ = os.Remove(finalPath) + log.Printf("upload attachment: commit: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to record attachment"}) + return + } + c.JSON(http.StatusCreated, gin.H{"filename": intendedFilename, "size": written}) } diff --git a/src/handlers/messages.go b/src/handlers/messages.go index 0df2a40..932977d 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -365,12 +365,13 @@ func (h *MessageHandler) Create(c *gin.Context) { ext := mimeToExt(msg.Type) // Insert message row with empty filepath; update after we know the ID. + dataSize := len(msg.Data) var msgID int64 err := h.DB.Pool.QueryRow(ctx, `INSERT INTO msg (version, pid, no_reply, is_important, is_deflate, from_addr, topic, type, size, filepath, time_sent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, '', NULL) RETURNING id`, - msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.From, msg.Topic, msg.Type, msg.Size, + msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.From, msg.Topic, msg.Type, dataSize, ).Scan(&msgID) if err != nil { log.Printf("create message: insert: %v", err) @@ -564,7 +565,7 @@ func (h *MessageHandler) Update(c *gin.Context) { _, err = h.DB.Pool.Exec(ctx, `UPDATE msg SET version=$1, pid=$2, no_reply=$3, is_important=$4, is_deflate=$5, topic=$6, type=$7, size=$8, filepath=$9 WHERE id=$10`, - msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.Topic, msg.Type, msg.Size, dataPath, msgID, + msg.Version, msg.PID, msg.NoReply, msg.Important, msg.Deflate, msg.Topic, msg.Type, len(msg.Data), dataPath, msgID, ) if err != nil { log.Printf("update message %d: %v", msgID, err) diff --git a/src/main.go b/src/main.go index e527f8a..03dcc5c 100644 --- a/src/main.go +++ b/src/main.go @@ -3,6 +3,7 @@ package main import ( "context" "encoding/base64" + "errors" "log" "net/http" "os" @@ -89,7 +90,7 @@ func main() { IdleTimeout: 120 * time.Second, MaxHeaderBytes: 1 << 20, // 1 MB } - if err = srv.ListenAndServe(); err != nil { + if err = srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("server error: %v", err) } } diff --git a/src/middleware/ratelimit.go b/src/middleware/ratelimit.go index 43af179..905649f 100644 --- a/src/middleware/ratelimit.go +++ b/src/middleware/ratelimit.go @@ -37,9 +37,13 @@ func NewRateLimiter(ctx context.Context, rps float64, burst int) gin.HandlerFunc func (rl *rateLimiter) getVisitor(ip string) *rate.Limiter { now := time.Now().UnixNano() + if val, ok := rl.visitors.Load(ip); ok { + v := val.(*visitor) + v.lastSeen.Store(now) + return v.limiter + } v := &visitor{limiter: rate.NewLimiter(rl.rps, rl.burst)} v.lastSeen.Store(now) - if actual, loaded := rl.visitors.LoadOrStore(ip, v); loaded { v = actual.(*visitor) v.lastSeen.Store(now)