From db091d441616ba8de41d9f497534a13ee25383c0 Mon Sep 17 00:00:00 2001 From: James Greenhill Date: Thu, 5 Feb 2026 17:23:01 -0800 Subject: [PATCH 1/3] Add control plane / data plane architecture for zero-downtime deployments Implement a multi-process architecture that splits duckgres into a control plane (connection management, routing) and data plane (pool of long-lived DuckDB worker processes). This enables zero-downtime deployments, cross-session DuckDB cache reuse, and rolling worker updates. Key components: - gRPC-based worker management (Configure, Health, Drain, Shutdown) - Unix socket FD passing via SCM_RIGHTS for TCP connection handoff - Least-connections load balancing across worker pool - Graceful control plane handover via listener FD transfer - Rolling worker updates triggered by SIGUSR2 - Health check loop with automatic worker restart New CLI modes: --mode control-plane | worker | standalone (default) Standalone mode (existing behavior) is completely unchanged. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 37 +- README.md | 89 ++- controlplane/control.go | 343 ++++++++++ controlplane/dbpool.go | 87 +++ controlplane/fdpass/fdpass.go | 113 ++++ controlplane/fdpass/fdpass_test.go | 133 ++++ controlplane/handover.go | 228 +++++++ controlplane/pool.go | 508 ++++++++++++++ controlplane/proto/generate.go | 3 + controlplane/proto/worker.pb.go | 964 +++++++++++++++++++++++++++ controlplane/proto/worker.proto | 112 ++++ controlplane/proto/worker_grpc.pb.go | 333 +++++++++ controlplane/worker.go | 551 +++++++++++++++ go.mod | 4 +- main.go | 31 + server/conn.go | 14 +- server/exports.go | 90 +++ server/server.go | 126 ++-- server/worker.go | 9 +- 19 files changed, 3694 insertions(+), 81 deletions(-) create mode 100644 controlplane/control.go create mode 100644 controlplane/dbpool.go create mode 100644 controlplane/fdpass/fdpass.go create mode 100644 controlplane/fdpass/fdpass_test.go create mode 100644 controlplane/handover.go create mode 100644 controlplane/pool.go create mode 100644 controlplane/proto/generate.go create mode 100644 controlplane/proto/worker.pb.go create mode 100644 controlplane/proto/worker.proto create mode 100644 controlplane/proto/worker_grpc.pb.go create mode 100644 controlplane/worker.go create mode 100644 server/exports.go diff --git a/CLAUDE.md b/CLAUDE.md index 5ae07b3..1526843 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,24 +8,39 @@ Duckgres is a PostgreSQL wire protocol server backed by DuckDB. It allows any Po ## Architecture +Duckgres supports three run modes: `standalone` (default), `control-plane`, and `worker`. 
+ ``` -PostgreSQL Client → TLS → Duckgres Server → DuckDB (per-user database) +Standalone: PostgreSQL Client → TLS → Duckgres Server → DuckDB (per-user database) +Control Plane: PostgreSQL Client → TLS → Control Plane → (FD pass) → Worker → DuckDB ``` ### Key Components -- **main.go**: Entry point, configuration loading (CLI flags, env vars, YAML) -- **server/server.go**: Server struct, connection handling, graceful shutdown +- **main.go**: Entry point, configuration loading (CLI flags, env vars, YAML), mode routing +- **server/server.go**: Server struct, connection handling, graceful shutdown, `CreateDBConnection()` (standalone function) - **server/conn.go**: Client connection handling, query execution, COPY protocol - **server/protocol.go**: PostgreSQL wire protocol message encoding/decoding +- **server/exports.go**: Exported wrappers for protocol functions (used by control plane workers) - **server/catalog.go**: pg_catalog compatibility views and macros initialization - **server/types.go**: Type OID mapping between DuckDB and PostgreSQL - **server/ratelimit.go**: Rate limiting for brute-force protection - **server/certs.go**: Auto-generation of self-signed TLS certificates +- **server/parent.go**: Child process spawning for ProcessIsolation mode +- **server/worker.go**: Per-connection child worker (ProcessIsolation mode) - **transpiler/**: AST-based SQL transpiler (PostgreSQL → DuckDB) - `transpiler.go`: Main API, transform pipeline orchestration - `config.go`: Configuration types (DuckLakeMode, ConvertPlaceholders) - `transform/`: Individual transform implementations +- **controlplane/**: Multi-process control plane architecture + - `proto/worker.proto`: gRPC service definition (Configure, AcceptConnection, CancelQuery, Drain, Health, Shutdown) + - `proto/*.pb.go`: Generated gRPC/protobuf code + - `fdpass/fdpass.go`: Unix socket FD passing via SCM_RIGHTS + - `worker.go`: Long-lived worker process (gRPC server, FD receiver, session handler) + - `dbpool.go`: Per-session DuckDB database pool management + - `control.go`: Control plane main loop (TCP listener, rate limiting, connection routing) + - `pool.go`: Worker pool management (spawn, health check, least-connections routing, rolling update) + - `handover.go`: Graceful deployment (listener FD transfer between control planes) ## PostgreSQL Wire Protocol @@ -74,10 +89,24 @@ Supports bulk data transfer: - **COPY FROM STDIN**: Receives data from client, inserts row by row - Supports CSV format with HEADER, DELIMITER, and NULL options +## Run Modes + +- **standalone** (default): Single process, handles everything. Current behavior unchanged. +- **control-plane**: Multi-process. Accepts TCP connections, passes FDs to worker pool via Unix sockets. +- **worker**: Long-lived child process spawned by control plane. Handles TLS, auth, query execution via gRPC + FD passing. + +Key CLI flags for control plane mode: +- `--mode control-plane|worker|standalone` +- `--worker-count N` (default 4) +- `--socket-dir /path` (Unix sockets for gRPC + FD passing) +- `--handover-socket /path` (graceful deployment between control planes) +- `--grpc-socket /path` (worker, set by control plane at spawn) +- `--fd-socket /path` (worker, set by control plane at spawn) + ## Configuration Three-tier configuration (highest to lowest priority): -1. CLI flags (`--port`, `--config`, etc.) +1. CLI flags (`--port`, `--config`, `--mode`, etc.) 2. Environment variables (`DUCKGRES_PORT`, etc.) 3. YAML config file 4. 
Built-in defaults diff --git a/README.md b/README.md index b01ee1d..b50a697 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,8 @@ A PostgreSQL wire protocol compatible server backed by DuckDB. Connect with any - [Rate Limiting](#rate-limiting) - [Usage Examples](#usage-examples) - [Architecture](#architecture) + - [Standalone Mode](#standalone-mode) + - [Control Plane Mode](#control-plane-mode) - [Two-Tier Query Processing](#two-tier-query-processing) - [Supported Features](#supported-features) - [Limitations](#limitations) @@ -45,6 +47,7 @@ A PostgreSQL wire protocol compatible server backed by DuckDB. Connect with any - **DuckLake Integration**: Auto-attach DuckLake catalogs for lakehouse workflows - **Rate Limiting**: Built-in protection against brute-force attacks - **Graceful Shutdown**: Waits for in-flight queries before exiting +- **Control Plane Mode**: Multi-process architecture with long-lived workers, zero-downtime deployments, and rolling updates - **Flexible Configuration**: YAML config files, environment variables, and CLI flags - **Prometheus Metrics**: Built-in metrics endpoint for monitoring @@ -177,12 +180,16 @@ export POSTHOG_HOST=eu.i.posthog.com ./duckgres --help Options: - -config string Path to YAML config file - -host string Host to bind to - -port int Port to listen on - -data-dir string Directory for DuckDB files - -cert string TLS certificate file - -key string TLS private key file + -config string Path to YAML config file + -host string Host to bind to + -port int Port to listen on + -data-dir string Directory for DuckDB files + -cert string TLS certificate file + -key string TLS private key file + -mode string Run mode: standalone (default), control-plane, or worker + -worker-count int Number of worker processes (control-plane mode, default 4) + -socket-dir string Unix socket directory (control-plane mode) + -handover-socket string Handover socket for graceful deployment (control-plane mode) ``` ## DuckDB Extensions @@ -428,6 +435,12 @@ GROUP BY name; ## Architecture +Duckgres supports two run modes: **standalone** (single process, default) and **control-plane** (multi-process with worker pool). + +### Standalone Mode + +The default mode runs everything in a single process: + ``` ┌─────────────────┐ │ PostgreSQL │ @@ -449,6 +462,64 @@ GROUP BY name; └─────────────────┘ ``` +### Control Plane Mode + +For production deployments, control-plane mode splits the server into a **control plane** (connection management, routing) and a pool of long-lived **worker processes** (query execution). This enables zero-downtime deployments and cross-session DuckDB cache reuse. + +``` + CONTROL PLANE (duckgres --mode control-plane) + ┌──────────────────────────────────────────┐ + PG Client ──TLS──>│ TCP Listener │ + │ Rate Limiting │ + │ Connection Router (least-connections) │ + │ │ FD pass via Unix socket (SCM_RIGHTS) │ + │ ▼ │ + │ gRPC Client ─────────────────────────+ │ + └──────────────────────────────────────────┘ + │ + gRPC (UDS) + │ + WORKER POOL ▼ + ┌──────────────────────────────────────────┐ + │ Worker 1 (duckgres --mode worker) │ + │ gRPC Server (Configure, Health, Drain) │ + │ FD Receiver (Unix socket) │ + │ Shared DuckDB instance (long-lived) │ + │ ├── Session 1 (goroutine) │ + │ ├── Session 2 (goroutine) │ + │ └── Session N ... │ + ├──────────────────────────────────────────┤ + │ Worker 2 ... 
│ + └──────────────────────────────────────────┘ +``` + +Start in control-plane mode: + +```bash +# Start with 4 workers (default) +./duckgres --mode control-plane --port 5432 --worker-count 4 + +# Connect with psql (identical to standalone mode) +PGPASSWORD=postgres psql "host=localhost port=5432 user=postgres sslmode=require" +``` + +**Zero-downtime deployment** using the handover protocol: + +```bash +# Start the first control plane with a handover socket +./duckgres --mode control-plane --port 5432 --handover-socket /var/run/duckgres/handover.sock + +# Deploy a new version - it takes over the listener and workers without dropping connections +./duckgres-v2 --mode control-plane --port 5432 --handover-socket /var/run/duckgres/handover.sock +``` + +**Rolling worker updates** via signal: + +```bash +# Replace workers one at a time (drains sessions before replacing each worker) +kill -USR2 +``` + ## Two-Tier Query Processing Duckgres uses a two-tier approach to handle both PostgreSQL and DuckDB-specific SQL syntax transparently: @@ -509,9 +580,9 @@ The following DuckDB features work transparently through the fallback mechanism: ## Limitations -- **Single Process**: Each user's database is opened in the same process -- **No Replication**: Single-node only -- **Limited System Catalog**: Some `pg_*` system tables are not available +- **Single Node**: No built-in replication or clustering +- **Limited System Catalog**: Some `pg_*` system tables are stubs (return empty) +- **Type OID Mapping**: Incomplete (some types show as "unknown") ## Dependencies diff --git a/controlplane/control.go b/controlplane/control.go new file mode 100644 index 0000000..2b118c6 --- /dev/null +++ b/controlplane/control.go @@ -0,0 +1,343 @@ +package controlplane + +import ( + "context" + "encoding/binary" + "fmt" + "log/slog" + "net" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/posthog/duckgres/server" +) + +// ControlPlaneConfig extends server.Config with control-plane-specific settings. +type ControlPlaneConfig struct { + server.Config + + WorkerCount int + SocketDir string + HandoverSocket string + HealthCheckInterval time.Duration + MaxConnsPerWorker int +} + +// ControlPlane manages the TCP listener and worker pool. +type ControlPlane struct { + cfg ControlPlaneConfig + pool *WorkerPool + rateLimiter *server.RateLimiter + listener net.Listener + activeConns int64 + closed bool + closeMu sync.Mutex + wg sync.WaitGroup +} + +// RunControlPlane is the entry point for the control plane process. 
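+// It applies defaults (4 workers, /var/run/duckgres, 5s health-check interval),
+// either adopts the listener and workers of a running control plane via the
+// handover socket or spawns a fresh worker pool and TCP listener, starts the
+// health-check and handover listeners, installs SIGTERM/SIGINT/SIGUSR2 handlers,
+// and then blocks in the accept loop. Fatal setup errors exit the process.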
+func RunControlPlane(cfg ControlPlaneConfig) { + // Apply defaults + if cfg.WorkerCount == 0 { + cfg.WorkerCount = 4 + } + if cfg.SocketDir == "" { + cfg.SocketDir = "/var/run/duckgres" + } + if cfg.HealthCheckInterval == 0 { + cfg.HealthCheckInterval = 5 * time.Second + } + + // Create socket directory + if err := os.MkdirAll(cfg.SocketDir, 0755); err != nil { + slog.Error("Failed to create socket directory.", "error", err) + os.Exit(1) + } + + // Use default rate limit config if not specified + if cfg.RateLimit.MaxFailedAttempts == 0 { + cfg.RateLimit = server.DefaultRateLimitConfig() + } + + cp := &ControlPlane{ + cfg: cfg, + pool: NewWorkerPool(cfg.SocketDir, cfg.Config), + rateLimiter: server.NewRateLimiter(cfg.RateLimit), + } + + // Set up signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGUSR2) + + // Try handover from existing control plane if handover socket exists + handoverDone := false + if cfg.HandoverSocket != "" { + if _, err := os.Stat(cfg.HandoverSocket); err == nil { + slog.Info("Existing handover socket found, attempting handover.", "socket", cfg.HandoverSocket) + tcpLn, existingWorkers, err := receiveHandover(cfg.HandoverSocket) + if err != nil { + slog.Warn("Handover failed, starting fresh.", "error", err) + } else { + cp.listener = tcpLn + handoverDone = true + + // Connect to existing workers instead of spawning new ones + for _, w := range existingWorkers { + if err := cp.pool.ConnectExistingWorker(w.ID, w.GRPCSocket, w.FDSocket); err != nil { + slog.Error("Failed to connect to handed-over worker.", "id", w.ID, "error", err) + } + } + slog.Info("Handover complete, took over listener and workers.", + "workers", len(existingWorkers)) + } + } + } + + if !handoverDone { + // Spawn new workers + for i := 0; i < cfg.WorkerCount; i++ { + if err := cp.pool.SpawnWorker(i); err != nil { + slog.Error("Failed to spawn worker.", "id", i, "error", err) + os.Exit(1) + } + } + + // Start TCP listener + addr := fmt.Sprintf("%s:%d", cfg.Host, cfg.Port) + ln, err := net.Listen("tcp", addr) + if err != nil { + slog.Error("Failed to listen.", "addr", addr, "error", err) + os.Exit(1) + } + cp.listener = ln + } + + // Start health check loop + go cp.pool.HealthCheckLoop(makeShutdownCtx(sigChan), cfg.HealthCheckInterval, cfg.WorkerCount) + + // Start handover listener for future deployments + cp.startHandoverListener() + + slog.Info("Control plane listening.", "addr", cp.listener.Addr().String(), "workers", cfg.WorkerCount) + + // Handle signals + go func() { + for sig := range sigChan { + switch sig { + case syscall.SIGUSR2: + slog.Info("Received SIGUSR2, starting rolling update.") + go func() { + if err := cp.pool.RollingUpdate(makeShutdownCtx(sigChan)); err != nil { + slog.Error("Rolling update failed.", "error", err) + } + }() + case syscall.SIGTERM, syscall.SIGINT: + slog.Info("Received shutdown signal.", "signal", sig) + cp.shutdown() + os.Exit(0) + } + } + }() + + // Accept loop + cp.acceptLoop() +} + +func makeShutdownCtx(_ <-chan os.Signal) context.Context { + ctx, cancel := context.WithCancel(context.Background()) + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT) + go func() { + <-sigCh + cancel() + }() + return ctx +} + +func (cp *ControlPlane) acceptLoop() { + for { + conn, err := cp.listener.Accept() + if err != nil { + cp.closeMu.Lock() + closed := cp.closed + cp.closeMu.Unlock() + if closed { + return + } + slog.Error("Accept error.", "error", err) + continue + } 
+ + // Enable TCP keepalive + if tcpConn, ok := conn.(*net.TCPConn); ok { + _ = tcpConn.SetKeepAlive(true) + _ = tcpConn.SetKeepAlivePeriod(30 * time.Second) + } + + cp.wg.Add(1) + go func() { + defer cp.wg.Done() + cp.handleConnection(conn) + }() + } +} + +func (cp *ControlPlane) handleConnection(conn net.Conn) { + remoteAddr := conn.RemoteAddr() + + // Rate limiting + if msg := cp.rateLimiter.CheckConnection(remoteAddr); msg != "" { + slog.Warn("Connection rejected.", "remote_addr", remoteAddr, "reason", msg) + _ = conn.Close() + return + } + + if !cp.rateLimiter.RegisterConnection(remoteAddr) { + slog.Warn("Connection rejected: rate limit.", "remote_addr", remoteAddr) + _ = conn.Close() + return + } + defer cp.rateLimiter.UnregisterConnection(remoteAddr) + + // Read startup message to determine SSL vs cancel + params, err := readStartupFromRaw(conn) + if err != nil { + slog.Error("Failed to read startup.", "remote_addr", remoteAddr, "error", err) + _ = conn.Close() + return + } + + // Handle cancel request + if params.cancelRequest { + cp.pool.CancelQuery(params.cancelPid, params.cancelSecretKey) + _ = conn.Close() + return + } + + // Handle SSL request + if params.sslRequest { + // Send 'S' to indicate SSL support + if _, err := conn.Write([]byte("S")); err != nil { + slog.Error("Failed to send SSL response.", "remote_addr", remoteAddr, "error", err) + _ = conn.Close() + return + } + + // Get TCP file descriptor to pass to worker + tcpConn, ok := conn.(*net.TCPConn) + if !ok { + slog.Error("Not a TCP connection.", "remote_addr", remoteAddr) + _ = conn.Close() + return + } + + file, err := tcpConn.File() + if err != nil { + slog.Error("Failed to get FD.", "remote_addr", remoteAddr, "error", err) + _ = conn.Close() + return + } + + // Close the original connection (file has a dup'd FD) + _ = conn.Close() + + // Generate a secret key for this connection + secretKey := server.GenerateSecretKey() + + atomic.AddInt64(&cp.activeConns, 1) + defer atomic.AddInt64(&cp.activeConns, -1) + + // Route to a worker + backendPid, err := cp.pool.RouteConnection(file, remoteAddr.String(), secretKey) + _ = file.Close() + if err != nil { + slog.Error("Failed to route connection.", "remote_addr", remoteAddr, "error", err) + return + } + + slog.Debug("Connection routed.", "remote_addr", remoteAddr, "backend_pid", backendPid) + } else { + // No SSL - reject + slog.Warn("Connection rejected: SSL required.", "remote_addr", remoteAddr) + _ = conn.Close() + } +} + +// startupResult holds the parsed initial startup message. +type startupResult struct { + sslRequest bool + cancelRequest bool + cancelPid int32 + cancelSecretKey int32 +} + +// readStartupFromRaw reads the startup message from a raw (unbuffered) connection. 
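+// Only the two pre-TLS request codes of the PostgreSQL protocol are recognised:
+// 80877103 (SSLRequest) and 80877102 (CancelRequest); anything else is an error.
+// The client's real StartupMessage is handled later, in the worker, after the
+// TLS handshake.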
+func readStartupFromRaw(conn net.Conn) (startupResult, error) { + // Read length (4 bytes) + var length int32 + if err := binary.Read(conn, binary.BigEndian, &length); err != nil { + return startupResult{}, fmt.Errorf("read length: %w", err) + } + + if length < 8 || length > 10000 { + return startupResult{}, fmt.Errorf("invalid startup message length: %d", length) + } + + remaining := make([]byte, length-4) + if _, err := fullRead(conn, remaining); err != nil { + return startupResult{}, fmt.Errorf("read body: %w", err) + } + + protocolVersion := binary.BigEndian.Uint32(remaining[:4]) + + // SSL request + if protocolVersion == 80877103 { + return startupResult{sslRequest: true}, nil + } + + // Cancel request + if protocolVersion == 80877102 && len(remaining) >= 12 { + pid := int32(binary.BigEndian.Uint32(remaining[4:8])) + key := int32(binary.BigEndian.Uint32(remaining[8:12])) + return startupResult{cancelRequest: true, cancelPid: pid, cancelSecretKey: key}, nil + } + + return startupResult{}, fmt.Errorf("unexpected protocol version: %d", protocolVersion) +} + +func fullRead(conn net.Conn, buf []byte) (int, error) { + total := 0 + for total < len(buf) { + n, err := conn.Read(buf[total:]) + total += n + if err != nil { + return total, err + } + } + return total, nil +} + +func (cp *ControlPlane) shutdown() { + cp.closeMu.Lock() + cp.closed = true + cp.closeMu.Unlock() + + if cp.listener != nil { + _ = cp.listener.Close() + } + + slog.Info("Draining workers...") + cp.pool.DrainAll(30 * time.Second) + + slog.Info("Shutting down workers...") + cp.pool.ShutdownAll(30 * time.Second) + + // Wait for in-flight accept loop goroutines + cp.wg.Wait() + + slog.Info("Control plane shutdown complete.") +} diff --git a/controlplane/dbpool.go b/controlplane/dbpool.go new file mode 100644 index 0000000..6b56371 --- /dev/null +++ b/controlplane/dbpool.go @@ -0,0 +1,87 @@ +package controlplane + +import ( + "database/sql" + "fmt" + "log/slog" + "sync" + "time" + + "github.com/posthog/duckgres/server" +) + +// DBPool manages a shared DuckDB database for a long-lived worker. +// The database is opened once and shared across all sessions. Each session +// gets its own *sql.DB with MaxOpenConns=1 for transaction isolation. +type DBPool struct { + cfg server.Config + duckLakeSem chan struct{} + + mu sync.Mutex + sessions map[int32]*sql.DB // keyed by session backend PID +} + +// NewDBPool creates a new database pool with the given server configuration. +func NewDBPool(cfg server.Config) *DBPool { + return &DBPool{ + cfg: cfg, + duckLakeSem: make(chan struct{}, 1), + sessions: make(map[int32]*sql.DB), + } +} + +// CreateSession creates a new DuckDB connection for a client session. +// The connection is registered by its backend PID for tracking. +func (p *DBPool) CreateSession(pid int32, username string) (*sql.DB, error) { + db, err := server.CreateDBConnection(p.cfg, p.duckLakeSem, username) + if err != nil { + return nil, fmt.Errorf("create session db: %w", err) + } + + p.mu.Lock() + p.sessions[pid] = db + p.mu.Unlock() + + slog.Debug("Created session database.", "pid", pid, "user", username) + return db, nil +} + +// CloseSession closes and unregisters a session's database connection. 
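+// Closing a pid that is not registered is a no-op, so the call is safe from
+// multiple cleanup paths.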
+func (p *DBPool) CloseSession(pid int32) { + p.mu.Lock() + db, ok := p.sessions[pid] + if ok { + delete(p.sessions, pid) + } + p.mu.Unlock() + + if ok && db != nil { + if err := db.Close(); err != nil { + slog.Warn("Failed to close session database.", "pid", pid, "error", err) + } + } +} + +// ActiveSessions returns the number of active sessions. +func (p *DBPool) ActiveSessions() int { + p.mu.Lock() + defer p.mu.Unlock() + return len(p.sessions) +} + +// CloseAll closes all session databases. Used during shutdown. +func (p *DBPool) CloseAll(timeout time.Duration) { + p.mu.Lock() + sessions := make(map[int32]*sql.DB, len(p.sessions)) + for k, v := range p.sessions { + sessions[k] = v + } + p.sessions = make(map[int32]*sql.DB) + p.mu.Unlock() + + for pid, db := range sessions { + if err := db.Close(); err != nil { + slog.Warn("Failed to close session database during shutdown.", "pid", pid, "error", err) + } + } +} diff --git a/controlplane/fdpass/fdpass.go b/controlplane/fdpass/fdpass.go new file mode 100644 index 0000000..9093fed --- /dev/null +++ b/controlplane/fdpass/fdpass.go @@ -0,0 +1,113 @@ +// Package fdpass provides Unix socket file descriptor passing via SCM_RIGHTS. +// +// This is used by the control plane to pass raw TCP file descriptors to worker +// processes at runtime, enabling workers to handle client connections that were +// accepted by the control plane. +package fdpass + +import ( + "fmt" + "net" + "os" + "syscall" +) + +// SendFD sends a file descriptor over a Unix socket using SCM_RIGHTS. +// The fd is the file descriptor to pass; conn is the Unix socket to send it over. +func SendFD(conn *net.UnixConn, fd int) error { + rights := syscall.UnixRights(fd) + // Send a single byte as the message body (required by sendmsg) + _, _, err := conn.WriteMsgUnix([]byte{0}, rights, nil) + if err != nil { + return fmt.Errorf("sendmsg: %w", err) + } + return nil +} + +// SendFile sends an *os.File's descriptor over a Unix socket using SCM_RIGHTS. +func SendFile(conn *net.UnixConn, f *os.File) error { + return SendFD(conn, int(f.Fd())) +} + +// RecvFD receives a file descriptor from a Unix socket using SCM_RIGHTS. +// Returns the received file descriptor. The caller is responsible for closing it. +func RecvFD(conn *net.UnixConn) (int, error) { + buf := make([]byte, 1) + oob := make([]byte, syscall.CmsgLen(4)) // space for one int32 fd + _, oobn, _, _, err := conn.ReadMsgUnix(buf, oob) + if err != nil { + return -1, fmt.Errorf("recvmsg: %w", err) + } + + cmsgs, err := syscall.ParseSocketControlMessage(oob[:oobn]) + if err != nil { + return -1, fmt.Errorf("parse control message: %w", err) + } + + for _, cmsg := range cmsgs { + fds, err := syscall.ParseUnixRights(&cmsg) + if err != nil { + continue + } + if len(fds) > 0 { + return fds[0], nil + } + } + + return -1, fmt.Errorf("no file descriptor received") +} + +// RecvFile receives a file descriptor from a Unix socket and wraps it as *os.File. +// The caller is responsible for closing the returned file. +func RecvFile(conn *net.UnixConn, name string) (*os.File, error) { + fd, err := RecvFD(conn) + if err != nil { + return nil, err + } + return os.NewFile(uintptr(fd), name), nil +} + +// SocketPair creates a connected pair of Unix sockets for FD passing. +// Returns (sender, receiver) connections. Both must be closed by the caller. 
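+// The pair is backed by socketpair(2) and is convenient for tests that exercise
+// SendFD/RecvFD without creating a filesystem socket, e.g. (a sketch):
+//
+//	sender, receiver, _ := SocketPair()
+//	_ = SendFile(sender, someFile)
+//	received, _ := RecvFile(receiver, "received")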
+func SocketPair() (*net.UnixConn, *net.UnixConn, error) { + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + return nil, nil, fmt.Errorf("socketpair: %w", err) + } + + sender, err := fdToUnixConn(fds[0], "sender") + if err != nil { + syscall.Close(fds[0]) + syscall.Close(fds[1]) + return nil, nil, err + } + + receiver, err := fdToUnixConn(fds[1], "receiver") + if err != nil { + sender.Close() + syscall.Close(fds[1]) + return nil, nil, err + } + + return sender, receiver, nil +} + +func fdToUnixConn(fd int, name string) (*net.UnixConn, error) { + f := os.NewFile(uintptr(fd), name) + if f == nil { + return nil, fmt.Errorf("invalid fd %d", fd) + } + defer f.Close() + + fc, err := net.FileConn(f) + if err != nil { + return nil, fmt.Errorf("FileConn: %w", err) + } + + uc, ok := fc.(*net.UnixConn) + if !ok { + fc.Close() + return nil, fmt.Errorf("not a UnixConn") + } + return uc, nil +} diff --git a/controlplane/fdpass/fdpass_test.go b/controlplane/fdpass/fdpass_test.go new file mode 100644 index 0000000..3f575aa --- /dev/null +++ b/controlplane/fdpass/fdpass_test.go @@ -0,0 +1,133 @@ +package fdpass + +import ( + "net" + "os" + "testing" +) + +func TestSendRecvFD(t *testing.T) { + // Create a socket pair for FD passing + sender, receiver, err := SocketPair() + if err != nil { + t.Fatalf("SocketPair: %v", err) + } + defer sender.Close() + defer receiver.Close() + + // Create a temp file to pass + tmp, err := os.CreateTemp("", "fdpass-test-*") + if err != nil { + t.Fatalf("CreateTemp: %v", err) + } + defer os.Remove(tmp.Name()) + + // Write some data + msg := "hello from fd passing" + if _, err := tmp.WriteString(msg); err != nil { + t.Fatalf("WriteString: %v", err) + } + if err := tmp.Sync(); err != nil { + t.Fatalf("Sync: %v", err) + } + + // Send the FD + if err := SendFile(sender, tmp); err != nil { + t.Fatalf("SendFile: %v", err) + } + tmp.Close() + + // Receive the FD + received, err := RecvFile(receiver, "received") + if err != nil { + t.Fatalf("RecvFile: %v", err) + } + defer received.Close() + + // Read data from received FD to verify it works + if _, err := received.Seek(0, 0); err != nil { + t.Fatalf("Seek: %v", err) + } + buf := make([]byte, len(msg)) + n, err := received.Read(buf) + if err != nil { + t.Fatalf("Read: %v", err) + } + if string(buf[:n]) != msg { + t.Errorf("got %q, want %q", string(buf[:n]), msg) + } +} + +func TestSendRecvTCPConn(t *testing.T) { + // Create a socket pair for FD passing + sender, receiver, err := SocketPair() + if err != nil { + t.Fatalf("SocketPair: %v", err) + } + defer sender.Close() + defer receiver.Close() + + // Create a TCP listener + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + // Connect a TCP client + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + + // Accept the server side + serverConn, err := ln.Accept() + if err != nil { + t.Fatalf("Accept: %v", err) + } + + // Get the FD from the server-side TCP connection + tcpConn := serverConn.(*net.TCPConn) + file, err := tcpConn.File() + if err != nil { + t.Fatalf("File: %v", err) + } + serverConn.Close() // Close the original; file has a dup'd FD + + // Send the TCP FD + if err := SendFile(sender, file); err != nil { + t.Fatalf("SendFile: %v", err) + } + file.Close() + + // Receive the TCP FD in the "worker" + recvFile, err := RecvFile(receiver, "tcp-conn") + if err != nil { + t.Fatalf("RecvFile: %v", err) + } + + // 
Reconstruct the connection from the received FD + fc, err := net.FileConn(recvFile) + if err != nil { + t.Fatalf("FileConn: %v", err) + } + recvFile.Close() + defer fc.Close() + + // Write from the reconstructed connection, read from client + msg := "hello via fd passing" + if _, err := fc.Write([]byte(msg)); err != nil { + t.Fatalf("Write: %v", err) + } + + buf := make([]byte, len(msg)) + n, err := clientConn.Read(buf) + if err != nil { + t.Fatalf("Read: %v", err) + } + if string(buf[:n]) != msg { + t.Errorf("got %q, want %q", string(buf[:n]), msg) + } + + clientConn.Close() +} diff --git a/controlplane/handover.go b/controlplane/handover.go new file mode 100644 index 0000000..d4f9909 --- /dev/null +++ b/controlplane/handover.go @@ -0,0 +1,228 @@ +package controlplane + +import ( + "encoding/json" + "fmt" + "log/slog" + "net" + "os" + "time" + + "github.com/posthog/duckgres/controlplane/fdpass" +) + +// Handover protocol messages +type handoverMsg struct { + Type string `json:"type"` + Workers []handoverWorker `json:"workers,omitempty"` +} + +type handoverWorker struct { + ID int `json:"id"` + GRPCSocket string `json:"grpc_socket"` + FDSocket string `json:"fd_socket"` +} + +// startHandoverListener starts listening for handover requests from a new control plane. +// When a new CP connects, the old CP will transfer its TCP listener FD and worker info. +func (cp *ControlPlane) startHandoverListener() { + if cp.cfg.HandoverSocket == "" { + return + } + + // Clean up old socket + os.Remove(cp.cfg.HandoverSocket) + + ln, err := net.Listen("unix", cp.cfg.HandoverSocket) + if err != nil { + slog.Error("Failed to start handover listener.", "error", err) + return + } + + slog.Info("Handover listener started.", "socket", cp.cfg.HandoverSocket) + + go func() { + defer ln.Close() + defer os.Remove(cp.cfg.HandoverSocket) + + for { + conn, err := ln.Accept() + if err != nil { + cp.closeMu.Lock() + closed := cp.closed + cp.closeMu.Unlock() + if closed { + return + } + slog.Error("Handover accept error.", "error", err) + continue + } + + // Handle handover in a goroutine (only one at a time is expected) + go cp.handleHandoverRequest(conn, ln) + return // Only handle one handover + } + }() +} + +// handleHandoverRequest processes an incoming handover request from a new control plane. 
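+// The exchange is: read "handover_request", reply with "handover_ack" carrying
+// the worker socket paths, pass the TCP listener FD over the Unix connection via
+// SCM_RIGHTS, then wait for "handover_complete". After that this (old) control
+// plane stops accepting, waits briefly for in-flight FD passes, and exits; the
+// worker processes keep running and are adopted by the new control plane.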
+func (cp *ControlPlane) handleHandoverRequest(conn net.Conn, handoverLn net.Listener) { + defer conn.Close() + defer handoverLn.Close() + + decoder := json.NewDecoder(conn) + encoder := json.NewEncoder(conn) + + // Read handover request + var req handoverMsg + if err := decoder.Decode(&req); err != nil { + slog.Error("Failed to read handover request.", "error", err) + return + } + + if req.Type != "handover_request" { + slog.Error("Unexpected handover message type.", "type", req.Type) + return + } + + slog.Info("Received handover request, preparing transfer...") + + // Build worker list + workers := cp.pool.Workers() + handoverWorkers := make([]handoverWorker, 0, len(workers)) + for _, w := range workers { + handoverWorkers = append(handoverWorkers, handoverWorker{ + ID: w.ID, + GRPCSocket: w.GRPCSocket, + FDSocket: w.FDSocket, + }) + } + + // Send ack with worker info + if err := encoder.Encode(handoverMsg{ + Type: "handover_ack", + Workers: handoverWorkers, + }); err != nil { + slog.Error("Failed to send handover ack.", "error", err) + return + } + + // Pass TCP listener FD via SCM_RIGHTS + tcpLn, ok := cp.listener.(*net.TCPListener) + if !ok { + slog.Error("Listener is not TCP, cannot handover.") + return + } + + file, err := tcpLn.File() + if err != nil { + slog.Error("Failed to get listener FD.", "error", err) + return + } + defer file.Close() + + uc, ok := conn.(*net.UnixConn) + if !ok { + slog.Error("Handover connection is not Unix.") + return + } + + if err := fdpass.SendFile(uc, file); err != nil { + slog.Error("Failed to send listener FD.", "error", err) + return + } + + slog.Info("Listener FD sent to new control plane.") + + // Wait for handover_complete + var complete handoverMsg + if err := decoder.Decode(&complete); err != nil { + slog.Error("Failed to read handover complete.", "error", err) + return + } + + if complete.Type != "handover_complete" { + slog.Error("Unexpected handover message type.", "type", complete.Type) + return + } + + slog.Info("Handover complete. Old control plane stopping accept loop...") + + // Stop accepting new connections + cp.closeMu.Lock() + cp.closed = true + cp.closeMu.Unlock() + _ = cp.listener.Close() + + // Brief wait for in-flight FD passes to complete + time.Sleep(2 * time.Second) + + // Wait for wg to drain + cp.wg.Wait() + + slog.Info("Old control plane exiting after handover.") + os.Exit(0) +} + +// receiveHandover connects to an existing control plane's handover socket, +// receives the TCP listener FD and worker info, and takes over. 
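+// This is the new control plane's side of the exchange in handleHandoverRequest:
+// send "handover_request", read the "handover_ack" worker list, receive the
+// listener FD, then confirm with "handover_complete". The caller re-attaches the
+// returned workers via WorkerPool.ConnectExistingWorker.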
+func receiveHandover(handoverSocket string) (*net.TCPListener, []handoverWorker, error) { + conn, err := net.Dial("unix", handoverSocket) + if err != nil { + return nil, nil, fmt.Errorf("connect handover socket: %w", err) + } + defer conn.Close() + + decoder := json.NewDecoder(conn) + encoder := json.NewEncoder(conn) + + // Send handover request + if err := encoder.Encode(handoverMsg{Type: "handover_request"}); err != nil { + return nil, nil, fmt.Errorf("send handover request: %w", err) + } + + // Read ack with worker info + var ack handoverMsg + if err := decoder.Decode(&ack); err != nil { + return nil, nil, fmt.Errorf("read handover ack: %w", err) + } + + if ack.Type != "handover_ack" { + return nil, nil, fmt.Errorf("unexpected handover message: %s", ack.Type) + } + + // Receive listener FD + uc, ok := conn.(*net.UnixConn) + if !ok { + return nil, nil, fmt.Errorf("handover connection is not Unix") + } + + file, err := fdpass.RecvFile(uc, "tcp-listener") + if err != nil { + return nil, nil, fmt.Errorf("receive listener FD: %w", err) + } + defer file.Close() + + // Reconstruct listener from FD + ln, err := net.FileListener(file) + if err != nil { + return nil, nil, fmt.Errorf("FileListener: %w", err) + } + + tcpLn, ok := ln.(*net.TCPListener) + if !ok { + ln.Close() + return nil, nil, fmt.Errorf("not a TCP listener") + } + + // Send handover complete + if err := encoder.Encode(handoverMsg{Type: "handover_complete"}); err != nil { + tcpLn.Close() + return nil, nil, fmt.Errorf("send handover complete: %w", err) + } + + slog.Info("Handover received: got listener and worker info.", + "workers", len(ack.Workers)) + + return tcpLn, ack.Workers, nil +} diff --git a/controlplane/pool.go b/controlplane/pool.go new file mode 100644 index 0000000..900fc44 --- /dev/null +++ b/controlplane/pool.go @@ -0,0 +1,508 @@ +package controlplane + +import ( + "context" + "fmt" + "log/slog" + "net" + "os" + "os/exec" + "path/filepath" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/posthog/duckgres/controlplane/fdpass" + pb "github.com/posthog/duckgres/controlplane/proto" + "github.com/posthog/duckgres/server" +) + +// ManagedWorker represents a spawned worker process and its connections. +type ManagedWorker struct { + ID int + Cmd *exec.Cmd + PID int + GRPCSocket string + FDSocket string + GRPCConn *grpc.ClientConn + Client pb.WorkerControlClient + StartTime time.Time + done chan struct{} // closed when process exits +} + +// WorkerPool manages a pool of long-lived worker processes. +type WorkerPool struct { + mu sync.RWMutex + workers map[int]*ManagedWorker + socketDir string + cfg server.Config + + // Round-robin counter for simple load balancing + nextWorker int +} + +// NewWorkerPool creates a new worker pool. +func NewWorkerPool(socketDir string, cfg server.Config) *WorkerPool { + return &WorkerPool{ + workers: make(map[int]*ManagedWorker), + socketDir: socketDir, + cfg: cfg, + } +} + +// SpawnWorker spawns a new worker process and establishes gRPC + FD socket connections. 
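+// The worker is started as "<self> --mode worker" with per-worker gRPC and FD
+// socket paths, then configured over gRPC once its gRPC socket accepts
+// connections. A background goroutine waits on the process and removes it from
+// the pool when it exits, so HealthCheckLoop can respawn it.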
+func (p *WorkerPool) SpawnWorker(id int) error { + grpcSocket := filepath.Join(p.socketDir, fmt.Sprintf("worker-%d-grpc.sock", id)) + fdSocket := filepath.Join(p.socketDir, fmt.Sprintf("worker-%d-fd.sock", id)) + + // Clean up old sockets + os.Remove(grpcSocket) + os.Remove(fdSocket) + + // Spawn child process + cmd := exec.Command(os.Args[0], + "--mode", "worker", + "--grpc-socket", grpcSocket, + "--fd-socket", fdSocket, + ) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = os.Environ() + + if err := cmd.Start(); err != nil { + return fmt.Errorf("start worker %d: %w", id, err) + } + + slog.Info("Spawned worker process.", "id", id, "pid", cmd.Process.Pid, + "grpc_socket", grpcSocket, "fd_socket", fdSocket) + + // Wait for the gRPC socket to become available + if err := waitForSocket(grpcSocket, 10*time.Second); err != nil { + _ = cmd.Process.Kill() + return fmt.Errorf("worker %d gRPC socket not ready: %w", id, err) + } + + // Connect gRPC + conn, err := grpc.NewClient( + "unix://"+grpcSocket, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + _ = cmd.Process.Kill() + return fmt.Errorf("connect gRPC to worker %d: %w", id, err) + } + + client := pb.NewWorkerControlClient(conn) + + // Send Configure + configReq := buildConfigureRequest(p.cfg) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + resp, err := client.Configure(ctx, configReq) + cancel() + if err != nil { + conn.Close() + _ = cmd.Process.Kill() + return fmt.Errorf("configure worker %d: %w", id, err) + } + if !resp.Ok { + conn.Close() + _ = cmd.Process.Kill() + return fmt.Errorf("configure worker %d: %s", id, resp.Error) + } + + worker := &ManagedWorker{ + ID: id, + Cmd: cmd, + PID: cmd.Process.Pid, + GRPCSocket: grpcSocket, + FDSocket: fdSocket, + GRPCConn: conn, + Client: client, + StartTime: time.Now(), + done: make(chan struct{}), + } + + // Monitor the process in the background + go func() { + err := cmd.Wait() + if err != nil { + slog.Error("Worker process exited.", "id", id, "pid", worker.PID, "error", err) + } else { + slog.Info("Worker process exited cleanly.", "id", id, "pid", worker.PID) + } + close(worker.done) + + // Remove from pool + p.mu.Lock() + delete(p.workers, id) + p.mu.Unlock() + }() + + p.mu.Lock() + p.workers[id] = worker + p.mu.Unlock() + + slog.Info("Worker configured and ready.", "id", id, "pid", worker.PID) + return nil +} + +// ConnectExistingWorker connects to a worker that was handed over from a previous control plane. +// The worker process is already running - we just need to establish gRPC connection. 
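+// The worker is verified with a Health RPC before being added to the pool; its
+// existing client sessions are untouched because the worker process itself is
+// never restarted here.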
+func (p *WorkerPool) ConnectExistingWorker(id int, grpcSocket, fdSocket string) error { + // Connect gRPC + conn, err := grpc.NewClient( + "unix://"+grpcSocket, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return fmt.Errorf("connect gRPC to worker %d: %w", id, err) + } + + client := pb.NewWorkerControlClient(conn) + + // Verify the worker is healthy + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + health, err := client.Health(ctx, &pb.HealthRequest{}) + cancel() + if err != nil { + conn.Close() + return fmt.Errorf("health check worker %d: %w", id, err) + } + if !health.Healthy { + conn.Close() + return fmt.Errorf("worker %d is not healthy", id) + } + + worker := &ManagedWorker{ + ID: id, + GRPCSocket: grpcSocket, + FDSocket: fdSocket, + GRPCConn: conn, + Client: client, + StartTime: time.Now(), + done: make(chan struct{}), + } + + p.mu.Lock() + p.workers[id] = worker + p.mu.Unlock() + + slog.Info("Connected to existing worker.", "id", id, + "active_connections", health.ActiveConnections) + return nil +} + +// RouteConnection sends a TCP file descriptor to a worker for handling. +// It selects the least-loaded worker based on health checks. +func (p *WorkerPool) RouteConnection(tcpFile *os.File, remoteAddr string, secretKey int32) (int32, error) { + worker, err := p.selectWorker() + if err != nil { + return 0, err + } + + // Wait for FD socket to be available + if err := waitForSocket(worker.FDSocket, 5*time.Second); err != nil { + return 0, fmt.Errorf("worker %d FD socket not ready: %w", worker.ID, err) + } + + // Connect to worker's FD socket and send the FD + fdConn, err := net.Dial("unix", worker.FDSocket) + if err != nil { + return 0, fmt.Errorf("connect FD socket for worker %d: %w", worker.ID, err) + } + + uc, ok := fdConn.(*net.UnixConn) + if !ok { + fdConn.Close() + return 0, fmt.Errorf("FD conn not UnixConn") + } + + if err := fdpass.SendFile(uc, tcpFile); err != nil { + uc.Close() + return 0, fmt.Errorf("send FD to worker %d: %w", worker.ID, err) + } + uc.Close() + + // Tell worker to accept the connection via gRPC + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + resp, err := worker.Client.AcceptConnection(ctx, &pb.AcceptConnectionRequest{ + RemoteAddr: remoteAddr, + BackendSecretKey: secretKey, + }) + cancel() + + if err != nil { + return 0, fmt.Errorf("AcceptConnection on worker %d: %w", worker.ID, err) + } + if !resp.Ok { + return 0, fmt.Errorf("AcceptConnection on worker %d: %s", worker.ID, resp.Error) + } + + return resp.BackendPid, nil +} + +// selectWorker picks the least-loaded worker using health information. +func (p *WorkerPool) selectWorker() (*ManagedWorker, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + if len(p.workers) == 0 { + return nil, fmt.Errorf("no workers available") + } + + var best *ManagedWorker + bestConns := int32(1<<31 - 1) + + for _, w := range p.workers { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + health, err := w.Client.Health(ctx, &pb.HealthRequest{}) + cancel() + + if err != nil || !health.Healthy { + continue + } + + if health.ActiveConnections < bestConns { + bestConns = health.ActiveConnections + best = w + } + } + + if best == nil { + // Fallback to round-robin if health checks fail + for _, w := range p.workers { + best = w + break + } + } + + if best == nil { + return nil, fmt.Errorf("no healthy workers") + } + + return best, nil +} + +// CancelQuery forwards a cancel request to the appropriate worker. 
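+// The control plane does not track which worker owns a backend PID, so the
+// request is broadcast to every worker; the worker holding a session with a
+// matching PID and secret key is expected to report it as cancelled. Returns
+// true if any worker cancelled the query.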
+func (p *WorkerPool) CancelQuery(backendPid, secretKey int32) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, w := range p.workers { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + resp, err := w.Client.CancelQuery(ctx, &pb.CancelQueryRequest{ + BackendPid: backendPid, + SecretKey: secretKey, + }) + cancel() + + if err == nil && resp.Cancelled { + return true + } + } + return false +} + +// DrainAll sends Drain to all workers. +func (p *WorkerPool) DrainAll(timeout time.Duration) error { + p.mu.RLock() + workers := make([]*ManagedWorker, 0, len(p.workers)) + for _, w := range p.workers { + workers = append(workers, w) + } + p.mu.RUnlock() + + var wg sync.WaitGroup + for _, w := range workers { + wg.Add(1) + go func(w *ManagedWorker) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + _, err := w.Client.Drain(ctx, &pb.DrainRequest{TimeoutNs: int64(timeout)}) + cancel() + if err != nil { + slog.Warn("Failed to drain worker.", "id", w.ID, "error", err) + } + }(w) + } + + wg.Wait() + return nil +} + +// ShutdownAll sends Shutdown to all workers and waits for them to exit. +func (p *WorkerPool) ShutdownAll(timeout time.Duration) { + p.mu.RLock() + workers := make([]*ManagedWorker, 0, len(p.workers)) + for _, w := range p.workers { + workers = append(workers, w) + } + p.mu.RUnlock() + + for _, w := range workers { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + _, _ = w.Client.Shutdown(ctx, &pb.ShutdownRequest{TimeoutNs: int64(timeout)}) + cancel() + } + + // Wait for processes to exit + deadline := time.After(timeout) + for _, w := range workers { + select { + case <-w.done: + case <-deadline: + slog.Warn("Worker shutdown timeout, killing.", "id", w.ID, "pid", w.PID) + if w.Cmd.Process != nil { + _ = w.Cmd.Process.Kill() + } + } + } +} + +// HealthCheckLoop periodically checks worker health and restarts dead workers. +func (p *WorkerPool) HealthCheckLoop(ctx context.Context, interval time.Duration, desiredCount int) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + p.mu.RLock() + currentCount := len(p.workers) + p.mu.RUnlock() + + if currentCount < desiredCount { + slog.Warn("Worker count below desired, spawning replacements.", + "current", currentCount, "desired", desiredCount) + for i := 0; i < desiredCount; i++ { + p.mu.RLock() + _, exists := p.workers[i] + p.mu.RUnlock() + if !exists { + if err := p.SpawnWorker(i); err != nil { + slog.Error("Failed to respawn worker.", "id", i, "error", err) + } + } + } + } + } + } +} + +// Workers returns a snapshot of all managed workers. +func (p *WorkerPool) Workers() []*ManagedWorker { + p.mu.RLock() + defer p.mu.RUnlock() + result := make([]*ManagedWorker, 0, len(p.workers)) + for _, w := range p.workers { + result = append(result, w) + } + return result +} + +// RollingUpdate replaces workers one at a time with a new binary. 
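+// Each worker is drained (60s), shut down (30s), waited on (and killed after a
+// further 30s), then respawned via SpawnWorker, which re-executes os.Args[0].
+// If the binary at that path has been replaced on disk before SIGUSR2 is sent,
+// the replacement workers run the new version while sessions on not-yet-replaced
+// workers continue uninterrupted.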
+func (p *WorkerPool) RollingUpdate(ctx context.Context) error { + p.mu.RLock() + ids := make([]int, 0, len(p.workers)) + for id := range p.workers { + ids = append(ids, id) + } + p.mu.RUnlock() + + for _, id := range ids { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + slog.Info("Rolling update: replacing worker.", "id", id) + + // Drain old worker + p.mu.RLock() + old, exists := p.workers[id] + p.mu.RUnlock() + if !exists { + continue + } + + drainCtx, drainCancel := context.WithTimeout(ctx, 60*time.Second) + _, _ = old.Client.Drain(drainCtx, &pb.DrainRequest{TimeoutNs: int64(60 * time.Second)}) + drainCancel() + + // Shutdown old worker + shutCtx, shutCancel := context.WithTimeout(ctx, 30*time.Second) + _, _ = old.Client.Shutdown(shutCtx, &pb.ShutdownRequest{TimeoutNs: int64(30 * time.Second)}) + shutCancel() + + // Wait for old worker to exit + select { + case <-old.done: + case <-time.After(30 * time.Second): + if old.Cmd.Process != nil { + _ = old.Cmd.Process.Kill() + } + } + + old.GRPCConn.Close() + + // Spawn replacement + if err := p.SpawnWorker(id); err != nil { + return fmt.Errorf("failed to spawn replacement worker %d: %w", id, err) + } + + slog.Info("Rolling update: worker replaced.", "id", id) + } + + return nil +} + +func buildConfigureRequest(cfg server.Config) *pb.ConfigureRequest { + req := &pb.ConfigureRequest{ + DataDir: cfg.DataDir, + Extensions: cfg.Extensions, + IdleTimeoutNs: int64(cfg.IdleTimeout), + TlsCertFile: cfg.TLSCertFile, + TlsKeyFile: cfg.TLSKeyFile, + Users: cfg.Users, + } + + if cfg.DuckLake.MetadataStore != "" { + req.Ducklake = &pb.DuckLakeConfig{ + MetadataStore: cfg.DuckLake.MetadataStore, + ObjectStore: cfg.DuckLake.ObjectStore, + DataPath: cfg.DuckLake.DataPath, + S3Provider: cfg.DuckLake.S3Provider, + S3Endpoint: cfg.DuckLake.S3Endpoint, + S3AccessKey: cfg.DuckLake.S3AccessKey, + S3SecretKey: cfg.DuckLake.S3SecretKey, + S3Region: cfg.DuckLake.S3Region, + S3UseSsl: cfg.DuckLake.S3UseSSL, + S3UrlStyle: cfg.DuckLake.S3URLStyle, + S3Chain: cfg.DuckLake.S3Chain, + S3Profile: cfg.DuckLake.S3Profile, + } + } + + return req +} + +func waitForSocket(path string, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if _, err := os.Stat(path); err == nil { + // Try to connect + conn, err := net.DialTimeout("unix", path, time.Second) + if err == nil { + conn.Close() + return nil + } + } + time.Sleep(50 * time.Millisecond) + } + return fmt.Errorf("socket %s not ready after %v", path, timeout) +} diff --git a/controlplane/proto/generate.go b/controlplane/proto/generate.go new file mode 100644 index 0000000..0ad97c7 --- /dev/null +++ b/controlplane/proto/generate.go @@ -0,0 +1,3 @@ +package proto + +//go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative worker.proto diff --git a/controlplane/proto/worker.pb.go b/controlplane/proto/worker.pb.go new file mode 100644 index 0000000..6716aa2 --- /dev/null +++ b/controlplane/proto/worker.pb.go @@ -0,0 +1,964 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v6.33.4 +// source: controlplane/proto/worker.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. 
+ _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type ConfigureRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + DataDir string `protobuf:"bytes,1,opt,name=data_dir,json=dataDir,proto3" json:"data_dir,omitempty"` + Extensions []string `protobuf:"bytes,2,rep,name=extensions,proto3" json:"extensions,omitempty"` + IdleTimeoutNs int64 `protobuf:"varint,3,opt,name=idle_timeout_ns,json=idleTimeoutNs,proto3" json:"idle_timeout_ns,omitempty"` + // TLS configuration + TlsCertFile string `protobuf:"bytes,4,opt,name=tls_cert_file,json=tlsCertFile,proto3" json:"tls_cert_file,omitempty"` + TlsKeyFile string `protobuf:"bytes,5,opt,name=tls_key_file,json=tlsKeyFile,proto3" json:"tls_key_file,omitempty"` + // DuckLake configuration + Ducklake *DuckLakeConfig `protobuf:"bytes,6,opt,name=ducklake,proto3" json:"ducklake,omitempty"` + // Authentication - map of username -> password + Users map[string]string `protobuf:"bytes,7,rep,name=users,proto3" json:"users,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ConfigureRequest) Reset() { + *x = ConfigureRequest{} + mi := &file_controlplane_proto_worker_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ConfigureRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ConfigureRequest) ProtoMessage() {} + +func (x *ConfigureRequest) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ConfigureRequest.ProtoReflect.Descriptor instead. 
+func (*ConfigureRequest) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{0} +} + +func (x *ConfigureRequest) GetDataDir() string { + if x != nil { + return x.DataDir + } + return "" +} + +func (x *ConfigureRequest) GetExtensions() []string { + if x != nil { + return x.Extensions + } + return nil +} + +func (x *ConfigureRequest) GetIdleTimeoutNs() int64 { + if x != nil { + return x.IdleTimeoutNs + } + return 0 +} + +func (x *ConfigureRequest) GetTlsCertFile() string { + if x != nil { + return x.TlsCertFile + } + return "" +} + +func (x *ConfigureRequest) GetTlsKeyFile() string { + if x != nil { + return x.TlsKeyFile + } + return "" +} + +func (x *ConfigureRequest) GetDucklake() *DuckLakeConfig { + if x != nil { + return x.Ducklake + } + return nil +} + +func (x *ConfigureRequest) GetUsers() map[string]string { + if x != nil { + return x.Users + } + return nil +} + +type DuckLakeConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + MetadataStore string `protobuf:"bytes,1,opt,name=metadata_store,json=metadataStore,proto3" json:"metadata_store,omitempty"` + ObjectStore string `protobuf:"bytes,2,opt,name=object_store,json=objectStore,proto3" json:"object_store,omitempty"` + DataPath string `protobuf:"bytes,3,opt,name=data_path,json=dataPath,proto3" json:"data_path,omitempty"` + S3Provider string `protobuf:"bytes,4,opt,name=s3_provider,json=s3Provider,proto3" json:"s3_provider,omitempty"` + S3Endpoint string `protobuf:"bytes,5,opt,name=s3_endpoint,json=s3Endpoint,proto3" json:"s3_endpoint,omitempty"` + S3AccessKey string `protobuf:"bytes,6,opt,name=s3_access_key,json=s3AccessKey,proto3" json:"s3_access_key,omitempty"` + S3SecretKey string `protobuf:"bytes,7,opt,name=s3_secret_key,json=s3SecretKey,proto3" json:"s3_secret_key,omitempty"` + S3Region string `protobuf:"bytes,8,opt,name=s3_region,json=s3Region,proto3" json:"s3_region,omitempty"` + S3UseSsl bool `protobuf:"varint,9,opt,name=s3_use_ssl,json=s3UseSsl,proto3" json:"s3_use_ssl,omitempty"` + S3UrlStyle string `protobuf:"bytes,10,opt,name=s3_url_style,json=s3UrlStyle,proto3" json:"s3_url_style,omitempty"` + S3Chain string `protobuf:"bytes,11,opt,name=s3_chain,json=s3Chain,proto3" json:"s3_chain,omitempty"` + S3Profile string `protobuf:"bytes,12,opt,name=s3_profile,json=s3Profile,proto3" json:"s3_profile,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DuckLakeConfig) Reset() { + *x = DuckLakeConfig{} + mi := &file_controlplane_proto_worker_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DuckLakeConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DuckLakeConfig) ProtoMessage() {} + +func (x *DuckLakeConfig) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DuckLakeConfig.ProtoReflect.Descriptor instead. 
+func (*DuckLakeConfig) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{1} +} + +func (x *DuckLakeConfig) GetMetadataStore() string { + if x != nil { + return x.MetadataStore + } + return "" +} + +func (x *DuckLakeConfig) GetObjectStore() string { + if x != nil { + return x.ObjectStore + } + return "" +} + +func (x *DuckLakeConfig) GetDataPath() string { + if x != nil { + return x.DataPath + } + return "" +} + +func (x *DuckLakeConfig) GetS3Provider() string { + if x != nil { + return x.S3Provider + } + return "" +} + +func (x *DuckLakeConfig) GetS3Endpoint() string { + if x != nil { + return x.S3Endpoint + } + return "" +} + +func (x *DuckLakeConfig) GetS3AccessKey() string { + if x != nil { + return x.S3AccessKey + } + return "" +} + +func (x *DuckLakeConfig) GetS3SecretKey() string { + if x != nil { + return x.S3SecretKey + } + return "" +} + +func (x *DuckLakeConfig) GetS3Region() string { + if x != nil { + return x.S3Region + } + return "" +} + +func (x *DuckLakeConfig) GetS3UseSsl() bool { + if x != nil { + return x.S3UseSsl + } + return false +} + +func (x *DuckLakeConfig) GetS3UrlStyle() string { + if x != nil { + return x.S3UrlStyle + } + return "" +} + +func (x *DuckLakeConfig) GetS3Chain() string { + if x != nil { + return x.S3Chain + } + return "" +} + +func (x *DuckLakeConfig) GetS3Profile() string { + if x != nil { + return x.S3Profile + } + return "" +} + +type ConfigureResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ok bool `protobuf:"varint,1,opt,name=ok,proto3" json:"ok,omitempty"` + Error string `protobuf:"bytes,2,opt,name=error,proto3" json:"error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ConfigureResponse) Reset() { + *x = ConfigureResponse{} + mi := &file_controlplane_proto_worker_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ConfigureResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ConfigureResponse) ProtoMessage() {} + +func (x *ConfigureResponse) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ConfigureResponse.ProtoReflect.Descriptor instead. 
+func (*ConfigureResponse) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{2} +} + +func (x *ConfigureResponse) GetOk() bool { + if x != nil { + return x.Ok + } + return false +} + +func (x *ConfigureResponse) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +type AcceptConnectionRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + RemoteAddr string `protobuf:"bytes,1,opt,name=remote_addr,json=remoteAddr,proto3" json:"remote_addr,omitempty"` + BackendSecretKey int32 `protobuf:"varint,2,opt,name=backend_secret_key,json=backendSecretKey,proto3" json:"backend_secret_key,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AcceptConnectionRequest) Reset() { + *x = AcceptConnectionRequest{} + mi := &file_controlplane_proto_worker_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AcceptConnectionRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AcceptConnectionRequest) ProtoMessage() {} + +func (x *AcceptConnectionRequest) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AcceptConnectionRequest.ProtoReflect.Descriptor instead. +func (*AcceptConnectionRequest) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{3} +} + +func (x *AcceptConnectionRequest) GetRemoteAddr() string { + if x != nil { + return x.RemoteAddr + } + return "" +} + +func (x *AcceptConnectionRequest) GetBackendSecretKey() int32 { + if x != nil { + return x.BackendSecretKey + } + return 0 +} + +type AcceptConnectionResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ok bool `protobuf:"varint,1,opt,name=ok,proto3" json:"ok,omitempty"` + BackendPid int32 `protobuf:"varint,2,opt,name=backend_pid,json=backendPid,proto3" json:"backend_pid,omitempty"` + Error string `protobuf:"bytes,3,opt,name=error,proto3" json:"error,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AcceptConnectionResponse) Reset() { + *x = AcceptConnectionResponse{} + mi := &file_controlplane_proto_worker_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AcceptConnectionResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AcceptConnectionResponse) ProtoMessage() {} + +func (x *AcceptConnectionResponse) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AcceptConnectionResponse.ProtoReflect.Descriptor instead. 
+func (*AcceptConnectionResponse) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{4} +} + +func (x *AcceptConnectionResponse) GetOk() bool { + if x != nil { + return x.Ok + } + return false +} + +func (x *AcceptConnectionResponse) GetBackendPid() int32 { + if x != nil { + return x.BackendPid + } + return 0 +} + +func (x *AcceptConnectionResponse) GetError() string { + if x != nil { + return x.Error + } + return "" +} + +type CancelQueryRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + BackendPid int32 `protobuf:"varint,1,opt,name=backend_pid,json=backendPid,proto3" json:"backend_pid,omitempty"` + SecretKey int32 `protobuf:"varint,2,opt,name=secret_key,json=secretKey,proto3" json:"secret_key,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CancelQueryRequest) Reset() { + *x = CancelQueryRequest{} + mi := &file_controlplane_proto_worker_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CancelQueryRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CancelQueryRequest) ProtoMessage() {} + +func (x *CancelQueryRequest) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CancelQueryRequest.ProtoReflect.Descriptor instead. +func (*CancelQueryRequest) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{5} +} + +func (x *CancelQueryRequest) GetBackendPid() int32 { + if x != nil { + return x.BackendPid + } + return 0 +} + +func (x *CancelQueryRequest) GetSecretKey() int32 { + if x != nil { + return x.SecretKey + } + return 0 +} + +type CancelQueryResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Cancelled bool `protobuf:"varint,1,opt,name=cancelled,proto3" json:"cancelled,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *CancelQueryResponse) Reset() { + *x = CancelQueryResponse{} + mi := &file_controlplane_proto_worker_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *CancelQueryResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CancelQueryResponse) ProtoMessage() {} + +func (x *CancelQueryResponse) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CancelQueryResponse.ProtoReflect.Descriptor instead. 
+func (*CancelQueryResponse) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{6} +} + +func (x *CancelQueryResponse) GetCancelled() bool { + if x != nil { + return x.Cancelled + } + return false +} + +type DrainRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + TimeoutNs int64 `protobuf:"varint,1,opt,name=timeout_ns,json=timeoutNs,proto3" json:"timeout_ns,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DrainRequest) Reset() { + *x = DrainRequest{} + mi := &file_controlplane_proto_worker_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DrainRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DrainRequest) ProtoMessage() {} + +func (x *DrainRequest) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DrainRequest.ProtoReflect.Descriptor instead. +func (*DrainRequest) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{7} +} + +func (x *DrainRequest) GetTimeoutNs() int64 { + if x != nil { + return x.TimeoutNs + } + return 0 +} + +type DrainResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ok bool `protobuf:"varint,1,opt,name=ok,proto3" json:"ok,omitempty"` + RemainingConnections int32 `protobuf:"varint,2,opt,name=remaining_connections,json=remainingConnections,proto3" json:"remaining_connections,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DrainResponse) Reset() { + *x = DrainResponse{} + mi := &file_controlplane_proto_worker_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DrainResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DrainResponse) ProtoMessage() {} + +func (x *DrainResponse) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DrainResponse.ProtoReflect.Descriptor instead. 
+func (*DrainResponse) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{8} +} + +func (x *DrainResponse) GetOk() bool { + if x != nil { + return x.Ok + } + return false +} + +func (x *DrainResponse) GetRemainingConnections() int32 { + if x != nil { + return x.RemainingConnections + } + return 0 +} + +type HealthRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HealthRequest) Reset() { + *x = HealthRequest{} + mi := &file_controlplane_proto_worker_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HealthRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HealthRequest) ProtoMessage() {} + +func (x *HealthRequest) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HealthRequest.ProtoReflect.Descriptor instead. +func (*HealthRequest) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{9} +} + +type HealthResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Healthy bool `protobuf:"varint,1,opt,name=healthy,proto3" json:"healthy,omitempty"` + ActiveConnections int32 `protobuf:"varint,2,opt,name=active_connections,json=activeConnections,proto3" json:"active_connections,omitempty"` + UptimeNs int64 `protobuf:"varint,3,opt,name=uptime_ns,json=uptimeNs,proto3" json:"uptime_ns,omitempty"` + TotalQueries int64 `protobuf:"varint,4,opt,name=total_queries,json=totalQueries,proto3" json:"total_queries,omitempty"` + Draining bool `protobuf:"varint,5,opt,name=draining,proto3" json:"draining,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HealthResponse) Reset() { + *x = HealthResponse{} + mi := &file_controlplane_proto_worker_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HealthResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HealthResponse) ProtoMessage() {} + +func (x *HealthResponse) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HealthResponse.ProtoReflect.Descriptor instead. 
+func (*HealthResponse) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{10} +} + +func (x *HealthResponse) GetHealthy() bool { + if x != nil { + return x.Healthy + } + return false +} + +func (x *HealthResponse) GetActiveConnections() int32 { + if x != nil { + return x.ActiveConnections + } + return 0 +} + +func (x *HealthResponse) GetUptimeNs() int64 { + if x != nil { + return x.UptimeNs + } + return 0 +} + +func (x *HealthResponse) GetTotalQueries() int64 { + if x != nil { + return x.TotalQueries + } + return 0 +} + +func (x *HealthResponse) GetDraining() bool { + if x != nil { + return x.Draining + } + return false +} + +type ShutdownRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + TimeoutNs int64 `protobuf:"varint,1,opt,name=timeout_ns,json=timeoutNs,proto3" json:"timeout_ns,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ShutdownRequest) Reset() { + *x = ShutdownRequest{} + mi := &file_controlplane_proto_worker_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ShutdownRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ShutdownRequest) ProtoMessage() {} + +func (x *ShutdownRequest) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ShutdownRequest.ProtoReflect.Descriptor instead. +func (*ShutdownRequest) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{11} +} + +func (x *ShutdownRequest) GetTimeoutNs() int64 { + if x != nil { + return x.TimeoutNs + } + return 0 +} + +type ShutdownResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Ok bool `protobuf:"varint,1,opt,name=ok,proto3" json:"ok,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ShutdownResponse) Reset() { + *x = ShutdownResponse{} + mi := &file_controlplane_proto_worker_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ShutdownResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ShutdownResponse) ProtoMessage() {} + +func (x *ShutdownResponse) ProtoReflect() protoreflect.Message { + mi := &file_controlplane_proto_worker_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ShutdownResponse.ProtoReflect.Descriptor instead. 
+func (*ShutdownResponse) Descriptor() ([]byte, []int) { + return file_controlplane_proto_worker_proto_rawDescGZIP(), []int{12} +} + +func (x *ShutdownResponse) GetOk() bool { + if x != nil { + return x.Ok + } + return false +} + +var File_controlplane_proto_worker_proto protoreflect.FileDescriptor + +const file_controlplane_proto_worker_proto_rawDesc = "" + + "\n" + + "\x1fcontrolplane/proto/worker.proto\x12\fcontrolplane\"\xf0\x02\n" + + "\x10ConfigureRequest\x12\x19\n" + + "\bdata_dir\x18\x01 \x01(\tR\adataDir\x12\x1e\n" + + "\n" + + "extensions\x18\x02 \x03(\tR\n" + + "extensions\x12&\n" + + "\x0fidle_timeout_ns\x18\x03 \x01(\x03R\ridleTimeoutNs\x12\"\n" + + "\rtls_cert_file\x18\x04 \x01(\tR\vtlsCertFile\x12 \n" + + "\ftls_key_file\x18\x05 \x01(\tR\n" + + "tlsKeyFile\x128\n" + + "\bducklake\x18\x06 \x01(\v2\x1c.controlplane.DuckLakeConfigR\bducklake\x12?\n" + + "\x05users\x18\a \x03(\v2).controlplane.ConfigureRequest.UsersEntryR\x05users\x1a8\n" + + "\n" + + "UsersEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"\x98\x03\n" + + "\x0eDuckLakeConfig\x12%\n" + + "\x0emetadata_store\x18\x01 \x01(\tR\rmetadataStore\x12!\n" + + "\fobject_store\x18\x02 \x01(\tR\vobjectStore\x12\x1b\n" + + "\tdata_path\x18\x03 \x01(\tR\bdataPath\x12\x1f\n" + + "\vs3_provider\x18\x04 \x01(\tR\n" + + "s3Provider\x12\x1f\n" + + "\vs3_endpoint\x18\x05 \x01(\tR\n" + + "s3Endpoint\x12\"\n" + + "\rs3_access_key\x18\x06 \x01(\tR\vs3AccessKey\x12\"\n" + + "\rs3_secret_key\x18\a \x01(\tR\vs3SecretKey\x12\x1b\n" + + "\ts3_region\x18\b \x01(\tR\bs3Region\x12\x1c\n" + + "\n" + + "s3_use_ssl\x18\t \x01(\bR\bs3UseSsl\x12 \n" + + "\fs3_url_style\x18\n" + + " \x01(\tR\n" + + "s3UrlStyle\x12\x19\n" + + "\bs3_chain\x18\v \x01(\tR\as3Chain\x12\x1d\n" + + "\n" + + "s3_profile\x18\f \x01(\tR\ts3Profile\"9\n" + + "\x11ConfigureResponse\x12\x0e\n" + + "\x02ok\x18\x01 \x01(\bR\x02ok\x12\x14\n" + + "\x05error\x18\x02 \x01(\tR\x05error\"h\n" + + "\x17AcceptConnectionRequest\x12\x1f\n" + + "\vremote_addr\x18\x01 \x01(\tR\n" + + "remoteAddr\x12,\n" + + "\x12backend_secret_key\x18\x02 \x01(\x05R\x10backendSecretKey\"a\n" + + "\x18AcceptConnectionResponse\x12\x0e\n" + + "\x02ok\x18\x01 \x01(\bR\x02ok\x12\x1f\n" + + "\vbackend_pid\x18\x02 \x01(\x05R\n" + + "backendPid\x12\x14\n" + + "\x05error\x18\x03 \x01(\tR\x05error\"T\n" + + "\x12CancelQueryRequest\x12\x1f\n" + + "\vbackend_pid\x18\x01 \x01(\x05R\n" + + "backendPid\x12\x1d\n" + + "\n" + + "secret_key\x18\x02 \x01(\x05R\tsecretKey\"3\n" + + "\x13CancelQueryResponse\x12\x1c\n" + + "\tcancelled\x18\x01 \x01(\bR\tcancelled\"-\n" + + "\fDrainRequest\x12\x1d\n" + + "\n" + + "timeout_ns\x18\x01 \x01(\x03R\ttimeoutNs\"T\n" + + "\rDrainResponse\x12\x0e\n" + + "\x02ok\x18\x01 \x01(\bR\x02ok\x123\n" + + "\x15remaining_connections\x18\x02 \x01(\x05R\x14remainingConnections\"\x0f\n" + + "\rHealthRequest\"\xb7\x01\n" + + "\x0eHealthResponse\x12\x18\n" + + "\ahealthy\x18\x01 \x01(\bR\ahealthy\x12-\n" + + "\x12active_connections\x18\x02 \x01(\x05R\x11activeConnections\x12\x1b\n" + + "\tuptime_ns\x18\x03 \x01(\x03R\buptimeNs\x12#\n" + + "\rtotal_queries\x18\x04 \x01(\x03R\ftotalQueries\x12\x1a\n" + + "\bdraining\x18\x05 \x01(\bR\bdraining\"0\n" + + "\x0fShutdownRequest\x12\x1d\n" + + "\n" + + "timeout_ns\x18\x01 \x01(\x03R\ttimeoutNs\"\"\n" + + "\x10ShutdownResponse\x12\x0e\n" + + "\x02ok\x18\x01 \x01(\bR\x02ok2\xe6\x03\n" + + "\rWorkerControl\x12L\n" + + 
"\tConfigure\x12\x1e.controlplane.ConfigureRequest\x1a\x1f.controlplane.ConfigureResponse\x12a\n" + + "\x10AcceptConnection\x12%.controlplane.AcceptConnectionRequest\x1a&.controlplane.AcceptConnectionResponse\x12R\n" + + "\vCancelQuery\x12 .controlplane.CancelQueryRequest\x1a!.controlplane.CancelQueryResponse\x12@\n" + + "\x05Drain\x12\x1a.controlplane.DrainRequest\x1a\x1b.controlplane.DrainResponse\x12C\n" + + "\x06Health\x12\x1b.controlplane.HealthRequest\x1a\x1c.controlplane.HealthResponse\x12I\n" + + "\bShutdown\x12\x1d.controlplane.ShutdownRequest\x1a\x1e.controlplane.ShutdownResponseB0Z.github.com/posthog/duckgres/controlplane/protob\x06proto3" + +var ( + file_controlplane_proto_worker_proto_rawDescOnce sync.Once + file_controlplane_proto_worker_proto_rawDescData []byte +) + +func file_controlplane_proto_worker_proto_rawDescGZIP() []byte { + file_controlplane_proto_worker_proto_rawDescOnce.Do(func() { + file_controlplane_proto_worker_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_controlplane_proto_worker_proto_rawDesc), len(file_controlplane_proto_worker_proto_rawDesc))) + }) + return file_controlplane_proto_worker_proto_rawDescData +} + +var file_controlplane_proto_worker_proto_msgTypes = make([]protoimpl.MessageInfo, 14) +var file_controlplane_proto_worker_proto_goTypes = []any{ + (*ConfigureRequest)(nil), // 0: controlplane.ConfigureRequest + (*DuckLakeConfig)(nil), // 1: controlplane.DuckLakeConfig + (*ConfigureResponse)(nil), // 2: controlplane.ConfigureResponse + (*AcceptConnectionRequest)(nil), // 3: controlplane.AcceptConnectionRequest + (*AcceptConnectionResponse)(nil), // 4: controlplane.AcceptConnectionResponse + (*CancelQueryRequest)(nil), // 5: controlplane.CancelQueryRequest + (*CancelQueryResponse)(nil), // 6: controlplane.CancelQueryResponse + (*DrainRequest)(nil), // 7: controlplane.DrainRequest + (*DrainResponse)(nil), // 8: controlplane.DrainResponse + (*HealthRequest)(nil), // 9: controlplane.HealthRequest + (*HealthResponse)(nil), // 10: controlplane.HealthResponse + (*ShutdownRequest)(nil), // 11: controlplane.ShutdownRequest + (*ShutdownResponse)(nil), // 12: controlplane.ShutdownResponse + nil, // 13: controlplane.ConfigureRequest.UsersEntry +} +var file_controlplane_proto_worker_proto_depIdxs = []int32{ + 1, // 0: controlplane.ConfigureRequest.ducklake:type_name -> controlplane.DuckLakeConfig + 13, // 1: controlplane.ConfigureRequest.users:type_name -> controlplane.ConfigureRequest.UsersEntry + 0, // 2: controlplane.WorkerControl.Configure:input_type -> controlplane.ConfigureRequest + 3, // 3: controlplane.WorkerControl.AcceptConnection:input_type -> controlplane.AcceptConnectionRequest + 5, // 4: controlplane.WorkerControl.CancelQuery:input_type -> controlplane.CancelQueryRequest + 7, // 5: controlplane.WorkerControl.Drain:input_type -> controlplane.DrainRequest + 9, // 6: controlplane.WorkerControl.Health:input_type -> controlplane.HealthRequest + 11, // 7: controlplane.WorkerControl.Shutdown:input_type -> controlplane.ShutdownRequest + 2, // 8: controlplane.WorkerControl.Configure:output_type -> controlplane.ConfigureResponse + 4, // 9: controlplane.WorkerControl.AcceptConnection:output_type -> controlplane.AcceptConnectionResponse + 6, // 10: controlplane.WorkerControl.CancelQuery:output_type -> controlplane.CancelQueryResponse + 8, // 11: controlplane.WorkerControl.Drain:output_type -> controlplane.DrainResponse + 10, // 12: controlplane.WorkerControl.Health:output_type -> controlplane.HealthResponse + 12, // 13: 
controlplane.WorkerControl.Shutdown:output_type -> controlplane.ShutdownResponse + 8, // [8:14] is the sub-list for method output_type + 2, // [2:8] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_controlplane_proto_worker_proto_init() } +func file_controlplane_proto_worker_proto_init() { + if File_controlplane_proto_worker_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_controlplane_proto_worker_proto_rawDesc), len(file_controlplane_proto_worker_proto_rawDesc)), + NumEnums: 0, + NumMessages: 14, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_controlplane_proto_worker_proto_goTypes, + DependencyIndexes: file_controlplane_proto_worker_proto_depIdxs, + MessageInfos: file_controlplane_proto_worker_proto_msgTypes, + }.Build() + File_controlplane_proto_worker_proto = out.File + file_controlplane_proto_worker_proto_goTypes = nil + file_controlplane_proto_worker_proto_depIdxs = nil +} diff --git a/controlplane/proto/worker.proto b/controlplane/proto/worker.proto new file mode 100644 index 0000000..bba1c11 --- /dev/null +++ b/controlplane/proto/worker.proto @@ -0,0 +1,112 @@ +syntax = "proto3"; + +package controlplane; + +option go_package = "github.com/posthog/duckgres/controlplane/proto"; + +// WorkerControl is the gRPC service exposed by each long-lived worker process. +// The control plane uses this service to manage workers. +service WorkerControl { + // Configure initializes the worker with database and TLS settings. + // Must be called once before AcceptConnection. + rpc Configure(ConfigureRequest) returns (ConfigureResponse); + + // AcceptConnection tells the worker to accept a new client connection. + // The raw TCP file descriptor is sent out-of-band via the companion Unix socket. + rpc AcceptConnection(AcceptConnectionRequest) returns (AcceptConnectionResponse); + + // CancelQuery cancels a running query by backend key. + rpc CancelQuery(CancelQueryRequest) returns (CancelQueryResponse); + + // Drain tells the worker to stop accepting new connections and finish existing ones. + rpc Drain(DrainRequest) returns (DrainResponse); + + // Health returns the worker's current health and load status. + rpc Health(HealthRequest) returns (HealthResponse); + + // Shutdown tells the worker to drain and exit. 
+  rpc Shutdown(ShutdownRequest) returns (ShutdownResponse);
+}
+
+message ConfigureRequest {
+  string data_dir = 1;
+  repeated string extensions = 2;
+  int64 idle_timeout_ns = 3;
+
+  // TLS configuration
+  string tls_cert_file = 4;
+  string tls_key_file = 5;
+
+  // DuckLake configuration
+  DuckLakeConfig ducklake = 6;
+
+  // Authentication - map of username -> password
+  map<string, string> users = 7;
+}
+
+message DuckLakeConfig {
+  string metadata_store = 1;
+  string object_store = 2;
+  string data_path = 3;
+  string s3_provider = 4;
+  string s3_endpoint = 5;
+  string s3_access_key = 6;
+  string s3_secret_key = 7;
+  string s3_region = 8;
+  bool s3_use_ssl = 9;
+  string s3_url_style = 10;
+  string s3_chain = 11;
+  string s3_profile = 12;
+}
+
+message ConfigureResponse {
+  bool ok = 1;
+  string error = 2;
+}
+
+message AcceptConnectionRequest {
+  string remote_addr = 1;
+  int32 backend_secret_key = 2;
+}
+
+message AcceptConnectionResponse {
+  bool ok = 1;
+  int32 backend_pid = 2;
+  string error = 3;
+}
+
+message CancelQueryRequest {
+  int32 backend_pid = 1;
+  int32 secret_key = 2;
+}
+
+message CancelQueryResponse {
+  bool cancelled = 1;
+}
+
+message DrainRequest {
+  int64 timeout_ns = 1;
+}
+
+message DrainResponse {
+  bool ok = 1;
+  int32 remaining_connections = 2;
+}
+
+message HealthRequest {}
+
+message HealthResponse {
+  bool healthy = 1;
+  int32 active_connections = 2;
+  int64 uptime_ns = 3;
+  int64 total_queries = 4;
+  bool draining = 5;
+}
+
+message ShutdownRequest {
+  int64 timeout_ns = 1;
+}
+
+message ShutdownResponse {
+  bool ok = 1;
+}
diff --git a/controlplane/proto/worker_grpc.pb.go b/controlplane/proto/worker_grpc.pb.go
new file mode 100644
index 0000000..629b105
--- /dev/null
+++ b/controlplane/proto/worker_grpc.pb.go
@@ -0,0 +1,333 @@
+// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
+// versions:
+// - protoc-gen-go-grpc v1.6.1
+// - protoc             v6.33.4
+// source: controlplane/proto/worker.proto
+
+package proto
+
+import (
+	context "context"
+	grpc "google.golang.org/grpc"
+	codes "google.golang.org/grpc/codes"
+	status "google.golang.org/grpc/status"
+)
+
+// This is a compile-time assertion to ensure that this generated file
+// is compatible with the grpc package it is being compiled against.
+// Requires gRPC-Go v1.64.0 or later.
+const _ = grpc.SupportPackageIsVersion9
+
+const (
+	WorkerControl_Configure_FullMethodName        = "/controlplane.WorkerControl/Configure"
+	WorkerControl_AcceptConnection_FullMethodName = "/controlplane.WorkerControl/AcceptConnection"
+	WorkerControl_CancelQuery_FullMethodName      = "/controlplane.WorkerControl/CancelQuery"
+	WorkerControl_Drain_FullMethodName            = "/controlplane.WorkerControl/Drain"
+	WorkerControl_Health_FullMethodName           = "/controlplane.WorkerControl/Health"
+	WorkerControl_Shutdown_FullMethodName         = "/controlplane.WorkerControl/Shutdown"
+)
+
+// WorkerControlClient is the client API for WorkerControl service.
+//
+// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
+//
+// WorkerControl is the gRPC service exposed by each long-lived worker process.
+// The control plane uses this service to manage workers.
+type WorkerControlClient interface {
+	// Configure initializes the worker with database and TLS settings.
+	// Must be called once before AcceptConnection.
+ Configure(ctx context.Context, in *ConfigureRequest, opts ...grpc.CallOption) (*ConfigureResponse, error) + // AcceptConnection tells the worker to accept a new client connection. + // The raw TCP file descriptor is sent out-of-band via the companion Unix socket. + AcceptConnection(ctx context.Context, in *AcceptConnectionRequest, opts ...grpc.CallOption) (*AcceptConnectionResponse, error) + // CancelQuery cancels a running query by backend key. + CancelQuery(ctx context.Context, in *CancelQueryRequest, opts ...grpc.CallOption) (*CancelQueryResponse, error) + // Drain tells the worker to stop accepting new connections and finish existing ones. + Drain(ctx context.Context, in *DrainRequest, opts ...grpc.CallOption) (*DrainResponse, error) + // Health returns the worker's current health and load status. + Health(ctx context.Context, in *HealthRequest, opts ...grpc.CallOption) (*HealthResponse, error) + // Shutdown tells the worker to drain and exit. + Shutdown(ctx context.Context, in *ShutdownRequest, opts ...grpc.CallOption) (*ShutdownResponse, error) +} + +type workerControlClient struct { + cc grpc.ClientConnInterface +} + +func NewWorkerControlClient(cc grpc.ClientConnInterface) WorkerControlClient { + return &workerControlClient{cc} +} + +func (c *workerControlClient) Configure(ctx context.Context, in *ConfigureRequest, opts ...grpc.CallOption) (*ConfigureResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ConfigureResponse) + err := c.cc.Invoke(ctx, WorkerControl_Configure_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerControlClient) AcceptConnection(ctx context.Context, in *AcceptConnectionRequest, opts ...grpc.CallOption) (*AcceptConnectionResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(AcceptConnectionResponse) + err := c.cc.Invoke(ctx, WorkerControl_AcceptConnection_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerControlClient) CancelQuery(ctx context.Context, in *CancelQueryRequest, opts ...grpc.CallOption) (*CancelQueryResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(CancelQueryResponse) + err := c.cc.Invoke(ctx, WorkerControl_CancelQuery_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerControlClient) Drain(ctx context.Context, in *DrainRequest, opts ...grpc.CallOption) (*DrainResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DrainResponse) + err := c.cc.Invoke(ctx, WorkerControl_Drain_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerControlClient) Health(ctx context.Context, in *HealthRequest, opts ...grpc.CallOption) (*HealthResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(HealthResponse) + err := c.cc.Invoke(ctx, WorkerControl_Health_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *workerControlClient) Shutdown(ctx context.Context, in *ShutdownRequest, opts ...grpc.CallOption) (*ShutdownResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(ShutdownResponse) + err := c.cc.Invoke(ctx, WorkerControl_Shutdown_FullMethodName, in, out, cOpts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + +// WorkerControlServer is the server API for WorkerControl service. +// All implementations must embed UnimplementedWorkerControlServer +// for forward compatibility. +// +// WorkerControl is the gRPC service exposed by each long-lived worker process. +// The control plane uses this service to manage workers. +type WorkerControlServer interface { + // Configure initializes the worker with database and TLS settings. + // Must be called once before AcceptConnection. + Configure(context.Context, *ConfigureRequest) (*ConfigureResponse, error) + // AcceptConnection tells the worker to accept a new client connection. + // The raw TCP file descriptor is sent out-of-band via the companion Unix socket. + AcceptConnection(context.Context, *AcceptConnectionRequest) (*AcceptConnectionResponse, error) + // CancelQuery cancels a running query by backend key. + CancelQuery(context.Context, *CancelQueryRequest) (*CancelQueryResponse, error) + // Drain tells the worker to stop accepting new connections and finish existing ones. + Drain(context.Context, *DrainRequest) (*DrainResponse, error) + // Health returns the worker's current health and load status. + Health(context.Context, *HealthRequest) (*HealthResponse, error) + // Shutdown tells the worker to drain and exit. + Shutdown(context.Context, *ShutdownRequest) (*ShutdownResponse, error) + mustEmbedUnimplementedWorkerControlServer() +} + +// UnimplementedWorkerControlServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedWorkerControlServer struct{} + +func (UnimplementedWorkerControlServer) Configure(context.Context, *ConfigureRequest) (*ConfigureResponse, error) { + return nil, status.Error(codes.Unimplemented, "method Configure not implemented") +} +func (UnimplementedWorkerControlServer) AcceptConnection(context.Context, *AcceptConnectionRequest) (*AcceptConnectionResponse, error) { + return nil, status.Error(codes.Unimplemented, "method AcceptConnection not implemented") +} +func (UnimplementedWorkerControlServer) CancelQuery(context.Context, *CancelQueryRequest) (*CancelQueryResponse, error) { + return nil, status.Error(codes.Unimplemented, "method CancelQuery not implemented") +} +func (UnimplementedWorkerControlServer) Drain(context.Context, *DrainRequest) (*DrainResponse, error) { + return nil, status.Error(codes.Unimplemented, "method Drain not implemented") +} +func (UnimplementedWorkerControlServer) Health(context.Context, *HealthRequest) (*HealthResponse, error) { + return nil, status.Error(codes.Unimplemented, "method Health not implemented") +} +func (UnimplementedWorkerControlServer) Shutdown(context.Context, *ShutdownRequest) (*ShutdownResponse, error) { + return nil, status.Error(codes.Unimplemented, "method Shutdown not implemented") +} +func (UnimplementedWorkerControlServer) mustEmbedUnimplementedWorkerControlServer() {} +func (UnimplementedWorkerControlServer) testEmbeddedByValue() {} + +// UnsafeWorkerControlServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to WorkerControlServer will +// result in compilation errors. 
+type UnsafeWorkerControlServer interface { + mustEmbedUnimplementedWorkerControlServer() +} + +func RegisterWorkerControlServer(s grpc.ServiceRegistrar, srv WorkerControlServer) { + // If the following call panics, it indicates UnimplementedWorkerControlServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&WorkerControl_ServiceDesc, srv) +} + +func _WorkerControl_Configure_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ConfigureRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerControlServer).Configure(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: WorkerControl_Configure_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerControlServer).Configure(ctx, req.(*ConfigureRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _WorkerControl_AcceptConnection_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(AcceptConnectionRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerControlServer).AcceptConnection(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: WorkerControl_AcceptConnection_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerControlServer).AcceptConnection(ctx, req.(*AcceptConnectionRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _WorkerControl_CancelQuery_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CancelQueryRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerControlServer).CancelQuery(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: WorkerControl_CancelQuery_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerControlServer).CancelQuery(ctx, req.(*CancelQueryRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _WorkerControl_Drain_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DrainRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerControlServer).Drain(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: WorkerControl_Drain_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerControlServer).Drain(ctx, req.(*DrainRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _WorkerControl_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HealthRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + 
return srv.(WorkerControlServer).Health(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: WorkerControl_Health_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerControlServer).Health(ctx, req.(*HealthRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _WorkerControl_Shutdown_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ShutdownRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(WorkerControlServer).Shutdown(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: WorkerControl_Shutdown_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(WorkerControlServer).Shutdown(ctx, req.(*ShutdownRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// WorkerControl_ServiceDesc is the grpc.ServiceDesc for WorkerControl service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var WorkerControl_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "controlplane.WorkerControl", + HandlerType: (*WorkerControlServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Configure", + Handler: _WorkerControl_Configure_Handler, + }, + { + MethodName: "AcceptConnection", + Handler: _WorkerControl_AcceptConnection_Handler, + }, + { + MethodName: "CancelQuery", + Handler: _WorkerControl_CancelQuery_Handler, + }, + { + MethodName: "Drain", + Handler: _WorkerControl_Drain_Handler, + }, + { + MethodName: "Health", + Handler: _WorkerControl_Health_Handler, + }, + { + MethodName: "Shutdown", + Handler: _WorkerControl_Shutdown_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "controlplane/proto/worker.proto", +} diff --git a/controlplane/worker.go b/controlplane/worker.go new file mode 100644 index 0000000..d7f8170 --- /dev/null +++ b/controlplane/worker.go @@ -0,0 +1,551 @@ +package controlplane + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + "time" + + "google.golang.org/grpc" + + "github.com/posthog/duckgres/controlplane/fdpass" + pb "github.com/posthog/duckgres/controlplane/proto" + "github.com/posthog/duckgres/server" +) + +// Worker is a long-lived worker process that handles many client connections. +// It exposes a gRPC server for control and a Unix socket for FD passing. +type Worker struct { + pb.UnimplementedWorkerControlServer + + grpcSocketPath string + fdSocketPath string + + mu sync.RWMutex + cfg server.Config + configured bool + draining bool + tlsConfig *tls.Config + dbPool *DBPool + startTime time.Time + totalQueries atomic.Int64 + + // Session tracking + sessions map[int32]*workerSession + sessionsMu sync.RWMutex + sessionsWg sync.WaitGroup + + // Cancellation + activeQueries map[server.BackendKey]context.CancelFunc + activeQueriesMu sync.RWMutex + + // FD passing - stores the most recently received FD + pendingFD int +} + +type workerSession struct { + pid int32 + secretKey int32 + cancel context.CancelFunc + remoteAddr string +} + +// RunWorker is the entry point for a worker process. +// It starts the gRPC server and FD receiver, then waits for signals. 
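The doc comment above describes the worker side of the lifecycle; the spawning side lives in controlplane/pool.go, which is not reproduced in this excerpt. The following is only a rough sketch of how a pool could start one worker with the flags added in main.go and issue the one-time Configure call. The helper name, the insecure local transport, and the absence of health polling are assumptions for illustration, not the patch's actual implementation.

```go
package controlplane

import (
	"context"
	"fmt"
	"os"
	"os/exec"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	pb "github.com/posthog/duckgres/controlplane/proto"
)

// spawnWorkerSketch starts one worker process and configures it over its gRPC
// Unix socket. A real implementation would wait for the socket to appear (or
// poll Health) before configuring, and would track the process for restarts
// and rolling updates.
func spawnWorkerSketch(ctx context.Context, binary, grpcSock, fdSock string, req *pb.ConfigureRequest) (pb.WorkerControlClient, *exec.Cmd, error) {
	cmd := exec.Command(binary, "--mode", "worker", "--grpc-socket", grpcSock, "--fd-socket", fdSock)
	cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
	if err := cmd.Start(); err != nil {
		return nil, nil, err
	}

	// Plaintext gRPC is acceptable here: the socket is local and permission-guarded.
	conn, err := grpc.NewClient("unix://"+grpcSock, grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return nil, nil, err
	}
	client := pb.NewWorkerControlClient(conn)

	// Configure must be called exactly once before any AcceptConnection.
	resp, err := client.Configure(ctx, req)
	if err != nil {
		return nil, nil, err
	}
	if !resp.GetOk() {
		return nil, nil, fmt.Errorf("configure worker: %s", resp.GetError())
	}
	return client, cmd, nil
}
```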
+func RunWorker(grpcSocket, fdSocket string) { + w := &Worker{ + grpcSocketPath: grpcSocket, + fdSocketPath: fdSocket, + startTime: time.Now(), + sessions: make(map[int32]*workerSession), + activeQueries: make(map[server.BackendKey]context.CancelFunc), + pendingFD: -1, + } + + // Set up signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + sig := <-sigChan + slog.Info("Worker received signal, shutting down.", "signal", sig) + cancel() + }() + + if err := w.run(ctx); err != nil { + slog.Error("Worker exited with error.", "error", err) + os.Exit(1) + } +} + +func (w *Worker) run(ctx context.Context) error { + // Start gRPC server + grpcLn, err := net.Listen("unix", w.grpcSocketPath) + if err != nil { + return fmt.Errorf("listen gRPC socket: %w", err) + } + defer os.Remove(w.grpcSocketPath) + + grpcServer := grpc.NewServer() + pb.RegisterWorkerControlServer(grpcServer, w) + + go func() { + if err := grpcServer.Serve(grpcLn); err != nil { + slog.Error("gRPC server error.", "error", err) + } + }() + + // Start FD receiver + fdLn, err := net.Listen("unix", w.fdSocketPath) + if err != nil { + return fmt.Errorf("listen FD socket: %w", err) + } + defer os.Remove(w.fdSocketPath) + + // Accept FD connections in a goroutine + go w.fdReceiverLoop(fdLn) + + slog.Info("Worker started.", + "pid", os.Getpid(), + "grpc_socket", w.grpcSocketPath, + "fd_socket", w.fdSocketPath, + ) + + // Wait for context cancellation (signal) + <-ctx.Done() + + slog.Info("Worker shutting down, draining sessions...") + w.mu.Lock() + w.draining = true + w.mu.Unlock() + + // Wait for active sessions with timeout + done := make(chan struct{}) + go func() { + w.sessionsWg.Wait() + close(done) + }() + + select { + case <-done: + slog.Info("All sessions drained.") + case <-time.After(30 * time.Second): + slog.Warn("Drain timeout, force closing sessions.") + } + + grpcServer.GracefulStop() + _ = fdLn.Close() + + if w.dbPool != nil { + w.dbPool.CloseAll(5 * time.Second) + } + + return nil +} + +// fdReceiverLoop accepts connections on the FD socket and reads file descriptors from them. +// Each connection from the control plane is used for a single FD pass. 
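The comment above covers the receiving half of the handoff. For context, here is a sketch of the sending half as the control plane might implement it; the real code lives in controlplane/control.go and pool.go, which are not included in this excerpt, and fdpass.SendFD is assumed to exist as the counterpart of the RecvFD call used below.

```go
package controlplane

import (
	"context"
	"fmt"
	"net"

	"github.com/posthog/duckgres/controlplane/fdpass"
	pb "github.com/posthog/duckgres/controlplane/proto"
	"github.com/posthog/duckgres/server"
)

// handoffSketch passes an accepted TCP connection to a worker: one short-lived
// Unix connection carries the descriptor (mirroring fdReceiverLoop), then the
// AcceptConnection RPC tells the worker to adopt it.
func handoffSketch(ctx context.Context, wc pb.WorkerControlClient, fdSocket string, tcp *net.TCPConn) error {
	f, err := tcp.File() // duplicates the underlying descriptor
	if err != nil {
		return err
	}
	defer f.Close()

	c, err := net.Dial("unix", fdSocket)
	if err != nil {
		return err
	}
	if err := fdpass.SendFD(c.(*net.UnixConn), int(f.Fd())); err != nil { // SendFD assumed
		c.Close()
		return err
	}
	c.Close()

	resp, err := wc.AcceptConnection(ctx, &pb.AcceptConnectionRequest{
		RemoteAddr:       tcp.RemoteAddr().String(),
		BackendSecretKey: server.GenerateSecretKey(),
	})
	if err != nil {
		return err
	}
	if !resp.GetOk() {
		return fmt.Errorf("worker rejected connection: %s", resp.GetError())
	}
	// The worker now owns its duplicate of the socket; the control plane drops its copies.
	return tcp.Close()
}
```

Because the worker keeps only a single pendingFD slot, the SendFD/AcceptConnection pair has to be serialized per worker on the sending side.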
+func (w *Worker) fdReceiverLoop(ln net.Listener) { + for { + conn, err := ln.Accept() + if err != nil { + // Check if listener was closed + if opErr, ok := err.(*net.OpError); ok && opErr.Err.Error() == "use of closed network connection" { + return + } + slog.Error("FD socket accept error.", "error", err) + continue + } + + uc, ok := conn.(*net.UnixConn) + if !ok { + slog.Error("FD socket: not a UnixConn") + conn.Close() + continue + } + + // Receive the FD + fd, err := fdpass.RecvFD(uc) + uc.Close() + if err != nil { + slog.Error("Failed to receive FD.", "error", err) + continue + } + + // Store the FD for the next AcceptConnection RPC to pick up + w.mu.Lock() + w.pendingFD = fd + w.mu.Unlock() + } +} + +func (w *Worker) Configure(_ context.Context, req *pb.ConfigureRequest) (*pb.ConfigureResponse, error) { + w.mu.Lock() + defer w.mu.Unlock() + + if w.configured { + return &pb.ConfigureResponse{Ok: false, Error: "already configured"}, nil + } + + // Build server config from proto + w.cfg = server.Config{ + DataDir: req.DataDir, + Extensions: req.Extensions, + IdleTimeout: time.Duration(req.IdleTimeoutNs), + TLSCertFile: req.TlsCertFile, + TLSKeyFile: req.TlsKeyFile, + Users: req.Users, + } + + if req.Ducklake != nil { + w.cfg.DuckLake = server.DuckLakeConfig{ + MetadataStore: req.Ducklake.MetadataStore, + ObjectStore: req.Ducklake.ObjectStore, + DataPath: req.Ducklake.DataPath, + S3Provider: req.Ducklake.S3Provider, + S3Endpoint: req.Ducklake.S3Endpoint, + S3AccessKey: req.Ducklake.S3AccessKey, + S3SecretKey: req.Ducklake.S3SecretKey, + S3Region: req.Ducklake.S3Region, + S3UseSSL: req.Ducklake.S3UseSsl, + S3URLStyle: req.Ducklake.S3UrlStyle, + S3Chain: req.Ducklake.S3Chain, + S3Profile: req.Ducklake.S3Profile, + } + } + + // Load TLS + if w.cfg.TLSCertFile != "" && w.cfg.TLSKeyFile != "" { + cert, err := tls.LoadX509KeyPair(w.cfg.TLSCertFile, w.cfg.TLSKeyFile) + if err != nil { + return &pb.ConfigureResponse{Ok: false, Error: fmt.Sprintf("load TLS: %v", err)}, nil + } + w.tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + } + + // Create DB pool + w.dbPool = NewDBPool(w.cfg) + + w.configured = true + slog.Info("Worker configured.", "data_dir", w.cfg.DataDir, "ducklake", w.cfg.DuckLake.MetadataStore != "") + return &pb.ConfigureResponse{Ok: true}, nil +} + +func (w *Worker) AcceptConnection(_ context.Context, req *pb.AcceptConnectionRequest) (*pb.AcceptConnectionResponse, error) { + w.mu.RLock() + if !w.configured { + w.mu.RUnlock() + return &pb.AcceptConnectionResponse{Ok: false, Error: "not configured"}, nil + } + if w.draining { + w.mu.RUnlock() + return &pb.AcceptConnectionResponse{Ok: false, Error: "draining"}, nil + } + w.mu.RUnlock() + + // Get the pending FD + w.mu.Lock() + fd := w.pendingFD + w.pendingFD = -1 + w.mu.Unlock() + + if fd < 0 { + return &pb.AcceptConnectionResponse{Ok: false, Error: "no pending FD"}, nil + } + + // Create a net.Conn from the FD + file := os.NewFile(uintptr(fd), "tcp-conn") + if file == nil { + return &pb.AcceptConnectionResponse{Ok: false, Error: "invalid FD"}, nil + } + fc, err := net.FileConn(file) + file.Close() + if err != nil { + return &pb.AcceptConnectionResponse{Ok: false, Error: fmt.Sprintf("FileConn: %v", err)}, nil + } + + tcpConn, ok := fc.(*net.TCPConn) + if !ok { + fc.Close() + return &pb.AcceptConnectionResponse{Ok: false, Error: "not TCP"}, nil + } + + // Use a unique session counter within this worker + sessionPid := w.nextSessionPID() + + w.sessionsWg.Add(1) + go w.handleSession(tcpConn, req.RemoteAddr, 
sessionPid, req.BackendSecretKey) + + return &pb.AcceptConnectionResponse{Ok: true, BackendPid: sessionPid}, nil +} + +var sessionCounter atomic.Int32 + +func (w *Worker) nextSessionPID() int32 { + // Use worker PID * 1000 + counter to create unique pseudo-PIDs + base := int32(os.Getpid()) + counter := sessionCounter.Add(1) + return base*1000 + counter%1000 +} + +func (w *Worker) handleSession(tcpConn *net.TCPConn, remoteAddr string, pid, secretKey int32) { + defer w.sessionsWg.Done() + + slog.Info("Session starting.", "pid", pid, "remote_addr", remoteAddr) + + // TLS handshake + tlsConn := tls.Server(tcpConn, w.tlsConfig) + if err := tlsConn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil { + slog.Error("Failed to set TLS deadline.", "error", err) + tcpConn.Close() + return + } + if err := tlsConn.Handshake(); err != nil { + slog.Error("TLS handshake failed.", "error", err, "remote_addr", remoteAddr) + tcpConn.Close() + return + } + if err := tlsConn.SetDeadline(time.Time{}); err != nil { + slog.Error("Failed to clear TLS deadline.", "error", err) + tlsConn.Close() + return + } + + reader := bufio.NewReader(tlsConn) + writer := bufio.NewWriter(tlsConn) + + // Read startup message + params, err := server.ReadStartupMessage(reader) + if err != nil { + slog.Error("Failed to read startup message.", "error", err, "remote_addr", remoteAddr) + tlsConn.Close() + return + } + + username := params["user"] + database := params["database"] + + if username == "" { + _ = server.WriteErrorResponse(writer, "FATAL", "28000", "no user specified") + _ = writer.Flush() + tlsConn.Close() + return + } + + // Authenticate + expectedPassword, ok := w.cfg.Users[username] + if !ok { + _ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed") + _ = writer.Flush() + tlsConn.Close() + return + } + + if err := server.WriteAuthCleartextPassword(writer); err != nil { + slog.Error("Failed to request password.", "error", err) + tlsConn.Close() + return + } + if err := writer.Flush(); err != nil { + slog.Error("Failed to flush.", "error", err) + tlsConn.Close() + return + } + + msgType, body, err := server.ReadMessage(reader) + if err != nil { + slog.Error("Failed to read password.", "error", err) + tlsConn.Close() + return + } + if msgType != 'p' { + _ = server.WriteErrorResponse(writer, "FATAL", "28000", "expected password message") + _ = writer.Flush() + tlsConn.Close() + return + } + + password := string(bytes.TrimRight(body, "\x00")) + if password != expectedPassword { + _ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed") + _ = writer.Flush() + tlsConn.Close() + return + } + + if err := server.WriteAuthOK(writer); err != nil { + slog.Error("Failed to send auth OK.", "error", err) + tlsConn.Close() + return + } + + slog.Info("Session authenticated.", "user", username, "pid", pid, "remote_addr", remoteAddr) + + // Create per-session DuckDB connection + db, err := w.dbPool.CreateSession(pid, username) + if err != nil { + slog.Error("Failed to create session DB.", "error", err) + _ = server.WriteErrorResponse(writer, "FATAL", "28000", fmt.Sprintf("failed to open database: %v", err)) + _ = writer.Flush() + tlsConn.Close() + return + } + defer w.dbPool.CloseSession(pid) + + // Register session for cancel tracking + sessionCtx, sessionCancel := context.WithCancel(context.Background()) + defer sessionCancel() + + w.sessionsMu.Lock() + w.sessions[pid] = &workerSession{ + pid: pid, + secretKey: secretKey, + cancel: sessionCancel, + remoteAddr: 
remoteAddr, + } + w.sessionsMu.Unlock() + + defer func() { + w.sessionsMu.Lock() + delete(w.sessions, pid) + w.sessionsMu.Unlock() + }() + + // Create a minimal Server for the clientConn + queryCancelCh := make(chan struct{}) + minServer := &server.Server{} + server.InitMinimalServer(minServer, w.cfg, queryCancelCh) + + cc := server.NewClientConn(minServer, tlsConn, reader, writer, username, database, db, pid, secretKey) + + // Send initial params and ready for query + server.SendInitialParams(cc) + if err := server.WriteReadyForQuery(writer, 'I'); err != nil { + slog.Error("Failed to send ready for query.", "error", err) + return + } + if err := writer.Flush(); err != nil { + slog.Error("Failed to flush.", "error", err) + return + } + + // Run message loop + errChan := make(chan error, 1) + go func() { + errChan <- server.RunMessageLoop(cc) + }() + + select { + case err := <-errChan: + if err != nil { + slog.Error("Session error.", "error", err, "pid", pid) + } else { + slog.Info("Session disconnected cleanly.", "user", username, "pid", pid) + } + case <-sessionCtx.Done(): + slog.Info("Session cancelled.", "user", username, "pid", pid) + tlsConn.Close() + } +} + +func (w *Worker) CancelQuery(_ context.Context, req *pb.CancelQueryRequest) (*pb.CancelQueryResponse, error) { + w.sessionsMu.RLock() + defer w.sessionsMu.RUnlock() + + for _, s := range w.sessions { + if s.pid == req.BackendPid && s.secretKey == req.SecretKey { + s.cancel() + slog.Info("Query cancelled via gRPC.", "pid", req.BackendPid) + return &pb.CancelQueryResponse{Cancelled: true}, nil + } + } + return &pb.CancelQueryResponse{Cancelled: false}, nil +} + +func (w *Worker) Drain(_ context.Context, req *pb.DrainRequest) (*pb.DrainResponse, error) { + w.mu.Lock() + w.draining = true + w.mu.Unlock() + + timeout := time.Duration(req.TimeoutNs) + if timeout == 0 { + timeout = 30 * time.Second + } + + done := make(chan struct{}) + go func() { + w.sessionsWg.Wait() + close(done) + }() + + select { + case <-done: + return &pb.DrainResponse{Ok: true, RemainingConnections: 0}, nil + case <-time.After(timeout): + w.sessionsMu.RLock() + remaining := int32(len(w.sessions)) + w.sessionsMu.RUnlock() + return &pb.DrainResponse{Ok: false, RemainingConnections: remaining}, nil + } +} + +func (w *Worker) Health(_ context.Context, _ *pb.HealthRequest) (*pb.HealthResponse, error) { + w.mu.RLock() + draining := w.draining + w.mu.RUnlock() + + w.sessionsMu.RLock() + active := int32(len(w.sessions)) + w.sessionsMu.RUnlock() + + return &pb.HealthResponse{ + Healthy: w.configured && !draining, + ActiveConnections: active, + UptimeNs: int64(time.Since(w.startTime)), + TotalQueries: w.totalQueries.Load(), + Draining: draining, + }, nil +} + +func (w *Worker) Shutdown(_ context.Context, req *pb.ShutdownRequest) (*pb.ShutdownResponse, error) { + timeout := time.Duration(req.TimeoutNs) + if timeout == 0 { + timeout = 30 * time.Second + } + + w.mu.Lock() + w.draining = true + w.mu.Unlock() + + go func() { + done := make(chan struct{}) + go func() { + w.sessionsWg.Wait() + close(done) + }() + + select { + case <-done: + case <-time.After(timeout): + slog.Warn("Shutdown timeout, exiting.") + } + + os.Exit(0) + }() + + return &pb.ShutdownResponse{Ok: true}, nil +} diff --git a/go.mod b/go.mod index af7d058..a4771b5 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,8 @@ require ( go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.16.0 go.opentelemetry.io/otel/sdk v1.40.0 go.opentelemetry.io/otel/sdk/log v0.16.0 + google.golang.org/grpc v1.78.0 + 
google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 ) @@ -60,6 +62,4 @@ require ( golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260128011058-8636f8732409 // indirect - google.golang.org/grpc v1.78.0 // indirect - google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/main.go b/main.go index cc28fec..63de2e4 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "syscall" "time" + "github.com/posthog/duckgres/controlplane" "github.com/posthog/duckgres/server" "github.com/prometheus/client_golang/prometheus/promhttp" "gopkg.in/yaml.v3" @@ -116,6 +117,14 @@ func main() { idleTimeout := flag.String("idle-timeout", "", "Connection idle timeout (e.g., '30m', '1h', '-1' to disable) (env: DUCKGRES_IDLE_TIMEOUT)") showHelp := flag.Bool("help", false, "Show help message") + // Control plane flags + mode := flag.String("mode", "standalone", "Run mode: standalone (default), control-plane, or worker") + workerCount := flag.Int("worker-count", 4, "Number of worker processes (control-plane mode)") + socketDir := flag.String("socket-dir", "/var/run/duckgres", "Unix socket directory (control-plane mode)") + handoverSocket := flag.String("handover-socket", "", "Handover socket for graceful deployment (control-plane mode)") + grpcSocket := flag.String("grpc-socket", "", "gRPC socket path (worker mode, set by control-plane)") + fdSocket := flag.String("fd-socket", "", "FD passing socket path (worker mode, set by control-plane)") + flag.Usage = func() { fmt.Fprintf(os.Stderr, "Duckgres - PostgreSQL wire protocol server for DuckDB\n\n") fmt.Fprintf(os.Stderr, "Usage: duckgres [options]\n\n") @@ -356,6 +365,15 @@ func main() { } } + // Handle worker mode early (before metrics, certs, etc.) 
+ if *mode == "worker" { + if *grpcSocket == "" || *fdSocket == "" { + fatal("Worker mode requires --grpc-socket and --fd-socket flags") + } + controlplane.RunWorker(*grpcSocket, *fdSocket) + return + } + initMetrics() // Create data directory if it doesn't exist @@ -369,6 +387,19 @@ func main() { } slog.Info("Using TLS certificates", "cert_file", cfg.TLSCertFile, "key_file", cfg.TLSKeyFile) + // Handle control-plane mode + if *mode == "control-plane" { + cpCfg := controlplane.ControlPlaneConfig{ + Config: cfg, + WorkerCount: *workerCount, + SocketDir: *socketDir, + HandoverSocket: *handoverSocket, + } + controlplane.RunControlPlane(cpCfg) + return + } + + // Standalone mode (default) srv, err := server.New(cfg) if err != nil { fatal("Failed to create server: " + err.Error()) diff --git a/server/conn.go b/server/conn.go index 102853d..4de29ab 100644 --- a/server/conn.go +++ b/server/conn.go @@ -381,13 +381,15 @@ func (c *clientConn) serve() error { return fmt.Errorf("startup failed: %w", err) } - // Create a DuckDB connection for this client session - db, err := c.server.createDBConnection(c.username) - if err != nil { - c.sendError("FATAL", "28000", fmt.Sprintf("failed to open database: %v", err)) - return err + // Create a DuckDB connection for this client session (unless pre-created by caller) + if c.db == nil { + db, err := c.server.createDBConnection(c.username) + if err != nil { + c.sendError("FATAL", "28000", fmt.Sprintf("failed to open database: %v", err)) + return err + } + c.db = db } - c.db = db defer func() { if c.db != nil { c.safeCleanupDB() diff --git a/server/exports.go b/server/exports.go new file mode 100644 index 0000000..1ccfaf1 --- /dev/null +++ b/server/exports.go @@ -0,0 +1,90 @@ +package server + +import ( + "bufio" + "context" + "database/sql" + "io" + "net" +) + +// Exported wrappers for protocol functions used by the control plane worker. +// These delegate to the internal (lowercase) implementations. + +func ReadStartupMessage(r io.Reader) (map[string]string, error) { + return readStartupMessage(r) +} + +func ReadMessage(r io.Reader) (byte, []byte, error) { + return readMessage(r) +} + +func WriteAuthOK(w io.Writer) error { + return writeAuthOK(w) +} + +func WriteAuthCleartextPassword(w io.Writer) error { + return writeAuthCleartextPassword(w) +} + +func WriteReadyForQuery(w io.Writer, txStatus byte) error { + return writeReadyForQuery(w, txStatus) +} + +func WriteErrorResponse(w io.Writer, severity, code, message string) error { + return writeErrorResponse(w, severity, code, message) +} + +func WriteParameterStatus(w io.Writer, name, value string) error { + return writeParameterStatus(w, name, value) +} + +func WriteBackendKeyData(w io.Writer, pid, secretKey int32) error { + return writeBackendKeyData(w, pid, secretKey) +} + +// NewClientConn creates a clientConn with pre-initialized fields for use by +// the control plane worker. The returned value is opaque (*clientConn) but +// can be used with SendInitialParams and RunMessageLoop. 
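A condensed view of how these exported wrappers are meant to compose, assuming TLS, the startup message and authentication have already been handled (handleSession in controlplane/worker.go above is the full version). The reader and writer must be the same buffered ones used during startup so no already-buffered bytes are lost; the helper name is illustrative only.

```go
package controlplane

import (
	"bufio"
	"database/sql"
	"net"

	"github.com/posthog/duckgres/server"
)

// runAdoptedSession drives a session on an already-authenticated connection.
func runAdoptedSession(cfg server.Config, conn net.Conn, r *bufio.Reader, w *bufio.Writer,
	username, database string, db *sql.DB, pid, secretKey int32) error {

	srv := &server.Server{}
	server.InitMinimalServer(srv, cfg, make(chan struct{}))

	cc := server.NewClientConn(srv, conn, r, w, username, database, db, pid, secretKey)
	server.SendInitialParams(cc)
	if err := server.WriteReadyForQuery(w, 'I'); err != nil {
		return err
	}
	if err := w.Flush(); err != nil {
		return err
	}
	return server.RunMessageLoop(cc)
}
```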
+func NewClientConn(s *Server, conn net.Conn, reader *bufio.Reader, writer *bufio.Writer, + username, database string, db *sql.DB, pid, secretKey int32) *clientConn { + + return &clientConn{ + server: s, + conn: conn, + reader: reader, + writer: writer, + username: username, + database: database, + db: db, + pid: pid, + secretKey: secretKey, + stmts: make(map[string]*preparedStmt), + portals: make(map[string]*portal), + txStatus: txStatusIdle, + } +} + +// SendInitialParams sends the initial parameter status messages and backend key data. +func SendInitialParams(cc *clientConn) { + cc.sendInitialParams() +} + +// RunMessageLoop runs the main message loop for a client connection. +func RunMessageLoop(cc *clientConn) error { + return cc.messageLoop() +} + +// InitMinimalServer initializes a Server struct with minimal fields for use +// in control plane worker sessions. +func InitMinimalServer(s *Server, cfg Config, queryCancelCh <-chan struct{}) { + s.cfg = cfg + s.activeQueries = make(map[BackendKey]context.CancelFunc) + s.duckLakeSem = make(chan struct{}, 1) + s.externalCancelCh = queryCancelCh +} + +// GenerateSecretKey generates a cryptographically random secret key for cancel requests. +func GenerateSecretKey() int32 { + return generateSecretKey() +} diff --git a/server/server.go b/server/server.go index 4cceac3..ec06f46 100644 --- a/server/server.go +++ b/server/server.go @@ -398,8 +398,15 @@ func (s *Server) CancelQuery(key BackendKey) bool { } // createDBConnection creates a DuckDB connection for a client session. -// Uses in-memory database as an anchor for DuckLake attachment (actual data lives in RDS/S3). +// This is a thin wrapper around CreateDBConnection using the server's config. func (s *Server) createDBConnection(username string) (*sql.DB, error) { + return CreateDBConnection(s.cfg, s.duckLakeSem, username) +} + +// CreateDBConnection creates a DuckDB connection for a client session. +// Uses in-memory database as an anchor for DuckLake attachment (actual data lives in RDS/S3). +// This is a standalone function so it can be reused by both the server and control plane workers. +func CreateDBConnection(cfg Config, duckLakeSem chan struct{}, username string) (*sql.DB, error) { // Create new in-memory connection (DuckLake provides actual storage) db, err := sql.Open("duckdb", ":memory:") if err != nil { @@ -429,7 +436,7 @@ func (s *Server) createDBConnection(username string) (*sql.DB, error) { // Set temp directory to a subdirectory under DataDir to ensure DuckDB has a // writable location for intermediate results. This prevents "Read-only file system" // errors in containerized or restricted environments. 
- tempDir := filepath.Join(s.cfg.DataDir, "tmp") + tempDir := filepath.Join(cfg.DataDir, "tmp") if _, err := db.Exec(fmt.Sprintf("SET temp_directory = '%s'", tempDir)); err != nil { slog.Warn("Failed to set DuckDB temp_directory.", "temp_directory", tempDir, "error", err) } else { @@ -437,7 +444,7 @@ func (s *Server) createDBConnection(username string) (*sql.DB, error) { } // Load configured extensions - if err := s.loadExtensions(db); err != nil { + if err := LoadExtensions(db, cfg.Extensions); err != nil { slog.Warn("Failed to load some extensions.", "user", username, "error", err) } @@ -451,16 +458,16 @@ func (s *Server) createDBConnection(username string) (*sql.DB, error) { // Attach DuckLake catalog if configured (but don't set as default yet) duckLakeMode := false - if err := s.attachDuckLake(db); err != nil { + if err := AttachDuckLake(db, cfg.DuckLake, duckLakeSem); err != nil { // If DuckLake was explicitly configured, fail the connection. // Silent fallback to local DB causes schema/table mismatches. - if s.cfg.DuckLake.MetadataStore != "" { + if cfg.DuckLake.MetadataStore != "" { _ = db.Close() return nil, fmt.Errorf("DuckLake configured but attachment failed: %w", err) } // DuckLake not configured, this warning is just informational slog.Warn("Failed to attach DuckLake.", "user", username, "error", err) - } else if s.cfg.DuckLake.MetadataStore != "" { + } else if cfg.DuckLake.MetadataStore != "" { duckLakeMode = true // Recreate pg_class_full to source from DuckLake metadata instead of DuckDB's pg_catalog. @@ -497,14 +504,21 @@ func (s *Server) createDBConnection(username string) (*sql.DB, error) { return db, nil } -// loadExtensions installs and loads configured DuckDB extensions +// loadExtensions installs and loads configured DuckDB extensions. +// This is a thin wrapper around LoadExtensions using the server's config. func (s *Server) loadExtensions(db *sql.DB) error { - if len(s.cfg.Extensions) == 0 { + return LoadExtensions(db, s.cfg.Extensions) +} + +// LoadExtensions installs and loads DuckDB extensions. +// This is a standalone function so it can be reused by control plane workers. +func LoadExtensions(db *sql.DB, extensions []string) error { + if len(extensions) == 0 { return nil } var lastErr error - for _, ext := range s.cfg.Extensions { + for _, ext := range extensions { // First install the extension (downloads if needed) if _, err := db.Exec("INSTALL " + ext); err != nil { slog.Warn("Failed to install extension.", "extension", ext, "error", err) @@ -526,9 +540,16 @@ func (s *Server) loadExtensions(db *sql.DB) error { } // attachDuckLake attaches a DuckLake catalog if configured (but does NOT set it as default). -// Call setDuckLakeDefault after creating per-connection views in memory.main. +// This is a thin wrapper around AttachDuckLake using the server's config. func (s *Server) attachDuckLake(db *sql.DB) error { - if s.cfg.DuckLake.MetadataStore == "" { + return AttachDuckLake(db, s.cfg.DuckLake, s.duckLakeSem) +} + +// AttachDuckLake attaches a DuckLake catalog if configured (but does NOT set it as default). +// Call setDuckLakeDefault after creating per-connection views in memory.main. +// This is a standalone function so it can be reused by control plane workers. 
+func AttachDuckLake(db *sql.DB, dlCfg DuckLakeConfig, sem chan struct{}) error { + if dlCfg.MetadataStore == "" { return nil // DuckLake not configured } @@ -538,8 +559,8 @@ func (s *Server) attachDuckLake(db *sql.DB) error { // Use a 30-second timeout to prevent connections from hanging indefinitely // if attachment is slow (e.g., network latency to metadata store). select { - case s.duckLakeSem <- struct{}{}: - defer func() { <-s.duckLakeSem }() + case sem <- struct{}{}: + defer func() { <-sem }() case <-time.After(30 * time.Second): return fmt.Errorf("timeout waiting for DuckLake attachment lock") } @@ -555,15 +576,15 @@ func (s *Server) attachDuckLake(db *sql.DB) error { // Create S3 secret if using object store // - With explicit credentials (S3AccessKey set) or custom endpoint // - With credential_chain provider (for AWS S3) - if s.cfg.DuckLake.ObjectStore != "" { - needsSecret := s.cfg.DuckLake.S3Endpoint != "" || - s.cfg.DuckLake.S3AccessKey != "" || - s.cfg.DuckLake.S3Provider == "credential_chain" || - s.cfg.DuckLake.S3Chain != "" || - s.cfg.DuckLake.S3Profile != "" + if dlCfg.ObjectStore != "" { + needsSecret := dlCfg.S3Endpoint != "" || + dlCfg.S3AccessKey != "" || + dlCfg.S3Provider == "credential_chain" || + dlCfg.S3Chain != "" || + dlCfg.S3Profile != "" if needsSecret { - if err := s.createS3Secret(db); err != nil { + if err := createS3Secret(db, dlCfg); err != nil { return fmt.Errorf("failed to create S3 secret: %w", err) } } @@ -574,17 +595,17 @@ func (s *Server) attachDuckLake(db *sql.DB) error { // Format with data path: ATTACH 'ducklake:' AS ducklake (DATA_PATH '') // See: https://ducklake.select/docs/stable/duckdb/usage/connecting var attachStmt string - dataPath := s.cfg.DuckLake.ObjectStore + dataPath := dlCfg.ObjectStore if dataPath == "" { - dataPath = s.cfg.DuckLake.DataPath + dataPath = dlCfg.DataPath } if dataPath != "" { attachStmt = fmt.Sprintf("ATTACH 'ducklake:%s' AS ducklake (DATA_PATH '%s')", - s.cfg.DuckLake.MetadataStore, dataPath) - slog.Info("Attaching DuckLake catalog with data path.", "metadata", redactConnectionString(s.cfg.DuckLake.MetadataStore), "data", dataPath) + dlCfg.MetadataStore, dataPath) + slog.Info("Attaching DuckLake catalog with data path.", "metadata", redactConnectionString(dlCfg.MetadataStore), "data", dataPath) } else { - attachStmt = fmt.Sprintf("ATTACH 'ducklake:%s' AS ducklake", s.cfg.DuckLake.MetadataStore) - slog.Info("Attaching DuckLake catalog.", "metadata", redactConnectionString(s.cfg.DuckLake.MetadataStore)) + attachStmt = fmt.Sprintf("ATTACH 'ducklake:%s' AS ducklake", dlCfg.MetadataStore) + slog.Info("Attaching DuckLake catalog.", "metadata", redactConnectionString(dlCfg.MetadataStore)) } if _, err := db.Exec(attachStmt); err != nil { @@ -615,14 +636,15 @@ func setDuckLakeDefault(db *sql.DB) error { return nil } -// createS3Secret creates a DuckDB secret for S3/MinIO access +// createS3Secret creates a DuckDB secret for S3/MinIO access. +// This is a standalone function so it can be reused by control plane workers. // Supports two providers: // - "config": explicit credentials (for MinIO or when you have access keys) // - "credential_chain": AWS SDK credential chain (env vars, config files, instance metadata, etc.) // // Note: Caller must hold duckLakeSem to avoid race conditions. 
// See: https://duckdb.org/docs/stable/core_extensions/httpfs/s3api -func (s *Server) createS3Secret(db *sql.DB) error { +func createS3Secret(db *sql.DB, dlCfg DuckLakeConfig) error { // Check if secret already exists to avoid unnecessary creation var count int err := db.QueryRow("SELECT COUNT(*) FROM duckdb_secrets() WHERE name = 'ducklake_s3'").Scan(&count) @@ -631,9 +653,9 @@ func (s *Server) createS3Secret(db *sql.DB) error { } // Determine provider: use credential_chain if explicitly set or if no access key provided - provider := s.cfg.DuckLake.S3Provider + provider := dlCfg.S3Provider if provider == "" { - if s.cfg.DuckLake.S3AccessKey != "" { + if dlCfg.S3AccessKey != "" { provider = "config" } else { provider = "credential_chain" @@ -644,12 +666,12 @@ func (s *Server) createS3Secret(db *sql.DB) error { if provider == "credential_chain" { // Use AWS SDK credential chain - secretStmt = s.buildCredentialChainSecret() + secretStmt = buildCredentialChainSecret(dlCfg) slog.Info("Creating S3 secret with credential_chain provider.") } else { // Use explicit credentials (config provider) - secretStmt = s.buildConfigSecret() - slog.Info("Creating S3 secret with config provider.", "endpoint", s.cfg.DuckLake.S3Endpoint) + secretStmt = buildConfigSecret(dlCfg) + slog.Info("Creating S3 secret with config provider.", "endpoint", dlCfg.S3Endpoint) } if _, err := db.Exec(secretStmt); err != nil { @@ -661,19 +683,19 @@ func (s *Server) createS3Secret(db *sql.DB) error { } // buildConfigSecret builds a CREATE SECRET statement with explicit credentials -func (s *Server) buildConfigSecret() string { - region := s.cfg.DuckLake.S3Region +func buildConfigSecret(dlCfg DuckLakeConfig) string { + region := dlCfg.S3Region if region == "" { region = "us-east-1" } - urlStyle := s.cfg.DuckLake.S3URLStyle + urlStyle := dlCfg.S3URLStyle if urlStyle == "" { urlStyle = "path" // Default to path style for MinIO compatibility } useSSL := "false" - if s.cfg.DuckLake.S3UseSSL { + if dlCfg.S3UseSSL { useSSL = "true" } @@ -687,16 +709,16 @@ func (s *Server) buildConfigSecret() string { REGION '%s', URL_STYLE '%s', USE_SSL %s`, - s.cfg.DuckLake.S3AccessKey, - s.cfg.DuckLake.S3SecretKey, + dlCfg.S3AccessKey, + dlCfg.S3SecretKey, region, urlStyle, useSSL, ) // Add endpoint if specified (for MinIO or custom S3-compatible storage) - if s.cfg.DuckLake.S3Endpoint != "" { - secret += fmt.Sprintf(",\n\t\t\tENDPOINT '%s'", s.cfg.DuckLake.S3Endpoint) + if dlCfg.S3Endpoint != "" { + secret += fmt.Sprintf(",\n\t\t\tENDPOINT '%s'", dlCfg.S3Endpoint) } secret += "\n\t\t)" @@ -704,7 +726,7 @@ func (s *Server) buildConfigSecret() string { } // buildCredentialChainSecret builds a CREATE SECRET statement using AWS SDK credential chain -func (s *Server) buildCredentialChainSecret() string { +func buildCredentialChainSecret(dlCfg DuckLakeConfig) string { // Start with base credential_chain secret secret := ` CREATE OR REPLACE SECRET ducklake_s3 ( @@ -712,33 +734,33 @@ func (s *Server) buildCredentialChainSecret() string { PROVIDER credential_chain` // Add chain if specified (e.g., "env;config" to check specific sources) - if s.cfg.DuckLake.S3Chain != "" { - secret += fmt.Sprintf(",\n\t\t\tCHAIN '%s'", s.cfg.DuckLake.S3Chain) + if dlCfg.S3Chain != "" { + secret += fmt.Sprintf(",\n\t\t\tCHAIN '%s'", dlCfg.S3Chain) } // Add profile if specified (for config chain) - if s.cfg.DuckLake.S3Profile != "" { - secret += fmt.Sprintf(",\n\t\t\tPROFILE '%s'", s.cfg.DuckLake.S3Profile) + if dlCfg.S3Profile != "" { + secret += fmt.Sprintf(",\n\t\t\tPROFILE 
'%s'", dlCfg.S3Profile) } // Add region override if specified - if s.cfg.DuckLake.S3Region != "" { - secret += fmt.Sprintf(",\n\t\t\tREGION '%s'", s.cfg.DuckLake.S3Region) + if dlCfg.S3Region != "" { + secret += fmt.Sprintf(",\n\t\t\tREGION '%s'", dlCfg.S3Region) } // Add endpoint if specified (for custom S3-compatible storage) - if s.cfg.DuckLake.S3Endpoint != "" { - secret += fmt.Sprintf(",\n\t\t\tENDPOINT '%s'", s.cfg.DuckLake.S3Endpoint) + if dlCfg.S3Endpoint != "" { + secret += fmt.Sprintf(",\n\t\t\tENDPOINT '%s'", dlCfg.S3Endpoint) // Also set URL style and SSL for custom endpoints - urlStyle := s.cfg.DuckLake.S3URLStyle + urlStyle := dlCfg.S3URLStyle if urlStyle == "" { urlStyle = "path" } secret += fmt.Sprintf(",\n\t\t\tURL_STYLE '%s'", urlStyle) useSSL := "false" - if s.cfg.DuckLake.S3UseSSL { + if dlCfg.S3UseSSL { useSSL = "true" } secret += fmt.Sprintf(",\n\t\t\tUSE_SSL %s", useSSL) diff --git a/server/worker.go b/server/worker.go index ebbed35..19f487d 100644 --- a/server/worker.go +++ b/server/worker.go @@ -358,13 +358,6 @@ type workerServer struct { } // createDBConnection creates a DuckDB connection for the child worker. -// This is a copy of Server.createDBConnection adapted for the worker context. func (w *workerServer) createDBConnection(username string) (*sql.DB, error) { - // Import sql package - we'll use the Server's method through embedding - // For now, create a temporary Server instance to reuse the logic - tempServer := &Server{ - cfg: w.cfg, - duckLakeSem: make(chan struct{}, 1), - } - return tempServer.createDBConnection(username) + return CreateDBConnection(w.cfg, make(chan struct{}, 1), username) } From b741e5242589ccc97cd762da81b6899d3f484a82 Mon Sep 17 00:00:00 2001 From: James Greenhill Date: Fri, 6 Feb 2026 03:17:47 +0000 Subject: [PATCH 2/3] Fix done channel leak and per-query cancellation in control plane ConnectExistingWorker never closed the done channel, causing ShutdownAll and RollingUpdate to always hit their timeout for handed-over workers. Add a health-check monitoring goroutine. CancelQuery killed the entire session instead of just the running query. Use the per-session minServer.CancelQuery() to cancel only the in-flight query, matching standalone mode behavior. Co-Authored-By: Claude Opus 4.6 --- controlplane/pool.go | 22 ++++++++++++++++++++++ controlplane/worker.go | 28 ++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/controlplane/pool.go b/controlplane/pool.go index 900fc44..088c74b 100644 --- a/controlplane/pool.go +++ b/controlplane/pool.go @@ -185,6 +185,28 @@ func (p *WorkerPool) ConnectExistingWorker(id int, grpcSocket, fdSocket string) done: make(chan struct{}), } + // Monitor the handed-over worker via gRPC health checks. + // Unlike SpawnWorker where we can cmd.Wait(), we don't own this process, + // so we detect exit by polling health. This ensures done is closed for + // ShutdownAll and RollingUpdate which wait on <-w.done. 
+ go func() { + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + for range ticker.C { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + _, err := client.Health(ctx, &pb.HealthRequest{}) + cancel() + if err != nil { + slog.Info("Handed-over worker unreachable, marking as done.", "id", id, "error", err) + close(worker.done) + p.mu.Lock() + delete(p.workers, id) + p.mu.Unlock() + return + } + } + }() + p.mu.Lock() p.workers[id] = worker p.mu.Unlock() diff --git a/controlplane/worker.go b/controlplane/worker.go index d7f8170..bab90f8 100644 --- a/controlplane/worker.go +++ b/controlplane/worker.go @@ -53,10 +53,11 @@ type Worker struct { } type workerSession struct { - pid int32 - secretKey int32 - cancel context.CancelFunc + pid int32 + secretKey int32 + cancel context.CancelFunc remoteAddr string + minServer *server.Server // per-session server for query cancellation } // RunWorker is the entry point for a worker process. @@ -430,6 +431,14 @@ func (w *Worker) handleSession(tcpConn *net.TCPConn, remoteAddr string, pid, sec minServer := &server.Server{} server.InitMinimalServer(minServer, w.cfg, queryCancelCh) + // Store minServer in session so CancelQuery can cancel individual queries + // rather than tearing down the entire session. + w.sessionsMu.Lock() + if s, ok := w.sessions[pid]; ok { + s.minServer = minServer + } + w.sessionsMu.Unlock() + cc := server.NewClientConn(minServer, tlsConn, reader, writer, username, database, db, pid, secretKey) // Send initial params and ready for query @@ -468,9 +477,16 @@ func (w *Worker) CancelQuery(_ context.Context, req *pb.CancelQueryRequest) (*pb for _, s := range w.sessions { if s.pid == req.BackendPid && s.secretKey == req.SecretKey { - s.cancel() - slog.Info("Query cancelled via gRPC.", "pid", req.BackendPid) - return &pb.CancelQueryResponse{Cancelled: true}, nil + // Cancel the active query via the per-session server's activeQueries map, + // matching standalone mode behavior. This cancels only the in-flight query + // while keeping the connection alive (instead of tearing down the session). 
+ key := server.BackendKey{Pid: req.BackendPid, SecretKey: req.SecretKey} + if s.minServer != nil && s.minServer.CancelQuery(key) { + slog.Info("Query cancelled via gRPC.", "pid", req.BackendPid) + return &pb.CancelQueryResponse{Cancelled: true}, nil + } + // No active query registered yet (session still initializing) + return &pb.CancelQueryResponse{Cancelled: false}, nil } } return &pb.CancelQueryResponse{Cancelled: false}, nil From 9b4c408ffe0dc21094ed1d71aae96dde8f0b9db4 Mon Sep 17 00:00:00 2001 From: James Greenhill Date: Fri, 6 Feb 2026 03:23:41 +0000 Subject: [PATCH 3/3] Fix lint errors: unchecked Close() returns and unused code - Add _ = prefix to all unchecked .Close() return values (errcheck) - Remove unused nextWorker field from WorkerPool (unused) - Remove unused activeQueriesMu field from Worker (unused) - Remove unused loadExtensions/attachDuckLake method wrappers (unused) Co-Authored-By: Claude Opus 4.6 --- controlplane/control.go | 2 +- controlplane/fdpass/fdpass.go | 12 ++++----- controlplane/fdpass/fdpass_test.go | 26 +++++++++--------- controlplane/handover.go | 20 +++++++------- controlplane/pool.go | 25 ++++++++--------- controlplane/worker.go | 43 +++++++++++++++--------------- server/server.go | 12 --------- 7 files changed, 62 insertions(+), 78 deletions(-) diff --git a/controlplane/control.go b/controlplane/control.go index 2b118c6..168fa9c 100644 --- a/controlplane/control.go +++ b/controlplane/control.go @@ -331,7 +331,7 @@ func (cp *ControlPlane) shutdown() { } slog.Info("Draining workers...") - cp.pool.DrainAll(30 * time.Second) + _ = cp.pool.DrainAll(30 * time.Second) slog.Info("Shutting down workers...") cp.pool.ShutdownAll(30 * time.Second) diff --git a/controlplane/fdpass/fdpass.go b/controlplane/fdpass/fdpass.go index 9093fed..e3ce02b 100644 --- a/controlplane/fdpass/fdpass.go +++ b/controlplane/fdpass/fdpass.go @@ -77,15 +77,15 @@ func SocketPair() (*net.UnixConn, *net.UnixConn, error) { sender, err := fdToUnixConn(fds[0], "sender") if err != nil { - syscall.Close(fds[0]) - syscall.Close(fds[1]) + _ = syscall.Close(fds[0]) + _ = syscall.Close(fds[1]) return nil, nil, err } receiver, err := fdToUnixConn(fds[1], "receiver") if err != nil { - sender.Close() - syscall.Close(fds[1]) + _ = sender.Close() + _ = syscall.Close(fds[1]) return nil, nil, err } @@ -97,7 +97,7 @@ func fdToUnixConn(fd int, name string) (*net.UnixConn, error) { if f == nil { return nil, fmt.Errorf("invalid fd %d", fd) } - defer f.Close() + defer func() { _ = f.Close() }() fc, err := net.FileConn(f) if err != nil { @@ -106,7 +106,7 @@ func fdToUnixConn(fd int, name string) (*net.UnixConn, error) { uc, ok := fc.(*net.UnixConn) if !ok { - fc.Close() + _ = fc.Close() return nil, fmt.Errorf("not a UnixConn") } return uc, nil diff --git a/controlplane/fdpass/fdpass_test.go b/controlplane/fdpass/fdpass_test.go index 3f575aa..90d8dc7 100644 --- a/controlplane/fdpass/fdpass_test.go +++ b/controlplane/fdpass/fdpass_test.go @@ -12,15 +12,15 @@ func TestSendRecvFD(t *testing.T) { if err != nil { t.Fatalf("SocketPair: %v", err) } - defer sender.Close() - defer receiver.Close() + defer func() { _ = sender.Close() }() + defer func() { _ = receiver.Close() }() // Create a temp file to pass tmp, err := os.CreateTemp("", "fdpass-test-*") if err != nil { t.Fatalf("CreateTemp: %v", err) } - defer os.Remove(tmp.Name()) + defer func() { _ = os.Remove(tmp.Name()) }() // Write some data msg := "hello from fd passing" @@ -35,14 +35,14 @@ func TestSendRecvFD(t *testing.T) { if err := SendFile(sender, tmp); err != 
nil { t.Fatalf("SendFile: %v", err) } - tmp.Close() + _ = tmp.Close() // Receive the FD received, err := RecvFile(receiver, "received") if err != nil { t.Fatalf("RecvFile: %v", err) } - defer received.Close() + defer func() { _ = received.Close() }() // Read data from received FD to verify it works if _, err := received.Seek(0, 0); err != nil { @@ -64,15 +64,15 @@ func TestSendRecvTCPConn(t *testing.T) { if err != nil { t.Fatalf("SocketPair: %v", err) } - defer sender.Close() - defer receiver.Close() + defer func() { _ = sender.Close() }() + defer func() { _ = receiver.Close() }() // Create a TCP listener ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("Listen: %v", err) } - defer ln.Close() + defer func() { _ = ln.Close() }() // Connect a TCP client clientConn, err := net.Dial("tcp", ln.Addr().String()) @@ -92,13 +92,13 @@ func TestSendRecvTCPConn(t *testing.T) { if err != nil { t.Fatalf("File: %v", err) } - serverConn.Close() // Close the original; file has a dup'd FD + _ = serverConn.Close() // Close the original; file has a dup'd FD // Send the TCP FD if err := SendFile(sender, file); err != nil { t.Fatalf("SendFile: %v", err) } - file.Close() + _ = file.Close() // Receive the TCP FD in the "worker" recvFile, err := RecvFile(receiver, "tcp-conn") @@ -111,8 +111,8 @@ func TestSendRecvTCPConn(t *testing.T) { if err != nil { t.Fatalf("FileConn: %v", err) } - recvFile.Close() - defer fc.Close() + _ = recvFile.Close() + defer func() { _ = fc.Close() }() // Write from the reconstructed connection, read from client msg := "hello via fd passing" @@ -129,5 +129,5 @@ func TestSendRecvTCPConn(t *testing.T) { t.Errorf("got %q, want %q", string(buf[:n]), msg) } - clientConn.Close() + _ = clientConn.Close() } diff --git a/controlplane/handover.go b/controlplane/handover.go index d4f9909..131ae2f 100644 --- a/controlplane/handover.go +++ b/controlplane/handover.go @@ -31,7 +31,7 @@ func (cp *ControlPlane) startHandoverListener() { } // Clean up old socket - os.Remove(cp.cfg.HandoverSocket) + _ = os.Remove(cp.cfg.HandoverSocket) ln, err := net.Listen("unix", cp.cfg.HandoverSocket) if err != nil { @@ -42,8 +42,8 @@ func (cp *ControlPlane) startHandoverListener() { slog.Info("Handover listener started.", "socket", cp.cfg.HandoverSocket) go func() { - defer ln.Close() - defer os.Remove(cp.cfg.HandoverSocket) + defer func() { _ = ln.Close() }() + defer func() { _ = os.Remove(cp.cfg.HandoverSocket) }() for { conn, err := ln.Accept() @@ -67,8 +67,8 @@ func (cp *ControlPlane) startHandoverListener() { // handleHandoverRequest processes an incoming handover request from a new control plane. 
func (cp *ControlPlane) handleHandoverRequest(conn net.Conn, handoverLn net.Listener) { - defer conn.Close() - defer handoverLn.Close() + defer func() { _ = conn.Close() }() + defer func() { _ = handoverLn.Close() }() decoder := json.NewDecoder(conn) encoder := json.NewEncoder(conn) @@ -119,7 +119,7 @@ func (cp *ControlPlane) handleHandoverRequest(conn net.Conn, handoverLn net.List slog.Error("Failed to get listener FD.", "error", err) return } - defer file.Close() + defer func() { _ = file.Close() }() uc, ok := conn.(*net.UnixConn) if !ok { @@ -171,7 +171,7 @@ func receiveHandover(handoverSocket string) (*net.TCPListener, []handoverWorker, if err != nil { return nil, nil, fmt.Errorf("connect handover socket: %w", err) } - defer conn.Close() + defer func() { _ = conn.Close() }() decoder := json.NewDecoder(conn) encoder := json.NewEncoder(conn) @@ -201,7 +201,7 @@ func receiveHandover(handoverSocket string) (*net.TCPListener, []handoverWorker, if err != nil { return nil, nil, fmt.Errorf("receive listener FD: %w", err) } - defer file.Close() + defer func() { _ = file.Close() }() // Reconstruct listener from FD ln, err := net.FileListener(file) @@ -211,13 +211,13 @@ func receiveHandover(handoverSocket string) (*net.TCPListener, []handoverWorker, tcpLn, ok := ln.(*net.TCPListener) if !ok { - ln.Close() + _ = ln.Close() return nil, nil, fmt.Errorf("not a TCP listener") } // Send handover complete if err := encoder.Encode(handoverMsg{Type: "handover_complete"}); err != nil { - tcpLn.Close() + _ = tcpLn.Close() return nil, nil, fmt.Errorf("send handover complete: %w", err) } diff --git a/controlplane/pool.go b/controlplane/pool.go index 088c74b..cface75 100644 --- a/controlplane/pool.go +++ b/controlplane/pool.go @@ -38,9 +38,6 @@ type WorkerPool struct { workers map[int]*ManagedWorker socketDir string cfg server.Config - - // Round-robin counter for simple load balancing - nextWorker int } // NewWorkerPool creates a new worker pool. 
@@ -58,8 +55,8 @@ func (p *WorkerPool) SpawnWorker(id int) error { fdSocket := filepath.Join(p.socketDir, fmt.Sprintf("worker-%d-fd.sock", id)) // Clean up old sockets - os.Remove(grpcSocket) - os.Remove(fdSocket) + _ = os.Remove(grpcSocket) + _ = os.Remove(fdSocket) // Spawn child process cmd := exec.Command(os.Args[0], @@ -102,12 +99,12 @@ func (p *WorkerPool) SpawnWorker(id int) error { resp, err := client.Configure(ctx, configReq) cancel() if err != nil { - conn.Close() + _ = conn.Close() _ = cmd.Process.Kill() return fmt.Errorf("configure worker %d: %w", id, err) } if !resp.Ok { - conn.Close() + _ = conn.Close() _ = cmd.Process.Kill() return fmt.Errorf("configure worker %d: %s", id, resp.Error) } @@ -167,11 +164,11 @@ func (p *WorkerPool) ConnectExistingWorker(id int, grpcSocket, fdSocket string) health, err := client.Health(ctx, &pb.HealthRequest{}) cancel() if err != nil { - conn.Close() + _ = conn.Close() return fmt.Errorf("health check worker %d: %w", id, err) } if !health.Healthy { - conn.Close() + _ = conn.Close() return fmt.Errorf("worker %d is not healthy", id) } @@ -237,15 +234,15 @@ func (p *WorkerPool) RouteConnection(tcpFile *os.File, remoteAddr string, secret uc, ok := fdConn.(*net.UnixConn) if !ok { - fdConn.Close() + _ = fdConn.Close() return 0, fmt.Errorf("FD conn not UnixConn") } if err := fdpass.SendFile(uc, tcpFile); err != nil { - uc.Close() + _ = uc.Close() return 0, fmt.Errorf("send FD to worker %d: %w", worker.ID, err) } - uc.Close() + _ = uc.Close() // Tell worker to accept the connection via gRPC ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) @@ -470,7 +467,7 @@ func (p *WorkerPool) RollingUpdate(ctx context.Context) error { } } - old.GRPCConn.Close() + _ = old.GRPCConn.Close() // Spawn replacement if err := p.SpawnWorker(id); err != nil { @@ -520,7 +517,7 @@ func waitForSocket(path string, timeout time.Duration) error { // Try to connect conn, err := net.DialTimeout("unix", path, time.Second) if err == nil { - conn.Close() + _ = conn.Close() return nil } } diff --git a/controlplane/worker.go b/controlplane/worker.go index bab90f8..47e5406 100644 --- a/controlplane/worker.go +++ b/controlplane/worker.go @@ -45,8 +45,7 @@ type Worker struct { sessionsWg sync.WaitGroup // Cancellation - activeQueries map[server.BackendKey]context.CancelFunc - activeQueriesMu sync.RWMutex + activeQueries map[server.BackendKey]context.CancelFunc // FD passing - stores the most recently received FD pendingFD int @@ -97,7 +96,7 @@ func (w *Worker) run(ctx context.Context) error { if err != nil { return fmt.Errorf("listen gRPC socket: %w", err) } - defer os.Remove(w.grpcSocketPath) + defer func() { _ = os.Remove(w.grpcSocketPath) }() grpcServer := grpc.NewServer() pb.RegisterWorkerControlServer(grpcServer, w) @@ -113,7 +112,7 @@ func (w *Worker) run(ctx context.Context) error { if err != nil { return fmt.Errorf("listen FD socket: %w", err) } - defer os.Remove(w.fdSocketPath) + defer func() { _ = os.Remove(w.fdSocketPath) }() // Accept FD connections in a goroutine go w.fdReceiverLoop(fdLn) @@ -173,13 +172,13 @@ func (w *Worker) fdReceiverLoop(ln net.Listener) { uc, ok := conn.(*net.UnixConn) if !ok { slog.Error("FD socket: not a UnixConn") - conn.Close() + _ = conn.Close() continue } // Receive the FD fd, err := fdpass.RecvFD(uc) - uc.Close() + _ = uc.Close() if err != nil { slog.Error("Failed to receive FD.", "error", err) continue @@ -274,14 +273,14 @@ func (w *Worker) AcceptConnection(_ context.Context, req *pb.AcceptConnectionReq return 
&pb.AcceptConnectionResponse{Ok: false, Error: "invalid FD"}, nil } fc, err := net.FileConn(file) - file.Close() + _ = file.Close() if err != nil { return &pb.AcceptConnectionResponse{Ok: false, Error: fmt.Sprintf("FileConn: %v", err)}, nil } tcpConn, ok := fc.(*net.TCPConn) if !ok { - fc.Close() + _ = fc.Close() return &pb.AcceptConnectionResponse{Ok: false, Error: "not TCP"}, nil } @@ -312,17 +311,17 @@ func (w *Worker) handleSession(tcpConn *net.TCPConn, remoteAddr string, pid, sec tlsConn := tls.Server(tcpConn, w.tlsConfig) if err := tlsConn.SetDeadline(time.Now().Add(30 * time.Second)); err != nil { slog.Error("Failed to set TLS deadline.", "error", err) - tcpConn.Close() + _ = tcpConn.Close() return } if err := tlsConn.Handshake(); err != nil { slog.Error("TLS handshake failed.", "error", err, "remote_addr", remoteAddr) - tcpConn.Close() + _ = tcpConn.Close() return } if err := tlsConn.SetDeadline(time.Time{}); err != nil { slog.Error("Failed to clear TLS deadline.", "error", err) - tlsConn.Close() + _ = tlsConn.Close() return } @@ -333,7 +332,7 @@ func (w *Worker) handleSession(tcpConn *net.TCPConn, remoteAddr string, pid, sec params, err := server.ReadStartupMessage(reader) if err != nil { slog.Error("Failed to read startup message.", "error", err, "remote_addr", remoteAddr) - tlsConn.Close() + _ = tlsConn.Close() return } @@ -343,7 +342,7 @@ func (w *Worker) handleSession(tcpConn *net.TCPConn, remoteAddr string, pid, sec if username == "" { _ = server.WriteErrorResponse(writer, "FATAL", "28000", "no user specified") _ = writer.Flush() - tlsConn.Close() + _ = tlsConn.Close() return } @@ -352,31 +351,31 @@ func (w *Worker) handleSession(tcpConn *net.TCPConn, remoteAddr string, pid, sec if !ok { _ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed") _ = writer.Flush() - tlsConn.Close() + _ = tlsConn.Close() return } if err := server.WriteAuthCleartextPassword(writer); err != nil { slog.Error("Failed to request password.", "error", err) - tlsConn.Close() + _ = tlsConn.Close() return } if err := writer.Flush(); err != nil { slog.Error("Failed to flush.", "error", err) - tlsConn.Close() + _ = tlsConn.Close() return } msgType, body, err := server.ReadMessage(reader) if err != nil { slog.Error("Failed to read password.", "error", err) - tlsConn.Close() + _ = tlsConn.Close() return } if msgType != 'p' { _ = server.WriteErrorResponse(writer, "FATAL", "28000", "expected password message") _ = writer.Flush() - tlsConn.Close() + _ = tlsConn.Close() return } @@ -384,13 +383,13 @@ func (w *Worker) handleSession(tcpConn *net.TCPConn, remoteAddr string, pid, sec if password != expectedPassword { _ = server.WriteErrorResponse(writer, "FATAL", "28P01", "password authentication failed") _ = writer.Flush() - tlsConn.Close() + _ = tlsConn.Close() return } if err := server.WriteAuthOK(writer); err != nil { slog.Error("Failed to send auth OK.", "error", err) - tlsConn.Close() + _ = tlsConn.Close() return } @@ -402,7 +401,7 @@ func (w *Worker) handleSession(tcpConn *net.TCPConn, remoteAddr string, pid, sec slog.Error("Failed to create session DB.", "error", err) _ = server.WriteErrorResponse(writer, "FATAL", "28000", fmt.Sprintf("failed to open database: %v", err)) _ = writer.Flush() - tlsConn.Close() + _ = tlsConn.Close() return } defer w.dbPool.CloseSession(pid) @@ -467,7 +466,7 @@ func (w *Worker) handleSession(tcpConn *net.TCPConn, remoteAddr string, pid, sec } case <-sessionCtx.Done(): slog.Info("Session cancelled.", "user", username, "pid", pid) - tlsConn.Close() + 
_ = tlsConn.Close() } } diff --git a/server/server.go b/server/server.go index ec06f46..cc2337e 100644 --- a/server/server.go +++ b/server/server.go @@ -504,12 +504,6 @@ func CreateDBConnection(cfg Config, duckLakeSem chan struct{}, username string) return db, nil } -// loadExtensions installs and loads configured DuckDB extensions. -// This is a thin wrapper around LoadExtensions using the server's config. -func (s *Server) loadExtensions(db *sql.DB) error { - return LoadExtensions(db, s.cfg.Extensions) -} - // LoadExtensions installs and loads DuckDB extensions. // This is a standalone function so it can be reused by control plane workers. func LoadExtensions(db *sql.DB, extensions []string) error { @@ -539,12 +533,6 @@ func LoadExtensions(db *sql.DB, extensions []string) error { return lastErr } -// attachDuckLake attaches a DuckLake catalog if configured (but does NOT set it as default). -// This is a thin wrapper around AttachDuckLake using the server's config. -func (s *Server) attachDuckLake(db *sql.DB) error { - return AttachDuckLake(db, s.cfg.DuckLake, s.duckLakeSem) -} - // AttachDuckLake attaches a DuckLake catalog if configured (but does NOT set it as default). // Call setDuckLakeDefault after creating per-connection views in memory.main. // This is a standalone function so it can be reused by control plane workers.
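
Appendix (illustrative only, not part of the patch): the SCM_RIGHTS handoff that
controlplane/fdpass wraps comes down to sending a descriptor as ancillary data on
a Unix domain socket. Below is a minimal sketch of that mechanism, assuming
golang.org/x/sys/unix and hypothetical sendFD/recvFD names rather than the
patch's SendFile/RecvFD API.

package fdpassexample

import (
	"fmt"
	"net"

	"golang.org/x/sys/unix"
)

// sendFD passes an open file descriptor over a Unix domain socket
// as an SCM_RIGHTS control message.
func sendFD(conn *net.UnixConn, fd int) error {
	rights := unix.UnixRights(fd)
	// One placeholder data byte; the descriptor rides in the ancillary data.
	_, _, err := conn.WriteMsgUnix([]byte{0}, rights, nil)
	return err
}

// recvFD receives a descriptor sent with sendFD. The kernel installs a
// duplicate of the sender's descriptor into this process's fd table.
func recvFD(conn *net.UnixConn) (int, error) {
	buf := make([]byte, 1)
	oob := make([]byte, unix.CmsgSpace(4)) // space for one 32-bit descriptor
	_, oobn, _, _, err := conn.ReadMsgUnix(buf, oob)
	if err != nil {
		return -1, err
	}
	msgs, err := unix.ParseSocketControlMessage(oob[:oobn])
	if err != nil {
		return -1, err
	}
	if len(msgs) == 0 {
		return -1, fmt.Errorf("no control message received")
	}
	fds, err := unix.ParseUnixRights(&msgs[0])
	if err != nil {
		return -1, err
	}
	if len(fds) == 0 {
		return -1, fmt.Errorf("no descriptor in control message")
	}
	return fds[0], nil
}

The receiving side can then rebuild a usable connection with os.NewFile followed
by net.FileConn, which is how AcceptConnection in controlplane/worker.go turns
the pending FD back into a *net.TCPConn before starting the session.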