diff --git a/CLAUDE.md b/CLAUDE.md index ca9620e..0c61728 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -37,7 +37,7 @@ This is a library project without a main entry point. Integration is done by imp - Intent-based transaction system: clients submit intents, sequencer processes them - Key components: - `ocp/rpc/`: gRPC service implementations (entry points for client requests): `transaction`, `account`, `currency`, `messaging` - - `ocp/worker/`: Background workers (nonce management, swap processing, sequencer, Geyser integration, account sync, currency tasks) + - `ocp/worker/`: Background workers (nonce management, swap processing, sequencer, Geyser integration, account sync, currency tasks, guaranteed task execution) - `ocp/transaction/`: Transaction building and local nonce pool management - `ocp/data/`: Data layer with Store interfaces for all domain entities - `ocp/integration/`: Pluggable integration hooks (SubmitIntent, Swap, Geyser, Moderation, Antispam) for application-specific behavior @@ -77,6 +77,13 @@ This is a library project without a main entry point. 
Integration is done by imp - `solana/coinbasestableswapper/`: Solana program interface for Coinbase's Stable Swapper (used for USDC↔USDF) - `coinbase/`: HTTP client for the Coinbase Developer Platform Onramp API (JWT auth, used as a swap funding source) +**Task System** (`ocp/task/`, `ocp/worker/task/`, `ocp/data/task/`) +- Durable, app-defined background work with at-least-once execution guarantees +- Apps return tasks from the `GetTasksToSchedule` SubmitIntent hook; the scheduler (`ocp/task/scheduler.go`) enqueues them in the *same* DB transaction that commits the intent, so scheduling is atomic with the intent +- After commit, a best-effort fast path attempts immediate execution; the background worker (`ocp/worker/task/`) sweeps `StatePending` tasks whose `NextAttemptAt` has elapsed and retries with backoff, dead-lettering to `StateFailed` after a max attempt count +- Execution is delegated to the app's `TaskExecutor` integration hook. Because tasks can run concurrently and more than once (fast path + worker, multiple processes), **TaskExecutor implementations must be idempotent** — the task ID is the natural dedup key. 
Concurrent state advances are resolved by an optimistic version check on the task record +- Task `Type`/`Data` are opaque to the base system; the implementing app owns their namespace and serialization + **Solana Integration** - `solana/` package: Low-level Solana primitives (accounts, transactions, programs) - `solana/token/`: SPL Token program interface diff --git a/ocp/data/internal.go b/ocp/data/internal.go index 2ac7bb3..2fd5e42 100644 --- a/ocp/data/internal.go +++ b/ocp/data/internal.go @@ -31,6 +31,7 @@ import ( "github.com/code-payments/ocp-server/ocp/data/nonce" "github.com/code-payments/ocp-server/ocp/data/rendezvous" "github.com/code-payments/ocp-server/ocp/data/swap" + "github.com/code-payments/ocp-server/ocp/data/task" "github.com/code-payments/ocp-server/ocp/data/timelock" "github.com/code-payments/ocp-server/ocp/data/transaction" "github.com/code-payments/ocp-server/ocp/data/vault" @@ -49,6 +50,7 @@ import ( nonce_memory_client "github.com/code-payments/ocp-server/ocp/data/nonce/memory" rendezvous_memory_client "github.com/code-payments/ocp-server/ocp/data/rendezvous/memory" swap_memory_client "github.com/code-payments/ocp-server/ocp/data/swap/memory" + task_memory_client "github.com/code-payments/ocp-server/ocp/data/task/memory" timelock_memory_client "github.com/code-payments/ocp-server/ocp/data/timelock/memory" transaction_memory_client "github.com/code-payments/ocp-server/ocp/data/transaction/memory" vault_memory_client "github.com/code-payments/ocp-server/ocp/data/vault/memory" @@ -67,6 +69,7 @@ import ( nonce_postgres_client "github.com/code-payments/ocp-server/ocp/data/nonce/postgres" rendezvous_postgres_client "github.com/code-payments/ocp-server/ocp/data/rendezvous/postgres" swap_postgres_client "github.com/code-payments/ocp-server/ocp/data/swap/postgres" + task_postgres_client "github.com/code-payments/ocp-server/ocp/data/task/postgres" timelock_postgres_client "github.com/code-payments/ocp-server/ocp/data/timelock/postgres" 
transaction_postgres_client "github.com/code-payments/ocp-server/ocp/data/transaction/postgres" vault_postgres_client "github.com/code-payments/ocp-server/ocp/data/vault/postgres" @@ -228,6 +231,14 @@ type DatabaseData interface { GetAllSwapsByState(ctx context.Context, state swap.State, opts ...query.Option) ([]*swap.Record, error) GetSwapCountByState(ctx context.Context, state swap.State) (uint64, error) + // Tasks + // -------------------------------------------------------------------------------- + PutAllTasks(ctx context.Context, records ...*task.Record) error + UpdateTask(ctx context.Context, record *task.Record) error + GetTaskById(ctx context.Context, taskId string) (*task.Record, error) + GetAllReadyTasksByState(ctx context.Context, state task.State, asOf time.Time, opts ...query.Option) ([]*task.Record, error) + GetTaskCountByState(ctx context.Context, state task.State) (uint64, error) + // Timelocks // -------------------------------------------------------------------------------- SaveTimelock(ctx context.Context, record *timelock.Record) error @@ -293,6 +304,7 @@ type DatabaseProvider struct { nonces nonce.Store rendezvous rendezvous.Store swaps swap.Store + tasks task.Store timelocks timelock.Store transactions transaction.Store vault vault.Store @@ -339,6 +351,7 @@ func NewDatabaseProvider(dbConfig *pg.Config) (DatabaseData, error) { nonces: nonce_postgres_client.New(db), rendezvous: rendezvous_postgres_client.New(db), swaps: swap_postgres_client.New(db), + tasks: task_postgres_client.New(db), timelocks: timelock_postgres_client.New(db), transactions: transaction_postgres_client.New(db), vault: vault_postgres_client.New(db), @@ -366,6 +379,7 @@ func NewTestDatabaseProvider() DatabaseData { nonces: nonce_memory_client.New(), rendezvous: rendezvous_memory_client.New(), swaps: swap_memory_client.New(), + tasks: task_memory_client.New(), timelocks: timelock_memory_client.New(), transactions: transaction_memory_client.New(), vault: 
vault_memory_client.New(), @@ -818,6 +832,29 @@ func (dp *DatabaseProvider) GetSwapCountByState(ctx context.Context, state swap. return dp.swaps.CountByState(ctx, state) } +// Tasks +// -------------------------------------------------------------------------------- + +func (dp *DatabaseProvider) PutAllTasks(ctx context.Context, records ...*task.Record) error { + return dp.tasks.PutAll(ctx, records...) +} +func (dp *DatabaseProvider) UpdateTask(ctx context.Context, record *task.Record) error { + return dp.tasks.Update(ctx, record) +} +func (dp *DatabaseProvider) GetTaskById(ctx context.Context, taskId string) (*task.Record, error) { + return dp.tasks.GetByTaskId(ctx, taskId) +} +func (dp *DatabaseProvider) GetAllReadyTasksByState(ctx context.Context, state task.State, asOf time.Time, opts ...query.Option) ([]*task.Record, error) { + req, err := query.DefaultPaginationHandler(opts...) + if err != nil { + return nil, err + } + return dp.tasks.GetAllReadyByState(ctx, state, asOf, req.Cursor, req.Limit, req.SortBy) +} +func (dp *DatabaseProvider) GetTaskCountByState(ctx context.Context, state task.State) (uint64, error) { + return dp.tasks.CountByState(ctx, state) +} + // Timelocks // -------------------------------------------------------------------------------- func (dp *DatabaseProvider) SaveTimelock(ctx context.Context, record *timelock.Record) error { diff --git a/ocp/data/task/memory/store.go b/ocp/data/task/memory/store.go new file mode 100644 index 0000000..06ba019 --- /dev/null +++ b/ocp/data/task/memory/store.go @@ -0,0 +1,214 @@ +package memory + +import ( + "context" + "errors" + "sort" + "sync" + "time" + + "github.com/code-payments/ocp-server/database/query" + "github.com/code-payments/ocp-server/ocp/data/task" +) + +type ById []*task.Record + +func (a ById) Len() int { return len(a) } +func (a ById) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a ById) Less(i, j int) bool { return a[i].Id < a[j].Id } + +type store struct { + mu sync.RWMutex + 
records []*task.Record + last uint64 +} + +func New() task.Store { + return &store{} +} + +func (s *store) PutAll(ctx context.Context, records ...*task.Record) error { + if len(records) == 0 { + return errors.New("empty task set") + } + + for _, data := range records { + if data.Id > 0 { + return task.ErrExists + } + + if err := data.Validate(); err != nil { + return err + } + } + + s.mu.Lock() + defer s.mu.Unlock() + + seen := make(map[string]struct{}) + for _, data := range records { + if _, ok := seen[data.TaskId]; ok { + return task.ErrExists + } + seen[data.TaskId] = struct{}{} + + if item := s.findByTaskId(data.TaskId); item != nil { + return task.ErrExists + } + } + + for _, data := range records { + s.last++ + data.Id = s.last + if data.CreatedAt.IsZero() { + data.CreatedAt = time.Now() + } + if data.NextAttemptAt.IsZero() { + data.NextAttemptAt = data.CreatedAt + } + data.Version++ + + c := data.Clone() + s.records = append(s.records, &c) + } + + return nil +} + +func (s *store) Update(ctx context.Context, data *task.Record) error { + if err := data.Validate(); err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + + item := s.findByTaskId(data.TaskId) + if item == nil || item.Version != data.Version { + return task.ErrStaleVersion + } + + data.Version++ + + item.State = data.State + item.FailedAttempts = data.FailedAttempts + item.NextAttemptAt = data.NextAttemptAt + item.Version = data.Version + + item.CopyTo(data) + + return nil +} + +func (s *store) GetByTaskId(ctx context.Context, taskId string) (*task.Record, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + item := s.findByTaskId(taskId) + if item == nil { + return nil, task.ErrNotFound + } + + cloned := item.Clone() + return &cloned, nil +} + +func (s *store) GetAllReadyByState(ctx context.Context, state task.State, asOf time.Time, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*task.Record, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + items := 
s.findByState(state) + items = s.filterReady(items, asOf) + + if items = s.filter(items, cursor, limit, direction); len(items) > 0 { + return cloneRecords(items), nil + } + + return nil, task.ErrNotFound +} + +func (s *store) CountByState(ctx context.Context, state task.State) (uint64, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + items := s.findByState(state) + return uint64(len(items)), nil +} + +func (s *store) findByTaskId(taskId string) *task.Record { + for _, item := range s.records { + if item.TaskId == taskId { + return item + } + } + return nil +} + +func (s *store) findByState(state task.State) []*task.Record { + var res []*task.Record + for _, item := range s.records { + if item.State == state { + res = append(res, item) + } + } + return res +} + +func (s *store) filterReady(items []*task.Record, asOf time.Time) []*task.Record { + var res []*task.Record + for _, item := range items { + if !item.NextAttemptAt.After(asOf) { + res = append(res, item) + } + } + return res +} + +func (s *store) filter(items []*task.Record, cursor query.Cursor, limit uint64, direction query.Ordering) []*task.Record { + var start uint64 + + start = 0 + if direction == query.Descending { + start = s.last + 1 + } + if len(cursor) > 0 { + start = cursor.ToUint64() + } + + var res []*task.Record + for _, item := range items { + if item.Id > start && direction == query.Ascending { + res = append(res, item) + } + if item.Id < start && direction == query.Descending { + res = append(res, item) + } + } + + if direction == query.Descending { + sort.Sort(sort.Reverse(ById(res))) + } + + if len(res) >= int(limit) { + return res[:limit] + } + + return res +} + +func cloneRecords(items []*task.Record) []*task.Record { + var res []*task.Record + for _, item := range items { + cloned := item.Clone() + res = append(res, &cloned) + } + return res +} + +func (s *store) reset() { + s.mu.Lock() + defer s.mu.Unlock() + + s.records = nil + s.last = 0 +} diff --git 
// model is the row representation of a task.Record in the ocp__core_task
// table. Each field is bound to its column through the db struct tag.
type model struct {
	// Id is the row's numeric identifier, held as a nullable integer.
	Id sql.NullInt64 `db:"id"`
	// TaskId is the task's string identifier; lookups and updates key on it.
	TaskId string `db:"task_id"`
	// TaskType and Data carry the record's Type and Data values unchanged.
	TaskType uint32 `db:"task_type"`
	Data     []byte `db:"data"`
	// ReferenceId is nullable because the record's reference ID is optional.
	ReferenceId sql.NullString `db:"reference_id"`
	// State is the numeric form of task.State.
	State          uint8     `db:"state"`
	FailedAttempts uint32    `db:"failed_attempts"`
	NextAttemptAt  time.Time `db:"next_attempt_at"`
	// Version is compared in the WHERE clause of updates and incremented on
	// every successful write.
	Version   uint64    `db:"version"`
	CreatedAt time.Time `db:"created_at"`
}
obj.FailedAttempts, + NextAttemptAt: obj.NextAttemptAt, + Version: obj.Version, + CreatedAt: obj.CreatedAt, + }, nil +} + +func fromModel(m *model) *task.Record { + var referenceId *string + if m.ReferenceId.Valid { + value := m.ReferenceId.String + referenceId = &value + } + + return &task.Record{ + Id: uint64(m.Id.Int64), + TaskId: m.TaskId, + Type: m.TaskType, + Data: m.Data, + ReferenceId: referenceId, + State: task.State(m.State), + FailedAttempts: m.FailedAttempts, + NextAttemptAt: m.NextAttemptAt, + Version: m.Version, + CreatedAt: m.CreatedAt, + } +} + +func dbPutAllInTx(ctx context.Context, tx *sqlx.Tx, models []*model) ([]*model, error) { + var res []*model + + query := `INSERT INTO ` + tableName + ` (task_id, task_type, data, reference_id, state, failed_attempts, next_attempt_at, version, created_at) VALUES ` + + var parameters []interface{} + for i, model := range models { + baseIndex := len(parameters) + query += fmt.Sprintf( + `($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d + 1, $%d)`, + baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8, baseIndex+9, + ) + + if i != len(models)-1 { + query += "," + } + + parameters = append( + parameters, + model.TaskId, + model.TaskType, + model.Data, + model.ReferenceId, + model.State, + model.FailedAttempts, + model.NextAttemptAt, + model.Version, + model.CreatedAt, + ) + } + + query += ` RETURNING id, task_id, task_type, data, reference_id, state, failed_attempts, next_attempt_at, version, created_at` + + err := tx.SelectContext( + ctx, + &res, + query, + parameters..., + ) + if err != nil { + return nil, pgutil.CheckUniqueViolation(err, task.ErrExists) + } + + return res, nil +} + +func (m *model) dbUpdate(ctx context.Context, db *sqlx.DB) error { + return pgutil.ExecuteInTx(ctx, db, sql.LevelDefault, func(tx *sqlx.Tx) error { + query := `UPDATE ` + tableName + ` + SET state = $3, failed_attempts = $4, next_attempt_at = $5, version = version + 1 + WHERE task_id = $1 
AND version = $2 + RETURNING id, task_id, task_type, data, reference_id, state, failed_attempts, next_attempt_at, version, created_at` + + err := tx.QueryRowxContext( + ctx, + query, + m.TaskId, + m.Version, + m.State, + m.FailedAttempts, + m.NextAttemptAt, + ).StructScan(m) + if err != nil { + return pgutil.CheckNoRows(err, task.ErrStaleVersion) + } + + return nil + }) +} + +func dbGetByTaskId(ctx context.Context, db *sqlx.DB, taskId string) (*model, error) { + res := &model{} + + query := `SELECT id, task_id, task_type, data, reference_id, state, failed_attempts, next_attempt_at, version, created_at + FROM ` + tableName + ` + WHERE task_id = $1 + LIMIT 1` + + err := db.GetContext(ctx, res, query, taskId) + if err != nil { + return nil, pgutil.CheckNoRows(err, task.ErrNotFound) + } + return res, nil +} + +func dbGetAllReadyByState(ctx context.Context, db *sqlx.DB, state task.State, asOf time.Time, cursor q.Cursor, limit uint64, direction q.Ordering) ([]*model, error) { + res := []*model{} + + query := `SELECT + id, task_id, task_type, data, reference_id, state, failed_attempts, next_attempt_at, version, created_at + FROM ` + tableName + ` + WHERE state = $1 AND next_attempt_at <= $2` + + opts := []interface{}{state, asOf} + query, opts = q.PaginateQuery(query, opts, cursor, limit, direction) + + err := db.SelectContext(ctx, &res, query, opts...) 
+ if err != nil { + return nil, pgutil.CheckNoRows(err, task.ErrNotFound) + } + + if len(res) == 0 { + return nil, task.ErrNotFound + } + return res, nil +} + +func dbCountByState(ctx context.Context, db *sqlx.DB, state task.State) (uint64, error) { + var res uint64 + query := `SELECT COUNT(*) FROM ` + tableName + ` WHERE state = $1` + err := db.GetContext(ctx, &res, query, state) + if err != nil { + return 0, err + } + return res, nil +} diff --git a/ocp/data/task/postgres/store.go b/ocp/data/task/postgres/store.go new file mode 100644 index 0000000..343ef4d --- /dev/null +++ b/ocp/data/task/postgres/store.go @@ -0,0 +1,118 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "time" + + "github.com/jmoiron/sqlx" + + pgutil "github.com/code-payments/ocp-server/database/postgres" + "github.com/code-payments/ocp-server/database/query" + "github.com/code-payments/ocp-server/ocp/data/task" +) + +type store struct { + db *sqlx.DB +} + +func New(db *sql.DB) task.Store { + return &store{ + db: sqlx.NewDb(db, "pgx"), + } +} + +// PutAll implements task.Store.PutAll +func (s *store) PutAll(ctx context.Context, records ...*task.Record) error { + if len(records) == 0 { + return errors.New("empty task set") + } + + models := make([]*model, len(records)) + for i, record := range records { + if record.Id > 0 { + return task.ErrExists + } + + model, err := toModel(record) + if err != nil { + return err + } + + models[i] = model + } + + return pgutil.ExecuteInTx(ctx, s.db, sql.LevelDefault, func(tx *sqlx.Tx) error { + updated, err := dbPutAllInTx(ctx, tx, models) + if err != nil { + return err + } + + if len(updated) != len(records) { + return errors.New("unexpected count of task models returned") + } + + // Don't assume postgres properly orders things + updatedByTaskId := make(map[string]*model) + for _, model := range updated { + updatedByTaskId[model.TaskId] = model + } + + for _, record := range records { + model, ok := updatedByTaskId[record.TaskId] + if 
!ok { + return errors.New("task model not returned") + } + + fromModel(model).CopyTo(record) + } + + return nil + }) +} + +// Update implements task.Store.Update +func (s *store) Update(ctx context.Context, record *task.Record) error { + obj, err := toModel(record) + if err != nil { + return err + } + + err = obj.dbUpdate(ctx, s.db) + if err != nil { + return err + } + + fromModel(obj).CopyTo(record) + + return nil +} + +// GetByTaskId implements task.Store.GetByTaskId +func (s *store) GetByTaskId(ctx context.Context, taskId string) (*task.Record, error) { + obj, err := dbGetByTaskId(ctx, s.db, taskId) + if err != nil { + return nil, err + } + return fromModel(obj), nil +} + +// GetAllReadyByState implements task.Store.GetAllReadyByState +func (s *store) GetAllReadyByState(ctx context.Context, state task.State, asOf time.Time, cursor query.Cursor, limit uint64, direction query.Ordering) ([]*task.Record, error) { + models, err := dbGetAllReadyByState(ctx, s.db, state, asOf, cursor, limit, direction) + if err != nil { + return nil, err + } + + res := make([]*task.Record, len(models)) + for i, model := range models { + res[i] = fromModel(model) + } + return res, nil +} + +// CountByState implements task.Store.CountByState +func (s *store) CountByState(ctx context.Context, state task.State) (uint64, error) { + return dbCountByState(ctx, s.db, state) +} diff --git a/ocp/data/task/postgres/store_test.go b/ocp/data/task/postgres/store_test.go new file mode 100644 index 0000000..598d2cd --- /dev/null +++ b/ocp/data/task/postgres/store_test.go @@ -0,0 +1,173 @@ +package postgres + +import ( + "context" + "database/sql" + "errors" + "os" + "testing" + "time" + + "github.com/jmoiron/sqlx" + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/code-payments/ocp-server/ocp/data/task" + "github.com/code-payments/ocp-server/ocp/data/task/tests" + + pgutil 
"github.com/code-payments/ocp-server/database/postgres" + postgrestest "github.com/code-payments/ocp-server/database/postgres/test" + + _ "github.com/jackc/pgx/v4/stdlib" +) + +const ( + // Used for testing ONLY, the table and migrations are external to this repository + tableCreate = ` + CREATE TABLE ocp__core_task( + id SERIAL NOT NULL PRIMARY KEY, + + task_id TEXT NOT NULL UNIQUE, + + task_type INTEGER NOT NULL, + data BYTEA, + + reference_id TEXT, + + state INTEGER NOT NULL, + + failed_attempts INTEGER NOT NULL DEFAULT 0, + next_attempt_at TIMESTAMP WITH TIME ZONE NOT NULL, + + version INTEGER NOT NULL, + + created_at TIMESTAMP WITH TIME ZONE NOT NULL + ); + + CREATE INDEX ocp__core_task_ready_by_state_idx ON ocp__core_task (state, next_attempt_at, id); + ` + + // Used for testing ONLY, the table and migrations are external to this repository + tableDestroy = ` + DROP TABLE ocp__core_task; + ` +) + +var ( + testStore task.Store + testDb *sqlx.DB + teardown func() +) + +func TestMain(m *testing.M) { + log := zap.Must(zap.NewDevelopment()) + + testPool, err := dockertest.NewPool("") + if err != nil { + log.With(zap.Error(err)).Error("Error creating docker pool") + os.Exit(1) + } + + var cleanUpFunc func() + db, cleanUpFunc, err := postgrestest.StartPostgresDB(testPool) + if err != nil { + log.With(zap.Error(err)).Error("Error starting postgres image") + os.Exit(1) + } + defer db.Close() + + if err := createTestTables(log, db); err != nil { + log.With(zap.Error(err)).Error("Error creating test tables") + cleanUpFunc() + os.Exit(1) + } + + testStore = New(db) + testDb = sqlx.NewDb(db, "pgx") + teardown = func() { + if pc := recover(); pc != nil { + cleanUpFunc() + panic(pc) + } + + if err := resetTestTables(log, db); err != nil { + log.With(zap.Error(err)).Error("Error resetting test tables") + cleanUpFunc() + os.Exit(1) + } + } + + code := m.Run() + cleanUpFunc() + os.Exit(code) +} + +func TestTaskPostgresStore(t *testing.T) { + tests.RunTests(t, testStore, 
teardown) +} + +// Tasks must be created atomically with whatever DB transaction is being +// passed along the context (eg. the one that commits an intent). +func TestTaskPostgresStoreTxSupport(t *testing.T) { + defer teardown() + + ctx := context.Background() + + record := &task.Record{ + TaskId: "test_task_id", + Type: 1, + Data: []byte("test_data"), + State: task.StatePending, + NextAttemptAt: time.Now(), + CreatedAt: time.Now(), + } + + errRollback := errors.New("rollback") + err := pgutil.ExecuteTxWithinCtx(ctx, testDb, sql.LevelDefault, func(ctx context.Context) error { + if err := testStore.PutAll(ctx, record); err != nil { + return err + } + return errRollback + }) + assert.Equal(t, errRollback, err) + + _, err = testStore.GetByTaskId(ctx, "test_task_id") + assert.Equal(t, task.ErrNotFound, err) + + record.Id = 0 + record.Version = 0 + err = pgutil.ExecuteTxWithinCtx(ctx, testDb, sql.LevelDefault, func(ctx context.Context) error { + if err := testStore.PutAll(ctx, record); err != nil { + return err + } + + // The task is visible within the transaction via the same connection + return testStore.Update(ctx, record) + }) + require.NoError(t, err) + + actual, err := testStore.GetByTaskId(ctx, "test_task_id") + require.NoError(t, err) + assert.EqualValues(t, 2, actual.Version) +} + +func createTestTables(log *zap.Logger, db *sql.DB) error { + _, err := db.Exec(tableCreate) + if err != nil { + log.With(zap.Error(err)).Error("could not create test tables") + return err + } + return nil +} + +func resetTestTables(log *zap.Logger, db *sql.DB) error { + _, err := db.Exec(tableDestroy) + if err != nil { + log.With(zap.Error(err)).Error("could not drop test tables") + return err + } + + return createTestTables(log, db) +} diff --git a/ocp/data/task/store.go b/ocp/data/task/store.go new file mode 100644 index 0000000..0712420 --- /dev/null +++ b/ocp/data/task/store.go @@ -0,0 +1,36 @@ +package task + +import ( + "context" + "errors" + "time" + + 
// State is the lifecycle state of a task.
type State uint8

const (
	StateUnknown State = iota
	StatePending
	StateConfirmed
	StateFailed
)

// Record is a single task as persisted by a Store.
type Record struct {
	Id uint64

	TaskId string

	// Type and Data are opaque to the base system. The implementing app owns
	// the namespace and serialization.
	Type uint32
	Data []byte

	// ReferenceId is an optional correlation ID (eg. an intent ID) used purely
	// for observability. The base system never reads it.
	ReferenceId *string

	State State

	FailedAttempts uint32
	NextAttemptAt  time.Time

	Version uint64

	CreatedAt time.Time
}

// Clone returns a copy of the record that shares no memory with the
// original: Data and ReferenceId are duplicated, and a nil Data or
// ReferenceId stays nil in the copy.
func (r *Record) Clone() Record {
	// Start from a shallow copy, then detach the two reference-typed fields.
	cloned := *r

	if r.ReferenceId != nil {
		referenceId := *r.ReferenceId
		cloned.ReferenceId = &referenceId
	}

	if r.Data != nil {
		cloned.Data = append(make([]byte, 0, len(r.Data)), r.Data...)
	}

	return cloned
}

// CopyTo overwrites every field of dst with a deep copy of the receiver.
func (r *Record) CopyTo(dst *Record) {
	*dst = r.Clone()
}

// Validate reports the first problem found with the record: a missing task
// ID, a zero type, or a reference ID that is set but empty. It returns nil
// when none of those apply.
func (r *Record) Validate() error {
	switch {
	case len(r.TaskId) == 0:
		return errors.New("task id is required")
	case r.Type == 0:
		return errors.New("type is required")
	case r.ReferenceId != nil && len(*r.ReferenceId) == 0:
		return errors.New("reference id is empty when set")
	default:
		return nil
	}
}

// String returns the lowercase name of the state, or "unknown" for
// StateUnknown and any value outside the declared states.
func (s State) String() string {
	names := [...]string{
		StatePending:   "pending",
		StateConfirmed: "confirmed",
		StateFailed:    "failed",
	}

	if int(s) < len(names) && names[s] != "" {
		return names[s]
	}
	return "unknown"
}
"github.com/code-payments/ocp-server/ocp/data/task" + "github.com/code-payments/ocp-server/pointer" +) + +func RunTests(t *testing.T, s task.Store, teardown func()) { + for _, tf := range []func(t *testing.T, s task.Store){ + testRoundTrip, + testPutAllBatch, + testPutAllDuplicate, + testUpdateHappyPath, + testUpdateStaleRecord, + testGetAllReadyByState, + testCountByState, + } { + tf(t, s) + teardown() + } +} + +func testRoundTrip(t *testing.T, s task.Store) { + t.Run("testRoundTrip", func(t *testing.T) { + ctx := context.Background() + + actual, err := s.GetByTaskId(ctx, "test_task_id") + require.Error(t, err) + assert.Equal(t, task.ErrNotFound, err) + assert.Nil(t, actual) + + invalid := &task.Record{ + Type: 1, + Data: []byte("test_data"), + } + require.Error(t, s.PutAll(ctx, invalid)) + + invalid = &task.Record{ + TaskId: "test_task_id", + Data: []byte("test_data"), + } + require.Error(t, s.PutAll(ctx, invalid)) + + expected := &task.Record{ + TaskId: "test_task_id", + + Type: 1, + Data: []byte("test_data"), + + ReferenceId: pointer.String("test_reference_id"), + + State: task.StatePending, + + NextAttemptAt: time.Now(), + + CreatedAt: time.Now(), + } + cloned := expected.Clone() + err = s.PutAll(ctx, expected) + require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) + + actual, err = s.GetByTaskId(ctx, "test_task_id") + require.NoError(t, err) + assertEquivalentRecords(t, &cloned, actual) + assert.EqualValues(t, 1, actual.Id) + assert.EqualValues(t, 1, actual.Version) + }) +} + +func testPutAllBatch(t *testing.T, s task.Store) { + t.Run("testPutAllBatch", func(t *testing.T) { + ctx := context.Background() + + require.Error(t, s.PutAll(ctx)) + + var expected []*task.Record + for i := range 10 { + expected = append(expected, &task.Record{ + TaskId: fmt.Sprintf("test_task_id_%d", i), + + Type: uint32(i + 1), + Data: []byte(fmt.Sprintf("test_data_%d", i)), + + State: task.StatePending, + + NextAttemptAt: 
time.Now(), + + CreatedAt: time.Now(), + }) + } + + var cloned []task.Record + for _, record := range expected { + cloned = append(cloned, record.Clone()) + } + + require.NoError(t, s.PutAll(ctx, expected...)) + + for i, record := range expected { + assert.EqualValues(t, i+1, record.Id) + assert.EqualValues(t, 1, record.Version) + + actual, err := s.GetByTaskId(ctx, record.TaskId) + require.NoError(t, err) + assertEquivalentRecords(t, &cloned[i], actual) + } + }) +} + +func testPutAllDuplicate(t *testing.T, s task.Store) { + t.Run("testPutAllDuplicate", func(t *testing.T) { + ctx := context.Background() + + record := &task.Record{ + TaskId: "test_task_id", + + Type: 1, + Data: []byte("test_data"), + + State: task.StatePending, + + NextAttemptAt: time.Now(), + + CreatedAt: time.Now(), + } + require.NoError(t, s.PutAll(ctx, record)) + + duplicate := &task.Record{ + TaskId: "test_task_id", + + Type: 2, + Data: []byte("other_data"), + + State: task.StatePending, + + NextAttemptAt: time.Now(), + + CreatedAt: time.Now(), + } + err := s.PutAll(ctx, duplicate) + assert.Equal(t, task.ErrExists, err) + + // The batch is all-or-nothing + other := &task.Record{ + TaskId: "test_other_task_id", + + Type: 1, + Data: []byte("test_data"), + + State: task.StatePending, + + NextAttemptAt: time.Now(), + + CreatedAt: time.Now(), + } + duplicate.Id = 0 + err = s.PutAll(ctx, other, duplicate) + assert.Equal(t, task.ErrExists, err) + + _, err = s.GetByTaskId(ctx, "test_other_task_id") + assert.Equal(t, task.ErrNotFound, err) + + actual, err := s.GetByTaskId(ctx, "test_task_id") + require.NoError(t, err) + assert.EqualValues(t, 1, actual.Type) + }) +} + +func testUpdateHappyPath(t *testing.T, s task.Store) { + t.Run("testUpdateHappyPath", func(t *testing.T) { + ctx := context.Background() + + expected := &task.Record{ + TaskId: "test_task_id", + + Type: 1, + Data: []byte("test_data"), + + ReferenceId: pointer.String("test_reference_id"), + + State: task.StatePending, + + NextAttemptAt: 
time.Now(), + + CreatedAt: time.Now(), + } + require.NoError(t, s.PutAll(ctx, expected)) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 1, expected.Version) + + expected.State = task.StatePending + expected.FailedAttempts = 3 + expected.NextAttemptAt = time.Now().Add(time.Minute) + + err := s.Update(ctx, expected) + require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 2, expected.Version) + + actual, err := s.GetByTaskId(ctx, "test_task_id") + require.NoError(t, err) + assertEquivalentRecords(t, expected, actual) + assert.EqualValues(t, 3, actual.FailedAttempts) + + expected.State = task.StateConfirmed + mutatedData := []byte("mutated_data_should_be_ignored") + expected.Data = mutatedData + + err = s.Update(ctx, expected) + require.NoError(t, err) + assert.EqualValues(t, 1, expected.Id) + assert.EqualValues(t, 3, expected.Version) + + actual, err = s.GetByTaskId(ctx, "test_task_id") + require.NoError(t, err) + assert.Equal(t, task.StateConfirmed, actual.State) + assert.Equal(t, []byte("test_data"), actual.Data) + assert.NotEqual(t, mutatedData, actual.Data) + + // The stored value is copied back into the updated record + assert.Equal(t, []byte("test_data"), expected.Data) + }) +} + +func testUpdateStaleRecord(t *testing.T, s task.Store) { + t.Run("testUpdateStaleRecord", func(t *testing.T) { + ctx := context.Background() + + missing := &task.Record{ + TaskId: "test_missing_task_id", + + Type: 1, + Data: []byte("test_data"), + + State: task.StatePending, + + NextAttemptAt: time.Now(), + + CreatedAt: time.Now(), + } + err := s.Update(ctx, missing) + assert.Equal(t, task.ErrStaleVersion, err) + + expected := &task.Record{ + TaskId: "test_task_id", + + Type: 1, + Data: []byte("test_data"), + + State: task.StatePending, + + NextAttemptAt: time.Now(), + + CreatedAt: time.Now(), + } + require.NoError(t, s.PutAll(ctx, expected)) + assert.EqualValues(t, 1, expected.Version) + + stale := expected.Clone() + stale.Version -= 
1 + stale.State = task.StateFailed + + err = s.Update(ctx, &stale) + assert.Equal(t, task.ErrStaleVersion, err) + + actual, err := s.GetByTaskId(ctx, "test_task_id") + require.NoError(t, err) + assert.Equal(t, task.StatePending, actual.State) + assert.EqualValues(t, 1, actual.Version) + }) +} + +func testGetAllReadyByState(t *testing.T, s task.Store) { + t.Run("testGetAllReadyByState", func(t *testing.T) { + ctx := context.Background() + + now := time.Now() + + _, err := s.GetAllReadyByState(ctx, task.StatePending, now, query.EmptyCursor, 10, query.Ascending) + assert.Equal(t, task.ErrNotFound, err) + + var records []*task.Record + for i := range 100 { + state := task.StatePending + if i >= 50 { + state = task.StateConfirmed + } + + // Even pending tasks are ready, odd pending tasks are scheduled + // for the future + nextAttemptAt := now + if i%2 == 1 { + nextAttemptAt = now.Add(time.Hour) + } + + record := &task.Record{ + TaskId: fmt.Sprintf("test_task_id_%d", i), + + Type: uint32(i + 1), + Data: []byte(fmt.Sprintf("test_data_%d", i)), + + State: state, + + NextAttemptAt: nextAttemptAt, + + CreatedAt: now, + } + records = append(records, record) + } + require.NoError(t, s.PutAll(ctx, records...)) + + // Only the 25 ready pending tasks are returned + allActual, err := s.GetAllReadyByState(ctx, task.StatePending, now, query.EmptyCursor, 100, query.Ascending) + require.NoError(t, err) + require.Len(t, allActual, 25) + for i, actual := range allActual { + assertEquivalentRecords(t, records[2*i], actual) + } + + // All 50 pending tasks are ready in an hour + allActual, err = s.GetAllReadyByState(ctx, task.StatePending, now.Add(time.Hour), query.EmptyCursor, 100, query.Ascending) + require.NoError(t, err) + require.Len(t, allActual, 50) + + allActual, err = s.GetAllReadyByState(ctx, task.StatePending, now, query.EmptyCursor, 10, query.Ascending) + require.NoError(t, err) + require.Len(t, allActual, 10) + for i, actual := range allActual { + assertEquivalentRecords(t, 
records[2*i], actual) + } + + allActual, err = s.GetAllReadyByState(ctx, task.StatePending, now, query.EmptyCursor, 10, query.Descending) + require.NoError(t, err) + require.Len(t, allActual, 10) + for i, actual := range allActual { + assertEquivalentRecords(t, records[48-2*i], actual) + } + + allActual, err = s.GetAllReadyByState(ctx, task.StatePending, now, query.ToCursor(records[24].Id), 10, query.Ascending) + require.NoError(t, err) + require.Len(t, allActual, 10) + for i, actual := range allActual { + assertEquivalentRecords(t, records[26+2*i], actual) + } + + allActual, err = s.GetAllReadyByState(ctx, task.StatePending, now, query.ToCursor(records[24].Id), 10, query.Descending) + require.NoError(t, err) + require.Len(t, allActual, 10) + for i, actual := range allActual { + assertEquivalentRecords(t, records[22-2*i], actual) + } + + _, err = s.GetAllReadyByState(ctx, task.StatePending, now, query.ToCursor(records[98].Id), 10, query.Ascending) + assert.Equal(t, task.ErrNotFound, err) + + _, err = s.GetAllReadyByState(ctx, task.StateFailed, now, query.EmptyCursor, 10, query.Ascending) + assert.Equal(t, task.ErrNotFound, err) + }) +} + +func testCountByState(t *testing.T, s task.Store) { + t.Run("testCountByState", func(t *testing.T) { + ctx := context.Background() + + count, err := s.CountByState(ctx, task.StatePending) + require.NoError(t, err) + assert.EqualValues(t, 0, count) + + var records []*task.Record + for i := range 10 { + state := task.StatePending + if i >= 6 { + state = task.StateConfirmed + } + if i >= 9 { + state = task.StateFailed + } + + records = append(records, &task.Record{ + TaskId: fmt.Sprintf("test_task_id_%d", i), + + Type: uint32(i + 1), + Data: []byte(fmt.Sprintf("test_data_%d", i)), + + State: state, + + NextAttemptAt: time.Now(), + + CreatedAt: time.Now(), + }) + } + require.NoError(t, s.PutAll(ctx, records...)) + + count, err = s.CountByState(ctx, task.StatePending) + require.NoError(t, err) + assert.EqualValues(t, 6, count) + + 
count, err = s.CountByState(ctx, task.StateConfirmed) + require.NoError(t, err) + assert.EqualValues(t, 3, count) + + count, err = s.CountByState(ctx, task.StateFailed) + require.NoError(t, err) + assert.EqualValues(t, 1, count) + }) +} + +func assertEquivalentRecords(t *testing.T, obj1, obj2 *task.Record) { + assert.Equal(t, obj1.TaskId, obj2.TaskId) + + assert.Equal(t, obj1.Type, obj2.Type) + assert.Equal(t, obj1.Data, obj2.Data) + + assert.Equal(t, obj1.ReferenceId, obj2.ReferenceId) + + assert.Equal(t, obj1.State, obj2.State) + + assert.Equal(t, obj1.FailedAttempts, obj2.FailedAttempts) + assert.Equal(t, obj1.NextAttemptAt.UTC().Truncate(time.Microsecond), obj2.NextAttemptAt.UTC().Truncate(time.Microsecond)) +} diff --git a/ocp/integration/submit_intent.go b/ocp/integration/submit_intent.go index b3094ad..63b305c 100644 --- a/ocp/integration/submit_intent.go +++ b/ocp/integration/submit_intent.go @@ -6,6 +6,7 @@ import ( transactionpb "github.com/code-payments/ocp-protobuf-api/generated/go/transaction/v1" "github.com/code-payments/ocp-server/ocp/data/intent" + "github.com/code-payments/ocp-server/ocp/data/task" ) // SubmitIntent is an integration that hooks into SubmitIntent @@ -14,6 +15,17 @@ type SubmitIntent interface { // with app-specific validation rules AllowCreation(ctx context.Context, intentRecord *intent.Record, metadata *transactionpb.Metadata, actions []*transactionpb.Action) error + // GetTasksToSchedule returns app-defined tasks whose execution is + // guaranteed once the intent is committed. Returned tasks are persisted in + // the same DB transaction that commits the intent, so an error fails the + // whole submission and no tasks are scheduled. + // + // Execution is delegated to the app's TaskExecutor and is at-least-once: a + // best-effort attempt is made immediately after commit, with the background + // worker retrying anything that fails. 
The same task may therefore execute + // more than once and concurrently, so the TaskExecutor must be idempotent. + GetTasksToSchedule(ctx context.Context, intentRecord *intent.Record) ([]*task.Record, error) + // OnSuccess is a best-effort callback when an intent has been successfully // submitted OnSuccess(ctx context.Context, intentRecord *intent.Record) error @@ -31,6 +43,10 @@ func (i *defaultSubmitIntentIntegration) AllowCreation(ctx context.Context, inte return nil } +func (i *defaultSubmitIntentIntegration) GetTasksToSchedule(ctx context.Context, intentRecord *intent.Record) ([]*task.Record, error) { + return nil, nil +} + func (i *defaultSubmitIntentIntegration) OnSuccess(ctx context.Context, intentRecord *intent.Record) error { return nil } diff --git a/ocp/integration/task.go b/ocp/integration/task.go new file mode 100644 index 0000000..dadc4b4 --- /dev/null +++ b/ocp/integration/task.go @@ -0,0 +1,34 @@ +package integration + +import ( + "context" + "errors" + + "github.com/code-payments/ocp-server/ocp/data/task" +) + +// TaskExecutor executes app-defined tasks whose execution is guaranteed by +// the base system. +// +// Implementations must be idempotent. Tasks are delivered at-least-once and +// may be executed concurrently, in any order, by any number of processes. +// The task ID is a natural deduplication key. +// +// A returned error means the task will be retried at a later time, until a +// maximum attempt count is reached. +type TaskExecutor interface { + Execute(ctx context.Context, record *task.Record) error +} + +type defaultTaskExecutor struct { +} + +// NewDefaultTaskExecutor returns a TaskExecutor that fails every task, so +// orphaned tasks are surfaced loudly instead of being silently confirmed. 
+func NewDefaultTaskExecutor() TaskExecutor { + return &defaultTaskExecutor{} +} + +func (e *defaultTaskExecutor) Execute(ctx context.Context, record *task.Record) error { + return errors.New("no task executor registered") +} diff --git a/ocp/rpc/transaction/intent.go b/ocp/rpc/transaction/intent.go index c9ccd3f..bf92d1e 100644 --- a/ocp/rpc/transaction/intent.go +++ b/ocp/rpc/transaction/intent.go @@ -28,6 +28,7 @@ import ( "github.com/code-payments/ocp-server/ocp/data/fulfillment" "github.com/code-payments/ocp-server/ocp/data/intent" "github.com/code-payments/ocp-server/ocp/data/nonce" + "github.com/code-payments/ocp-server/ocp/data/task" "github.com/code-payments/ocp-server/ocp/rpc" "github.com/code-payments/ocp-server/ocp/transaction" "github.com/code-payments/ocp-server/pointer" @@ -626,6 +627,7 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm // Note: This is the first use case of this new method to do this kind of // operation. Not all store implementations have real support for this, so // if anything is added, then ensure it does! + var tasksToSchedule []*task.Record err = s.data.ExecuteInTx(ctx, sql.LevelDefault, func(ctx context.Context) error { // Save any supporting records that must exist before the intent record err = intentHandler.OnPreSaveToDB(ctx) @@ -698,6 +700,19 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm } } + // Schedule app-defined tasks atomically with the intent, so their + // execution is guaranteed once the intent is committed + tasksToSchedule, err = s.submitIntentIntegration.GetTasksToSchedule(ctx, intentRecord) + if err != nil { + log.With(zap.Error(err)).Warn("failure getting tasks to schedule from integration") + return err + } + err = s.taskScheduler.Enqueue(ctx, tasksToSchedule...) 
+ if err != nil { + log.With(zap.Error(err)).Warn("failure enqueuing tasks") + return err + } + return nil }) if err != nil { @@ -709,12 +724,24 @@ func (s *transactionServer) SubmitIntent(streamer transactionpb.Transaction_Subm } go func() { - err := s.submitIntentIntegration.OnSuccess(context.Background(), intentRecord) + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 10*time.Second) + defer cancel() + + err := s.submitIntentIntegration.OnSuccess(ctx, intentRecord) if err != nil { log.With(zap.Error(err)).Warn("failure calling integration success callback") } }() + // Fast path for scheduled tasks. Anything that fails here is picked up + // by the background worker. + go func() { + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 10*time.Second) + defer cancel() + + s.taskScheduler.TryExecuteNow(ctx, tasksToSchedule...) + }() + // // Intent is submitted, and anything beyond this point is best-effort. // We must send success back to the client. Rolling back the intent is diff --git a/ocp/rpc/transaction/server.go b/ocp/rpc/transaction/server.go index cc93526..9bbca93 100644 --- a/ocp/rpc/transaction/server.go +++ b/ocp/rpc/transaction/server.go @@ -18,6 +18,7 @@ import ( ocp_data "github.com/code-payments/ocp-server/ocp/data" "github.com/code-payments/ocp-server/ocp/data/nonce" "github.com/code-payments/ocp-server/ocp/integration" + ocp_task "github.com/code-payments/ocp-server/ocp/task" "github.com/code-payments/ocp-server/ocp/transaction" ) @@ -34,6 +35,8 @@ type transactionServer struct { submitIntentIntegration integration.SubmitIntent swapIntegration integration.Swap + taskScheduler *ocp_task.Scheduler + antispamGuard *antispam.Guard amlGuard *aml.Guard @@ -56,6 +59,7 @@ func NewTransactionServer( mintDataProvider *currency_util.MintDataProvider, submitIntentIntegration integration.SubmitIntent, swapIntegration integration.Swap, + taskScheduler *ocp_task.Scheduler, antispamGuard *antispam.Guard, amlGuard *aml.Guard, 
coinbaseClient *coinbase.Client, @@ -104,6 +108,8 @@ func NewTransactionServer( submitIntentIntegration: submitIntentIntegration, swapIntegration: swapIntegration, + taskScheduler: taskScheduler, + antispamGuard: antispamGuard, amlGuard: amlGuard, diff --git a/ocp/task/scheduler.go b/ocp/task/scheduler.go new file mode 100644 index 0000000..8decfb4 --- /dev/null +++ b/ocp/task/scheduler.go @@ -0,0 +1,168 @@ +package task + +import ( + "context" + "sync" + "time" + + "go.uber.org/zap" + + ocp_data "github.com/code-payments/ocp-server/ocp/data" + task_data "github.com/code-payments/ocp-server/ocp/data/task" + "github.com/code-payments/ocp-server/ocp/integration" + "github.com/code-payments/ocp-server/retry/backoff" +) + +const ( + defaultMaxFailedAttempts = 10 + defaultBackoffBase = time.Second + defaultMaxBackoffDelay = 5 * time.Minute +) + +// Scheduler is the entry point into the guaranteed task execution system. +// Tasks are durably enqueued within the caller's DB transaction and executed +// at least once via the app's TaskExecutor, either through the best-effort +// fast path or the background worker. 
+type Scheduler struct { + log *zap.Logger + data ocp_data.Provider + executor integration.TaskExecutor + maxFailedAttempts uint32 + backoffStrategy backoff.Strategy + maxBackoffDelay time.Duration +} + +type Option func(*Scheduler) + +// WithMaxFailedAttempts overrides the number of failed execution attempts +// before a task is marked as failed +func WithMaxFailedAttempts(maxFailedAttempts uint32) Option { + return func(s *Scheduler) { + s.maxFailedAttempts = maxFailedAttempts + } +} + +// WithBackoff overrides the retry backoff strategy and maximum delay between +// execution attempts +func WithBackoff(strategy backoff.Strategy, maxDelay time.Duration) Option { + return func(s *Scheduler) { + s.backoffStrategy = strategy + s.maxBackoffDelay = maxDelay + } +} + +func NewScheduler(log *zap.Logger, data ocp_data.Provider, executor integration.TaskExecutor, opts ...Option) *Scheduler { + s := &Scheduler{ + log: log, + data: data, + executor: executor, + maxFailedAttempts: defaultMaxFailedAttempts, + backoffStrategy: backoff.BinaryExponential(defaultBackoffBase), + maxBackoffDelay: defaultMaxBackoffDelay, + } + + for _, opt := range opts { + opt(s) + } + + return s +} + +// Enqueue durably persists tasks for guaranteed execution. When called within +// a DB transaction passed along ctx (eg. the one committing an intent), the +// tasks are persisted atomically with that transaction. +func (s *Scheduler) Enqueue(ctx context.Context, records ...*task_data.Record) error { + if len(records) == 0 { + return nil + } + + for _, record := range records { + if record.State == task_data.StateUnknown { + record.State = task_data.StatePending + } + + if record.State != task_data.StatePending { + return task_data.ErrExists + } + + if err := record.Validate(); err != nil { + return err + } + } + + return s.data.PutAllTasks(ctx, records...) +} + +// TryExecuteNow makes a best-effort immediate execution attempt for the +// provided tasks. 
Failures are left for the background worker to retry. +// +// This must only be called after the tasks have been committed to the DB. +func (s *Scheduler) TryExecuteNow(ctx context.Context, records ...*task_data.Record) { + var wg sync.WaitGroup + for _, record := range records { + wg.Add(1) + + go func(record *task_data.Record) { + defer wg.Done() + + err := s.ExecuteAndAdvance(ctx, record) + if err != nil { + s.log.With( + zap.Error(err), + zap.String("task_id", record.TaskId), + ).Warn("failure executing task on fast path") + } + }(record) + } + wg.Wait() +} + +// ExecuteAndAdvance executes a single pending task and advances its state +// based on the outcome. On success the task is confirmed. On failure the +// task is scheduled for a retry with backoff, or marked as failed once the +// maximum attempt count is reached. +// +// Both the fast path and the background worker funnel through this method, +// so the two cannot disagree on semantics. Concurrent executions of the same +// task are resolved by the optimistic concurrency check when the task record +// is updated. +func (s *Scheduler) ExecuteAndAdvance(ctx context.Context, record *task_data.Record) error { + if record.State != task_data.StatePending { + return nil + } + + executionErr := s.executor.Execute(ctx, record) + if executionErr == nil { + record.State = task_data.StateConfirmed + } else { + record.FailedAttempts++ + if record.FailedAttempts >= s.maxFailedAttempts { + record.State = task_data.StateFailed + } else { + delay := s.backoffStrategy(uint(record.FailedAttempts)) + if delay > s.maxBackoffDelay { + delay = s.maxBackoffDelay + } + record.NextAttemptAt = time.Now().Add(delay) + } + } + + err := s.data.UpdateTask(ctx, record) + if err == task_data.ErrStaleVersion { + // Another process has advanced the task. Idempotency of the executor + // makes this a safe no-op. 
+ return nil + } else if err != nil { + return err + } + + if record.State == task_data.StateFailed { + s.log.With( + zap.NamedError("execution_error", executionErr), + zap.String("task_id", record.TaskId), + zap.Uint32("failed_attempts", record.FailedAttempts), + ).Error("task exhausted max failed attempts and is marked as failed") + } + + return executionErr +} diff --git a/ocp/task/scheduler_test.go b/ocp/task/scheduler_test.go new file mode 100644 index 0000000..0f4f60a --- /dev/null +++ b/ocp/task/scheduler_test.go @@ -0,0 +1,188 @@ +package task + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + ocp_data "github.com/code-payments/ocp-server/ocp/data" + task_data "github.com/code-payments/ocp-server/ocp/data/task" + "github.com/code-payments/ocp-server/retry/backoff" +) + +type mockExecutor struct { + mu sync.Mutex + executed []string + executeFn func(record *task_data.Record) error +} + +func (e *mockExecutor) Execute(ctx context.Context, record *task_data.Record) error { + e.mu.Lock() + defer e.mu.Unlock() + + e.executed = append(e.executed, record.TaskId) + + if e.executeFn != nil { + return e.executeFn(record) + } + return nil +} + +func (e *mockExecutor) executionCount() int { + e.mu.Lock() + defer e.mu.Unlock() + + return len(e.executed) +} + +func newTestScheduler(executor *mockExecutor, opts ...Option) (*Scheduler, ocp_data.Provider) { + data := ocp_data.NewTestDataProvider() + return NewScheduler(zap.Must(zap.NewDevelopment()), data, executor, opts...), data +} + +func newTestTask(taskId string) *task_data.Record { + return &task_data.Record{ + TaskId: taskId, + Type: 1, + Data: []byte("test_data"), + } +} + +func TestScheduler_EnqueueHappyPath(t *testing.T) { + ctx := context.Background() + + scheduler, data := newTestScheduler(&mockExecutor{}) + + require.NoError(t, scheduler.Enqueue(ctx)) + + record := newTestTask("test_task_id") + 
require.NoError(t, scheduler.Enqueue(ctx, record)) + assert.Equal(t, task_data.StatePending, record.State) + + actual, err := data.GetTaskById(ctx, "test_task_id") + require.NoError(t, err) + assert.Equal(t, task_data.StatePending, actual.State) + assert.False(t, actual.NextAttemptAt.After(time.Now())) +} + +func TestScheduler_EnqueueValidation(t *testing.T) { + ctx := context.Background() + + scheduler, data := newTestScheduler(&mockExecutor{}) + + invalid := newTestTask("test_task_id") + invalid.State = task_data.StateConfirmed + require.Error(t, scheduler.Enqueue(ctx, invalid)) + + invalid = newTestTask("test_task_id") + invalid.Type = 0 + require.Error(t, scheduler.Enqueue(ctx, invalid)) + + _, err := data.GetTaskById(ctx, "test_task_id") + assert.Equal(t, task_data.ErrNotFound, err) +} + +func TestScheduler_TryExecuteNowHappyPath(t *testing.T) { + ctx := context.Background() + + executor := &mockExecutor{} + scheduler, data := newTestScheduler(executor) + + records := []*task_data.Record{ + newTestTask("test_task_id_1"), + newTestTask("test_task_id_2"), + } + require.NoError(t, scheduler.Enqueue(ctx, records...)) + + scheduler.TryExecuteNow(ctx, records...) 
+ assert.Equal(t, 2, executor.executionCount()) + + for _, record := range records { + actual, err := data.GetTaskById(ctx, record.TaskId) + require.NoError(t, err) + assert.Equal(t, task_data.StateConfirmed, actual.State) + assert.EqualValues(t, 0, actual.FailedAttempts) + } +} + +func TestScheduler_ExecuteAndAdvanceRetryWithBackoff(t *testing.T) { + ctx := context.Background() + + executor := &mockExecutor{ + executeFn: func(record *task_data.Record) error { + return errors.New("transient failure") + }, + } + scheduler, data := newTestScheduler(executor, WithBackoff(backoff.Constant(time.Minute), time.Hour)) + + record := newTestTask("test_task_id") + require.NoError(t, scheduler.Enqueue(ctx, record)) + + err := scheduler.ExecuteAndAdvance(ctx, record) + require.Error(t, err) + + actual, err := data.GetTaskById(ctx, "test_task_id") + require.NoError(t, err) + assert.Equal(t, task_data.StatePending, actual.State) + assert.EqualValues(t, 1, actual.FailedAttempts) + assert.True(t, actual.NextAttemptAt.After(time.Now().Add(30*time.Second))) + + // The task is no longer ready for execution until the backoff elapses + _, err = data.GetAllReadyTasksByState(ctx, task_data.StatePending, time.Now()) + assert.Equal(t, task_data.ErrNotFound, err) +} + +func TestScheduler_ExecuteAndAdvanceDeadLetter(t *testing.T) { + ctx := context.Background() + + executor := &mockExecutor{ + executeFn: func(record *task_data.Record) error { + return errors.New("permanent failure") + }, + } + scheduler, data := newTestScheduler(executor, WithMaxFailedAttempts(3), WithBackoff(backoff.Constant(0), 0)) + + record := newTestTask("test_task_id") + require.NoError(t, scheduler.Enqueue(ctx, record)) + + for range 3 { + require.Error(t, scheduler.ExecuteAndAdvance(ctx, record)) + } + + actual, err := data.GetTaskById(ctx, "test_task_id") + require.NoError(t, err) + assert.Equal(t, task_data.StateFailed, actual.State) + assert.EqualValues(t, 3, actual.FailedAttempts) + + // Failed tasks are no longer 
executed + require.NoError(t, scheduler.ExecuteAndAdvance(ctx, actual)) + assert.Equal(t, 3, executor.executionCount()) +} + +func TestScheduler_ExecuteAndAdvanceStaleRace(t *testing.T) { + ctx := context.Background() + + executor := &mockExecutor{} + scheduler, data := newTestScheduler(executor) + + record := newTestTask("test_task_id") + require.NoError(t, scheduler.Enqueue(ctx, record)) + + // Simulate another process executing the same task concurrently + racingCopy := record.Clone() + require.NoError(t, scheduler.ExecuteAndAdvance(ctx, &racingCopy)) + + // The loser of the race treats the stale update as a no-op success + require.NoError(t, scheduler.ExecuteAndAdvance(ctx, record)) + + actual, err := data.GetTaskById(ctx, "test_task_id") + require.NoError(t, err) + assert.Equal(t, task_data.StateConfirmed, actual.State) + assert.EqualValues(t, 2, actual.Version) +} diff --git a/ocp/worker/task/config.go b/ocp/worker/task/config.go new file mode 100644 index 0000000..12c8679 --- /dev/null +++ b/ocp/worker/task/config.go @@ -0,0 +1,29 @@ +package task + +import ( + "github.com/code-payments/ocp-server/config" + "github.com/code-payments/ocp-server/config/env" +) + +const ( + envConfigPrefix = "TASK_RUNTIME_" + + BatchSizeConfigEnvName = envConfigPrefix + "WORKER_BATCH_SIZE" + defaultBatchSize = 100 +) + +type conf struct { + batchSize config.Uint64 +} + +// ConfigProvider defines how config values are pulled +type ConfigProvider func() *conf + +// WithEnvConfigs returns configuration pulled from environment variables +func WithEnvConfigs() ConfigProvider { + return func() *conf { + return &conf{ + batchSize: env.NewUint64Config(BatchSizeConfigEnvName, defaultBatchSize), + } + } +} diff --git a/ocp/worker/task/metrics.go b/ocp/worker/task/metrics.go new file mode 100644 index 0000000..c41c3e9 --- /dev/null +++ b/ocp/worker/task/metrics.go @@ -0,0 +1,46 @@ +package task + +import ( + "context" + "time" + + "github.com/code-payments/ocp-server/metrics" + task_data 
"github.com/code-payments/ocp-server/ocp/data/task" +) + +const ( + taskCountEventName = "TaskCountPollingCheck" +) + +func (p *runtime) metricsGaugeWorker(ctx context.Context) error { + delay := time.Second + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delay): + start := time.Now() + + for _, state := range []task_data.State{ + task_data.StatePending, + task_data.StateFailed, + } { + count, err := p.data.GetTaskCountByState(ctx, state) + if err != nil { + continue + } + recordTaskCountEvent(ctx, state, count) + } + + delay = time.Second - time.Since(start) + } + } +} + +func recordTaskCountEvent(ctx context.Context, state task_data.State, count uint64) { + metrics.RecordEvent(ctx, taskCountEventName, map[string]interface{}{ + "count": count, + "state": state.String(), + }) +} diff --git a/ocp/worker/task/runtime.go b/ocp/worker/task/runtime.go new file mode 100644 index 0000000..0a1a865 --- /dev/null +++ b/ocp/worker/task/runtime.go @@ -0,0 +1,54 @@ +package task + +import ( + "context" + "time" + + "go.uber.org/zap" + + ocp_data "github.com/code-payments/ocp-server/ocp/data" + ocp_task "github.com/code-payments/ocp-server/ocp/task" + "github.com/code-payments/ocp-server/ocp/worker" +) + +type runtime struct { + log *zap.Logger + conf *conf + data ocp_data.Provider + scheduler *ocp_task.Scheduler +} + +func New( + log *zap.Logger, + data ocp_data.Provider, + scheduler *ocp_task.Scheduler, + configProvider ConfigProvider, +) worker.Runtime { + return &runtime{ + log: log, + conf: configProvider(), + data: data, + scheduler: scheduler, + } +} + +func (p *runtime) Start(ctx context.Context, interval time.Duration) error { + go func() { + err := p.worker(ctx, interval) + if err != nil && err != context.Canceled { + p.log.With(zap.Error(err)).Warn("task processing loop terminated unexpectedly") + } + }() + + go func() { + err := p.metricsGaugeWorker(ctx) + if err != nil && err != context.Canceled { + 
p.log.With(zap.Error(err)).Warn("task metrics gauge loop terminated unexpectedly") + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/ocp/worker/task/worker.go b/ocp/worker/task/worker.go new file mode 100644 index 0000000..d771bb7 --- /dev/null +++ b/ocp/worker/task/worker.go @@ -0,0 +1,69 @@ +package task + +import ( + "context" + "sync" + "time" + + "github.com/code-payments/ocp-server/database/query" + "github.com/code-payments/ocp-server/metrics" + task_data "github.com/code-payments/ocp-server/ocp/data/task" + "github.com/code-payments/ocp-server/retry" +) + +func (p *runtime) worker(runtimeCtx context.Context, interval time.Duration) error { + var cursor query.Cursor + delay := interval + + err := retry.Loop( + func() (err error) { + time.Sleep(delay) + + provider := runtimeCtx.Value(metrics.ProviderContextKey).(metrics.Provider) + trace := provider.StartTrace("task_runtime__handle_pending") + defer trace.End() + tracedCtx := metrics.NewContext(runtimeCtx, trace) + + items, err := p.data.GetAllReadyTasksByState( + tracedCtx, + task_data.StatePending, + time.Now(), + query.WithLimit(p.conf.batchSize.Get(runtimeCtx)), + query.WithCursor(cursor), + ) + if err == task_data.ErrNotFound { + cursor = query.EmptyCursor + return nil + } else if err != nil { + cursor = query.EmptyCursor + return err + } + + var wg sync.WaitGroup + for _, item := range items { + wg.Add(1) + + go func(record *task_data.Record) { + defer wg.Done() + + err := p.scheduler.ExecuteAndAdvance(tracedCtx, record) + if err != nil { + trace.OnError(err) + } + }(item) + } + wg.Wait() + + if len(items) > 0 { + cursor = query.ToCursor(items[len(items)-1].Id) + } else { + cursor = query.EmptyCursor + } + + return nil + }, + retry.NonRetriableErrors(context.Canceled), + ) + + return err +} diff --git a/ocp/worker/task/worker_test.go b/ocp/worker/task/worker_test.go new file mode 100644 index 0000000..d3c90d1 --- /dev/null +++ b/ocp/worker/task/worker_test.go @@ -0,0 
+1,85 @@ +package task + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/code-payments/ocp-server/metrics" + "github.com/code-payments/ocp-server/metrics/noop" + ocp_data "github.com/code-payments/ocp-server/ocp/data" + task_data "github.com/code-payments/ocp-server/ocp/data/task" + ocp_task "github.com/code-payments/ocp-server/ocp/task" + "github.com/code-payments/ocp-server/retry/backoff" + "github.com/code-payments/ocp-server/testutil" +) + +type mockExecutor struct { + executions int32 + executeFn func(record *task_data.Record) error +} + +func (e *mockExecutor) Execute(ctx context.Context, record *task_data.Record) error { + atomic.AddInt32(&e.executions, 1) + + if e.executeFn != nil { + return e.executeFn(record) + } + return nil +} + +func TestWorker_ProcessesPendingTasks(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = context.WithValue(ctx, metrics.ProviderContextKey, noop.NewProvider()) + + log := zap.Must(zap.NewDevelopment()) + data := ocp_data.NewTestDataProvider() + + // Fail the flaky task once, then succeed + var flakyFailures int32 + executor := &mockExecutor{ + executeFn: func(record *task_data.Record) error { + if record.TaskId == "flaky_task" && atomic.CompareAndSwapInt32(&flakyFailures, 0, 1) { + return errors.New("transient failure") + } + return nil + }, + } + + scheduler := ocp_task.NewScheduler(log, data, executor, ocp_task.WithBackoff(backoff.Constant(0), 0)) + + records := []*task_data.Record{ + {TaskId: "happy_task", Type: 1, Data: []byte("data"), State: task_data.StatePending}, + {TaskId: "flaky_task", Type: 1, Data: []byte("data"), State: task_data.StatePending}, + } + require.NoError(t, scheduler.Enqueue(ctx, records...)) + + worker := New(log, data, scheduler, WithEnvConfigs()) + go func() { + worker.Start(ctx, time.Millisecond) + }() + + 
require.NoError(t, testutil.WaitFor(5*time.Second, 10*time.Millisecond, func() bool { + count, err := data.GetTaskCountByState(ctx, task_data.StateConfirmed) + return err == nil && count == 2 + })) + + flaky, err := data.GetTaskById(ctx, "flaky_task") + require.NoError(t, err) + assert.Equal(t, task_data.StateConfirmed, flaky.State) + assert.EqualValues(t, 1, flaky.FailedAttempts) + + happy, err := data.GetTaskById(ctx, "happy_task") + require.NoError(t, err) + assert.Equal(t, task_data.StateConfirmed, happy.State) + assert.EqualValues(t, 0, happy.FailedAttempts) + + assert.GreaterOrEqual(t, atomic.LoadInt32(&executor.executions), int32(3)) +}