diff --git a/.gitignore b/.gitignore index 8489036..e5585d4 100644 --- a/.gitignore +++ b/.gitignore @@ -27,3 +27,6 @@ Thumbs.db *.test *.out coverage.html + +# Local planning docs +PLAN.md diff --git a/main.go b/main.go index 5696262..af77dc2 100644 --- a/main.go +++ b/main.go @@ -18,14 +18,15 @@ import ( // FileConfig represents the YAML configuration file structure type FileConfig struct { - Host string `yaml:"host"` - Port int `yaml:"port"` - DataDir string `yaml:"data_dir"` - TLS TLSConfig `yaml:"tls"` - Users map[string]string `yaml:"users"` - RateLimit RateLimitFileConfig `yaml:"rate_limit"` - Extensions []string `yaml:"extensions"` - DuckLake DuckLakeFileConfig `yaml:"ducklake"` + Host string `yaml:"host"` + Port int `yaml:"port"` + DataDir string `yaml:"data_dir"` + TLS TLSConfig `yaml:"tls"` + Users map[string]string `yaml:"users"` + RateLimit RateLimitFileConfig `yaml:"rate_limit"` + Extensions []string `yaml:"extensions"` + DuckLake DuckLakeFileConfig `yaml:"ducklake"` + QueryTimeout string `yaml:"query_timeout"` // e.g., "30s", "5m" } type TLSConfig struct { @@ -109,12 +110,13 @@ func main() { fmt.Fprintf(os.Stderr, "Options:\n") flag.PrintDefaults() fmt.Fprintf(os.Stderr, "\nEnvironment variables:\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_CONFIG Path to YAML config file\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_HOST Host to bind to (default: 0.0.0.0)\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_PORT Port to listen on (default: 5432)\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_DATA_DIR Directory for DuckDB files (default: ./data)\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_CERT TLS certificate file (default: ./certs/server.crt)\n") - fmt.Fprintf(os.Stderr, " DUCKGRES_KEY TLS private key file (default: ./certs/server.key)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_CONFIG Path to YAML config file\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_HOST Host to bind to (default: 0.0.0.0)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_PORT Port to listen on (default: 5432)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_DATA_DIR Directory for DuckDB files (default: ./data)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_CERT TLS certificate file (default: ./certs/server.crt)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_KEY TLS private key file (default: ./certs/server.key)\n") + fmt.Fprintf(os.Stderr, " DUCKGRES_QUERY_TIMEOUT Maximum query execution time (e.g., 30s, 5m)\n") fmt.Fprintf(os.Stderr, "\nPrecedence: CLI flags > environment variables > config file > defaults\n") } @@ -194,6 +196,15 @@ func main() { cfg.Extensions = fileCfg.Extensions } + // Apply query timeout config + if fileCfg.QueryTimeout != "" { + if d, err := time.ParseDuration(fileCfg.QueryTimeout); err == nil { + cfg.QueryTimeout = d + } else { + slog.Warn("Invalid query_timeout duration: " + err.Error()) + } + } + // Apply DuckLake config if fileCfg.DuckLake.MetadataStore != "" { cfg.DuckLake.MetadataStore = fileCfg.DuckLake.MetadataStore @@ -282,6 +293,13 @@ func main() { if v := os.Getenv("DUCKGRES_DUCKLAKE_S3_PROFILE"); v != "" { cfg.DuckLake.S3Profile = v } + if v := os.Getenv("DUCKGRES_QUERY_TIMEOUT"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + cfg.QueryTimeout = d + } else { + slog.Warn("Invalid DUCKGRES_QUERY_TIMEOUT: " + err.Error()) + } + } // Apply CLI flags (highest priority) if *host != "" { diff --git a/server/conn.go b/server/conn.go index 32f3b2d..08fb56d 100644 --- a/server/conn.go +++ b/server/conn.go @@ -3,6 +3,7 @@ package server import ( "bufio" "bytes" + "context" "crypto/tls" "database/sql" "encoding/binary" @@ -115,6 +116,21 @@ func (c *clientConn) newTranspiler(convertPlaceholders bool) *transpiler.Transpi }) } +// queryContext returns a context with the configured query timeout. +// If no timeout is configured, returns a background context. +// The caller should defer cancel() after receiving the context. +func (c *clientConn) queryContext() (context.Context, context.CancelFunc) { + if c.server.cfg.QueryTimeout > 0 { + return context.WithTimeout(context.Background(), c.server.cfg.QueryTimeout) + } + return context.Background(), func() {} +} + +// isQueryTimeout checks if an error is a context deadline exceeded error (query timeout) +func isQueryTimeout(err error) bool { + return err == context.DeadlineExceeded || (err != nil && strings.Contains(err.Error(), "context deadline exceeded")) +} + // stripSQLComments removes SQL comments from a query string. // It handles both line comments (-- ...) and block comments (/* ... */). func stripSQLComments(query string) string { @@ -296,7 +312,9 @@ func (c *clientConn) validateWithDuckDB(query string) error { // Use EXPLAIN to validate the query without executing it // DuckDB's EXPLAIN will fail if the query is syntactically invalid - _, err := c.db.Exec("EXPLAIN " + query) + ctx, cancel := c.queryContext() + defer cancel() + _, err := c.db.ExecContext(ctx, "EXPLAIN "+query) if err != nil { // Strip "EXPLAIN " from error messages to avoid confusing users, // but only if the original query didn't start with EXPLAIN @@ -736,12 +754,24 @@ func (c *clientConn) handleQuery(body []byte) error { return nil } - result, err := c.db.Exec(query) + ctx, cancel := c.queryContext() + defer cancel() + result, err := c.db.ExecContext(ctx, query) if err != nil { + // Check for query timeout + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", query, "timeout", c.server.cfg.QueryTimeout) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + c.setTxError() + _ = writeReadyForQuery(c.writer, c.txStatus) + _ = c.writer.Flush() + return nil + } // Retry ALTER TABLE as ALTER VIEW if target is a view if isAlterTableNotTableError(err) { if alteredQuery, ok := transpiler.ConvertAlterTableToAlterView(query); ok { - result, err = c.db.Exec(alteredQuery) + result, err = c.db.ExecContext(ctx, alteredQuery) } } if err != nil { @@ -763,8 +793,19 @@ func (c *clientConn) handleQuery(body []byte) error { } // Execute SELECT query - rows, err := c.db.Query(query) + ctx, cancel := c.queryContext() + defer cancel() + rows, err := c.db.QueryContext(ctx, query) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", query, "timeout", c.server.cfg.QueryTimeout) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + c.setTxError() + _ = writeReadyForQuery(c.writer, c.txStatus) + _ = c.writer.Flush() + return nil + } slog.Error("Query execution failed.", "user", c.username, "query", query, "error", err) c.sendError("ERROR", "42000", err.Error()) c.setTxError() @@ -857,11 +898,23 @@ func (c *clientConn) executeMultiStatement(statements []string, cleanup []string } // Execute setup statements (all but last) + ctx, cancel := c.queryContext() + defer cancel() for i := 0; i < len(statements)-1; i++ { stmt := statements[i] slog.Debug("Multi-stmt setup.", "user", c.username, "step", i+1, "total", len(statements)-1, "stmt", stmt) - _, err := c.db.Exec(stmt) + _, err := c.db.ExecContext(ctx, stmt) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", stmt, "timeout", c.server.cfg.QueryTimeout) + c.setTxError() + c.executeCleanup(cleanup) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + _ = writeReadyForQuery(c.writer, c.txStatus) + _ = c.writer.Flush() + return nil + } slog.Error("Multi-stmt setup error.", "user", c.username, "query", stmt, "error", err) c.setTxError() // On error, still try to cleanup (best effort) @@ -881,8 +934,18 @@ func (c *clientConn) executeMultiStatement(statements []string, cleanup []string if cmdType == "SELECT" || strings.HasPrefix(upperFinal, "WITH") || strings.HasPrefix(upperFinal, "TABLE") { // SELECT: obtain cursor FIRST, cleanup SECOND, stream THIRD - rows, err := c.db.Query(finalStmt) + rows, err := c.db.QueryContext(ctx, finalStmt) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", finalStmt, "timeout", c.server.cfg.QueryTimeout) + c.setTxError() + c.executeCleanup(cleanup) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + _ = writeReadyForQuery(c.writer, c.txStatus) + _ = c.writer.Flush() + return nil + } slog.Error("Multi-stmt final query error.", "user", c.username, "query", finalStmt, "error", err) c.setTxError() c.executeCleanup(cleanup) @@ -902,8 +965,18 @@ func (c *clientConn) executeMultiStatement(statements []string, cleanup []string } else { // DML (INSERT/UPDATE/DELETE): execute then cleanup - result, err := c.db.Exec(finalStmt) + result, err := c.db.ExecContext(ctx, finalStmt) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", finalStmt, "timeout", c.server.cfg.QueryTimeout) + c.setTxError() + c.executeCleanup(cleanup) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + _ = writeReadyForQuery(c.writer, c.txStatus) + _ = c.writer.Flush() + return nil + } slog.Error("Multi-stmt final exec error.", "user", c.username, "query", finalStmt, "error", err) c.setTxError() c.executeCleanup(cleanup) @@ -959,11 +1032,21 @@ func (c *clientConn) executeMultiStatementExtended(statements []string, cleanup } // Execute setup statements (all but last) + ctx, cancel := c.queryContext() + defer cancel() for i := 0; i < len(statements)-1; i++ { stmt := statements[i] slog.Debug("Multi-stmt-ext setup.", "user", c.username, "step", i+1, "total", len(statements)-1, "stmt", stmt) - _, err := c.db.Exec(stmt, args...) + _, err := c.db.ExecContext(ctx, stmt, args...) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", stmt, "timeout", c.server.cfg.QueryTimeout) + c.setTxError() + c.executeCleanup(cleanup) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + return + } slog.Error("Multi-stmt-ext setup error.", "user", c.username, "query", stmt, "error", err) c.setTxError() // On error, still try to cleanup (best effort) @@ -981,8 +1064,16 @@ func (c *clientConn) executeMultiStatementExtended(statements []string, cleanup if cmdType == "SELECT" || strings.HasPrefix(upperFinal, "WITH") || strings.HasPrefix(upperFinal, "TABLE") { // SELECT: obtain cursor FIRST, cleanup SECOND, stream THIRD - rows, err := c.db.Query(finalStmt, args...) + rows, err := c.db.QueryContext(ctx, finalStmt, args...) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", finalStmt, "timeout", c.server.cfg.QueryTimeout) + c.setTxError() + c.executeCleanup(cleanup) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + return + } slog.Error("Multi-stmt-ext final query error.", "user", c.username, "query", finalStmt, "error", err) c.setTxError() c.executeCleanup(cleanup) @@ -999,8 +1090,16 @@ func (c *clientConn) executeMultiStatementExtended(statements []string, cleanup } else { // DML (INSERT/UPDATE/DELETE): execute then cleanup - result, err := c.db.Exec(finalStmt, args...) + result, err := c.db.ExecContext(ctx, finalStmt, args...) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", finalStmt, "timeout", c.server.cfg.QueryTimeout) + c.setTxError() + c.executeCleanup(cleanup) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + return + } slog.Error("Multi-stmt-ext final exec error.", "user", c.username, "query", finalStmt, "error", err) c.setTxError() c.executeCleanup(cleanup) @@ -1500,8 +1599,19 @@ func (c *clientConn) handleCopyOut(query, upperQuery string) error { } // Execute the query - rows, err := c.db.Query(selectQuery) + ctx, cancel := c.queryContext() + defer cancel() + rows, err := c.db.QueryContext(ctx, selectQuery) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", selectQuery, "timeout", c.server.cfg.QueryTimeout) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + c.setTxError() + _ = writeReadyForQuery(c.writer, c.txStatus) + _ = c.writer.Flush() + return nil + } slog.Error("COPY TO query failed.", "user", c.username, "query", selectQuery, "error", err) c.sendError("ERROR", "42000", err.Error()) c.setTxError() @@ -1592,8 +1702,10 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error { slog.Debug("COPY FROM STDIN parsed.", "user", c.username, "table", tableName, "columns", columnList) // Get column count for the table + ctx, cancel := c.queryContext() + defer cancel() colQuery := fmt.Sprintf("SELECT * FROM %s LIMIT 0", tableName) - testRows, err := c.db.Query(colQuery) + testRows, err := c.db.QueryContext(ctx, colQuery) if err != nil { slog.Error("COPY FROM table check failed.", "user", c.username, "table", tableName, "error", err) c.sendError("ERROR", "42P01", fmt.Sprintf("relation \"%s\" does not exist", tableName)) @@ -1670,8 +1782,20 @@ func (c *clientConn) handleCopyIn(query, upperQuery string) error { slog.Debug("COPY FROM STDIN executing native DuckDB COPY.", "user", c.username, "sql", copySQL) loadStart := time.Now() - result, err := c.db.Exec(copySQL) + // Use a fresh context for the actual COPY operation (data already received) + copyCtx, copyCancel := c.queryContext() + defer copyCancel() + result, err := c.db.ExecContext(copyCtx, copySQL) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", copySQL, "timeout", c.server.cfg.QueryTimeout) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + c.setTxError() + _ = writeReadyForQuery(c.writer, c.txStatus) + _ = c.writer.Flush() + return nil + } slog.Error("COPY FROM STDIN DuckDB COPY failed.", "user", c.username, "error", err) c.sendError("ERROR", "22P02", fmt.Sprintf("COPY failed: %v", err)) c.setTxError() @@ -2212,7 +2336,9 @@ func (c *clientConn) handleDescribe(body []byte) { args[i] = nil } - rows, err := c.db.Query(describeQuery, args...) + ctx, cancel := c.queryContext() + defer cancel() + rows, err := c.db.QueryContext(ctx, describeQuery, args...) if err != nil { // Can't describe - send NoData slog.Debug("Describe failed to get columns.", "user", c.username, "error", err) @@ -2257,7 +2383,9 @@ func (c *clientConn) handleDescribe(body []byte) { } // Try to get column info - rows, err := c.db.Query(p.stmt.convertedQuery, args...) + ctx, cancel := c.queryContext() + defer cancel() + rows, err := c.db.QueryContext(ctx, p.stmt.convertedQuery, args...) if err != nil { // Can't describe - send NoData _ = writeNoData(c.writer) @@ -2372,6 +2500,9 @@ func (c *clientConn) handleExecute(body []byte) { return } + ctx, cancel := c.queryContext() + defer cancel() + if !returnsResults { // Handle nested BEGIN: PostgreSQL issues a warning but continues, // while DuckDB throws an error. Match PostgreSQL behavior. @@ -2382,12 +2513,20 @@ func (c *clientConn) handleExecute(body []byte) { } // Non-result-returning query: use Exec with converted query - result, err := c.db.Exec(p.stmt.convertedQuery, args...) + result, err := c.db.ExecContext(ctx, p.stmt.convertedQuery, args...) if err != nil { + // Check for query timeout + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", p.stmt.convertedQuery, "timeout", c.server.cfg.QueryTimeout) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + c.setTxError() + return + } // Retry ALTER TABLE as ALTER VIEW if target is a view if isAlterTableNotTableError(err) { if alteredQuery, ok := transpiler.ConvertAlterTableToAlterView(p.stmt.convertedQuery); ok { - result, err = c.db.Exec(alteredQuery, args...) + result, err = c.db.ExecContext(ctx, alteredQuery, args...) } } if err != nil { @@ -2404,8 +2543,15 @@ func (c *clientConn) handleExecute(body []byte) { } // Result-returning query: use Query with converted query - rows, err := c.db.Query(p.stmt.convertedQuery, args...) + rows, err := c.db.QueryContext(ctx, p.stmt.convertedQuery, args...) if err != nil { + if isQueryTimeout(err) { + queryTimeoutsCounter.Inc() + slog.Warn("Query timeout.", "user", c.username, "query", p.stmt.convertedQuery, "timeout", c.server.cfg.QueryTimeout) + c.sendError("ERROR", "57014", "canceling statement due to statement timeout") + c.setTxError() + return + } slog.Error("Query execution failed.", "user", c.username, "query", p.stmt.convertedQuery, "original_query", p.stmt.query, "error", err) c.sendError("ERROR", "42000", err.Error()) c.setTxError() diff --git a/server/server.go b/server/server.go index acda3f8..b9bb12d 100644 --- a/server/server.go +++ b/server/server.go @@ -51,6 +51,11 @@ var rateLimitedIPsGauge = promauto.NewGauge(prometheus.GaugeOpts{ Help: "Number of currently rate-limited IP addresses", }) +var queryTimeoutsCounter = promauto.NewCounter(prometheus.CounterOpts{ + Name: "duckgres_query_timeouts_total", + Help: "Total number of queries that timed out", +}) + func redactConnectionString(connStr string) string { return passwordPattern.ReplaceAllString(connStr, "${1}[REDACTED]") } @@ -81,6 +86,11 @@ type Config struct { // This prevents accumulation of zombie connections from clients that disconnect // uncleanly. Default: 10 minutes. Set to 0 to disable. IdleTimeout time.Duration + + // QueryTimeout is the maximum time a query can run before being cancelled. + // This prevents runaway queries from blocking connections forever. + // Default: 0 (no timeout). Common values: 30s, 1m, 5m. + QueryTimeout time.Duration } // DuckLakeConfig configures DuckLake catalog attachment