diff --git a/Makefile b/Makefile index fcfdf1738..56056f0fb 100644 --- a/Makefile +++ b/Makefile @@ -94,6 +94,14 @@ ifneq ($(TEST_PACKAGES),) GO_TEST_PACKAGES := $(addprefix ./, $(addsuffix /..., $(subst :, ,$(TEST_PACKAGES)))) endif +COVERAGE_DIR := coverage +COVER ?= +ifeq ($(COVER),true) + GO_TEST_FLAGS += -coverprofile=$(COVERAGE_DIR)/coverage.out -covermode=atomic -coverpkg=./... + COVER_DEPS := $(COVERAGE_DIR) + COVER_REPORT := coverage-report +endif + ROLLUPS_CONTRACTS_ABI_BASEDIR:= rollups-contracts/ ROLLUPS_PRT_CONTRACTS_ABI_BASEDIR:= rollups-prt-contracts/ @@ -190,6 +198,7 @@ clean-go: ## Clean Go artifacts @echo "Cleaning Go artifacts" @go clean -i -r -cache @rm -f $(GO_ARTIFACTS) + @rm -rf $(COVERAGE_DIR) clean-contracts: ## Clean contract artifacts @echo "Cleaning contract artifacts" @@ -221,10 +230,19 @@ clean-test-dependencies: ## Clean the test dependencies # ============================================================================= test: unit-test ## Execute all tests -unit-test: ## Execute go unit tests +$(COVERAGE_DIR): + @mkdir -p $@ + +coverage-report: + @go tool cover -func=$(COVERAGE_DIR)/coverage.out + @go tool cover -html=$(COVERAGE_DIR)/coverage.out -o $(COVERAGE_DIR)/coverage.html + @echo "Coverage report: $(COVERAGE_DIR)/coverage.html" + +unit-test: $(COVER_DEPS) ## Execute go unit tests @echo "Running go unit tests" @go clean -testcache @go test -p 1 $(GO_BUILD_PARAMS) $(GO_TEST_FLAGS) $(GO_TEST_PACKAGES) + @$(if $(COVER_REPORT),$(MAKE) $(COVER_REPORT)) integration-test: ## Execute e2e tests @echo "Running end-to-end tests" @@ -417,4 +435,4 @@ build-debian-package: install sed 's|ARG_VERSION|$(ROLLUPS_NODE_VERSION)|g;s|ARG_ARCH|$(DEB_ARCH)|g' control.template > $(DESTDIR)/DEBIAN/control dpkg-deb -Zxz --root-owner-group --build $(DESTDIR) $(DEB_FILENAME) -.PHONY: build build-go clean clean-go test unit-test-go e2e-test lint fmt vet escape md-lint devnet image run-with-compose shutdown-compose help docs $(GO_ARTIFACTS) +.PHONY: build build-go clean clean-go test unit-test-go e2e-test lint fmt vet escape md-lint devnet image run-with-compose shutdown-compose help docs coverage-report $(GO_ARTIFACTS) diff --git a/internal/claimer/claimer.go b/internal/claimer/claimer.go index 5c5f4b982..f852d448f 100644 --- a/internal/claimer/claimer.go +++ b/internal/claimer/claimer.go @@ -75,14 +75,14 @@ type iclaimerRepository interface { UpdateEpochWithSubmittedClaim( ctx context.Context, - application_id int64, + applicationID int64, index uint64, - transaction_hash common.Hash, + transactionHash common.Hash, ) error UpdateEpochWithAcceptedClaim( ctx context.Context, - application_id int64, + applicationID int64, index uint64, ) error diff --git a/internal/node/node.go b/internal/node/node.go index d10924663..6a2c6f15f 100644 --- a/internal/node/node.go +++ b/internal/node/node.go @@ -6,7 +6,6 @@ package node import ( "context" "fmt" - "os" "github.com/cartesi/rollups-node/pkg/service" @@ -22,6 +21,13 @@ import ( "github.com/ethereum/go-ethereum/ethclient" ) +// serviceResult carries either a successfully created service or an error +// back from the goroutines in createServices. +type serviceResult struct { + service service.IService + err error +} + type CreateInfo struct { service.CreateInfo @@ -66,56 +72,59 @@ func Create(ctx context.Context, c *CreateInfo) (*Service, error) { return s, nil } +type serviceCreator func(context.Context, *CreateInfo, *Service) (service.IService, error) + func createServices(ctx context.Context, c *CreateInfo, s *Service) error { - // Count services first - numChildren := 5 // evm-reader, advancer, validator, claimer, prt + creators := []serviceCreator{ + newEVMReader, + newAdvancer, + newValidator, + newClaimer, + newPrt, + } if c.Config.FeatureJsonrpcApiEnabled { - numChildren++ // jsonrpc + creators = append(creators, newJsonrpc) } - // Create buffered channel with correct size - ch := make(chan service.IService, numChildren) - - go func() { - ch <- newEVMReader(ctx, c, s) - }() - - go func() { - ch <- newAdvancer(ctx, c, s) - }() - - go func() { - ch <- newValidator(ctx, c, s) - }() - - go func() { - ch <- newClaimer(ctx, c, s) - }() - - go func() { - ch <- newPrt(ctx, c, s) - }() - - if c.Config.FeatureJsonrpcApiEnabled { + ch := make(chan serviceResult, len(creators)) + for _, create := range creators { go func() { - ch <- newJsonrpc(ctx, c, s) + svc, err := create(ctx, c, s) + ch <- serviceResult{service: svc, err: err} }() } - for range numChildren { + for range len(creators) { select { - case child := <-ch: - s.Children = append(s.Children, child) + case result := <-ch: + if result.err != nil { + stopAndDrain(s.Children, ch, len(creators)-len(s.Children)-1) + return fmt.Errorf("failed to create service: %w", result.err) + } + s.Children = append(s.Children, result.service) case <-ctx.Done(): - err := ctx.Err() - s.Logger.Error("Failed to create services. Time limit exceeded", - "err", err) - return fmt.Errorf("failed to create services. Time limit exceeded") + stopAndDrain(s.Children, ch, len(creators)-len(s.Children)) + return fmt.Errorf("failed to create services: %w", ctx.Err()) } } return nil } +// stopAndDrain stops already-created children and drains remaining results +// from the channel, stopping any successful services to prevent resource leaks. +func stopAndDrain(children []service.IService, ch <-chan serviceResult, remaining int) { + for _, child := range children { + child.Stop(true) + } + go func() { + for range remaining { + if r := <-ch; r.err == nil && r.service != nil { + r.service.Stop(true) + } + } + }() +} + func (me *Service) Alive() bool { allAlive := true for _, s := range me.Children { @@ -151,7 +160,7 @@ func (me *Service) Serve() error { // services creation -func newEVMReader(ctx context.Context, c *CreateInfo, s *Service) service.IService { +func newEVMReader(ctx context.Context, c *CreateInfo, s *Service) (service.IService, error) { readerArgs := evmreader.CreateInfo{ CreateInfo: service.CreateInfo{ Name: "evm-reader", @@ -171,13 +180,12 @@ func newEVMReader(ctx context.Context, c *CreateInfo, s *Service) service.IServi readerService, err := evmreader.Create(ctx, &readerArgs) if err != nil { - s.Logger.Error("Fatal", "error", err) - os.Exit(1) + return nil, fmt.Errorf("create evm-reader: %w", err) } - return readerService + return readerService, nil } -func newAdvancer(ctx context.Context, c *CreateInfo, s *Service) service.IService { +func newAdvancer(ctx context.Context, c *CreateInfo, s *Service) (service.IService, error) { advancerArgs := advancer.CreateInfo{ CreateInfo: service.CreateInfo{ Name: "advancer", @@ -196,13 +204,12 @@ func newAdvancer(ctx context.Context, c *CreateInfo, s *Service) service.IServic advancerService, err := advancer.Create(ctx, &advancerArgs) if err != nil { - s.Logger.Error("Fatal", "error", err) - os.Exit(1) + return nil, fmt.Errorf("create advancer: %w", err) } - return advancerService + return advancerService, nil } -func newValidator(ctx context.Context, c *CreateInfo, s *Service) service.IService { +func newValidator(ctx context.Context, c *CreateInfo, s *Service) (service.IService, error) { validatorArgs := validator.CreateInfo{ CreateInfo: service.CreateInfo{ Name: "validator", @@ -221,13 +228,12 @@ func newValidator(ctx context.Context, c *CreateInfo, s *Service) service.IServi validatorService, err := validator.Create(ctx, &validatorArgs) if err != nil { - s.Logger.Error("Fatal", "error", err) - os.Exit(1) + return nil, fmt.Errorf("create validator: %w", err) } - return validatorService + return validatorService, nil } -func newClaimer(ctx context.Context, c *CreateInfo, s *Service) service.IService { +func newClaimer(ctx context.Context, c *CreateInfo, s *Service) (service.IService, error) { claimerArgs := claimer.CreateInfo{ CreateInfo: service.CreateInfo{ Name: "claimer", @@ -247,13 +253,12 @@ func newClaimer(ctx context.Context, c *CreateInfo, s *Service) service.IService claimerService, err := claimer.Create(ctx, &claimerArgs) if err != nil { - s.Logger.Error("Fatal", "error", err) - os.Exit(1) + return nil, fmt.Errorf("create claimer: %w", err) } - return claimerService + return claimerService, nil } -func newJsonrpc(ctx context.Context, c *CreateInfo, s *Service) service.IService { +func newJsonrpc(ctx context.Context, c *CreateInfo, s *Service) (service.IService, error) { jsonrpcArgs := jsonrpc.CreateInfo{ CreateInfo: service.CreateInfo{ Name: "jsonrpc", @@ -271,13 +276,12 @@ func newJsonrpc(ctx context.Context, c *CreateInfo, s *Service) service.IService jsonrpcService, err := jsonrpc.Create(ctx, &jsonrpcArgs) if err != nil { - s.Logger.Error("Fatal", "error", err) - os.Exit(1) + return nil, fmt.Errorf("create jsonrpc: %w", err) } - return jsonrpcService + return jsonrpcService, nil } -func newPrt(ctx context.Context, c *CreateInfo, s *Service) service.IService { +func newPrt(ctx context.Context, c *CreateInfo, s *Service) (service.IService, error) { prtArgs := prt.CreateInfo{ CreateInfo: service.CreateInfo{ Name: "prt", @@ -297,8 +301,7 @@ func newPrt(ctx context.Context, c *CreateInfo, s *Service) service.IService { prtService, err := prt.Create(ctx, &prtArgs) if err != nil { - s.Logger.Error("Fatal", "error", err) - os.Exit(1) + return nil, fmt.Errorf("create prt: %w", err) } - return prtService + return prtService, nil } diff --git a/internal/repository/factory/factory.go b/internal/repository/factory/factory.go index 4a2313f76..0506a4501 100644 --- a/internal/repository/factory/factory.go +++ b/internal/repository/factory/factory.go @@ -28,7 +28,9 @@ func NewRepositoryFromConnectionString(ctx context.Context, conn string) (Reposi // case strings.HasPrefix(lowerConn, "sqlite://"): // return newSQLiteRepository(ctx, conn) default: - return nil, fmt.Errorf("unrecognized connection string format: %s", conn) + return nil, fmt.Errorf( + "unrecognized connection string scheme (expected postgres:// or postgresql://)", + ) } } diff --git a/internal/repository/postgres/application.go b/internal/repository/postgres/application.go index b40aedd3d..b18efd6d8 100644 --- a/internal/repository/postgres/application.go +++ b/internal/repository/postgres/application.go @@ -67,12 +67,13 @@ func (r *PostgresRepository) CreateApplication( if err != nil { return 0, err } + defer tx.Rollback(ctx) //nolint:errcheck sqlStr, args := insertStmt.Sql() var newID int64 err = tx.QueryRow(ctx, sqlStr, args...).Scan(&newID) if err != nil { - return 0, errors.Join(fmt.Errorf("unable to create database application: %w", err), tx.Rollback(ctx)) + return 0, fmt.Errorf("unable to create database application: %w", err) } if !withExecutionParameters { @@ -121,12 +122,12 @@ func (r *PostgresRepository) CreateApplication( _, err = tx.Exec(ctx, sqlStr, args...) if err != nil { - return 0, errors.Join(err, tx.Rollback(ctx)) + return 0, err } err = tx.Commit(ctx) if err != nil { - return 0, errors.Join(err, tx.Rollback(ctx)) + return 0, err } return newID, nil } @@ -137,10 +138,7 @@ func (r *PostgresRepository) GetApplication( nameOrAddress string, ) (*model.Application, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) stmt := table.Application. SELECT( @@ -193,7 +191,7 @@ func (r *PostgresRepository) GetApplication( row := r.db.QueryRow(ctx, sqlStr, args...) var app model.Application - err = row.Scan( + err := row.Scan( &app.ID, &app.Name, &app.IApplicationAddress, @@ -247,10 +245,7 @@ func (r *PostgresRepository) GetProcessedInputCount( nameOrAddress string, ) (uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) stmt := table.Application. SELECT(table.Application.ProcessedInputs). @@ -260,7 +255,7 @@ func (r *PostgresRepository) GetProcessedInputCount( row := r.db.QueryRow(ctx, sqlStr, args...) var processedInputs uint64 - err = row.Scan(&processedInputs) + err := row.Scan(&processedInputs) if errors.Is(err, sql.ErrNoRows) { return 0, repository.ErrNotFound } @@ -419,7 +414,7 @@ func (r *PostgresRepository) UpdateEventLastCheckBlock( column, ). SET( - postgres.RawFloat(fmt.Sprintf("%d", blockNumber)), + uint64Expr(blockNumber), ). WHERE(table.Application.ID.IN(ids...)) @@ -430,10 +425,7 @@ func (r *PostgresRepository) UpdateEventLastCheckBlock( // GetLastSnapshot retrieves the most recent input with a snapshot for the given application func (r *PostgresRepository) GetLastSnapshot(ctx context.Context, nameOrAddress string) (*model.Input, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Input. SELECT( @@ -468,7 +460,7 @@ func (r *PostgresRepository) GetLastSnapshot(ctx context.Context, nameOrAddress row := r.db.QueryRow(ctx, sqlStr, args...) var inp model.Input - err = row.Scan( + err := row.Scan( &inp.EpochApplicationID, &inp.EpochIndex, &inp.Index, @@ -652,6 +644,9 @@ func (r *PostgresRepository) ListApplications( } apps = append(apps, &app) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return apps, total, nil } @@ -757,7 +752,7 @@ func (r *PostgresRepository) UpdateExecutionParameters( return err } if cmd.RowsAffected() == 0 { - return sql.ErrNoRows + return repository.ErrNotFound } return nil } diff --git a/internal/repository/postgres/bulk.go b/internal/repository/postgres/bulk.go index 281acd79c..6bd47e2f0 100644 --- a/internal/repository/postgres/bulk.go +++ b/internal/repository/postgres/bulk.go @@ -5,8 +5,6 @@ package postgres import ( "context" - "database/sql" - "errors" "fmt" "unsafe" @@ -15,6 +13,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" "github.com/cartesi/rollups-node/internal/repository/postgres/db/rollupsdb/public/table" ) @@ -73,7 +72,8 @@ func getReportNextIndex( postgres.Float(0), ), ).FROM( - table.Report.INNER_JOIN(table.Input, table.Input.EpochApplicationID.EQ(table.Report.InputEpochApplicationID)), + table.Report.INNER_JOIN(table.Input, table.Input.EpochApplicationID.EQ(table.Report.InputEpochApplicationID). + AND(table.Input.Index.EQ(table.Report.InputIndex))), ).WHERE( table.Report.InputEpochApplicationID.EQ(postgres.Int64(appID)). AND(table.Input.Status.EQ(postgres.NewEnumValue(model.InputCompletionStatus_Accepted.String()))), @@ -102,7 +102,7 @@ func getStateHashNextIndex( ), ).WHERE( table.StateHashes.InputEpochApplicationID.EQ(postgres.Int64(appID)). - AND(table.StateHashes.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex)))), + AND(table.StateHashes.EpochIndex.EQ(uint64Expr(epochIndex))), ) queryStr, args := query.Sql() @@ -268,7 +268,7 @@ func updateInput( ). WHERE( table.Input.EpochApplicationID.EQ(postgres.Int64(appID)). - AND(table.Input.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", inputIndex)))), + AND(table.Input.Index.EQ(uint64Expr(inputIndex))), ) sqlStr, args := updStmt.Sql() @@ -277,7 +277,7 @@ func updateInput( return err } if cmd.RowsAffected() == 0 { - return sql.ErrNoRows + return repository.ErrNotFound } return nil } @@ -305,7 +305,7 @@ func updateEpochOutputsMerkleProof( ). WHERE( table.Epoch.ApplicationID.EQ(postgres.Int64(appID)). - AND(table.Epoch.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex)))), + AND(table.Epoch.Index.EQ(uint64Expr(epochIndex))), ) sqlStr, args := updStmt.Sql() @@ -314,7 +314,7 @@ func updateEpochOutputsMerkleProof( return err } if cmd.RowsAffected() == 0 { - return sql.ErrNoRows + return repository.ErrNotFound } return nil } @@ -331,7 +331,7 @@ func updateApp( table.Application.ProcessedInputs, ). SET( - postgres.RawFloat(fmt.Sprintf("%d", inputIndex+1)), + uint64Expr(inputIndex + 1), ). WHERE( table.Application.ID.EQ(postgres.Int64(appID)), @@ -343,7 +343,7 @@ func updateApp( return err } if cmd.RowsAffected() == 0 { - return sql.ErrNoRows + return repository.ErrNotFound } return nil } @@ -357,48 +357,44 @@ func (r *PostgresRepository) StoreAdvanceResult( if err != nil { return err } + defer tx.Rollback(ctx) //nolint:errcheck if res.Status == model.InputCompletionStatus_Accepted { err = insertOutputs(ctx, tx, appID, res.InputIndex, res.Outputs) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } err = insertReports(ctx, tx, appID, res.InputIndex, res.Reports) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } } if res.IsDaveConsensus { err = insertStateHashes(ctx, tx, appID, res.EpochIndex, res.InputIndex, res.Hashes, res.MachineHash, res.RemainingMetaCycles) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } } err = updateInput(ctx, tx, appID, res.InputIndex, res.Status, res.OutputsHash, res.MachineHash) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } err = updateEpochOutputsMerkleProof(ctx, tx, appID, res.EpochIndex, res.OutputsHash, byteSliceToHashSlice(res.OutputsHashProof), res.MachineHash) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } err = updateApp(ctx, tx, appID, res.InputIndex) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) - } - - err = tx.Commit(ctx) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } - return nil + return tx.Commit(ctx) } func updateEpochClaim( @@ -420,22 +416,16 @@ func updateEpochClaim( ). WHERE( table.Epoch.ApplicationID.EQ(postgres.Int64(e.ApplicationID)). - AND(table.Epoch.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", e.Index)))), + AND(table.Epoch.Index.EQ(uint64Expr(e.Index))), ) sqlStr, args := updStmt.Sql() cmd, err := tx.Exec(ctx, sqlStr, args...) if err != nil { - return errors.Join( - fmt.Errorf("SetEpochClaimAndInsertProofsTransaction failed: %w", err), - tx.Rollback(ctx), - ) + return fmt.Errorf("SetEpochClaimAndInsertProofsTransaction failed: %w", err) } if cmd.RowsAffected() != 1 { - return errors.Join( - fmt.Errorf("failed to update application %d epoch %d: no rows affected", e.ApplicationID, e.Index), - tx.Rollback(ctx), - ) + return fmt.Errorf("failed to update application %d epoch %d: no rows affected", e.ApplicationID, e.Index) } return nil } @@ -445,6 +435,11 @@ func updateOutputs( tx pgx.Tx, outputs []*model.Output, ) error { + if len(outputs) == 0 { + return nil + } + + batch := &pgx.Batch{} for _, output := range outputs { updStmt := table.Output. UPDATE( @@ -457,28 +452,26 @@ func updateOutputs( ). WHERE( table.Output.InputEpochApplicationID.EQ(postgres.Int64(output.InputEpochApplicationID)). - AND(table.Output.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", output.Index)))), + AND(table.Output.Index.EQ(uint64Expr(output.Index))), ) - sqlStr, args := updStmt.Sql() - cmd, err := tx.Exec(ctx, sqlStr, args...) + batch.Queue(sqlStr, args...) + } + + br := tx.SendBatch(ctx, batch) + + for _, output := range outputs { + cmd, err := br.Exec() if err != nil { - return errors.Join( - fmt.Errorf("failed to insert proof for output '%d'. %w", output.Index, err), - tx.Rollback(ctx), - ) + br.Close() + return fmt.Errorf("failed to insert proof for output '%d': %w", output.Index, err) } if cmd.RowsAffected() == 0 { - return errors.Join( - fmt.Errorf( - "failed to insert proof for output '%d'. No rows affected", - output.Index, - ), - tx.Rollback(ctx), - ) + br.Close() + return fmt.Errorf("failed to insert proof for output '%d'. No rows affected", output.Index) } } - return nil + return br.Close() } func (r *PostgresRepository) StoreClaimAndProofs(ctx context.Context, epoch *model.Epoch, outputs []*model.Output) error { @@ -487,6 +480,7 @@ func (r *PostgresRepository) StoreClaimAndProofs(ctx context.Context, epoch *mod if err != nil { return fmt.Errorf("SetEpochClaimAndInsertProofsTransaction failed: %w", err) } + defer tx.Rollback(ctx) //nolint:errcheck err = updateEpochClaim(ctx, tx, epoch) if err != nil { @@ -498,14 +492,7 @@ func (r *PostgresRepository) StoreClaimAndProofs(ctx context.Context, epoch *mod return err } - err = tx.Commit(ctx) - if err != nil { - return errors.Join( - fmt.Errorf("SetEpochClaimAndInsertProofsTransaction failed: %w", err), - tx.Rollback(ctx), - ) - } - return nil + return tx.Commit(ctx) } func insertCommitments(ctx context.Context, tx pgx.Tx, appID int64, commitments []*model.Commitment) error { @@ -538,10 +525,7 @@ func insertCommitments(ctx context.Context, tx pgx.Tx, appID int64, commitments sqlStr, args := stmt.Sql() _, err := tx.Exec(ctx, sqlStr, args...) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) - } - return nil + return err } func insertMatches(ctx context.Context, tx pgx.Tx, appID int64, matches []*model.Match) error { @@ -584,10 +568,7 @@ func insertMatches(ctx context.Context, tx pgx.Tx, appID int64, matches []*model sqlStr, args := stmt.Sql() _, err := tx.Exec(ctx, sqlStr, args...) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) - } - return nil + return err } func insertMatchAdvanced(ctx context.Context, tx pgx.Tx, appID int64, matchAdvanced []*model.MatchAdvanced) error { @@ -620,13 +601,15 @@ func insertMatchAdvanced(ctx context.Context, tx pgx.Tx, appID int64, matchAdvan sqlStr, args := stmt.Sql() _, err := tx.Exec(ctx, sqlStr, args...) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) - } - return nil + return err } func updateMatches(ctx context.Context, tx pgx.Tx, appID int64, matches []*model.Match) error { + if len(matches) == 0 { + return nil + } + + batch := &pgx.Batch{} for _, m := range matches { updStmt := table.Matches.UPDATE( table.Matches.Winner, @@ -640,28 +623,33 @@ func updateMatches(ctx context.Context, tx pgx.Tx, appID int64, matches []*model m.DeletionTxHash, ).WHERE( table.Matches.ApplicationID.EQ(postgres.Int64(appID)). - AND(table.Matches.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", m.EpochIndex)))). + AND(table.Matches.EpochIndex.EQ(uint64Expr(m.EpochIndex))). AND(table.Matches.TournamentAddress.EQ(postgres.Bytea(m.TournamentAddress.Bytes()))). AND(table.Matches.IDHash.EQ(postgres.Bytea(m.IDHash.Bytes()))), ) - sqlStr, args := updStmt.Sql() - cmd, err := tx.Exec(ctx, sqlStr, args...) + batch.Queue(sqlStr, args...) + } + + br := tx.SendBatch(ctx, batch) + + for _, m := range matches { + cmd, err := br.Exec() if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + br.Close() + return err } if cmd.RowsAffected() == 0 { - return errors.Join( - fmt.Errorf("no match found for update: app %d, epoch %d, tournament %s, idHash %s", m.ApplicationID, m.EpochIndex, m.TournamentAddress.Hex(), m.IDHash.Hex()), - tx.Rollback(ctx), - ) + br.Close() + return fmt.Errorf("no match found for update: app %d, epoch %d, tournament %s, idHash %s", + m.ApplicationID, m.EpochIndex, m.TournamentAddress.Hex(), m.IDHash.Hex()) } } - return nil + return br.Close() } func updateLastProcessedBlock(ctx context.Context, tx pgx.Tx, appID int64, lastProcessedBlock uint64) error { - lastBlock := postgres.RawFloat(fmt.Sprintf("%d", lastProcessedBlock)) + lastBlock := uint64Expr(lastProcessedBlock) appUpdateStmt := table.Application. UPDATE( table.Application.LastTournamentCheckBlock, @@ -676,10 +664,7 @@ func updateLastProcessedBlock(ctx context.Context, tx pgx.Tx, appID int64, lastP sqlStr, args := appUpdateStmt.Sql() _, err := tx.Exec(ctx, sqlStr, args...) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) - } - return nil + return err } func (r *PostgresRepository) StoreTournamentEvents( @@ -695,6 +680,7 @@ func (r *PostgresRepository) StoreTournamentEvents( if err != nil { return err } + defer tx.Rollback(ctx) //nolint:errcheck err = insertCommitments(ctx, tx, appID, commitments) if err != nil { @@ -721,10 +707,5 @@ func (r *PostgresRepository) StoreTournamentEvents( return err } - err = tx.Commit(ctx) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) - } - - return nil + return tx.Commit(ctx) } diff --git a/internal/repository/postgres/claimer.go b/internal/repository/postgres/claimer.go index 746986938..cd13526ee 100644 --- a/internal/repository/postgres/claimer.go +++ b/internal/repository/postgres/claimer.go @@ -12,18 +12,16 @@ import ( "github.com/jackc/pgx/v5" "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" "github.com/cartesi/rollups-node/internal/repository/postgres/db/rollupsdb/public/enum" "github.com/cartesi/rollups-node/internal/repository/postgres/db/rollupsdb/public/table" ) -var ( - ErrNoUpdate = fmt.Errorf("update did not take effect") -) - // Retrieve the claim of each application with the smallest index. // The query may return either 0 or 1 entries per application. func (r *PostgresRepository) selectOldestClaimPerApp( ctx context.Context, + tx pgx.Tx, epochStatus model.EpochStatus, ) ( map[int64]*model.Epoch, @@ -31,7 +29,7 @@ func (r *PostgresRepository) selectOldestClaimPerApp( error, ) { if (epochStatus != model.EpochStatus_ClaimSubmitted) && (epochStatus != model.EpochStatus_ClaimComputed) { - return nil, nil, fmt.Errorf("Invalid epoch status: %v", epochStatus) + return nil, nil, fmt.Errorf("invalid epoch status: %v", epochStatus) } // NOTE(mpolitzer): DISTINCT ON is a postgres extension. To implement @@ -86,7 +84,7 @@ func (r *PostgresRepository) selectOldestClaimPerApp( ) sqlStr, args := stmt.Sql() - rows, err := r.db.Query(ctx, sqlStr, args...) + rows, err := tx.Query(ctx, sqlStr, args...) if err != nil { return nil, nil, err } @@ -133,12 +131,16 @@ func (r *PostgresRepository) selectOldestClaimPerApp( epochs[application.ID] = &epoch applications[application.ID] = &application } + if err := rows.Err(); err != nil { + return nil, nil, err + } return epochs, applications, nil } // Retrieve the newest accepted claim of each application func (r *PostgresRepository) selectNewestAcceptedClaimPerApp( ctx context.Context, + tx pgx.Tx, includeSubmitted bool, ) ( map[int64]*model.Epoch, @@ -182,7 +184,7 @@ func (r *PostgresRepository) selectNewestAcceptedClaimPerApp( ) sqlStr, args := stmt.Sql() - rows, err := r.db.Query(ctx, sqlStr, args...) + rows, err := tx.Query(ctx, sqlStr, args...) if err != nil { return nil, err } @@ -208,6 +210,9 @@ func (r *PostgresRepository) selectNewestAcceptedClaimPerApp( } epochs[epoch.ApplicationID] = &epoch } + if err := rows.Err(); err != nil { + return nil, err + } return epochs, nil } @@ -224,14 +229,15 @@ func (r *PostgresRepository) SelectSubmittedClaimPairsPerApp(ctx context.Context if err != nil { return nil, nil, nil, err } - defer tx.Commit(ctx) + // Read-only tx: rollback releases the snapshot, equivalent to commit. + defer tx.Rollback(ctx) //nolint:errcheck - computed, applications, err := r.selectOldestClaimPerApp(ctx, model.EpochStatus_ClaimComputed) + computed, applications, err := r.selectOldestClaimPerApp(ctx, tx, model.EpochStatus_ClaimComputed) if err != nil { return nil, nil, nil, err } - acceptedOrSubmitted, err := r.selectNewestAcceptedClaimPerApp(ctx, true) + acceptedOrSubmitted, err := r.selectNewestAcceptedClaimPerApp(ctx, tx, true) if err != nil { return nil, nil, nil, err } @@ -252,14 +258,15 @@ func (r *PostgresRepository) SelectAcceptedClaimPairsPerApp(ctx context.Context) if err != nil { return nil, nil, nil, err } - defer tx.Commit(ctx) + // Read-only tx: rollback releases the snapshot, equivalent to commit. + defer tx.Rollback(ctx) //nolint:errcheck - submitted, applications, err := r.selectOldestClaimPerApp(ctx, model.EpochStatus_ClaimSubmitted) + submitted, applications, err := r.selectOldestClaimPerApp(ctx, tx, model.EpochStatus_ClaimSubmitted) if err != nil { return nil, nil, nil, err } - accepted, err := r.selectNewestAcceptedClaimPerApp(ctx, false) + accepted, err := r.selectNewestAcceptedClaimPerApp(ctx, tx, false) if err != nil { return nil, nil, nil, err } @@ -269,9 +276,9 @@ func (r *PostgresRepository) SelectAcceptedClaimPairsPerApp(ctx context.Context) func (r *PostgresRepository) UpdateEpochWithSubmittedClaim( ctx context.Context, - application_id int64, + applicationID int64, index uint64, - transaction_hash common.Hash, + transactionHash common.Hash, ) error { updStmt := table.Epoch. UPDATE( @@ -279,15 +286,15 @@ func (r *PostgresRepository) UpdateEpochWithSubmittedClaim( table.Epoch.Status, ). SET( - transaction_hash, + transactionHash, postgres.NewEnumValue(model.EpochStatus_ClaimSubmitted.String()), ). FROM( table.Application, ). WHERE( - table.Epoch.ApplicationID.EQ(postgres.Int64(application_id)). - AND(table.Epoch.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", index)))). + table.Epoch.ApplicationID.EQ(postgres.Int64(applicationID)). + AND(table.Epoch.Index.EQ(uint64Expr(index))). AND(table.Epoch.Status.EQ(postgres.NewEnumValue(model.EpochStatus_ClaimComputed.String()))), ) @@ -297,14 +304,14 @@ func (r *PostgresRepository) UpdateEpochWithSubmittedClaim( return err } if cmd.RowsAffected() == 0 { - return ErrNoUpdate + return repository.ErrNoUpdate } return nil } func (r *PostgresRepository) UpdateEpochWithAcceptedClaim( ctx context.Context, - application_id int64, + applicationID int64, index uint64, ) error { updStmt := table.Epoch. @@ -318,8 +325,8 @@ func (r *PostgresRepository) UpdateEpochWithAcceptedClaim( table.Application, ). WHERE( - table.Epoch.ApplicationID.EQ(postgres.Int64(application_id)). - AND(table.Epoch.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", index)))). + table.Epoch.ApplicationID.EQ(postgres.Int64(applicationID)). + AND(table.Epoch.Index.EQ(uint64Expr(index))). AND(table.Epoch.Status.EQ(postgres.NewEnumValue(model.EpochStatus_ClaimSubmitted.String()))), ) @@ -329,7 +336,7 @@ func (r *PostgresRepository) UpdateEpochWithAcceptedClaim( return err } if cmd.RowsAffected() == 0 { - return ErrNoUpdate + return repository.ErrNoUpdate } return nil } diff --git a/internal/repository/postgres/commitment.go b/internal/repository/postgres/commitment.go index d3bf00349..3293fefd8 100644 --- a/internal/repository/postgres/commitment.go +++ b/internal/repository/postgres/commitment.go @@ -7,7 +7,6 @@ import ( "context" "database/sql" "errors" - "fmt" "github.com/ethereum/go-ethereum/common" "github.com/go-jet/jet/v2/postgres" @@ -25,19 +24,16 @@ func (r *PostgresRepository) CreateCommitment( c *model.Commitment, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) selectQuery := table.Application.SELECT( table.Application.ID, - postgres.RawFloat(fmt.Sprintf("%d", c.EpochIndex)), + uint64Expr(c.EpochIndex), postgres.Bytea(c.TournamentAddress.Bytes()), postgres.Bytea(c.Commitment.Bytes()), postgres.Bytea(c.FinalStateHash.Bytes()), postgres.Bytea(c.SubmitterAddress.Bytes()), - postgres.RawFloat(fmt.Sprintf("%d", c.BlockNumber)), + uint64Expr(c.BlockNumber), postgres.Bytea(c.TxHash.Bytes()), ).WHERE( whereClause, @@ -57,7 +53,7 @@ func (r *PostgresRepository) CreateCommitment( ) sqlStr, args := insertStmt.Sql() - _, err = r.db.Exec(ctx, sqlStr, args...) + _, err := r.db.Exec(ctx, sqlStr, args...) return err } @@ -70,10 +66,7 @@ func (r *PostgresRepository) GetCommitment( commitmentHex string, ) (*model.Commitment, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) tournamentAddr := common.HexToAddress(tournamentAddress) commitment := common.HexToHash(commitmentHex) @@ -99,7 +92,7 @@ func (r *PostgresRepository) GetCommitment( ). WHERE( whereClause. - AND(table.Commitments.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex)))). + AND(table.Commitments.EpochIndex.EQ(uint64Expr(epochIndex))). AND(table.Commitments.TournamentAddress.EQ(postgres.Bytea(tournamentAddr.Bytes()))). AND(table.Commitments.Commitment.EQ(postgres.Bytea(commitment.Bytes()))), ) @@ -108,7 +101,7 @@ func (r *PostgresRepository) GetCommitment( row := r.db.QueryRow(ctx, sqlStr, args...) var c model.Commitment - err = row.Scan( + err := row.Scan( &c.ApplicationID, &c.EpochIndex, &c.TournamentAddress, @@ -137,10 +130,7 @@ func (r *PostgresRepository) ListCommitments( descending bool, ) ([]*model.Commitment, uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Commitments. SELECT( @@ -165,7 +155,7 @@ func (r *PostgresRepository) ListCommitments( conditions := []postgres.BoolExpression{whereClause} if f.EpochIndex != nil { - conditions = append(conditions, table.Commitments.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", *f.EpochIndex)))) + conditions = append(conditions, table.Commitments.EpochIndex.EQ(uint64Expr(*f.EpochIndex))) } if f.TournamentAddress != nil { tournamentAddr := common.HexToAddress(*f.TournamentAddress) @@ -225,6 +215,9 @@ func (r *PostgresRepository) ListCommitments( } commitments = append(commitments, &c) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return commitments, total, nil } diff --git a/internal/repository/postgres/epoch.go b/internal/repository/postgres/epoch.go index 6052f9fa5..abd7d688e 100644 --- a/internal/repository/postgres/epoch.go +++ b/internal/repository/postgres/epoch.go @@ -24,10 +24,7 @@ func getEpochNextVirtualIndex( nameOrAddress string, ) (uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) query := table.Epoch.SELECT( postgres.COALESCE( @@ -42,10 +39,9 @@ func getEpochNextVirtualIndex( queryStr, args := query.Sql() var currentIndex uint64 - err = tx.QueryRow(ctx, queryStr, args...).Scan(¤tIndex) + err := tx.QueryRow(ctx, queryStr, args...).Scan(¤tIndex) if err != nil { - err = fmt.Errorf("failed to get the next epoch virtual index: %w", err) - return 0, errors.Join(err, tx.Rollback(ctx)) + return 0, fmt.Errorf("failed to get the next epoch virtual index: %w", err) } return currentIndex, nil } @@ -70,10 +66,7 @@ func (r *PostgresRepository) CreateEpochsAndInputs( blockNumber uint64, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) epochInsertStmt := table.Epoch.INSERT( table.Epoch.ApplicationID, @@ -102,6 +95,7 @@ func (r *PostgresRepository) CreateEpochsAndInputs( if err != nil { return err } + defer tx.Rollback(ctx) //nolint:errcheck epochs := orderEpochs(epochInputsMap) for _, epoch := range epochs { @@ -120,14 +114,14 @@ func (r *PostgresRepository) CreateEpochsAndInputs( } epochSelectQuery := table.Application.SELECT( table.Application.ID, - postgres.RawFloat(fmt.Sprintf("%d", epoch.Index)), - postgres.RawFloat(fmt.Sprintf("%d", epoch.FirstBlock)), - postgres.RawFloat(fmt.Sprintf("%d", epoch.LastBlock)), - postgres.RawFloat(fmt.Sprintf("%d", epoch.InputIndexLowerBound)), - postgres.RawFloat(fmt.Sprintf("%d", epoch.InputIndexUpperBound)), + uint64Expr(epoch.Index), + uint64Expr(epoch.FirstBlock), + uint64Expr(epoch.LastBlock), + uint64Expr(epoch.InputIndexLowerBound), + uint64Expr(epoch.InputIndexUpperBound), tournamentAddress, postgres.NewEnumValue(epoch.Status.String()), - postgres.RawFloat(fmt.Sprintf("%d", nextVirtualIndex)), + uint64Expr(nextVirtualIndex), ).WHERE( whereClause, ) @@ -136,33 +130,45 @@ func (r *PostgresRepository) CreateEpochsAndInputs( ON_CONFLICT(table.Epoch.ApplicationID, table.Epoch.Index). DO_UPDATE(postgres.SET( table.Epoch.Status.SET(postgres.NewEnumValue(epoch.Status.String())), - table.Epoch.LastBlock.SET(postgres.RawFloat(fmt.Sprintf("%d", epoch.LastBlock))), - table.Epoch.InputIndexUpperBound.SET(postgres.RawFloat(fmt.Sprintf("%d", epoch.InputIndexUpperBound))), + table.Epoch.LastBlock.SET(uint64Expr(epoch.LastBlock)), + table.Epoch.InputIndexUpperBound.SET(uint64Expr(epoch.InputIndexUpperBound)), table.Epoch.TournamentAddress.SET(tournamentAddress), )).Sql() // FIXME on conflict _, err = tx.Exec(ctx, sqlStr, args...) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } - for _, input := range inputs { - inputSelectQuery := table.Application.SELECT( - table.Application.ID, - postgres.RawFloat(fmt.Sprintf("%d", epoch.Index)), - postgres.RawFloat(fmt.Sprintf("%d", input.Index)), - postgres.RawFloat(fmt.Sprintf("%d", input.BlockNumber)), - postgres.Bytea(input.RawData), - postgres.NewEnumValue(input.Status.String()), - postgres.Bytea(input.TransactionReference.Bytes()), - ).WHERE( - whereClause, - ) - - sqlStr, args := inputInsertStmt.QUERY(inputSelectQuery).Sql() - _, err := tx.Exec(ctx, sqlStr, args...) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + if len(inputs) > 0 { + batch := &pgx.Batch{} + for _, input := range inputs { + inputSelectQuery := table.Application.SELECT( + table.Application.ID, + uint64Expr(epoch.Index), + uint64Expr(input.Index), + uint64Expr(input.BlockNumber), + postgres.Bytea(input.RawData), + postgres.NewEnumValue(input.Status.String()), + postgres.Bytea(input.TransactionReference.Bytes()), + ).WHERE( + whereClause, + ) + + sqlStr, args := inputInsertStmt.QUERY(inputSelectQuery).Sql() + batch.Queue(sqlStr, args...) + } + + br := tx.SendBatch(ctx, batch) + for range inputs { + _, err := br.Exec() + if err != nil { + br.Close() + return err + } + } + if err := br.Close(); err != nil { + return err } } } @@ -173,23 +179,17 @@ func (r *PostgresRepository) CreateEpochsAndInputs( table.Application.LastInputCheckBlock, ). SET( - postgres.RawFloat(fmt.Sprintf("%d", blockNumber)), + uint64Expr(blockNumber), ). WHERE(whereClause) sqlStr, args := appUpdateStmt.Sql() _, err = tx.Exec(ctx, sqlStr, args...) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) - } - - // Commit transaction - err = tx.Commit(ctx) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } - return nil + return tx.Commit(ctx) } func (r *PostgresRepository) GetEpoch( @@ -198,10 +198,7 @@ func (r *PostgresRepository) GetEpoch( index uint64, ) (*model.Epoch, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) stmt := table.Epoch. SELECT( @@ -231,14 +228,14 @@ func (r *PostgresRepository) GetEpoch( ). WHERE( whereClause. - AND(table.Epoch.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", index)))), + AND(table.Epoch.Index.EQ(uint64Expr(index))), ) sqlStr, args := stmt.Sql() row := r.db.QueryRow(ctx, sqlStr, args...) var ep model.Epoch - err = row.Scan( + err := row.Scan( &ep.ApplicationID, &ep.Index, &ep.FirstBlock, @@ -271,10 +268,7 @@ func (r *PostgresRepository) GetLastAcceptedEpochIndex( nameOrAddress string, ) (uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) stmt := table.Epoch. SELECT( @@ -297,7 +291,7 @@ func (r *PostgresRepository) GetLastAcceptedEpochIndex( row := r.db.QueryRow(ctx, sqlStr, args...) var index uint64 - err = row.Scan( + err := row.Scan( &index, ) if errors.Is(err, sql.ErrNoRows) { @@ -314,10 +308,7 @@ func (r *PostgresRepository) GetLastNonOpenEpoch( nameOrAddress string, ) (*model.Epoch, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) stmt := table.Epoch. SELECT( @@ -356,7 +347,7 @@ func (r *PostgresRepository) GetLastNonOpenEpoch( row := r.db.QueryRow(ctx, sqlStr, args...) var ep model.Epoch - err = row.Scan( + err := row.Scan( &ep.ApplicationID, &ep.Index, &ep.FirstBlock, @@ -390,10 +381,7 @@ func (r *PostgresRepository) GetEpochByVirtualIndex( index uint64, ) (*model.Epoch, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) stmt := table.Epoch. SELECT( @@ -423,14 +411,14 @@ func (r *PostgresRepository) GetEpochByVirtualIndex( ). WHERE( whereClause. - AND(table.Epoch.VirtualIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", index)))), + AND(table.Epoch.VirtualIndex.EQ(uint64Expr(index))), ) sqlStr, args := stmt.Sql() row := r.db.QueryRow(ctx, sqlStr, args...) var ep model.Epoch - err = row.Scan( + err := row.Scan( &ep.ApplicationID, &ep.Index, &ep.FirstBlock, @@ -464,10 +452,7 @@ func (r *PostgresRepository) UpdateEpochClaimTransactionHash( e *model.Epoch, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) updStmt := table.Epoch. UPDATE( @@ -482,7 +467,7 @@ func (r *PostgresRepository) UpdateEpochClaimTransactionHash( WHERE( whereClause. AND(table.Epoch.ApplicationID.EQ(table.Application.ID)). - AND(table.Epoch.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", e.Index)))), + AND(table.Epoch.Index.EQ(uint64Expr(e.Index))), ) sqlStr, args := updStmt.Sql() @@ -491,7 +476,7 @@ func (r *PostgresRepository) UpdateEpochClaimTransactionHash( return err } if cmd.RowsAffected() == 0 { - return sql.ErrNoRows + return repository.ErrNotFound } return nil } @@ -501,18 +486,15 @@ func (r *PostgresRepository) UpdateEpochOutputsProof(ctx context.Context, appID if err != nil { return err } + defer tx.Rollback(ctx) //nolint:errcheck err = updateEpochOutputsMerkleProof(ctx, tx, appID, epochIndex, proof.OutputsHash, byteSliceToHashSlice(proof.OutputsHashProof), proof.MachineHash) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } - err = tx.Commit(ctx) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) - } - return nil + return tx.Commit(ctx) } func (r *PostgresRepository) UpdateEpochStatus( @@ -521,10 +503,7 @@ func (r *PostgresRepository) UpdateEpochStatus( e *model.Epoch, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) updStmt := table.Epoch. UPDATE( @@ -539,7 +518,7 @@ func (r *PostgresRepository) UpdateEpochStatus( WHERE( whereClause. AND(table.Epoch.ApplicationID.EQ(table.Application.ID)). - AND(table.Epoch.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", e.Index)))), + AND(table.Epoch.Index.EQ(uint64Expr(e.Index))), ) sqlStr, args := updStmt.Sql() @@ -548,7 +527,7 @@ func (r *PostgresRepository) UpdateEpochStatus( return err } if cmd.RowsAffected() == 0 { - return sql.ErrNoRows + return repository.ErrNotFound } return nil } @@ -559,10 +538,7 @@ func (r *PostgresRepository) UpdateEpochInputsProcessed( epochIndex uint64, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) // Subquery to check if the previous epoch is not open or closed prevTable := table.Epoch.AS("prev") @@ -608,7 +584,7 @@ func (r *PostgresRepository) UpdateEpochInputsProcessed( WHERE(postgres.AND( table.Epoch.Status.EQ(postgres.NewEnumValue(model.EpochStatus_Closed.String())), table.Epoch.ApplicationID.EQ(table.Application.ID), - table.Epoch.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex))), + table.Epoch.Index.EQ(uint64Expr(epochIndex)), whereClause, prevCondition, inputsCondition, @@ -619,9 +595,9 @@ func (r *PostgresRepository) UpdateEpochInputsProcessed( sqlStr, args := updateStmt.Sql() var index uint64 - err = r.db.QueryRow(ctx, sqlStr, args...).Scan(&index) + err := r.db.QueryRow(ctx, sqlStr, args...).Scan(&index) if err != nil { - if err == pgx.ErrNoRows { + if errors.Is(err, pgx.ErrNoRows) { return nil } return err @@ -641,10 +617,7 @@ func (r *PostgresRepository) ListEpochs( descending bool, ) ([]*model.Epoch, uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Epoch. SELECT( @@ -682,7 +655,7 @@ func (r *PostgresRepository) ListEpochs( } if f.BeforeBlock != nil { - conditions = append(conditions, table.Epoch.LastBlock.LT(postgres.RawFloat(fmt.Sprintf("%d", *f.BeforeBlock)))) + conditions = append(conditions, table.Epoch.LastBlock.LT(uint64Expr(*f.BeforeBlock))) } sel = sel.WHERE(postgres.AND(conditions...)) @@ -735,6 +708,9 @@ func (r *PostgresRepository) ListEpochs( } epochs = append(epochs, &ep) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return epochs, total, nil } @@ -763,9 +739,9 @@ func (r *PostgresRepository) RepeatPreviousEpochOutputsProof( FROM(e2). WHERE(postgres.AND( e1.ApplicationID.EQ(postgres.Int64(appID)), - e1.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex))), + e1.Index.EQ(uint64Expr(epochIndex)), e2.ApplicationID.EQ(postgres.Int64(appID)), - e2.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex-1))), + e2.Index.EQ(uint64Expr(epochIndex-1)), )) sqlStr, args := updStmt.Sql() @@ -775,7 +751,7 @@ func (r *PostgresRepository) RepeatPreviousEpochOutputsProof( return err } if cmd.RowsAffected() == 0 { - return sql.ErrNoRows + return repository.ErrNotFound } return nil } diff --git a/internal/repository/postgres/input.go b/internal/repository/postgres/input.go index 8e4d76740..649bb23f3 100644 --- a/internal/repository/postgres/input.go +++ b/internal/repository/postgres/input.go @@ -22,10 +22,7 @@ func (r *PostgresRepository) GetInput( inputIndex uint64, ) (*model.Input, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Input. SELECT( @@ -50,14 +47,14 @@ func (r *PostgresRepository) GetInput( ). WHERE( whereClause. - AND(table.Input.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", inputIndex)))), + AND(table.Input.Index.EQ(uint64Expr(inputIndex))), ) sqlStr, args := sel.Sql() row := r.db.QueryRow(ctx, sqlStr, args...) var inp model.Input - err = row.Scan( + err := row.Scan( &inp.EpochApplicationID, &inp.EpochIndex, &inp.Index, @@ -90,10 +87,7 @@ func (r *PostgresRepository) GetInputByTxReference( return nil, fmt.Errorf("tx reference is nil") } - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Input. SELECT( @@ -125,7 +119,7 @@ func (r *PostgresRepository) GetInputByTxReference( row := r.db.QueryRow(ctx, sqlStr, args...) var inp model.Input - err = row.Scan( + err := row.Scan( &inp.EpochApplicationID, &inp.EpochIndex, &inp.Index, @@ -154,10 +148,7 @@ func (r *PostgresRepository) GetLastInput( epochIndex uint64, ) (*model.Input, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Input. SELECT( @@ -182,7 +173,7 @@ func (r *PostgresRepository) GetLastInput( ). WHERE( whereClause. - AND(table.Input.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex)))), + AND(table.Input.EpochIndex.EQ(uint64Expr(epochIndex))), ). ORDER_BY(table.Input.Index.DESC()). LIMIT(1) @@ -191,7 +182,7 @@ func (r *PostgresRepository) GetLastInput( row := r.db.QueryRow(ctx, sqlStr, args...) var inp model.Input - err = row.Scan( + err := row.Scan( &inp.EpochApplicationID, &inp.EpochIndex, &inp.Index, @@ -219,10 +210,7 @@ func (r *PostgresRepository) GetLastProcessedInput( nameOrAddress string, ) (*model.Input, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Input. SELECT( @@ -256,7 +244,7 @@ func (r *PostgresRepository) GetLastProcessedInput( row := r.db.QueryRow(ctx, sqlStr, args...) var inp model.Input - err = row.Scan( + err := row.Scan( &inp.EpochApplicationID, &inp.EpochIndex, &inp.Index, @@ -287,10 +275,7 @@ func (r *PostgresRepository) ListInputs( descending bool, ) ([]*model.Input, uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Input. SELECT( @@ -317,7 +302,7 @@ func (r *PostgresRepository) ListInputs( conditions := []postgres.BoolExpression{whereClause} if f.EpochIndex != nil { - conditions = append(conditions, table.Input.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", *f.EpochIndex)))) + conditions = append(conditions, table.Input.EpochIndex.EQ(uint64Expr(*f.EpochIndex))) } if f.Status != nil { @@ -380,6 +365,9 @@ func (r *PostgresRepository) ListInputs( } inputs = append(inputs, &in) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return inputs, total, nil } @@ -388,10 +376,7 @@ func (r *PostgresRepository) GetNumberOfInputs( nameOrAddress string, ) (uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Input. SELECT(postgres.COUNT(postgres.STAR)). @@ -407,7 +392,7 @@ func (r *PostgresRepository) GetNumberOfInputs( row := r.db.QueryRow(ctx, sqlStr, args...) var count uint64 - err = row.Scan(&count) + err := row.Scan(&count) if err != nil { return 0, err } @@ -424,7 +409,7 @@ func (r *PostgresRepository) UpdateInputSnapshotURI(ctx context.Context, appId i ). WHERE( table.Input.EpochApplicationID.EQ(postgres.Int64(appId)). - AND(table.Input.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", inputIndex)))), + AND(table.Input.Index.EQ(uint64Expr(inputIndex))), ) sqlStr, args := updStmt.Sql() diff --git a/internal/repository/postgres/match.go b/internal/repository/postgres/match.go index 371f94dd9..10da18520 100644 --- a/internal/repository/postgres/match.go +++ b/internal/repository/postgres/match.go @@ -7,7 +7,6 @@ import ( "context" "database/sql" "errors" - "fmt" "github.com/ethereum/go-ethereum/common" "github.com/go-jet/jet/v2/postgres" @@ -25,24 +24,21 @@ func (r *PostgresRepository) CreateMatch( m *model.Match, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) selectQuery := table.Application.SELECT( table.Application.ID, - postgres.RawFloat(fmt.Sprintf("%d", m.EpochIndex)), + uint64Expr(m.EpochIndex), postgres.Bytea(m.TournamentAddress.Bytes()), postgres.Bytea(m.IDHash.Bytes()), postgres.Bytea(m.CommitmentOne.Bytes()), postgres.Bytea(m.CommitmentTwo.Bytes()), postgres.Bytea(m.LeftOfTwo.Bytes()), - postgres.RawFloat(fmt.Sprintf("%d", m.BlockNumber)), + uint64Expr(m.BlockNumber), postgres.Bytea(m.TxHash.Bytes()), postgres.NewEnumValue(m.Winner.String()), postgres.NewEnumValue(m.DeletionReason.String()), - postgres.RawFloat(fmt.Sprintf("%d", m.DeletionBlockNumber)), + uint64Expr(m.DeletionBlockNumber), postgres.Bytea(m.DeletionTxHash.Bytes()), ).WHERE( whereClause, @@ -67,7 +63,7 @@ func (r *PostgresRepository) CreateMatch( ) sqlStr, args := insertStmt.Sql() - _, err = r.db.Exec(ctx, sqlStr, args...) + _, err := r.db.Exec(ctx, sqlStr, args...) return err } @@ -78,10 +74,7 @@ func (r *PostgresRepository) UpdateMatch( m *model.Match, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) updateStmt := table.Matches. UPDATE( @@ -102,7 +95,7 @@ func (r *PostgresRepository) UpdateMatch( WHERE( whereClause. AND(table.Matches.ApplicationID.EQ(postgres.Int(m.ApplicationID))). - AND(table.Matches.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", m.EpochIndex)))). + AND(table.Matches.EpochIndex.EQ(uint64Expr(m.EpochIndex))). AND(table.Matches.TournamentAddress.EQ(postgres.Bytea(m.TournamentAddress.Bytes()))). AND(table.Matches.IDHash.EQ(postgres.Bytea(m.IDHash.Bytes()))), ) @@ -113,7 +106,7 @@ func (r *PostgresRepository) UpdateMatch( return err } if cmd.RowsAffected() == 0 { - return sql.ErrNoRows + return repository.ErrNotFound } return nil } @@ -126,10 +119,7 @@ func (r *PostgresRepository) GetMatch( idHashHex string, ) (*model.Match, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) tournamentAddr := common.HexToAddress(tournamentAddress) idHash := common.HexToHash(idHashHex) @@ -160,7 +150,7 @@ func (r *PostgresRepository) GetMatch( ). WHERE( whereClause. - AND(table.Matches.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex)))). + AND(table.Matches.EpochIndex.EQ(uint64Expr(epochIndex))). AND(table.Matches.TournamentAddress.EQ(postgres.Bytea(tournamentAddr.Bytes()))). AND(table.Matches.IDHash.EQ(postgres.Bytea(idHash.Bytes()))), ) @@ -169,7 +159,7 @@ func (r *PostgresRepository) GetMatch( row := r.db.QueryRow(ctx, sqlStr, args...) var m model.Match - err = row.Scan( + err := row.Scan( &m.ApplicationID, &m.EpochIndex, &m.TournamentAddress, @@ -203,10 +193,7 @@ func (r *PostgresRepository) ListMatches( descending bool, ) ([]*model.Match, uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Matches. SELECT( @@ -236,7 +223,7 @@ func (r *PostgresRepository) ListMatches( conditions := []postgres.BoolExpression{whereClause} if f.EpochIndex != nil { - conditions = append(conditions, table.Matches.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", *f.EpochIndex)))) + conditions = append(conditions, table.Matches.EpochIndex.EQ(uint64Expr(*f.EpochIndex))) } if f.TournamentAddress != nil { tournamentAddr := common.HexToAddress(*f.TournamentAddress) @@ -293,6 +280,9 @@ func (r *PostgresRepository) ListMatches( } matches = append(matches, &m) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return matches, total, nil } diff --git a/internal/repository/postgres/match_advanced.go b/internal/repository/postgres/match_advanced.go index 635fde749..2309bf4b3 100644 --- a/internal/repository/postgres/match_advanced.go +++ b/internal/repository/postgres/match_advanced.go @@ -7,7 +7,6 @@ import ( "context" "database/sql" "errors" - "fmt" "github.com/ethereum/go-ethereum/common" "github.com/go-jet/jet/v2/postgres" @@ -25,19 +24,16 @@ func (r *PostgresRepository) CreateMatchAdvanced( m *model.MatchAdvanced, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) selectQuery := table.Application.SELECT( table.Application.ID, - postgres.RawFloat(fmt.Sprintf("%d", m.EpochIndex)), + uint64Expr(m.EpochIndex), postgres.Bytea(m.TournamentAddress.Bytes()), postgres.Bytea(m.IDHash.Bytes()), postgres.Bytea(m.OtherParent.Bytes()), postgres.Bytea(m.LeftNode.Bytes()), - postgres.RawFloat(fmt.Sprintf("%d", m.BlockNumber)), + uint64Expr(m.BlockNumber), postgres.Bytea(m.TxHash.Bytes()), ).WHERE( whereClause, @@ -57,7 +53,7 @@ func (r *PostgresRepository) CreateMatchAdvanced( ) sqlStr, args := insertStmt.Sql() - _, err = r.db.Exec(ctx, sqlStr, args...) + _, err := r.db.Exec(ctx, sqlStr, args...) return err } @@ -71,10 +67,7 @@ func (r *PostgresRepository) GetMatchAdvanced( parentHex string, ) (*model.MatchAdvanced, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) tournamentAddr := common.HexToAddress(tournamentAddress) idHash := common.HexToHash(idHashHex) @@ -101,7 +94,7 @@ func (r *PostgresRepository) GetMatchAdvanced( ). WHERE( whereClause. - AND(table.MatchAdvances.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex)))). + AND(table.MatchAdvances.EpochIndex.EQ(uint64Expr(epochIndex))). AND(table.MatchAdvances.TournamentAddress.EQ(postgres.Bytea(tournamentAddr.Bytes()))). AND(table.MatchAdvances.IDHash.EQ(postgres.Bytea(idHash.Bytes()))). AND(table.MatchAdvances.OtherParent.EQ(postgres.Bytea(parent.Bytes()))), @@ -111,7 +104,7 @@ func (r *PostgresRepository) GetMatchAdvanced( row := r.db.QueryRow(ctx, sqlStr, args...) var m model.MatchAdvanced - err = row.Scan( + err := row.Scan( &m.ApplicationID, &m.EpochIndex, &m.TournamentAddress, @@ -142,10 +135,7 @@ func (r *PostgresRepository) ListMatchAdvances( descending bool, ) ([]*model.MatchAdvanced, uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.MatchAdvances. SELECT( @@ -169,7 +159,7 @@ func (r *PostgresRepository) ListMatchAdvances( ) conditions := []postgres.BoolExpression{whereClause} - conditions = append(conditions, table.MatchAdvances.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", epochIndex)))) + conditions = append(conditions, table.MatchAdvances.EpochIndex.EQ(uint64Expr(epochIndex))) tAddr := common.HexToAddress(tournamentAddress) conditions = append(conditions, table.MatchAdvances.TournamentAddress.EQ(postgres.Bytea(tAddr.Bytes()))) @@ -232,6 +222,9 @@ func (r *PostgresRepository) ListMatchAdvances( } matchAdvances = append(matchAdvances, &m) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return matchAdvances, total, nil } diff --git a/internal/repository/postgres/output.go b/internal/repository/postgres/output.go index 31425b6c1..98ad458fe 100644 --- a/internal/repository/postgres/output.go +++ b/internal/repository/postgres/output.go @@ -22,10 +22,7 @@ func (r *PostgresRepository) GetOutput( outputIndex uint64, ) (*model.Output, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Output. SELECT( @@ -52,14 +49,14 @@ func (r *PostgresRepository) GetOutput( ). WHERE( whereClause. - AND(table.Output.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", outputIndex)))), + AND(table.Output.Index.EQ(uint64Expr(outputIndex))), ) sqlStr, args := sel.Sql() row := r.db.QueryRow(ctx, sqlStr, args...) var o model.Output - err = row.Scan( + err := row.Scan( &o.InputEpochApplicationID, &o.InputIndex, &o.Index, @@ -87,21 +84,19 @@ func (r *PostgresRepository) UpdateOutputsExecution( lastOutputCheckBlock uint64, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) tx, err := r.db.Begin(ctx) if err != nil { return err } + defer tx.Rollback(ctx) //nolint:errcheck for _, o := range outputs { if o.ExecutionTransactionHash == nil { - return errors.Join( - fmt.Errorf("output ExecutionTransactionHash must be not nil when updating app %s output %d", nameOrAddress, o.Index), - tx.Rollback(ctx), + return fmt.Errorf( + "output ExecutionTransactionHash must be not nil when updating app %s output %d", + nameOrAddress, o.Index, ) } updStmt := table.Output. @@ -117,19 +112,16 @@ func (r *PostgresRepository) UpdateOutputsExecution( WHERE( whereClause. AND(table.Output.InputEpochApplicationID.EQ(table.Application.ID)). - AND(table.Output.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", o.Index)))), + AND(table.Output.Index.EQ(uint64Expr(o.Index))), ) sqlStr, args := updStmt.Sql() cmd, err := tx.Exec(ctx, sqlStr, args...) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } if cmd.RowsAffected() != 1 { - return errors.Join( - fmt.Errorf("no row affected when updating app %s epoch %d", nameOrAddress, o.Index), - tx.Rollback(ctx), - ) + return fmt.Errorf("no row affected when updating app %s output %d", nameOrAddress, o.Index) } } @@ -139,23 +131,17 @@ func (r *PostgresRepository) UpdateOutputsExecution( table.Application.LastOutputCheckBlock, ). SET( - postgres.RawFloat(fmt.Sprintf("%d", lastOutputCheckBlock)), + uint64Expr(lastOutputCheckBlock), ). WHERE(whereClause) sqlStr, args := appUpdateStmt.Sql() _, err = tx.Exec(ctx, sqlStr, args...) if err != nil { - return errors.Join(err, tx.Rollback(ctx)) - } - - // Commit transaction - err = tx.Commit(ctx) - if err != nil { - return errors.Join(err, tx.Rollback(ctx)) + return err } - return nil + return tx.Commit(ctx) } func (r *PostgresRepository) ListOutputs( @@ -166,10 +152,7 @@ func (r *PostgresRepository) ListOutputs( descending bool, ) ([]*model.Output, uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Output. SELECT( @@ -199,19 +182,19 @@ func (r *PostgresRepository) ListOutputs( conditions := []postgres.BoolExpression{whereClause} if f.BlockRange != nil { conditions = append(conditions, table.Input.BlockNumber.BETWEEN( - postgres.RawFloat(fmt.Sprintf("%d", f.BlockRange.Start)), - postgres.RawFloat(fmt.Sprintf("%d", f.BlockRange.End)), + uint64Expr(f.BlockRange.Start), + uint64Expr(f.BlockRange.End), )) conditions = append(conditions, table.Input.Status.EQ(postgres.NewEnumValue(model.InputCompletionStatus_Accepted.String()))) } if f.EpochIndex != nil { - conditions = append(conditions, table.Input.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", *f.EpochIndex)))) + conditions = append(conditions, table.Input.EpochIndex.EQ(uint64Expr(*f.EpochIndex))) conditions = append(conditions, table.Input.Status.EQ(postgres.NewEnumValue(model.InputCompletionStatus_Accepted.String()))) } if f.InputIndex != nil { - conditions = append(conditions, table.Output.InputIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", *f.InputIndex)))) + conditions = append(conditions, table.Output.InputIndex.EQ(uint64Expr(*f.InputIndex))) } if f.OutputType != nil { @@ -270,6 +253,9 @@ func (r *PostgresRepository) ListOutputs( } outputs = append(outputs, &out) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return outputs, total, nil } @@ -279,10 +265,7 @@ func (r *PostgresRepository) GetLastOutputBeforeBlock( block uint64, ) (*model.Output, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Output. SELECT( @@ -311,7 +294,7 @@ func (r *PostgresRepository) GetLastOutputBeforeBlock( WHERE( postgres.AND( whereClause, - table.Input.BlockNumber.LT(postgres.RawFloat(fmt.Sprintf("%d", block))), + table.Input.BlockNumber.LT(uint64Expr(block)), table.Input.Status.EQ(postgres.NewEnumValue(model.InputCompletionStatus_Accepted.String())), ), ). @@ -322,7 +305,7 @@ func (r *PostgresRepository) GetLastOutputBeforeBlock( row := r.db.QueryRow(ctx, sqlStr, args...) var out model.Output - err = row.Scan( + err := row.Scan( &out.InputEpochApplicationID, &out.InputIndex, &out.Index, @@ -348,10 +331,7 @@ func (r *PostgresRepository) GetNumberOfExecutedOutputs( nameOrAddress string, ) (uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Output. SELECT(postgres.COUNT(postgres.STAR)). @@ -367,7 +347,7 @@ func (r *PostgresRepository) GetNumberOfExecutedOutputs( row := r.db.QueryRow(ctx, sqlStr, args...) var count uint64 - err = row.Scan(&count) + err := row.Scan(&count) if err != nil { return 0, err } diff --git a/internal/repository/postgres/postgres_repo_test.go b/internal/repository/postgres/postgres_repo_test.go new file mode 100644 index 000000000..96e3a8e0d --- /dev/null +++ b/internal/repository/postgres/postgres_repo_test.go @@ -0,0 +1,33 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package postgres_test + +import ( + "context" + "testing" + + "github.com/cartesi/rollups-node/internal/repository" + "github.com/cartesi/rollups-node/internal/repository/factory" + "github.com/cartesi/rollups-node/internal/repository/repotest" + "github.com/cartesi/rollups-node/test/tooling/db" + "github.com/stretchr/testify/require" +) + +func TestPostgresRepository(t *testing.T) { + endpoint, err := db.GetTestDatabaseEndpoint() + if err != nil { + t.Skipf("Skipping: %v", err) + } + + repotest.RunAllSuites(t, func(ctx context.Context, t *testing.T) (repository.Repository, func()) { + t.Helper() + err := db.SetupTestPostgres(endpoint) + require.NoError(t, err) + + repo, err := factory.NewRepositoryFromConnectionString(ctx, endpoint) + require.NoError(t, err) + + return repo, func() { repo.Close() } + }) +} diff --git a/internal/repository/postgres/report.go b/internal/repository/postgres/report.go index 9de91fa63..bc7069a61 100644 --- a/internal/repository/postgres/report.go +++ b/internal/repository/postgres/report.go @@ -7,7 +7,6 @@ import ( "context" "database/sql" "errors" - "fmt" "github.com/go-jet/jet/v2/postgres" @@ -22,10 +21,7 @@ func (r *PostgresRepository) GetReport( reportIndex uint64, ) (*model.Report, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Report. SELECT( @@ -49,14 +45,14 @@ func (r *PostgresRepository) GetReport( ). WHERE( whereClause. - AND(table.Report.Index.EQ(postgres.RawFloat(fmt.Sprintf("%d", reportIndex)))), + AND(table.Report.Index.EQ(uint64Expr(reportIndex))), ) sqlStr, args := sel.Sql() row := r.db.QueryRow(ctx, sqlStr, args...) var rp model.Report - err = row.Scan( + err := row.Scan( &rp.InputEpochApplicationID, &rp.InputIndex, &rp.Index, @@ -82,10 +78,7 @@ func (r *PostgresRepository) ListReports( descending bool, ) ([]*model.Report, uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Report. SELECT( @@ -111,11 +104,11 @@ func (r *PostgresRepository) ListReports( conditions := []postgres.BoolExpression{whereClause} if f.InputIndex != nil { - conditions = append(conditions, table.Report.InputIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", *f.InputIndex)))) + conditions = append(conditions, table.Report.InputIndex.EQ(uint64Expr(*f.InputIndex))) } if f.EpochIndex != nil { - conditions = append(conditions, table.Input.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", *f.EpochIndex)))) + conditions = append(conditions, table.Input.EpochIndex.EQ(uint64Expr(*f.EpochIndex))) conditions = append(conditions, table.Input.Status.EQ(postgres.NewEnumValue(model.InputCompletionStatus_Accepted.String()))) } @@ -160,5 +153,8 @@ func (r *PostgresRepository) ListReports( } reports = append(reports, &rp) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return reports, total, nil } diff --git a/internal/repository/postgres/repository.go b/internal/repository/postgres/repository.go index 1c71e23b8..61e0ba534 100644 --- a/internal/repository/postgres/repository.go +++ b/internal/repository/postgres/repository.go @@ -57,7 +57,12 @@ func NewPostgresRepository(ctx context.Context, conn string, maxRetries int, del pool.Close() return nil, fmt.Errorf("failed to ping Postgres after %d retries", maxRetries) } - time.Sleep(delay) + select { + case <-time.After(delay): + case <-ctx.Done(): + pool.Close() + return nil, ctx.Err() + } } // Wait for schema validation (migrations) to complete. Workaround to facilitate container startup order. @@ -70,7 +75,12 @@ func NewPostgresRepository(ctx context.Context, conn string, maxRetries int, del pool.Close() return nil, fmt.Errorf("failed to validate Postgres schema version: %w", err) } - time.Sleep(delay) + select { + case <-time.After(delay): + case <-ctx.Done(): + pool.Close() + return nil, ctx.Err() + } } // This should never be reached due to the returns in the loops above diff --git a/internal/repository/postgres/repository_error_test.go b/internal/repository/postgres/repository_error_test.go new file mode 100644 index 000000000..4394aab87 --- /dev/null +++ b/internal/repository/postgres/repository_error_test.go @@ -0,0 +1,47 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package postgres_test + +import ( + "context" + "testing" + "time" + + "github.com/cartesi/rollups-node/internal/repository/postgres" + "github.com/stretchr/testify/require" +) + +func TestNewPostgresRepository_InvalidConnectionString(t *testing.T) { + ctx := context.Background() + _, err := postgres.NewPostgresRepository( + ctx, "not-a-valid-connection-string", 1, time.Millisecond) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to parse Postgres connection string") +} + +func TestNewPostgresRepository_UnreachableHostRetriesExhausted(t *testing.T) { + ctx := context.Background() + // Port 1 on localhost is almost certainly not running PostgreSQL. + // With maxRetries=2 and minimal delay the retries exhaust quickly. + _, err := postgres.NewPostgresRepository( + ctx, "postgres://user:pass@localhost:1/testdb?connect_timeout=1", 2, time.Millisecond) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to ping Postgres after 2 retries") +} + +func TestNewPostgresRepository_ContextCancelledDuringRetry(t *testing.T) { + // Use a short-lived context so that it expires while the function is + // waiting between retry attempts (delay is deliberately long). + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + _, err := postgres.NewPostgresRepository( + ctx, + "postgres://user:pass@localhost:1/testdb?connect_timeout=1", + 100, // many retries — we won't exhaust them + 10*time.Second, // long delay — context expires before this elapses + ) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) +} diff --git a/internal/repository/postgres/state_hash.go b/internal/repository/postgres/state_hash.go index e41a9dbf2..9a38c917d 100644 --- a/internal/repository/postgres/state_hash.go +++ b/internal/repository/postgres/state_hash.go @@ -5,7 +5,6 @@ package postgres import ( "context" - "fmt" "github.com/go-jet/jet/v2/postgres" @@ -22,10 +21,7 @@ func (r *PostgresRepository) ListStateHashes( descending bool, ) ([]*model.StateHash, uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.StateHashes. SELECT( @@ -48,7 +44,7 @@ func (r *PostgresRepository) ListStateHashes( conditions := []postgres.BoolExpression{whereClause} if f.EpochIndex != nil { - conditions = append(conditions, table.StateHashes.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", *f.EpochIndex)))) + conditions = append(conditions, table.StateHashes.EpochIndex.EQ(uint64Expr(*f.EpochIndex))) } sel = sel.WHERE(postgres.AND(conditions...)) @@ -93,5 +89,8 @@ func (r *PostgresRepository) ListStateHashes( } stateHashes = append(stateHashes, &sh) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return stateHashes, total, nil } diff --git a/internal/repository/postgres/tournament.go b/internal/repository/postgres/tournament.go index 0d70b610d..7ded0dc3f 100644 --- a/internal/repository/postgres/tournament.go +++ b/internal/repository/postgres/tournament.go @@ -25,10 +25,7 @@ func (r *PostgresRepository) CreateTournament( t *model.Tournament, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) insertStmt := table.Tournaments. INSERT( @@ -65,23 +62,23 @@ func (r *PostgresRepository) CreateTournament( selectQuery := table.Application.SELECT( table.Application.ID, - postgres.RawFloat(fmt.Sprintf("%d", t.EpochIndex)), + uint64Expr(t.EpochIndex), postgres.Bytea(t.Address.Bytes()), parentAddress, parentMatch, - postgres.RawFloat(fmt.Sprintf("%d", t.MaxLevel)), - postgres.RawFloat(fmt.Sprintf("%d", t.Level)), - postgres.RawFloat(fmt.Sprintf("%d", t.Log2Step)), - postgres.RawFloat(fmt.Sprintf("%d", t.Height)), + uint64Expr(t.MaxLevel), + uint64Expr(t.Level), + uint64Expr(t.Log2Step), + uint64Expr(t.Height), winnerCommitment, finalState, - postgres.RawFloat(fmt.Sprintf("%d", t.FinishedAtBlock)), + uint64Expr(t.FinishedAtBlock), ).WHERE( whereClause, ) sqlStr, args := insertStmt.QUERY(selectQuery).Sql() - _, err = r.db.Exec(ctx, sqlStr, args...) + _, err := r.db.Exec(ctx, sqlStr, args...) return err } @@ -92,10 +89,7 @@ func (r *PostgresRepository) UpdateTournament( t *model.Tournament, ) error { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) winnerCommitment := postgres.NULL if t.WinnerCommitment != nil { @@ -123,7 +117,7 @@ func (r *PostgresRepository) UpdateTournament( WHERE(postgres.AND( whereClause, table.Tournaments.ApplicationID.EQ(postgres.Int(t.ApplicationID)), - table.Tournaments.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", t.EpochIndex))), + table.Tournaments.EpochIndex.EQ(uint64Expr(t.EpochIndex)), table.Tournaments.Address.EQ(postgres.Bytea(t.Address.Bytes())), )) @@ -133,7 +127,7 @@ func (r *PostgresRepository) UpdateTournament( return err } if cmd.RowsAffected() == 0 { - return sql.ErrNoRows + return repository.ErrNotFound } return nil } @@ -144,10 +138,7 @@ func (r *PostgresRepository) GetTournament( address string, ) (*model.Tournament, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) tournamentAddress := common.HexToAddress(address) sel := table.Tournaments. @@ -182,7 +173,7 @@ func (r *PostgresRepository) GetTournament( row := r.db.QueryRow(ctx, sqlStr, args...) var t model.Tournament - err = row.Scan( + err := row.Scan( &t.ApplicationID, &t.EpochIndex, &t.Address, @@ -215,10 +206,7 @@ func (r *PostgresRepository) ListTournaments( descending bool, ) ([]*model.Tournament, uint64, error) { - whereClause, err := getWhereClauseFromNameOrAddress(nameOrAddress) - if err != nil { - return nil, 0, err - } + whereClause := getWhereClauseFromNameOrAddress(nameOrAddress) sel := table.Tournaments. SELECT( @@ -247,7 +235,7 @@ func (r *PostgresRepository) ListTournaments( conditions := []postgres.BoolExpression{whereClause} if f.EpochIndex != nil { - conditions = append(conditions, table.Tournaments.EpochIndex.EQ(postgres.RawFloat(fmt.Sprintf("%d", *f.EpochIndex)))) + conditions = append(conditions, table.Tournaments.EpochIndex.EQ(uint64Expr(*f.EpochIndex))) } if f.Level != nil { conditions = append(conditions, table.Tournaments.Level.EQ(postgres.RawInt(fmt.Sprintf("%d", *f.Level)))) @@ -308,6 +296,9 @@ func (r *PostgresRepository) ListTournaments( } tournaments = append(tournaments, &t) } + if err := rows.Err(); err != nil { + return nil, 0, err + } return tournaments, total, nil } diff --git a/internal/repository/postgres/util.go b/internal/repository/postgres/util.go index 763f86179..6dacc4140 100644 --- a/internal/repository/postgres/util.go +++ b/internal/repository/postgres/util.go @@ -20,17 +20,18 @@ func isHexAddress(s string) bool { return hexAddressRegex.MatchString(s) } -func getWhereClauseFromNameOrAddress(nameOrAddress string) (postgres.BoolExpression, error) { - - var whereClause postgres.BoolExpression +func getWhereClauseFromNameOrAddress(nameOrAddress string) postgres.BoolExpression { if isHexAddress(nameOrAddress) { address := common.HexToAddress(nameOrAddress) - whereClause = table.Application.IapplicationAddress.EQ(postgres.Bytea(address.Bytes())) - } else { - // treat as name - whereClause = table.Application.Name.EQ(postgres.String(nameOrAddress)) + return table.Application.IapplicationAddress.EQ(postgres.Bytea(address.Bytes())) } - return whereClause, nil + return table.Application.Name.EQ(postgres.String(nameOrAddress)) +} + +// uint64Expr converts a uint64 to a go-jet FloatExpression for use with +// PostgreSQL NUMERIC(20,0) "uint64" domain columns. +func uint64Expr(v uint64) postgres.FloatExpression { + return postgres.RawFloat(fmt.Sprintf("%d", v)) } func hashToBytes(h *common.Hash) any { diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 99308b9bb..380c96a8a 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -6,6 +6,7 @@ package repository import ( "context" "encoding/json" + "errors" "fmt" "time" @@ -14,7 +15,8 @@ import ( ) var ( - ErrNotFound = fmt.Errorf("not found") + ErrNotFound = errors.New("not found") + ErrNoUpdate = errors.New("update did not take effect") ) type Pagination struct { @@ -197,13 +199,13 @@ type ClaimerRepository interface { ) UpdateEpochWithSubmittedClaim( ctx context.Context, - application_id int64, + applicationID int64, index uint64, - transaction_hash common.Hash, + transactionHash common.Hash, ) error UpdateEpochWithAcceptedClaim( ctx context.Context, - application_id int64, + applicationID int64, index uint64, ) error } diff --git a/internal/repository/repotest/application_test_cases.go b/internal/repository/repotest/application_test_cases.go new file mode 100644 index 000000000..5206225fb --- /dev/null +++ b/internal/repository/repotest/application_test_cases.go @@ -0,0 +1,529 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + "time" + + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/crypto" +) + +type ApplicationSuite struct { + BaseSuite +} + +func NewApplicationSuite(factory RepositoryFactory) *ApplicationSuite { + return &ApplicationSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *ApplicationSuite) TestCreateApplication() { + s.Run("ReturnsGeneratedID", func() { + app := NewApplicationBuilder().Build() + id, err := s.Repo.CreateApplication(s.Ctx, app, false) + s.Require().NoError(err) + s.Greater(id, int64(0)) + }) + + s.Run("WithExecutionParameters", func() { + ep := ExecutionParameters{ + SnapshotPolicy: SnapshotPolicy_EveryEpoch, + AdvanceIncCycles: 1000, + AdvanceMaxCycles: 5000, + InspectIncCycles: 1000, + InspectMaxCycles: 5000, + AdvanceIncDeadline: 10 * time.Second, + AdvanceMaxDeadline: 60 * time.Second, + InspectIncDeadline: 10 * time.Second, + InspectMaxDeadline: 60 * time.Second, + LoadDeadline: 30 * time.Second, + StoreDeadline: 30 * time.Second, + FastDeadline: 5 * time.Second, + MaxConcurrentInspects: 10, + } + app := NewApplicationBuilder(). + WithExecutionParameters(ep). + Create(s.Ctx, s.T(), s.Repo) + s.Greater(app.ID, int64(0)) + + got, err := s.Repo.GetExecutionParameters(s.Ctx, app.ID) + s.Require().NoError(err) + s.Equal(ep.SnapshotPolicy, got.SnapshotPolicy) + s.Equal(ep.AdvanceIncCycles, got.AdvanceIncCycles) + s.Equal(ep.AdvanceMaxCycles, got.AdvanceMaxCycles) + s.Equal(ep.MaxConcurrentInspects, got.MaxConcurrentInspects) + }) +} + +func (s *ApplicationSuite) TestGetApplication() { + s.Run("ByName", func() { + app := NewApplicationBuilder(). + WithName("my-unique-app"). + Create(s.Ctx, s.T(), s.Repo) + + got, err := s.Repo.GetApplication(s.Ctx, "my-unique-app") + s.Require().NoError(err) + s.Equal(app.ID, got.ID) + s.Equal("my-unique-app", got.Name) + s.Equal(app.IApplicationAddress, got.IApplicationAddress) + s.Equal(app.IConsensusAddress, got.IConsensusAddress) + s.Equal(app.IInputBoxAddress, got.IInputBoxAddress) + s.Equal(app.TemplateHash, got.TemplateHash) + s.Equal(app.EpochLength, got.EpochLength) + s.Equal(app.ConsensusType, got.ConsensusType) + s.Equal(app.State, got.State) + s.Equal(app.DataAvailability, got.DataAvailability) + s.False(got.CreatedAt.IsZero(), "CreatedAt should be set") + s.False(got.UpdatedAt.IsZero(), "UpdatedAt should be set") + }) + + s.Run("ByAddress", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + got, err := s.Repo.GetApplication(s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(app.ID, got.ID) + s.Equal(app.Name, got.Name) + }) + + s.Run("NotFound", func() { + got, err := s.Repo.GetApplication(s.Ctx, "nonexistent") + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *ApplicationSuite) TestListApplications() { + s.Run("EmptyResult", func() { + apps, total, err := s.Repo.ListApplications( + s.Ctx, repository.ApplicationFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(apps) + s.Equal(uint64(0), total) + }) + + s.Run("ReturnsAllApps", func() { + for range 3 { + NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + } + apps, total, err := s.Repo.ListApplications( + s.Ctx, repository.ApplicationFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(apps, 3) + s.Equal(uint64(3), total) + }) + + s.Run("FilterByState", func() { + NewApplicationBuilder().WithState(ApplicationState_Enabled).Create(s.Ctx, s.T(), s.Repo) + NewApplicationBuilder().WithState(ApplicationState_Disabled).Create(s.Ctx, s.T(), s.Repo) + + state := ApplicationState_Enabled + apps, total, err := s.Repo.ListApplications( + s.Ctx, + repository.ApplicationFilter{State: &state}, + repository.Pagination{Limit: 10}, + false, + ) + s.Require().NoError(err) + s.Len(apps, 1) + s.Equal(uint64(1), total) + s.Equal(ApplicationState_Enabled, apps[0].State) + }) + + s.Run("FilterByConsensus", func() { + NewApplicationBuilder().WithConsensus(Consensus_Authority).Create(s.Ctx, s.T(), s.Repo) + NewApplicationBuilder().WithConsensus(Consensus_PRT).Create(s.Ctx, s.T(), s.Repo) + + consensus := Consensus_PRT + apps, total, err := s.Repo.ListApplications( + s.Ctx, + repository.ApplicationFilter{ConsensusType: &consensus}, + repository.Pagination{Limit: 10}, + false, + ) + s.Require().NoError(err) + s.Len(apps, 1) + s.Equal(uint64(1), total) + s.Equal(Consensus_PRT, apps[0].ConsensusType) + }) + + s.Run("FilterByDataAvailability", func() { + NewApplicationBuilder(). + WithDataAvailability(DataAvailability_InputBox[:]). + Create(s.Ctx, s.T(), s.Repo) + // Create another app with a different DA selector + otherDA := DataAvailabilitySelector{0xaa, 0xbb, 0xcc, 0xdd} + NewApplicationBuilder(). + WithDataAvailability(otherDA[:]). + Create(s.Ctx, s.T(), s.Repo) + + da := DataAvailability_InputBox + apps, total, err := s.Repo.ListApplications( + s.Ctx, + repository.ApplicationFilter{DataAvailability: &da}, + repository.Pagination{Limit: 10}, + false, + ) + s.Require().NoError(err) + s.Len(apps, 1) + s.Equal(uint64(1), total) + s.Equal(DataAvailability_InputBox[:], apps[0].DataAvailability[:4]) + }) + + s.Run("Pagination", func() { + for range 5 { + NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + } + + apps, total, err := s.Repo.ListApplications( + s.Ctx, + repository.ApplicationFilter{}, + repository.Pagination{Limit: 2, Offset: 0}, + false, + ) + s.Require().NoError(err) + s.Len(apps, 2) + s.Equal(uint64(5), total) + + apps2, total2, err := s.Repo.ListApplications( + s.Ctx, + repository.ApplicationFilter{}, + repository.Pagination{Limit: 2, Offset: 2}, + false, + ) + s.Require().NoError(err) + s.Len(apps2, 2) + s.Equal(uint64(5), total2) + // Pages should be different + s.NotEqual(apps[0].ID, apps2[0].ID) + }) + + s.Run("Descending", func() { + a1 := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + a2 := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + apps, _, err := s.Repo.ListApplications( + s.Ctx, repository.ApplicationFilter{}, repository.Pagination{Limit: 10}, true) + s.Require().NoError(err) + s.Require().Len(apps, 2) + // Descending: second created should be first + s.Equal(a2.ID, apps[0].ID) + s.Equal(a1.ID, apps[1].ID) + }) + + s.Run("CombinedFilters", func() { + // Create apps with different combinations of state, consensus, and DA + NewApplicationBuilder(). + WithState(ApplicationState_Enabled). + WithConsensus(Consensus_Authority). + WithDataAvailability(DataAvailability_InputBox[:]). + Create(s.Ctx, s.T(), s.Repo) + NewApplicationBuilder(). + WithState(ApplicationState_Enabled). + WithConsensus(Consensus_PRT). + WithDataAvailability(DataAvailability_InputBox[:]). + Create(s.Ctx, s.T(), s.Repo) + NewApplicationBuilder(). + WithState(ApplicationState_Disabled). + WithConsensus(Consensus_Authority). + WithDataAvailability(DataAvailability_InputBox[:]). + Create(s.Ctx, s.T(), s.Repo) + + state := ApplicationState_Enabled + consensus := Consensus_Authority + apps, total, err := s.Repo.ListApplications( + s.Ctx, + repository.ApplicationFilter{ + State: &state, + ConsensusType: &consensus, + }, + repository.Pagination{Limit: 10}, + false, + ) + s.Require().NoError(err) + s.Len(apps, 1) + s.Equal(uint64(1), total) + s.Equal(ApplicationState_Enabled, apps[0].State) + s.Equal(Consensus_Authority, apps[0].ConsensusType) + }) + + s.Run("CombinedStateAndDataAvailability", func() { + NewApplicationBuilder(). + WithState(ApplicationState_Enabled). + WithDataAvailability(DataAvailability_InputBox[:]). + Create(s.Ctx, s.T(), s.Repo) + + otherDA := DataAvailabilitySelector{0xaa, 0xbb, 0xcc, 0xdd} + NewApplicationBuilder(). + WithState(ApplicationState_Enabled). + WithDataAvailability(otherDA[:]). + Create(s.Ctx, s.T(), s.Repo) + + state := ApplicationState_Enabled + da := DataAvailability_InputBox + apps, total, err := s.Repo.ListApplications( + s.Ctx, + repository.ApplicationFilter{ + State: &state, + DataAvailability: &da, + }, + repository.Pagination{Limit: 10}, + false, + ) + s.Require().NoError(err) + s.Len(apps, 1) + s.Equal(uint64(1), total) + s.Equal(DataAvailability_InputBox[:], apps[0].DataAvailability[:4]) + }) +} + +func (s *ApplicationSuite) TestUpdateApplicationState() { + s.Run("UpdatesState", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + s.Equal(ApplicationState_Enabled, app.State) + + reason := "maintenance" + err := s.Repo.UpdateApplicationState(s.Ctx, app.ID, ApplicationState_Disabled, &reason) + s.Require().NoError(err) + + got, err := s.Repo.GetApplication(s.Ctx, app.Name) + s.Require().NoError(err) + s.Equal(ApplicationState_Disabled, got.State) + s.Require().NotNil(got.Reason) + s.Equal("maintenance", *got.Reason) + }) +} + +func (s *ApplicationSuite) TestDeleteApplication() { + s.Run("DeletesExistingApp", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + err := s.Repo.DeleteApplication(s.Ctx, app.ID) + s.Require().NoError(err) + + got, err := s.Repo.GetApplication(s.Ctx, app.Name) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *ApplicationSuite) TestGetExecutionParameters() { + s.Run("DefaultParameters", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + ep, err := s.Repo.GetExecutionParameters(s.Ctx, app.ID) + s.Require().NoError(err) + s.NotNil(ep) + }) +} + +func (s *ApplicationSuite) TestUpdateExecutionParameters() { + s.Run("UpdatesValues", func() { + ep := ExecutionParameters{ + SnapshotPolicy: SnapshotPolicy_EveryInput, + AdvanceIncCycles: 2000, + AdvanceMaxCycles: 10000, + InspectIncCycles: 2000, + InspectMaxCycles: 10000, + AdvanceIncDeadline: 20 * time.Second, + AdvanceMaxDeadline: 120 * time.Second, + InspectIncDeadline: 20 * time.Second, + InspectMaxDeadline: 120 * time.Second, + LoadDeadline: 60 * time.Second, + StoreDeadline: 60 * time.Second, + FastDeadline: 10 * time.Second, + MaxConcurrentInspects: 5, + } + app := NewApplicationBuilder(). + WithExecutionParameters(ep). + Create(s.Ctx, s.T(), s.Repo) + + ep.ApplicationID = app.ID + ep.AdvanceMaxCycles = 99999 + err := s.Repo.UpdateExecutionParameters(s.Ctx, &ep) + s.Require().NoError(err) + + got, err := s.Repo.GetExecutionParameters(s.Ctx, app.ID) + s.Require().NoError(err) + s.Equal(uint64(99999), got.AdvanceMaxCycles) + }) +} + +func (s *ApplicationSuite) TestEventLastCheckBlock() { + s.Run("DefaultIsZero", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + block, err := s.Repo.GetEventLastCheckBlock(s.Ctx, app.ID, MonitoredEvent_InputAdded) + s.Require().NoError(err) + s.Equal(uint64(0), block) + }) + + s.Run("UpdateAndGet", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + err := s.Repo.UpdateEventLastCheckBlock( + s.Ctx, []int64{app.ID}, MonitoredEvent_InputAdded, 42) + s.Require().NoError(err) + + block, err := s.Repo.GetEventLastCheckBlock(s.Ctx, app.ID, MonitoredEvent_InputAdded) + s.Require().NoError(err) + s.Equal(uint64(42), block) + }) + + s.Run("AllMonitoredEventTypes", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + // Events that map to the epoch check block column + err := s.Repo.UpdateEventLastCheckBlock( + s.Ctx, []int64{app.ID}, MonitoredEvent_EpochSealed, 10) + s.Require().NoError(err) + block, err := s.Repo.GetEventLastCheckBlock( + s.Ctx, app.ID, MonitoredEvent_EpochSealed) + s.Require().NoError(err) + s.Equal(uint64(10), block) + + // Events that map to the output check block column + err = s.Repo.UpdateEventLastCheckBlock( + s.Ctx, []int64{app.ID}, MonitoredEvent_OutputExecuted, 20) + s.Require().NoError(err) + block, err = s.Repo.GetEventLastCheckBlock( + s.Ctx, app.ID, MonitoredEvent_OutputExecuted) + s.Require().NoError(err) + s.Equal(uint64(20), block) + + // Tournament events all map to the tournament check block column + tournamentEvents := []MonitoredEvent{ + MonitoredEvent_CommitmentJoined, + MonitoredEvent_MatchAdvanced, + MonitoredEvent_MatchCreated, + MonitoredEvent_MatchDeleted, + MonitoredEvent_NewInnerTournament, + } + for _, event := range tournamentEvents { + err = s.Repo.UpdateEventLastCheckBlock( + s.Ctx, []int64{app.ID}, event, 30) + s.Require().NoError(err) + block, err = s.Repo.GetEventLastCheckBlock(s.Ctx, app.ID, event) + s.Require().NoError(err) + s.Equal(uint64(30), block) + } + + // ClaimSubmitted and ClaimAccepted should return errors + _, err = s.Repo.GetEventLastCheckBlock( + s.Ctx, app.ID, MonitoredEvent_ClaimSubmitted) + s.Require().Error(err) + + _, err = s.Repo.GetEventLastCheckBlock( + s.Ctx, app.ID, MonitoredEvent_ClaimAccepted) + s.Require().Error(err) + }) + + s.Run("UpdateMultipleAppIDs", func() { + app1 := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + app2 := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + err := s.Repo.UpdateEventLastCheckBlock( + s.Ctx, []int64{app1.ID, app2.ID}, MonitoredEvent_InputAdded, 55) + s.Require().NoError(err) + + block1, err := s.Repo.GetEventLastCheckBlock( + s.Ctx, app1.ID, MonitoredEvent_InputAdded) + s.Require().NoError(err) + s.Equal(uint64(55), block1) + + block2, err := s.Repo.GetEventLastCheckBlock( + s.Ctx, app2.ID, MonitoredEvent_InputAdded) + s.Require().NoError(err) + s.Equal(uint64(55), block2) + }) + + s.Run("EmptyAppIDsIsNoOp", func() { + err := s.Repo.UpdateEventLastCheckBlock( + s.Ctx, []int64{}, MonitoredEvent_InputAdded, 42) + s.Require().NoError(err) + }) +} + +func (s *ApplicationSuite) TestGetProcessedInputCount() { + s.Run("ReturnsZeroInitially", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + count, err := s.Repo.GetProcessedInputCount(s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(0), count) + }) + + s.Run("ReturnsCountAfterProcessing", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // StoreAdvanceResult increments ProcessedInputs + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + OutputsProof: OutputsProof{ + MachineHash: crypto.Keccak256Hash([]byte("machine")), + }, + } + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + count, err := s.Repo.GetProcessedInputCount( + s.Ctx, seed.App.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(1), count) + }) +} + +func (s *ApplicationSuite) TestUpdateApplication() { + s.Run("UpdatesFields", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + app.EpochLength = 20 + err := s.Repo.UpdateApplication(s.Ctx, app) + s.Require().NoError(err) + + got, err := s.Repo.GetApplication(s.Ctx, app.Name) + s.Require().NoError(err) + s.Equal(uint64(20), got.EpochLength) + }) +} + +func (s *ApplicationSuite) TestGetLastSnapshot() { + s.Run("NotFound", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetLastSnapshot(s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Nil(got) + }) + + s.Run("ReturnsInputWithSnapshot", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // GetLastSnapshot requires Status == Accepted AND SnapshotURI IS NOT NULL. + // Use StoreAdvanceResult to set the input to Accepted first. + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + OutputsProof: OutputsProof{ + MachineHash: crypto.Keccak256Hash([]byte("machine")), + }, + } + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + uri := "/snapshots/epoch-0-input-0" + err = s.Repo.UpdateInputSnapshotURI(s.Ctx, seed.App.ID, 0, uri) + s.Require().NoError(err) + + got, err := s.Repo.GetLastSnapshot( + s.Ctx, seed.App.IApplicationAddress.String()) + s.Require().NoError(err) + s.Require().NotNil(got) + s.Require().NotNil(got.SnapshotURI) + s.Equal(uri, *got.SnapshotURI) + s.Equal(uint64(0), got.Index) + }) +} diff --git a/internal/repository/repotest/builders.go b/internal/repository/repotest/builders.go new file mode 100644 index 000000000..dcf0d4161 --- /dev/null +++ b/internal/repository/repotest/builders.go @@ -0,0 +1,587 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + "context" + "fmt" + "math/big" + "sync/atomic" + "testing" + + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/common" + "github.com/stretchr/testify/require" +) + +// counter provides unique values across all builders to avoid collisions. +var counter atomic.Uint64 + +func nextID() uint64 { + return counter.Add(1) +} + +// UniqueAddress returns a unique Ethereum address for test isolation. +func UniqueAddress() common.Address { + return common.BigToAddress(big.NewInt(int64(nextID()))) +} + +// UniqueHash returns a unique hash for test isolation. +func UniqueHash() common.Hash { + return common.BigToHash(big.NewInt(int64(nextID()))) +} + +// --------------------------------------------------------------------------- +// ApplicationBuilder +// --------------------------------------------------------------------------- + +type ApplicationBuilder struct { + app *Application + withExecutionParameters bool +} + +func NewApplicationBuilder() *ApplicationBuilder { + id := nextID() + return &ApplicationBuilder{ + app: &Application{ + Name: fmt.Sprintf("test-app-%d", id), + IApplicationAddress: UniqueAddress(), + IConsensusAddress: UniqueAddress(), + IInputBoxAddress: UniqueAddress(), + TemplateHash: UniqueHash(), + TemplateURI: fmt.Sprintf("/template/%d", id), + EpochLength: 10, + DataAvailability: DataAvailability_InputBox[:], + ConsensusType: Consensus_Authority, + State: ApplicationState_Enabled, + }, + } +} + +func (b *ApplicationBuilder) WithName(name string) *ApplicationBuilder { + b.app.Name = name + return b +} + +func (b *ApplicationBuilder) WithAddress(addr common.Address) *ApplicationBuilder { + b.app.IApplicationAddress = addr + return b +} + +func (b *ApplicationBuilder) WithConsensus(c Consensus) *ApplicationBuilder { + b.app.ConsensusType = c + return b +} + +func (b *ApplicationBuilder) WithState(s ApplicationState) *ApplicationBuilder { + b.app.State = s + return b +} + +func (b *ApplicationBuilder) WithEpochLength(l uint64) *ApplicationBuilder { + b.app.EpochLength = l + return b +} + +func (b *ApplicationBuilder) WithDataAvailability(da []byte) *ApplicationBuilder { + b.app.DataAvailability = da + return b +} + +func (b *ApplicationBuilder) WithExecutionParameters(ep ExecutionParameters) *ApplicationBuilder { + b.app.ExecutionParameters = ep + b.withExecutionParameters = true + return b +} + +// Build returns a copy of the Application model without persisting it. +func (b *ApplicationBuilder) Build() *Application { + a := *b.app + return &a +} + +// Create persists the Application via the repository and returns a copy with the generated ID. +func (b *ApplicationBuilder) Create( + ctx context.Context, t *testing.T, repo repository.Repository, +) *Application { + t.Helper() + a := *b.app + id, err := repo.CreateApplication(ctx, &a, b.withExecutionParameters) + require.NoError(t, err) + a.ID = id + return &a +} + +// --------------------------------------------------------------------------- +// EpochBuilder +// --------------------------------------------------------------------------- + +type EpochBuilder struct { + epoch *Epoch +} + +func NewEpochBuilder(appID int64) *EpochBuilder { + return &EpochBuilder{ + epoch: &Epoch{ + ApplicationID: appID, + Index: 0, + VirtualIndex: 0, + FirstBlock: 0, + LastBlock: 9, + Status: EpochStatus_Open, + }, + } +} + +func (b *EpochBuilder) WithIndex(i uint64) *EpochBuilder { + b.epoch.Index = i + b.epoch.VirtualIndex = i + return b +} + +func (b *EpochBuilder) WithVirtualIndex(i uint64) *EpochBuilder { + b.epoch.VirtualIndex = i + return b +} + +func (b *EpochBuilder) WithStatus(s EpochStatus) *EpochBuilder { + b.epoch.Status = s + return b +} + +func (b *EpochBuilder) WithBlocks(first, last uint64) *EpochBuilder { + b.epoch.FirstBlock = first + b.epoch.LastBlock = last + return b +} + +func (b *EpochBuilder) WithInputBounds(lower, upper uint64) *EpochBuilder { + b.epoch.InputIndexLowerBound = lower + b.epoch.InputIndexUpperBound = upper + return b +} + +func (b *EpochBuilder) WithClaimHash(h common.Hash) *EpochBuilder { + b.epoch.OutputsMerkleRoot = &h + return b +} + +func (b *EpochBuilder) WithClaimTransactionHash(h common.Hash) *EpochBuilder { + b.epoch.ClaimTransactionHash = &h + return b +} + +func (b *EpochBuilder) WithMachineHash(h common.Hash) *EpochBuilder { + b.epoch.MachineHash = &h + return b +} + +// Build returns a copy of the Epoch model without persisting it. +func (b *EpochBuilder) Build() *Epoch { + e := *b.epoch + return &e +} + +// --------------------------------------------------------------------------- +// InputBuilder +// --------------------------------------------------------------------------- + +type InputBuilder struct { + input *Input +} + +func NewInputBuilder() *InputBuilder { + return &InputBuilder{ + input: &Input{ + Index: 0, + BlockNumber: 1, + RawData: []byte("input-data"), + Status: InputCompletionStatus_None, + TransactionReference: UniqueHash(), + }, + } +} + +func (b *InputBuilder) WithIndex(i uint64) *InputBuilder { + b.input.Index = i + return b +} + +func (b *InputBuilder) WithEpochIndex(i uint64) *InputBuilder { + b.input.EpochIndex = i + return b +} + +func (b *InputBuilder) WithBlockNumber(n uint64) *InputBuilder { + b.input.BlockNumber = n + return b +} + +func (b *InputBuilder) WithStatus(s InputCompletionStatus) *InputBuilder { + b.input.Status = s + return b +} + +func (b *InputBuilder) WithRawData(data []byte) *InputBuilder { + b.input.RawData = data + return b +} + +func (b *InputBuilder) WithTransactionReference(h common.Hash) *InputBuilder { + b.input.TransactionReference = h + return b +} + +// Build returns a copy of the Input model without persisting it. +func (b *InputBuilder) Build() *Input { + i := *b.input + return &i +} + +// --------------------------------------------------------------------------- +// OutputBuilder +// --------------------------------------------------------------------------- + +type OutputBuilder struct { + output *Output +} + +func NewOutputBuilder(appID int64) *OutputBuilder { + return &OutputBuilder{ + output: &Output{ + InputEpochApplicationID: appID, + EpochIndex: 0, + InputIndex: 0, + Index: 0, + RawData: []byte("output-data"), + }, + } +} + +func (b *OutputBuilder) WithEpochIndex(i uint64) *OutputBuilder { + b.output.EpochIndex = i + return b +} + +func (b *OutputBuilder) WithInputIndex(i uint64) *OutputBuilder { + b.output.InputIndex = i + return b +} + +func (b *OutputBuilder) WithIndex(i uint64) *OutputBuilder { + b.output.Index = i + return b +} + +func (b *OutputBuilder) WithRawData(data []byte) *OutputBuilder { + b.output.RawData = data + return b +} + +func (b *OutputBuilder) WithHash(h common.Hash) *OutputBuilder { + b.output.Hash = &h + return b +} + +// Build returns a copy of the Output model without persisting it. +func (b *OutputBuilder) Build() *Output { + o := *b.output + return &o +} + +// --------------------------------------------------------------------------- +// ReportBuilder +// --------------------------------------------------------------------------- + +type ReportBuilder struct { + report *Report +} + +func NewReportBuilder(appID int64) *ReportBuilder { + return &ReportBuilder{ + report: &Report{ + InputEpochApplicationID: appID, + EpochIndex: 0, + InputIndex: 0, + Index: 0, + RawData: []byte("report-data"), + }, + } +} + +func (b *ReportBuilder) WithEpochIndex(i uint64) *ReportBuilder { + b.report.EpochIndex = i + return b +} + +func (b *ReportBuilder) WithInputIndex(i uint64) *ReportBuilder { + b.report.InputIndex = i + return b +} + +func (b *ReportBuilder) WithIndex(i uint64) *ReportBuilder { + b.report.Index = i + return b +} + +func (b *ReportBuilder) WithRawData(data []byte) *ReportBuilder { + b.report.RawData = data + return b +} + +// Build returns a copy of the Report model without persisting it. +func (b *ReportBuilder) Build() *Report { + r := *b.report + return &r +} + +// --------------------------------------------------------------------------- +// TournamentBuilder +// --------------------------------------------------------------------------- + +type TournamentBuilder struct { + tournament *Tournament +} + +func NewTournamentBuilder(appID int64) *TournamentBuilder { + return &TournamentBuilder{ + tournament: &Tournament{ + ApplicationID: appID, + EpochIndex: 0, + Address: UniqueAddress(), + MaxLevel: 3, + Level: 0, + Log2Step: 20, + Height: 2, + }, + } +} + +func (b *TournamentBuilder) WithEpochIndex(i uint64) *TournamentBuilder { + b.tournament.EpochIndex = i + return b +} + +func (b *TournamentBuilder) WithAddress(addr common.Address) *TournamentBuilder { + b.tournament.Address = addr + return b +} + +func (b *TournamentBuilder) WithLevel(l uint64) *TournamentBuilder { + b.tournament.Level = l + return b +} + +func (b *TournamentBuilder) WithParent(addr common.Address, matchIDHash common.Hash) *TournamentBuilder { + b.tournament.ParentTournamentAddress = &addr + b.tournament.ParentMatchIDHash = &matchIDHash + return b +} + +// Build returns a copy of the Tournament model without persisting it. +func (b *TournamentBuilder) Build() *Tournament { + t := *b.tournament + return &t +} + +// --------------------------------------------------------------------------- +// CommitmentBuilder +// --------------------------------------------------------------------------- + +type CommitmentBuilder struct { + commitment *Commitment +} + +func NewCommitmentBuilder(appID int64) *CommitmentBuilder { + return &CommitmentBuilder{ + commitment: &Commitment{ + ApplicationID: appID, + EpochIndex: 0, + TournamentAddress: UniqueAddress(), + Commitment: UniqueHash(), + FinalStateHash: UniqueHash(), + SubmitterAddress: UniqueAddress(), + BlockNumber: 100, + TxHash: UniqueHash(), + }, + } +} + +func (b *CommitmentBuilder) WithEpochIndex(i uint64) *CommitmentBuilder { + b.commitment.EpochIndex = i + return b +} + +func (b *CommitmentBuilder) WithTournamentAddress(addr common.Address) *CommitmentBuilder { + b.commitment.TournamentAddress = addr + return b +} + +func (b *CommitmentBuilder) WithCommitmentHash(h common.Hash) *CommitmentBuilder { + b.commitment.Commitment = h + return b +} + +// Build returns a copy of the Commitment model without persisting it. +func (b *CommitmentBuilder) Build() *Commitment { + c := *b.commitment + return &c +} + +// --------------------------------------------------------------------------- +// MatchBuilder +// --------------------------------------------------------------------------- + +type MatchBuilder struct { + match *Match +} + +func NewMatchBuilder(appID int64) *MatchBuilder { + return &MatchBuilder{ + match: &Match{ + ApplicationID: appID, + EpochIndex: 0, + TournamentAddress: UniqueAddress(), + IDHash: UniqueHash(), + CommitmentOne: UniqueHash(), + CommitmentTwo: UniqueHash(), + LeftOfTwo: UniqueHash(), + BlockNumber: 100, + TxHash: UniqueHash(), + Winner: WinnerCommitment_NONE, + DeletionReason: MatchDeletionReason_NOT_DELETED, + }, + } +} + +func (b *MatchBuilder) WithEpochIndex(i uint64) *MatchBuilder { + b.match.EpochIndex = i + return b +} + +func (b *MatchBuilder) WithTournamentAddress(addr common.Address) *MatchBuilder { + b.match.TournamentAddress = addr + return b +} + +func (b *MatchBuilder) WithIDHash(h common.Hash) *MatchBuilder { + b.match.IDHash = h + return b +} + +func (b *MatchBuilder) WithWinner(w WinnerCommitment) *MatchBuilder { + b.match.Winner = w + return b +} + +func (b *MatchBuilder) WithCommitmentOne(h common.Hash) *MatchBuilder { + b.match.CommitmentOne = h + return b +} + +func (b *MatchBuilder) WithCommitmentTwo(h common.Hash) *MatchBuilder { + b.match.CommitmentTwo = h + return b +} + +func (b *MatchBuilder) WithDeletionReason(r MatchDeletionReason) *MatchBuilder { + b.match.DeletionReason = r + return b +} + +// Build returns a copy of the Match model without persisting it. +func (b *MatchBuilder) Build() *Match { + m := *b.match + return &m +} + +// --------------------------------------------------------------------------- +// MatchAdvancedBuilder +// --------------------------------------------------------------------------- + +type MatchAdvancedBuilder struct { + ma *MatchAdvanced +} + +func NewMatchAdvancedBuilder(appID int64) *MatchAdvancedBuilder { + return &MatchAdvancedBuilder{ + ma: &MatchAdvanced{ + ApplicationID: appID, + EpochIndex: 0, + TournamentAddress: UniqueAddress(), + IDHash: UniqueHash(), + OtherParent: UniqueHash(), + LeftNode: UniqueHash(), + BlockNumber: 100, + TxHash: UniqueHash(), + }, + } +} + +func (b *MatchAdvancedBuilder) WithEpochIndex(i uint64) *MatchAdvancedBuilder { + b.ma.EpochIndex = i + return b +} + +func (b *MatchAdvancedBuilder) WithTournamentAddress(addr common.Address) *MatchAdvancedBuilder { + b.ma.TournamentAddress = addr + return b +} + +func (b *MatchAdvancedBuilder) WithIDHash(h common.Hash) *MatchAdvancedBuilder { + b.ma.IDHash = h + return b +} + +func (b *MatchAdvancedBuilder) WithOtherParent(h common.Hash) *MatchAdvancedBuilder { + b.ma.OtherParent = h + return b +} + +// Build returns a copy of the MatchAdvanced model without persisting it. +func (b *MatchAdvancedBuilder) Build() *MatchAdvanced { + m := *b.ma + return &m +} + +// --------------------------------------------------------------------------- +// Seed — convenience for creating a minimum viable Application + Epoch + Input +// --------------------------------------------------------------------------- + +// SeedResult holds the entities created by Seed. +type SeedResult struct { + App *Application + Epoch *Epoch + Input *Input +} + +// Seed creates and persists a minimal Application with one Epoch and one Input. +func Seed(ctx context.Context, t *testing.T, repo repository.Repository) *SeedResult { + t.Helper() + + app := NewApplicationBuilder().Create(ctx, t, repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0). + WithStatus(EpochStatus_Closed). + WithBlocks(0, 9). + WithInputBounds(0, 0). + Build() + + input := NewInputBuilder(). + WithIndex(0). + WithBlockNumber(5). + Build() + + epochInputMap := map[*Epoch][]*Input{epoch: {input}} + err := repo.CreateEpochsAndInputs(ctx, app.IApplicationAddress.String(), epochInputMap, 10) + require.NoError(t, err) + + return &SeedResult{ + App: app, + Epoch: epoch, + Input: input, + } +} diff --git a/internal/repository/repotest/bulk_test_cases.go b/internal/repository/repotest/bulk_test_cases.go new file mode 100644 index 000000000..779fb28a6 --- /dev/null +++ b/internal/repository/repotest/bulk_test_cases.go @@ -0,0 +1,979 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + "context" + "encoding/hex" + "fmt" + "sync" + + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/crypto" +) + +type BulkOperationsSuite struct { + BaseSuite +} + +func NewBulkOperationsSuite(factory RepositoryFactory) *BulkOperationsSuite { + return &BulkOperationsSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *BulkOperationsSuite) TestStoreAdvanceResult() { + s.Run("AcceptedInput", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + machineHash := crypto.Keccak256Hash([]byte("machine")) + outputsHash := crypto.Keccak256Hash([]byte("outputs")) + + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("output1"), []byte("output2")}, + Reports: [][]byte{[]byte("report1")}, + OutputsProof: OutputsProof{ + OutputsHash: outputsHash, + MachineHash: machineHash, + }, + } + + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + // Verify the input was updated + input, err := s.Repo.GetInput(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_Accepted, input.Status) + s.Require().NotNil(input.MachineHash) + s.Equal(machineHash, *input.MachineHash) + + // Verify outputs were created + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs, 2) + s.Equal(uint64(2), total) + + // Verify reports were created + reports, total, err := s.Repo.ListReports( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.ReportFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(reports, 1) + s.Equal(uint64(1), total) + }) + + s.Run("RejectedInput", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + machineHash := crypto.Keccak256Hash([]byte("machine-rejected")) + + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Rejected, + OutputsProof: OutputsProof{ + MachineHash: machineHash, + }, + } + + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + input, err := s.Repo.GetInput(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_Rejected, input.Status) + }) + + s.Run("WithNoOutputsOrReports", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + machineHash := crypto.Keccak256Hash([]byte("machine-empty")) + + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + OutputsProof: OutputsProof{ + MachineHash: machineHash, + }, + } + + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + input, err := s.Repo.GetInput(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_Accepted, input.Status) + }) + + // Verify that a failure mid-transaction rolls back all prior changes. + // We trigger failure by providing a non-existent epoch index, causing the + // epoch outputs proof update to fail after outputs and input are written. + s.Run("RollbackOnPartialFailure", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + result := &AdvanceResult{ + EpochIndex: 99, // non-existent epoch + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("should-be-rolled-back")}, + Reports: [][]byte{[]byte("should-be-rolled-back")}, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().Error(err) + + // Input status should remain unchanged (NONE) + input, err := s.Repo.GetInput( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_None, input.Status) + + // No outputs should have been persisted + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(outputs) + s.Equal(uint64(0), total) + + // No reports should have been persisted + reports, total, err := s.Repo.ListReports( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.ReportFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(reports) + s.Equal(uint64(0), total) + + // ProcessedInputs should remain at 0 + count, err := s.Repo.GetProcessedInputCount( + s.Ctx, seed.App.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(0), count) + }) + + s.Run("DaveConsensusWithStateHashes", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + machineHash := crypto.Keccak256Hash([]byte("dave-machine")) + outputsHash := crypto.Keccak256Hash([]byte("dave-outputs")) + + hash1 := [32]byte(crypto.Keccak256Hash([]byte("state-1"))) + hash2 := [32]byte(crypto.Keccak256Hash([]byte("state-2"))) + hash3 := [32]byte(crypto.Keccak256Hash([]byte("state-3"))) + + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("dave-output")}, + Hashes: [][32]byte{hash1, hash2, hash3}, + RemainingMetaCycles: 42, + IsDaveConsensus: true, + OutputsProof: OutputsProof{ + OutputsHash: outputsHash, + MachineHash: machineHash, + }, + } + + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + // Verify state hashes were created (3 intermediate + 1 final = 4) + epochIdx := uint64(0) + stateHashes, total, err := s.Repo.ListStateHashes( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.StateHashFilter{EpochIndex: &epochIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(stateHashes, 4) + s.Equal(uint64(4), total) + + // Verify intermediate hashes have Repetitions=1 + s.Equal(common.Hash(hash1), stateHashes[0].MachineHash) + s.Equal(uint64(1), stateHashes[0].Repetitions) + s.Equal(common.Hash(hash2), stateHashes[1].MachineHash) + s.Equal(uint64(1), stateHashes[1].Repetitions) + s.Equal(common.Hash(hash3), stateHashes[2].MachineHash) + s.Equal(uint64(1), stateHashes[2].Repetitions) + + // Verify final hash has RemainingMetaCycles as Repetitions + s.Equal(machineHash, stateHashes[3].MachineHash) + s.Equal(uint64(42), stateHashes[3].Repetitions) + + // Verify outputs were also created + outputs, _, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs, 1) + + // Verify input was updated + input, err := s.Repo.GetInput(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_Accepted, input.Status) + }) +} + +func (s *BulkOperationsSuite) TestStoreAdvanceResultRollback() { + // Trigger rollback by referencing a non-existent input index (the input + // doesn't exist in the DB, so updateInput will fail with sql.ErrNoRows). + // This tests that outputs inserted earlier in the same transaction are + // rolled back when a subsequent step fails. + s.Run("RollbackOnInputUpdateFailure", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 999, // non-existent input index + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("should-be-rolled-back")}, + Reports: [][]byte{[]byte("should-be-rolled-back")}, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().Error(err) + + // Verify no outputs were persisted (rolled back) + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(outputs) + s.Equal(uint64(0), total) + + // Verify no reports were persisted (rolled back) + reports, total, err := s.Repo.ListReports( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.ReportFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(reports) + s.Equal(uint64(0), total) + + // Verify the original input is untouched + input, err := s.Repo.GetInput( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_None, input.Status) + + // Verify ProcessedInputs remains 0 + count, err := s.Repo.GetProcessedInputCount( + s.Ctx, seed.App.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(0), count) + }) + + // Trigger rollback by providing a valid input index but a non-existent + // app ID, so updateApp fails. This verifies that outputs, reports, and + // the input status update are all rolled back. + s.Run("RollbackOnAppUpdateFailure", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("should-be-rolled-back")}, + Reports: [][]byte{[]byte("should-be-rolled-back")}, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + + // Use a non-existent app ID -- updateApp will fail + // because no application row matches. + err := s.Repo.StoreAdvanceResult(s.Ctx, 999999, result) + s.Require().Error(err) + + // Verify the original input is untouched + input, err := s.Repo.GetInput( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_None, input.Status) + + // Verify no outputs were persisted + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(outputs) + s.Equal(uint64(0), total) + }) + + // Verify that when Dave consensus state hash insertion fails (due to + // a bad epoch index), all prior work (outputs, reports) is rolled back. + s.Run("DaveConsensusRollbackOnStateHashFailure", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + result := &AdvanceResult{ + EpochIndex: 99, // non-existent epoch + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("should-be-rolled-back")}, + Hashes: [][32]byte{{1}, {2}}, + RemainingMetaCycles: 10, + IsDaveConsensus: true, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().Error(err) + + // Verify no outputs were persisted (rolled back) + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(outputs) + s.Equal(uint64(0), total) + + // Verify the input remains unprocessed + input, err := s.Repo.GetInput( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_None, input.Status) + }) +} + +func (s *BulkOperationsSuite) TestStoreClaimAndProofs() { + s.Run("StoresClaimAndOutputProofs", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // First store an advance result to create outputs + machineHash := crypto.Keccak256Hash([]byte("machine")) + outputData := []byte("output-for-claim") + outputsHash := crypto.Keccak256Hash([]byte("outputs-merkle")) + + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{outputData}, + OutputsProof: OutputsProof{ + OutputsHash: outputsHash, + MachineHash: machineHash, + }, + } + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + // Now store claim and proofs using Commitment/CommitmentProof fields + commitmentHash := crypto.Keccak256Hash([]byte("commitment")) + seed.Epoch.Commitment = &commitmentHash + seed.Epoch.CommitmentProof = []common.Hash{UniqueHash(), UniqueHash()} + + outputHash := crypto.Keccak256Hash(outputData) + proof := []common.Hash{UniqueHash(), UniqueHash()} + out := &Output{ + InputEpochApplicationID: seed.App.ID, + InputIndex: 0, + Index: 0, + RawData: outputData, + Hash: &outputHash, + OutputHashesSiblings: proof, + } + + err = s.Repo.StoreClaimAndProofs(s.Ctx, seed.Epoch, []*Output{out}) + s.Require().NoError(err) + + gotEpoch, err := s.Repo.GetEpoch(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_ClaimComputed, gotEpoch.Status) + s.Require().NotNil(gotEpoch.Commitment) + s.Equal(commitmentHash, *gotEpoch.Commitment) + }) +} + +func (s *BulkOperationsSuite) TestStoreTournamentEvents() { + // setupTournamentWithMatch creates a PRT app with a tournament, two + // commitments, and one match, all stored via StoreTournamentEvents. + type tournamentSetup struct { + app *Application + tournAddr common.Address + match *Match + } + setupTournamentWithMatch := func() *tournamentSetup { + s.T().Helper() + app := NewApplicationBuilder(). + WithConsensus(Consensus_PRT). + Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed).WithBlocks(0, 9).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + tournAddr := UniqueAddress() + tournament := NewTournamentBuilder(app.ID). + WithEpochIndex(0).WithAddress(tournAddr).Build() + err = s.Repo.CreateTournament( + s.Ctx, app.IApplicationAddress.String(), tournament) + s.Require().NoError(err) + + commitment1 := NewCommitmentBuilder(app.ID). + WithEpochIndex(0).WithTournamentAddress(tournAddr).Build() + commitment2 := NewCommitmentBuilder(app.ID). + WithEpochIndex(0).WithTournamentAddress(tournAddr).Build() + + matchIDHash := UniqueHash() + match := NewMatchBuilder(app.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + WithIDHash(matchIDHash). + WithCommitmentOne(commitment1.Commitment). + WithCommitmentTwo(commitment2.Commitment). + Build() + + err = s.Repo.StoreTournamentEvents( + s.Ctx, app.ID, + []*Commitment{commitment1, commitment2}, + []*Match{match}, + nil, nil, 100) + s.Require().NoError(err) + + return &tournamentSetup{app: app, tournAddr: tournAddr, match: match} + } + + s.Run("StoresCommitmentsAndMatches", func() { + ts := setupTournamentWithMatch() + + // Verify commitment was stored + gotCommitment, err := s.Repo.GetCommitment( + s.Ctx, ts.app.IApplicationAddress.String(), + 0, ts.tournAddr.String(), ts.match.CommitmentOne.Hex()) + s.Require().NoError(err) + s.Equal(ts.match.CommitmentOne, gotCommitment.Commitment) + + // Verify match was stored + gotMatch, err := s.Repo.GetMatch( + s.Ctx, ts.app.IApplicationAddress.String(), + 0, ts.tournAddr.String(), ts.match.IDHash.Hex()) + s.Require().NoError(err) + s.Equal(ts.match.IDHash, gotMatch.IDHash) + }) + + s.Run("StoresMatchAdvanced", func() { + ts := setupTournamentWithMatch() + + // Now store a match advanced event for the existing match + ma := NewMatchAdvancedBuilder(ts.app.ID). + WithEpochIndex(0). + WithTournamentAddress(ts.tournAddr). + WithIDHash(ts.match.IDHash). + Build() + + err := s.Repo.StoreTournamentEvents( + s.Ctx, ts.app.ID, + nil, nil, + []*MatchAdvanced{ma}, nil, 200) + s.Require().NoError(err) + + // Verify match advanced was stored + gotMA, err := s.Repo.GetMatchAdvanced( + s.Ctx, ts.app.IApplicationAddress.String(), + 0, ts.tournAddr.String(), ts.match.IDHash.Hex(), + hex.EncodeToString(ma.OtherParent[:])) + s.Require().NoError(err) + s.Require().NotNil(gotMA) + s.Equal(ma.OtherParent, gotMA.OtherParent) + }) + + s.Run("UpdatesDeletedMatches", func() { + ts := setupTournamentWithMatch() + + // Mark the match as deleted (winner decided) + deletedMatch := &Match{ + EpochIndex: 0, + TournamentAddress: ts.tournAddr, + IDHash: ts.match.IDHash, + Winner: WinnerCommitment_ONE, + DeletionReason: MatchDeletionReason_TIMEOUT, + DeletionBlockNumber: 200, + DeletionTxHash: UniqueHash(), + } + + err := s.Repo.StoreTournamentEvents( + s.Ctx, ts.app.ID, + nil, nil, nil, + []*Match{deletedMatch}, 300) + s.Require().NoError(err) + + // Verify the match was updated + gotMatch, err := s.Repo.GetMatch( + s.Ctx, ts.app.IApplicationAddress.String(), + 0, ts.tournAddr.String(), ts.match.IDHash.Hex()) + s.Require().NoError(err) + s.Equal(WinnerCommitment_ONE, gotMatch.Winner) + s.Equal(MatchDeletionReason_TIMEOUT, gotMatch.DeletionReason) + }) +} + +func (s *BulkOperationsSuite) TestConcurrentStoreAdvanceResult() { + // Verify that concurrent StoreAdvanceResult calls for different + // applications succeed independently without corrupting data. + s.Run("DifferentApplications", func() { + seed1 := Seed(s.Ctx, s.T(), s.Repo) + seed2 := Seed(s.Ctx, s.T(), s.Repo) + + var wg sync.WaitGroup + errs := make([]error, 2) + + seeds := []*SeedResult{seed1, seed2} + for i, seed := range seeds { + wg.Add(1) + go func() { + defer wg.Done() + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte(fmt.Sprintf("output-%d", i))}, + Reports: [][]byte{[]byte(fmt.Sprintf("report-%d", i))}, + OutputsProof: OutputsProof{ + OutputsHash: crypto.Keccak256Hash( + []byte(fmt.Sprintf("outputs-%d", i))), + MachineHash: crypto.Keccak256Hash( + []byte(fmt.Sprintf("machine-%d", i))), + }, + } + errs[i] = s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + }() + } + wg.Wait() + + s.Require().NoError(errs[0], "first app store should succeed") + s.Require().NoError(errs[1], "second app store should succeed") + + // Verify first application + input1, err := s.Repo.GetInput( + s.Ctx, seed1.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_Accepted, input1.Status) + + count1, err := s.Repo.GetProcessedInputCount( + s.Ctx, seed1.App.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(1), count1) + + outputs1, total1, err := s.Repo.ListOutputs( + s.Ctx, seed1.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs1, 1) + s.Equal(uint64(1), total1) + + // Verify second application + input2, err := s.Repo.GetInput( + s.Ctx, seed2.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_Accepted, input2.Status) + + count2, err := s.Repo.GetProcessedInputCount( + s.Ctx, seed2.App.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(1), count2) + + outputs2, total2, err := s.Repo.ListOutputs( + s.Ctx, seed2.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs2, 1) + s.Equal(uint64(1), total2) + }) + + // Verify that concurrent StoreAdvanceResult calls for the same input + // do not corrupt data. At least one goroutine must succeed and the + // final state must be consistent. + s.Run("SameInputAtomicity", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + const numGoroutines = 5 + var wg sync.WaitGroup + errs := make([]error, numGoroutines) + + for i := range numGoroutines { + wg.Add(1) + go func() { + defer wg.Done() + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte(fmt.Sprintf("output-%d", i))}, + OutputsProof: OutputsProof{ + OutputsHash: crypto.Keccak256Hash( + []byte(fmt.Sprintf("outputs-%d", i))), + MachineHash: crypto.Keccak256Hash( + []byte(fmt.Sprintf("machine-%d", i))), + }, + } + errs[i] = s.Repo.StoreAdvanceResult( + s.Ctx, seed.App.ID, result) + }() + } + wg.Wait() + + successCount := 0 + for _, err := range errs { + if err == nil { + successCount++ + } + } + s.GreaterOrEqual(successCount, 1, + "at least one concurrent store should succeed") + + // Verify data integrity: the input must be in Accepted state + input, err := s.Repo.GetInput( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_Accepted, input.Status) + + // ProcessedInputs must reflect exactly one successful processing + count, err := s.Repo.GetProcessedInputCount( + s.Ctx, seed.App.IApplicationAddress.String()) + s.Require().NoError(err) + s.GreaterOrEqual(count, uint64(1)) + }) +} + +func (s *BulkOperationsSuite) TestStoreClaimAndProofsRollback() { + // When the epoch doesn't exist, updateEpochClaim should fail with + // RowsAffected == 0 and the whole transaction should be rolled back. + s.Run("RollbackOnEpochUpdateFailure", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // Store advance result first so the epoch has outputs + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("output")}, + OutputsProof: OutputsProof{ + OutputsHash: crypto.Keccak256Hash([]byte("outputs")), + MachineHash: crypto.Keccak256Hash([]byte("machine")), + }, + } + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + // Build an epoch that doesn't exist in the DB + commitmentHash := crypto.Keccak256Hash([]byte("commitment")) + badEpoch := &Epoch{ + ApplicationID: seed.App.ID, + Index: 99, // doesn't exist + Commitment: &commitmentHash, + CommitmentProof: []common.Hash{UniqueHash()}, + } + + err = s.Repo.StoreClaimAndProofs(s.Ctx, badEpoch, nil) + s.Require().Error(err) + + // Verify the real epoch is untouched + gotEpoch, err := s.Repo.GetEpoch( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_Closed, gotEpoch.Status) + s.Nil(gotEpoch.Commitment) + }) + + // When updateOutputs fails (output doesn't exist), the epoch status + // change from updateEpochClaim must also be rolled back. + s.Run("RollbackOnOutputProofUpdateFailure", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // Store advance result to create one output (index 0) + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("real-output")}, + OutputsProof: OutputsProof{ + OutputsHash: crypto.Keccak256Hash([]byte("outputs")), + MachineHash: crypto.Keccak256Hash([]byte("machine")), + }, + } + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + // Prepare a valid epoch claim (this part would succeed) + commitmentHash := crypto.Keccak256Hash([]byte("commitment")) + seed.Epoch.Commitment = &commitmentHash + seed.Epoch.CommitmentProof = []common.Hash{UniqueHash()} + + // Prepare an output with a non-existent index to trigger failure + badHash := UniqueHash() + nonExistentOutput := &Output{ + InputEpochApplicationID: seed.App.ID, + InputIndex: 0, + Index: 999, // doesn't exist + Hash: &badHash, + OutputHashesSiblings: []common.Hash{UniqueHash()}, + } + + err = s.Repo.StoreClaimAndProofs( + s.Ctx, seed.Epoch, []*Output{nonExistentOutput}) + s.Require().Error(err) + + // Verify the epoch status was rolled back — still Closed + gotEpoch, err := s.Repo.GetEpoch( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_Closed, gotEpoch.Status, + "epoch status should be rolled back to Closed") + s.Nil(gotEpoch.Commitment, + "commitment should not have been persisted") + + // Verify the real output's hash was NOT changed + realOutput, err := s.Repo.GetOutput( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Nil(realOutput.Hash, + "existing output hash should remain nil after rollback") + }) +} + +func (s *BulkOperationsSuite) TestStoreTournamentEventsRollback() { + // Helper: create a PRT application with one closed epoch and a tournament. + setupPRTApp := func() (app *Application, tournAddr common.Address) { + s.T().Helper() + app = NewApplicationBuilder(). + WithConsensus(Consensus_PRT). + Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed).WithBlocks(0, 9).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + tournAddr = UniqueAddress() + tournament := NewTournamentBuilder(app.ID). + WithEpochIndex(0).WithAddress(tournAddr).Build() + err = s.Repo.CreateTournament( + s.Ctx, app.IApplicationAddress.String(), tournament) + s.Require().NoError(err) + return app, tournAddr + } + + // Insert valid commitments + a match that references a non-existent + // tournament address, causing an FK violation. The commitments + // inserted in the same transaction must be rolled back. + s.Run("RollbackOnMatchInsertFailure", func() { + app, tournAddr := setupPRTApp() + + // Valid commitment targeting the real tournament + commitment := NewCommitmentBuilder(app.ID). + WithEpochIndex(0).WithTournamentAddress(tournAddr).Build() + + // Match targeting a non-existent tournament → FK violation + bogusAddr := UniqueAddress() + match := NewMatchBuilder(app.ID). + WithEpochIndex(0). + WithTournamentAddress(bogusAddr). + WithCommitmentOne(commitment.Commitment). + WithCommitmentTwo(UniqueHash()). + Build() + + err := s.Repo.StoreTournamentEvents( + s.Ctx, app.ID, + []*Commitment{commitment}, + []*Match{match}, + nil, nil, 100) + s.Require().Error(err) + + // Verify the commitment was rolled back + got, err := s.Repo.GetCommitment( + s.Ctx, app.IApplicationAddress.String(), + 0, tournAddr.String(), commitment.Commitment.Hex()) + s.Require().NoError(err) + s.Nil(got, "commitment should have been rolled back") + }) + + // Insert valid commitments + try to delete (update) a non-existent + // match. updateMatches returns an error when RowsAffected == 0. + // The commitments inserted earlier in the same tx must be rolled back. + s.Run("RollbackOnMatchDeleteFailure", func() { + app, tournAddr := setupPRTApp() + + // Valid new commitment + newCommitment := NewCommitmentBuilder(app.ID). + WithEpochIndex(0).WithTournamentAddress(tournAddr).Build() + + // Try to delete a match that doesn't exist + deletedMatch := &Match{ + EpochIndex: 0, + TournamentAddress: tournAddr, + IDHash: UniqueHash(), // doesn't exist + Winner: WinnerCommitment_ONE, + DeletionReason: MatchDeletionReason_TIMEOUT, + } + + err := s.Repo.StoreTournamentEvents( + s.Ctx, app.ID, + []*Commitment{newCommitment}, + nil, nil, + []*Match{deletedMatch}, 100) + s.Require().Error(err) + + // Verify the new commitment was rolled back + got, err := s.Repo.GetCommitment( + s.Ctx, app.IApplicationAddress.String(), + 0, tournAddr.String(), newCommitment.Commitment.Hex()) + s.Require().NoError(err) + s.Nil(got, "commitment should have been rolled back") + }) + + // Insert a match advanced event for a match that doesn't exist, + // causing an FK violation. Commitments in the same tx should roll back. + s.Run("RollbackOnMatchAdvancedInsertFailure", func() { + app, tournAddr := setupPRTApp() + + // Valid new commitment + newCommitment := NewCommitmentBuilder(app.ID). + WithEpochIndex(0).WithTournamentAddress(tournAddr).Build() + + // Match advanced for a non-existent match → FK violation + bogusMA := NewMatchAdvancedBuilder(app.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + WithIDHash(UniqueHash()). // match doesn't exist + Build() + + err := s.Repo.StoreTournamentEvents( + s.Ctx, app.ID, + []*Commitment{newCommitment}, + nil, + []*MatchAdvanced{bogusMA}, + nil, 100) + s.Require().Error(err) + + // Verify the commitment was rolled back + got, err := s.Repo.GetCommitment( + s.Ctx, app.IApplicationAddress.String(), + 0, tournAddr.String(), newCommitment.Commitment.Hex()) + s.Require().NoError(err) + s.Nil(got, "commitment should have been rolled back") + }) +} + +func (s *BulkOperationsSuite) TestContextCancellation() { + // Verify that StoreAdvanceResult respects a cancelled context + // and does not persist any data. + s.Run("StoreAdvanceResultWithCancelledContext", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + cancelledCtx, cancel := context.WithCancel(s.Ctx) + cancel() // cancel immediately + + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("should-not-persist")}, + Reports: [][]byte{[]byte("should-not-persist")}, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + + err := s.Repo.StoreAdvanceResult(cancelledCtx, seed.App.ID, result) + s.Require().Error(err) + s.Require().ErrorIs(err, context.Canceled) + + // Data should not have been persisted + input, err := s.Repo.GetInput( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(InputCompletionStatus_None, input.Status, + "input should remain unprocessed after context cancellation") + + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(outputs) + s.Equal(uint64(0), total) + }) + + // Verify that CreateEpochsAndInputs respects a cancelled context. + s.Run("CreateEpochsAndInputsWithCancelledContext", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + cancelledCtx, cancel := context.WithCancel(s.Ctx) + cancel() // cancel immediately + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed).WithBlocks(0, 9).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + cancelledCtx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().Error(err) + + // The epoch should not have been persisted + got, err := s.Repo.GetEpoch( + s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Nil(got, "epoch should not exist after context cancellation") + }) + + // Verify that StoreClaimAndProofs respects a cancelled context. + s.Run("StoreClaimAndProofsWithCancelledContext", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // Store an advance result so the epoch has outputs + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("output")}, + OutputsProof: OutputsProof{ + OutputsHash: crypto.Keccak256Hash([]byte("outputs")), + MachineHash: crypto.Keccak256Hash([]byte("machine")), + }, + } + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + cancelledCtx, cancel := context.WithCancel(s.Ctx) + cancel() + + commitmentHash := crypto.Keccak256Hash([]byte("commitment")) + seed.Epoch.Commitment = &commitmentHash + seed.Epoch.CommitmentProof = []common.Hash{UniqueHash()} + + err = s.Repo.StoreClaimAndProofs(cancelledCtx, seed.Epoch, nil) + s.Require().Error(err) + + // Epoch should remain in Closed state + gotEpoch, err := s.Repo.GetEpoch( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_Closed, gotEpoch.Status, + "epoch status should not change after context cancellation") + }) +} diff --git a/internal/repository/repotest/claimer_test_cases.go b/internal/repository/repotest/claimer_test_cases.go new file mode 100644 index 000000000..6f616425e --- /dev/null +++ b/internal/repository/repotest/claimer_test_cases.go @@ -0,0 +1,485 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + "context" + + . "github.com/cartesi/rollups-node/internal/model" + "github.com/ethereum/go-ethereum/common" +) + +type ClaimerSuite struct { + BaseSuite +} + +func NewClaimerSuite(factory RepositoryFactory) *ClaimerSuite { + return &ClaimerSuite{BaseSuite: BaseSuite{factory: factory}} +} + +// createAppWithClaimComputedEpoch creates an app with one epoch at ClaimComputed status. +func (s *ClaimerSuite) createAppWithClaimComputedEpoch() *Application { + s.T().Helper() + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + epoch.Status = EpochStatus_ClaimComputed + err = s.Repo.UpdateEpochStatus(s.Ctx, app.IApplicationAddress.String(), epoch) + s.Require().NoError(err) + + return app +} + +func (s *ClaimerSuite) TestSelectSubmittedClaimPairsPerApp() { + s.Run("EmptyWhenNoClaimComputed", func() { + NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + // Returns: (acceptedOrSubmitted, computed, applications, error) + _, computed, apps, err := s.Repo.SelectSubmittedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Empty(computed) + s.Empty(apps) + }) + + s.Run("ReturnsPairWhenClaimComputed", func() { + app := s.createAppWithClaimComputedEpoch() + + // Returns: (acceptedOrSubmitted, computed, applications, error) + _, computed, apps, err := s.Repo.SelectSubmittedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.NotEmpty(computed) + s.NotEmpty(apps) + s.Contains(computed, app.ID) + s.Contains(apps, app.ID) + }) + + s.Run("MultipleAppsReturnsSeparateEntries", func() { + app1 := s.createAppWithClaimComputedEpoch() + app2 := s.createAppWithClaimComputedEpoch() + + _, computed, apps, err := s.Repo.SelectSubmittedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Len(computed, 2) + s.Len(apps, 2) + s.Contains(computed, app1.ID) + s.Contains(computed, app2.ID) + s.Contains(apps, app1.ID) + s.Contains(apps, app2.ID) + }) + + s.Run("IncludesAcceptedOrSubmittedForMultipleApps", func() { + // Create two apps, each with a submitted epoch. + // SelectSubmittedClaimPairsPerApp returns acceptedOrSubmitted + // via selectNewestAcceptedClaimPerApp(includeSubmitted=true). + app1 := s.createAppWithClaimComputedEpoch() + app2 := s.createAppWithClaimComputedEpoch() + + // Move both to ClaimSubmitted + txHash1 := UniqueHash() + err := s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app1.ID, 0, txHash1) + s.Require().NoError(err) + + txHash2 := UniqueHash() + err = s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app2.ID, 0, txHash2) + s.Require().NoError(err) + + acceptedOrSubmitted, _, _, err := s.Repo.SelectSubmittedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Len(acceptedOrSubmitted, 2) + s.Contains(acceptedOrSubmitted, app1.ID) + s.Contains(acceptedOrSubmitted, app2.ID) + }) + + // Regression guard: verify map keys are actual application IDs + // and that each epoch is stored under the correct key. + s.Run("MultiAppMapKeysMatchEpochApplicationIDs", func() { + app1 := s.createAppWithClaimComputedEpoch() + app2 := s.createAppWithClaimComputedEpoch() + + _, computed, apps, err := s.Repo.SelectSubmittedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + + for appID, epoch := range computed { + s.NotEqual(int64(0), appID, "map key must not be zero") + s.Equal(appID, epoch.ApplicationID, + "epoch stored under wrong key") + } + for appID, app := range apps { + s.NotEqual(int64(0), appID, "map key must not be zero") + s.Equal(appID, app.ID, + "application stored under wrong key") + } + + // Verify specific app data integrity + s.Equal(app1.IApplicationAddress, apps[app1.ID].IApplicationAddress) + s.Equal(app2.IApplicationAddress, apps[app2.ID].IApplicationAddress) + }) + + s.Run("ExcludesPRTApps", func() { + app := NewApplicationBuilder(). + WithConsensus(Consensus_PRT). + Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + epoch.Status = EpochStatus_ClaimComputed + err = s.Repo.UpdateEpochStatus(s.Ctx, app.IApplicationAddress.String(), epoch) + s.Require().NoError(err) + + _, computed, apps, err := s.Repo.SelectSubmittedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Empty(computed) + s.Empty(apps) + }) + + s.Run("ExcludesDisabledApps", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + epoch.Status = EpochStatus_ClaimComputed + err = s.Repo.UpdateEpochStatus(s.Ctx, app.IApplicationAddress.String(), epoch) + s.Require().NoError(err) + + reason := "test disabled" + err = s.Repo.UpdateApplicationState( + s.Ctx, app.ID, ApplicationState_Disabled, &reason) + s.Require().NoError(err) + + _, computed, apps, err := s.Repo.SelectSubmittedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Empty(computed) + s.Empty(apps) + }) + + s.Run("ContextCancellation", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, _, err := s.Repo.SelectSubmittedClaimPairsPerApp(ctx) + s.Require().Error(err) + }) +} + +func (s *ClaimerSuite) TestSelectAcceptedClaimPairsPerApp() { + s.Run("EmptyWhenNoClaimSubmitted", func() { + NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + accepted, submitted, apps, err := s.Repo.SelectAcceptedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Empty(accepted) + s.Empty(submitted) + s.Empty(apps) + }) + + s.Run("ReturnsPairWhenClaimAccepted", func() { + app := s.createAppWithClaimComputedEpoch() + + // Move to ClaimSubmitted + txHash := UniqueHash() + err := s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app.ID, 0, txHash) + s.Require().NoError(err) + + // Move to ClaimAccepted + err = s.Repo.UpdateEpochWithAcceptedClaim(s.Ctx, app.ID, 0) + s.Require().NoError(err) + + accepted, _, _, err := s.Repo.SelectAcceptedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Len(accepted, 1) + s.Contains(accepted, app.ID) + }) + + s.Run("MultipleAppsReturnsSeparateEntries", func() { + app1 := s.createAppWithClaimComputedEpoch() + app2 := s.createAppWithClaimComputedEpoch() + + // Move both through submitted -> accepted + for _, app := range []*Application{app1, app2} { + txHash := UniqueHash() + err := s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app.ID, 0, txHash) + s.Require().NoError(err) + err = s.Repo.UpdateEpochWithAcceptedClaim(s.Ctx, app.ID, 0) + s.Require().NoError(err) + } + + accepted, _, _, err := s.Repo.SelectAcceptedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Len(accepted, 2) + s.Contains(accepted, app1.ID) + s.Contains(accepted, app2.ID) + }) + + // Regression guard: verify accepted map keys match the actual + // epoch.ApplicationID, not a zero-value from an unscanned field. + s.Run("MultiAppMapKeysMatchEpochApplicationIDs", func() { + app1 := s.createAppWithClaimComputedEpoch() + app2 := s.createAppWithClaimComputedEpoch() + + for _, app := range []*Application{app1, app2} { + txHash := UniqueHash() + err := s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app.ID, 0, txHash) + s.Require().NoError(err) + err = s.Repo.UpdateEpochWithAcceptedClaim(s.Ctx, app.ID, 0) + s.Require().NoError(err) + } + + accepted, submitted, apps, err := s.Repo.SelectAcceptedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + + for appID, epoch := range accepted { + s.NotEqual(int64(0), appID, "accepted map key must not be zero") + s.Equal(appID, epoch.ApplicationID, + "accepted epoch stored under wrong key") + } + for appID, epoch := range submitted { + s.NotEqual(int64(0), appID, "submitted map key must not be zero") + s.Equal(appID, epoch.ApplicationID, + "submitted epoch stored under wrong key") + } + for appID, app := range apps { + s.NotEqual(int64(0), appID, "apps map key must not be zero") + s.Equal(appID, app.ID, + "application stored under wrong key") + } + + s.Contains(accepted, app1.ID) + s.Contains(accepted, app2.ID) + }) + + s.Run("ExcludesPRTApps", func() { + app := NewApplicationBuilder(). + WithConsensus(Consensus_PRT). + Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + epoch.Status = EpochStatus_ClaimComputed + err = s.Repo.UpdateEpochStatus(s.Ctx, app.IApplicationAddress.String(), epoch) + s.Require().NoError(err) + + txHash := UniqueHash() + err = s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app.ID, 0, txHash) + s.Require().NoError(err) + + err = s.Repo.UpdateEpochWithAcceptedClaim(s.Ctx, app.ID, 0) + s.Require().NoError(err) + + accepted, submitted, apps, err := s.Repo.SelectAcceptedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Empty(accepted) + s.Empty(submitted) + s.Empty(apps) + }) + + s.Run("ExcludesDisabledApps", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + epoch.Status = EpochStatus_ClaimComputed + err = s.Repo.UpdateEpochStatus(s.Ctx, app.IApplicationAddress.String(), epoch) + s.Require().NoError(err) + + txHash := UniqueHash() + err = s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app.ID, 0, txHash) + s.Require().NoError(err) + + err = s.Repo.UpdateEpochWithAcceptedClaim(s.Ctx, app.ID, 0) + s.Require().NoError(err) + + reason := "test disabled" + err = s.Repo.UpdateApplicationState( + s.Ctx, app.ID, ApplicationState_Disabled, &reason) + s.Require().NoError(err) + + accepted, submitted, apps, err := s.Repo.SelectAcceptedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + s.Empty(accepted) + s.Empty(submitted) + s.Empty(apps) + }) + + s.Run("ReturnsSubmittedMapWithBothEpochStates", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + // Create two epochs with inputs + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Closed). + WithBlocks(10, 19).WithInputBounds(1, 1).Build() + input1 := NewInputBuilder().WithIndex(1).WithBlockNumber(15).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch0: {input0}, epoch1: {input1}}, 20) + s.Require().NoError(err) + + // Move epoch 0 to ClaimAccepted + epoch0.Status = EpochStatus_ClaimComputed + err = s.Repo.UpdateEpochStatus( + s.Ctx, app.IApplicationAddress.String(), epoch0) + s.Require().NoError(err) + + txHash0 := UniqueHash() + err = s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app.ID, 0, txHash0) + s.Require().NoError(err) + + err = s.Repo.UpdateEpochWithAcceptedClaim(s.Ctx, app.ID, 0) + s.Require().NoError(err) + + // Move epoch 1 to ClaimSubmitted + epoch1.Status = EpochStatus_ClaimComputed + err = s.Repo.UpdateEpochStatus( + s.Ctx, app.IApplicationAddress.String(), epoch1) + s.Require().NoError(err) + + txHash1 := UniqueHash() + err = s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app.ID, 1, txHash1) + s.Require().NoError(err) + + accepted, submitted, apps, err := s.Repo.SelectAcceptedClaimPairsPerApp(s.Ctx) + s.Require().NoError(err) + + // accepted contains epoch 0 (newest accepted) + s.Len(accepted, 1) + s.Contains(accepted, app.ID) + s.Equal(uint64(0), accepted[app.ID].Index) + s.Equal(EpochStatus_ClaimAccepted, accepted[app.ID].Status) + + // submitted contains epoch 1 (oldest submitted) + s.Len(submitted, 1) + s.Contains(submitted, app.ID) + s.Equal(uint64(1), submitted[app.ID].Index) + s.Equal(EpochStatus_ClaimSubmitted, submitted[app.ID].Status) + + // apps contains the application + s.Len(apps, 1) + s.Contains(apps, app.ID) + s.Equal(app.IApplicationAddress, apps[app.ID].IApplicationAddress) + }) + + s.Run("ContextCancellation", func() { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, _, err := s.Repo.SelectAcceptedClaimPairsPerApp(ctx) + s.Require().Error(err) + }) +} + +func (s *ClaimerSuite) TestUpdateEpochWithSubmittedClaim() { + s.Run("SetsClaimSubmitted", func() { + app := s.createAppWithClaimComputedEpoch() + + txHash := common.HexToHash("0xdeadbeef") + err := s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app.ID, 0, txHash) + s.Require().NoError(err) + + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_ClaimSubmitted, got.Status) + s.Require().NotNil(got.ClaimTransactionHash) + s.Equal(txHash, *got.ClaimTransactionHash) + }) + + s.Run("ErrorWhenEpochNotClaimComputed", func() { + // Create an app with an epoch still in Closed status (not ClaimComputed) + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + txHash := UniqueHash() + err = s.Repo.UpdateEpochWithSubmittedClaim(s.Ctx, app.ID, 0, txHash) + s.Require().Error(err) + }) +} + +func (s *ClaimerSuite) TestUpdateEpochWithAcceptedClaim() { + s.Run("SetsClaimAccepted", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + epoch.Status = EpochStatus_ClaimSubmitted + err = s.Repo.UpdateEpochStatus(s.Ctx, app.IApplicationAddress.String(), epoch) + s.Require().NoError(err) + + err = s.Repo.UpdateEpochWithAcceptedClaim(s.Ctx, app.ID, 0) + s.Require().NoError(err) + + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_ClaimAccepted, got.Status) + }) + + s.Run("ErrorWhenEpochNotClaimSubmitted", func() { + // Create an app with an epoch in ClaimComputed status (not ClaimSubmitted) + app := s.createAppWithClaimComputedEpoch() + + err := s.Repo.UpdateEpochWithAcceptedClaim(s.Ctx, app.ID, 0) + s.Require().Error(err) + }) +} diff --git a/internal/repository/repotest/commitment_test_cases.go b/internal/repository/repotest/commitment_test_cases.go new file mode 100644 index 000000000..44316b627 --- /dev/null +++ b/internal/repository/repotest/commitment_test_cases.go @@ -0,0 +1,148 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/common" +) + +type CommitmentSuite struct { + BaseSuite +} + +func NewCommitmentSuite(factory RepositoryFactory) *CommitmentSuite { + return &CommitmentSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *CommitmentSuite) createTournament() (*SeedResult, common.Address) { + seed := Seed(s.Ctx, s.T(), s.Repo) + tournAddr := UniqueAddress() + t := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).WithAddress(tournAddr).Build() + err := s.Repo.CreateTournament(s.Ctx, seed.App.IApplicationAddress.String(), t) + s.Require().NoError(err) + return seed, tournAddr +} + +func (s *CommitmentSuite) TestCreateCommitment() { + s.Run("CreatesSuccessfully", func() { + seed, tournAddr := s.createTournament() + + c := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + Build() + + err := s.Repo.CreateCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), c) + s.Require().NoError(err) + }) +} + +func (s *CommitmentSuite) TestGetCommitment() { + s.Run("ExistingCommitment", func() { + seed, tournAddr := s.createTournament() + + c := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + Build() + + err := s.Repo.CreateCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), c) + s.Require().NoError(err) + + got, err := s.Repo.GetCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), + 0, tournAddr.String(), c.Commitment.Hex()) + s.Require().NoError(err) + s.Equal(c.Commitment, got.Commitment) + s.Equal(c.FinalStateHash, got.FinalStateHash) + }) + + s.Run("NotFound", func() { + seed, tournAddr := s.createTournament() + + got, err := s.Repo.GetCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), + 0, tournAddr.String(), UniqueHash().Hex()) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *CommitmentSuite) TestListCommitments() { + s.Run("EmptyResult", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + commitments, total, err := s.Repo.ListCommitments( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.CommitmentFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(commitments) + s.Equal(uint64(0), total) + }) + + s.Run("ReturnsAll", func() { + seed, tournAddr := s.createTournament() + for range 3 { + c := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + Build() + err := s.Repo.CreateCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), c) + s.Require().NoError(err) + } + + commitments, total, err := s.Repo.ListCommitments( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.CommitmentFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(commitments, 3) + s.Equal(uint64(3), total) + }) + + s.Run("FilterByEpochIndex", func() { + seed, tournAddr := s.createTournament() + + c := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + Build() + err := s.Repo.CreateCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), c) + s.Require().NoError(err) + + epochIdx := uint64(0) + commitments, total, err := s.Repo.ListCommitments( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.CommitmentFilter{EpochIndex: &epochIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(commitments, 1) + s.Equal(uint64(1), total) + }) + + s.Run("FilterByTournamentAddress", func() { + seed, tournAddr := s.createTournament() + + c := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + Build() + err := s.Repo.CreateCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), c) + s.Require().NoError(err) + + addrStr := tournAddr.String() + commitments, total, err := s.Repo.ListCommitments( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.CommitmentFilter{TournamentAddress: &addrStr}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(commitments, 1) + s.Equal(uint64(1), total) + }) +} diff --git a/internal/repository/repotest/epoch_test_cases.go b/internal/repository/repotest/epoch_test_cases.go new file mode 100644 index 000000000..e3068512e --- /dev/null +++ b/internal/repository/repotest/epoch_test_cases.go @@ -0,0 +1,725 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + "errors" + + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/common" +) + +type EpochSuite struct { + BaseSuite +} + +func NewEpochSuite(factory RepositoryFactory) *EpochSuite { + return &EpochSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *EpochSuite) TestCreateEpochsAndInputs() { + s.Run("SingleEpochSingleInput", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID).WithIndex(0).WithStatus(EpochStatus_Closed).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(uint64(0), got.Index) + s.Equal(EpochStatus_Closed, got.Status) + }) + + s.Run("MultipleEpochsMultipleInputs", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed).WithBlocks(0, 9).Build() + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Open).WithBlocks(10, 19).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithEpochIndex(1).WithBlockNumber(15).Build() + + epochInputMap := map[*Epoch][]*Input{ + epoch0: {input0}, + epoch1: {input1}, + } + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), epochInputMap, 20) + s.Require().NoError(err) + + got0, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_Closed, got0.Status) + + got1, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 1) + s.Require().NoError(err) + s.Equal(EpochStatus_Open, got1.Status) + }) + + s.Run("EpochWithNoInputs", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed).WithBlocks(0, 9).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {}}, 10) + s.Require().NoError(err) + + // Epoch should exist + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Require().NotNil(got) + s.Equal(EpochStatus_Closed, got.Status) + + // No inputs should exist + inputs, total, err := s.Repo.ListInputs( + s.Ctx, app.IApplicationAddress.String(), + repository.InputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(inputs) + s.Equal(uint64(0), total) + }) + + s.Run("UpsertExistingEpoch", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Open).WithBlocks(0, 9).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + // Upsert the same epoch with updated status + epoch2 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed).WithBlocks(0, 9).Build() + input2 := NewInputBuilder().WithIndex(1).WithBlockNumber(8).Build() + + err = s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch2: {input2}}, 10) + s.Require().NoError(err) + + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_Closed, got.Status) + }) +} + +func (s *EpochSuite) TestGetEpoch() { + s.Run("ExistingEpoch", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetEpoch(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(seed.App.ID, got.ApplicationID) + s.Equal(uint64(0), got.Index) + s.Equal(seed.Epoch.FirstBlock, got.FirstBlock) + s.Equal(seed.Epoch.LastBlock, got.LastBlock) + s.Equal(seed.Epoch.InputIndexLowerBound, got.InputIndexLowerBound) + s.Equal(seed.Epoch.InputIndexUpperBound, got.InputIndexUpperBound) + s.Equal(EpochStatus_Closed, got.Status) + s.Equal(uint64(0), got.VirtualIndex) + s.Nil(got.MachineHash) + s.Nil(got.OutputsMerkleRoot) + s.Nil(got.ClaimTransactionHash) + s.Nil(got.Commitment) + s.False(got.CreatedAt.IsZero(), "CreatedAt should be set") + s.False(got.UpdatedAt.IsZero(), "UpdatedAt should be set") + }) + + s.Run("NotFound", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 99) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *EpochSuite) TestGetEpochByVirtualIndex() { + s.Run("ExistingVirtualIndex", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetEpochByVirtualIndex( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(uint64(0), got.VirtualIndex) + }) + + s.Run("NotFound", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetEpochByVirtualIndex( + s.Ctx, app.IApplicationAddress.String(), 99) + s.Require().NoError(err) + s.Nil(got) + }) + + s.Run("DivergentVirtualAndPhysicalIndex", func() { + // CreateEpochsAndInputs auto-assigns VirtualIndex as MAX(VirtualIndex)+1. + // By creating an epoch with Index=5, VirtualIndex will be auto-assigned as 0. + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(5). + WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + // Lookup by virtual index 0 (auto-assigned) should find the epoch with Index=5 + got, err := s.Repo.GetEpochByVirtualIndex( + s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Require().NotNil(got) + s.Equal(uint64(5), got.Index) + s.Equal(uint64(0), got.VirtualIndex) + + // Lookup by virtual index 5 should NOT find it + got, err = s.Repo.GetEpochByVirtualIndex( + s.Ctx, app.IApplicationAddress.String(), 5) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *EpochSuite) TestGetLastAcceptedEpochIndex() { + s.Run("WithAcceptedEpoch", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_ClaimAccepted).WithBlocks(0, 9).Build() + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Closed).WithBlocks(10, 19).Build() + input1 := NewInputBuilder().WithIndex(1).WithEpochIndex(1).WithBlockNumber(15).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch0: {input0}, epoch1: {input1}}, 20) + s.Require().NoError(err) + + idx, err := s.Repo.GetLastAcceptedEpochIndex(s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(0), idx) + }) + + s.Run("ErrorWhenNoAcceptedEpochs", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + _, err = s.Repo.GetLastAcceptedEpochIndex( + s.Ctx, app.IApplicationAddress.String()) + s.Require().Error(err) + s.True(errors.Is(err, repository.ErrNotFound)) + }) +} + +func (s *EpochSuite) TestGetLastNonOpenEpoch() { + s.Run("ReturnsClosedEpoch", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed).WithBlocks(0, 9).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + got, err := s.Repo.GetLastNonOpenEpoch(s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(EpochStatus_Closed, got.Status) + }) + + s.Run("NilWhenAllEpochsAreOpen", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Open). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + input := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input}}, 10) + s.Require().NoError(err) + + got, err := s.Repo.GetLastNonOpenEpoch( + s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *EpochSuite) TestListEpochs() { + s.Run("EmptyResult", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epochs, total, err := s.Repo.ListEpochs( + s.Ctx, app.IApplicationAddress.String(), + repository.EpochFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(epochs) + s.Equal(uint64(0), total) + }) + + s.Run("FilterByStatus", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed).WithBlocks(0, 9).Build() + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Open).WithBlocks(10, 19).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithEpochIndex(1).WithBlockNumber(15).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch0: {input0}, epoch1: {input1}}, 20) + s.Require().NoError(err) + + epochs, total, err := s.Repo.ListEpochs( + s.Ctx, app.IApplicationAddress.String(), + repository.EpochFilter{Status: []EpochStatus{EpochStatus_Closed}}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(epochs, 1) + s.Equal(uint64(1), total) + s.Equal(EpochStatus_Closed, epochs[0].Status) + }) + + s.Run("FilterByBeforeBlock", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Closed). + WithBlocks(10, 19).WithInputBounds(1, 1).Build() + epoch2 := NewEpochBuilder(app.ID). + WithIndex(2).WithStatus(EpochStatus_Open). + WithBlocks(20, 29).WithInputBounds(2, 2).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithEpochIndex(1).WithBlockNumber(15).Build() + input2 := NewInputBuilder().WithIndex(2).WithEpochIndex(2).WithBlockNumber(25).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch0: {input0}, epoch1: {input1}, epoch2: {input2}}, 30) + s.Require().NoError(err) + + // BeforeBlock=15 means LastBlock < 15, so epoch0 (LastBlock=9) matches + beforeBlock := uint64(15) + epochs, total, err := s.Repo.ListEpochs( + s.Ctx, app.IApplicationAddress.String(), + repository.EpochFilter{BeforeBlock: &beforeBlock}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(epochs, 1) + s.Equal(uint64(1), total) + s.Equal(uint64(0), epochs[0].Index) + }) + + s.Run("Pagination", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epochInputMap := make(map[*Epoch][]*Input) + for i := range uint64(5) { + e := NewEpochBuilder(app.ID). + WithIndex(i).WithStatus(EpochStatus_Closed). + WithBlocks(i*10, i*10+9).WithInputBounds(i, i).Build() + inp := NewInputBuilder().WithIndex(i).WithEpochIndex(i).WithBlockNumber(i*10 + 5).Build() + epochInputMap[e] = []*Input{inp} + } + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), epochInputMap, 50) + s.Require().NoError(err) + + epochs, total, err := s.Repo.ListEpochs( + s.Ctx, app.IApplicationAddress.String(), + repository.EpochFilter{}, + repository.Pagination{Limit: 2, Offset: 0}, false) + s.Require().NoError(err) + s.Len(epochs, 2) + s.Equal(uint64(5), total) + }) + + s.Run("Descending", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epochInputMap := make(map[*Epoch][]*Input) + for i := range uint64(3) { + e := NewEpochBuilder(app.ID). + WithIndex(i).WithStatus(EpochStatus_Closed). + WithBlocks(i*10, i*10+9).WithInputBounds(i, i).Build() + inp := NewInputBuilder().WithIndex(i).WithEpochIndex(i). + WithBlockNumber(i*10 + 5).Build() + epochInputMap[e] = []*Input{inp} + } + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), epochInputMap, 30) + s.Require().NoError(err) + + epochs, _, err := s.Repo.ListEpochs( + s.Ctx, app.IApplicationAddress.String(), + repository.EpochFilter{}, + repository.Pagination{Limit: 10}, true) + s.Require().NoError(err) + s.Require().Len(epochs, 3) + // Descending: highest index first + s.Equal(uint64(2), epochs[0].Index) + s.Equal(uint64(1), epochs[1].Index) + s.Equal(uint64(0), epochs[2].Index) + }) + + s.Run("FilterByMultipleStatuses", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Open). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Closed). + WithBlocks(10, 19).WithInputBounds(1, 1).Build() + epoch2 := NewEpochBuilder(app.ID). + WithIndex(2).WithStatus(EpochStatus_InputsProcessed). + WithBlocks(20, 29).WithInputBounds(2, 2).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithEpochIndex(1).WithBlockNumber(15).Build() + input2 := NewInputBuilder().WithIndex(2).WithEpochIndex(2).WithBlockNumber(25).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch0: {input0}, epoch1: {input1}, epoch2: {input2}}, 30) + s.Require().NoError(err) + + // Filter for both Closed and InputsProcessed + epochs, total, err := s.Repo.ListEpochs( + s.Ctx, app.IApplicationAddress.String(), + repository.EpochFilter{ + Status: []EpochStatus{EpochStatus_Closed, EpochStatus_InputsProcessed}, + }, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(epochs, 2) + s.Equal(uint64(2), total) + for _, e := range epochs { + s.True( + e.Status == EpochStatus_Closed || e.Status == EpochStatus_InputsProcessed, + "unexpected status: %s", e.Status) + } + }) +} + +func (s *EpochSuite) TestUpdateEpochStatus() { + s.Run("UpdatesStatus", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + epoch := seed.Epoch + epoch.Status = EpochStatus_InputsProcessed + + err := s.Repo.UpdateEpochStatus(s.Ctx, seed.App.IApplicationAddress.String(), epoch) + s.Require().NoError(err) + + got, err := s.Repo.GetEpoch(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_InputsProcessed, got.Status) + }) + + s.Run("NotFoundForNonExistentEpoch", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + nonExistentEpoch := &Epoch{Index: 99, Status: EpochStatus_InputsProcessed} + + err := s.Repo.UpdateEpochStatus( + s.Ctx, app.IApplicationAddress.String(), nonExistentEpoch) + s.Require().Error(err) + s.Require().ErrorIs(err, repository.ErrNotFound) + }) +} + +func (s *EpochSuite) TestUpdateEpochInputsProcessed() { + s.Run("MarksEpochProcessed", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + err := s.Repo.UpdateEpochInputsProcessed( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + + got, err := s.Repo.GetEpoch(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(EpochStatus_InputsProcessed, got.Status) + }) + + // The update should be a no-op when the previous epoch is still Open + // (i.e., not yet past Closed). The SQL condition requires that the + // previous epoch status is NOT IN (Open, Closed). + s.Run("NoOpWhenPreviousEpochStillOpen", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Open). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Closed). + WithBlocks(10, 19).WithInputBounds(1, 1).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithEpochIndex(1). + WithBlockNumber(15).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch0: {input0}, epoch1: {input1}}, 20) + s.Require().NoError(err) + + // Process input1 so the inputs-present condition is satisfied + result := &AdvanceResult{ + EpochIndex: 1, + InputIndex: 1, + Status: InputCompletionStatus_Accepted, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + err = s.Repo.StoreAdvanceResult(s.Ctx, app.ID, result) + s.Require().NoError(err) + + // Try to mark epoch1 as InputsProcessed; previous epoch0 is Open + err = s.Repo.UpdateEpochInputsProcessed( + s.Ctx, app.IApplicationAddress.String(), 1) + s.Require().NoError(err) // returns nil (no-op), not an error + + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 1) + s.Require().NoError(err) + // Epoch1 should remain Closed -- not promoted to InputsProcessed + s.Equal(EpochStatus_Closed, got.Status) + }) + + // The update should be a no-op when the epoch still has pending + // (unprocessed) inputs. The SQL requires pending_count == 0. + s.Run("NoOpWhenPendingInputsRemain", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + // Create a single-epoch setup with 2 inputs + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 2).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(3).Build() + input1 := NewInputBuilder().WithIndex(1).WithBlockNumber(7).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1}}, 10) + s.Require().NoError(err) + + // Process only 1 of the 2 inputs + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + err = s.Repo.StoreAdvanceResult(s.Ctx, app.ID, result) + s.Require().NoError(err) + + // Try to mark epoch as InputsProcessed; input1 is still pending + err = s.Repo.UpdateEpochInputsProcessed( + s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) // returns nil (no-op) + + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + // Epoch should remain Closed since not all inputs are processed + s.Equal(EpochStatus_Closed, got.Status) + }) + + // The update should be a no-op when not all expected inputs are present. + // total_count != (upper_bound - lower_bound). + s.Run("NoOpWhenInputCountDoesNotMatchBounds", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + // Epoch expects 3 inputs (bounds 0..3) but we only provide 1. + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 3).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0}}, 10) + s.Require().NoError(err) + + // Process the single input we do have + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + err = s.Repo.StoreAdvanceResult(s.Ctx, app.ID, result) + s.Require().NoError(err) + + // Try to mark epoch as InputsProcessed + err = s.Repo.UpdateEpochInputsProcessed( + s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) // returns nil (no-op) + + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + // Epoch should remain Closed since not all expected inputs are present + s.Equal(EpochStatus_Closed, got.Status) + }) + + // Non-existent epoch should be a silent no-op (returns nil). + s.Run("NoOpForNonExistentEpoch", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + err := s.Repo.UpdateEpochInputsProcessed( + s.Ctx, app.IApplicationAddress.String(), 99) + s.Require().NoError(err) + }) +} + +func (s *EpochSuite) TestUpdateEpochClaimTransactionHash() { + s.Run("SetsTransactionHash", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + txHash := UniqueHash() + seed.Epoch.ClaimTransactionHash = &txHash + + err := s.Repo.UpdateEpochClaimTransactionHash( + s.Ctx, seed.App.IApplicationAddress.String(), seed.Epoch) + s.Require().NoError(err) + + got, err := s.Repo.GetEpoch(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Require().NotNil(got.ClaimTransactionHash) + s.Equal(txHash, *got.ClaimTransactionHash) + }) + + s.Run("NotFoundForNonExistentEpoch", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + txHash := UniqueHash() + nonExistentEpoch := &Epoch{Index: 99, ClaimTransactionHash: &txHash} + + err := s.Repo.UpdateEpochClaimTransactionHash( + s.Ctx, app.IApplicationAddress.String(), nonExistentEpoch) + s.Require().Error(err) + s.Require().ErrorIs(err, repository.ErrNotFound) + }) +} + +func (s *EpochSuite) TestUpdateEpochOutputsProof() { + s.Run("SetsOutputsProof", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + proof := &OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + OutputsHashProof: [][32]byte{ + [32]byte(common.HexToHash("0xaabb")), + }, + } + + err := s.Repo.UpdateEpochOutputsProof(s.Ctx, seed.App.ID, 0, proof) + s.Require().NoError(err) + + got, err := s.Repo.GetEpoch(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Require().NotNil(got.OutputsMerkleRoot) + }) +} + +func (s *EpochSuite) TestRepeatPreviousEpochOutputsProof() { + s.Run("CopiesProofFromPreviousEpoch", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Closed). + WithBlocks(10, 19).WithInputBounds(1, 1).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithEpochIndex(1).WithBlockNumber(15).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch0: {input0}, epoch1: {input1}}, 20) + s.Require().NoError(err) + + // Set proof on epoch 0 + outputsHash := UniqueHash() + machineHash := UniqueHash() + proof := &OutputsProof{ + OutputsHash: outputsHash, + MachineHash: machineHash, + OutputsHashProof: [][32]byte{ + [32]byte(UniqueHash()), + }, + } + err = s.Repo.UpdateEpochOutputsProof(s.Ctx, app.ID, 0, proof) + s.Require().NoError(err) + + // Copy proof from epoch 0 to epoch 1 + err = s.Repo.RepeatPreviousEpochOutputsProof(s.Ctx, app.ID, 1) + s.Require().NoError(err) + + // Verify epoch 1 has epoch 0's proof + got, err := s.Repo.GetEpoch(s.Ctx, app.IApplicationAddress.String(), 1) + s.Require().NoError(err) + s.Require().NotNil(got.OutputsMerkleRoot) + s.Equal(outputsHash, *got.OutputsMerkleRoot) + s.Require().NotNil(got.MachineHash) + s.Equal(machineHash, *got.MachineHash) + }) + + s.Run("ErrorsForEpochZero", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + err := s.Repo.RepeatPreviousEpochOutputsProof(s.Ctx, seed.App.ID, 0) + s.Require().Error(err) + s.Contains(err.Error(), "epoch 0") + }) + + s.Run("ErrorsForNonExistentEpoch", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + err := s.Repo.RepeatPreviousEpochOutputsProof(s.Ctx, seed.App.ID, 99) + s.Require().Error(err) + }) +} diff --git a/internal/repository/repotest/input_test_cases.go b/internal/repository/repotest/input_test_cases.go new file mode 100644 index 000000000..8b05914af --- /dev/null +++ b/internal/repository/repotest/input_test_cases.go @@ -0,0 +1,382 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" +) + +type InputSuite struct { + BaseSuite +} + +func NewInputSuite(factory RepositoryFactory) *InputSuite { + return &InputSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *InputSuite) TestGetInput() { + s.Run("ExistingInput", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetInput(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(seed.App.ID, got.EpochApplicationID) + s.Equal(uint64(0), got.EpochIndex) + s.Equal(uint64(0), got.Index) + s.Equal(seed.Input.BlockNumber, got.BlockNumber) + s.Equal(seed.Input.RawData, got.RawData) + s.Equal(InputCompletionStatus_None, got.Status) + s.Equal(seed.Input.TransactionReference, got.TransactionReference) + s.Nil(got.MachineHash) + s.Nil(got.OutputsHash) + s.Nil(got.SnapshotURI) + s.False(got.CreatedAt.IsZero(), "CreatedAt should be set") + s.False(got.UpdatedAt.IsZero(), "UpdatedAt should be set") + }) + + s.Run("NotFound", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetInput(s.Ctx, app.IApplicationAddress.String(), 99) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *InputSuite) TestGetInputByTxReference() { + s.Run("NilRef", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + _, err := s.Repo.GetInputByTxReference(s.Ctx, app.IApplicationAddress.String(), nil) + s.Error(err) + }) + + s.Run("ExistingRef", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + ref := seed.Input.TransactionReference + + got, err := s.Repo.GetInputByTxReference( + s.Ctx, seed.App.IApplicationAddress.String(), &ref) + s.Require().NoError(err) + s.Require().NotNil(got) + s.Equal(seed.Input.Index, got.Index) + s.Equal(ref, got.TransactionReference) + }) + + s.Run("NotFound", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + nonExistentRef := UniqueHash() + + got, err := s.Repo.GetInputByTxReference( + s.Ctx, seed.App.IApplicationAddress.String(), &nonExistentRef) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *InputSuite) TestGetLastInput() { + s.Run("ReturnsLastInput", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 19).WithInputBounds(0, 2).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithBlockNumber(10).Build() + input2 := NewInputBuilder().WithIndex(2).WithBlockNumber(15).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1, input2}}, 20) + s.Require().NoError(err) + + got, err := s.Repo.GetLastInput(s.Ctx, app.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(uint64(2), got.Index) + }) +} + +func (s *InputSuite) TestGetLastProcessedInput() { + s.Run("ReturnsLastProcessed", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 19).WithInputBounds(0, 1).Build() + + input0 := NewInputBuilder(). + WithIndex(0).WithBlockNumber(5). + WithStatus(InputCompletionStatus_Accepted).Build() + input1 := NewInputBuilder(). + WithIndex(1).WithBlockNumber(10). + WithStatus(InputCompletionStatus_None).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1}}, 20) + s.Require().NoError(err) + + got, err := s.Repo.GetLastProcessedInput(s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(0), got.Index) + s.Equal(InputCompletionStatus_Accepted, got.Status) + }) +} + +func (s *InputSuite) TestListInputs() { + s.Run("EmptyResult", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + inputs, total, err := s.Repo.ListInputs( + s.Ctx, app.IApplicationAddress.String(), + repository.InputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(inputs) + s.Equal(uint64(0), total) + }) + + s.Run("ReturnsAllInputs", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 29).WithInputBounds(0, 2).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithBlockNumber(10).Build() + input2 := NewInputBuilder().WithIndex(2).WithBlockNumber(20).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1, input2}}, 30) + s.Require().NoError(err) + + inputs, total, err := s.Repo.ListInputs( + s.Ctx, app.IApplicationAddress.String(), + repository.InputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(inputs, 3) + s.Equal(uint64(3), total) + }) + + s.Run("FilterByEpochIndex", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Open). + WithBlocks(10, 19).WithInputBounds(1, 1).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithEpochIndex(1).WithBlockNumber(15).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch0: {input0}, epoch1: {input1}}, 20) + s.Require().NoError(err) + + epochIdx := uint64(1) + inputs, total, err := s.Repo.ListInputs( + s.Ctx, app.IApplicationAddress.String(), + repository.InputFilter{EpochIndex: &epochIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(inputs, 1) + s.Equal(uint64(1), total) + s.Equal(uint64(1), inputs[0].Index) + }) + + s.Run("FilterByStatus", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 19).WithInputBounds(0, 1).Build() + + input0 := NewInputBuilder(). + WithIndex(0).WithBlockNumber(5). + WithStatus(InputCompletionStatus_Accepted).Build() + input1 := NewInputBuilder(). + WithIndex(1).WithBlockNumber(10). + WithStatus(InputCompletionStatus_None).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1}}, 20) + s.Require().NoError(err) + + status := InputCompletionStatus_Accepted + inputs, total, err := s.Repo.ListInputs( + s.Ctx, app.IApplicationAddress.String(), + repository.InputFilter{Status: &status}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(inputs, 1) + s.Equal(uint64(1), total) + s.Equal(InputCompletionStatus_Accepted, inputs[0].Status) + }) + + s.Run("FilterByNotStatus", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 19).WithInputBounds(0, 2).Build() + + input0 := NewInputBuilder(). + WithIndex(0).WithBlockNumber(5). + WithStatus(InputCompletionStatus_Accepted).Build() + input1 := NewInputBuilder(). + WithIndex(1).WithBlockNumber(10). + WithStatus(InputCompletionStatus_Rejected).Build() + input2 := NewInputBuilder(). + WithIndex(2).WithBlockNumber(15). + WithStatus(InputCompletionStatus_None).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1, input2}}, 20) + s.Require().NoError(err) + + notStatus := InputCompletionStatus_None + inputs, total, err := s.Repo.ListInputs( + s.Ctx, app.IApplicationAddress.String(), + repository.InputFilter{NotStatus: ¬Status}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(inputs, 2) + s.Equal(uint64(2), total) + for _, inp := range inputs { + s.NotEqual(InputCompletionStatus_None, inp.Status) + } + }) + + s.Run("FilterBySender", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 19).WithInputBounds(0, 1).Build() + + // The Sender filter uses SUBSTR(raw_data, 81, 20) to extract a + // 20-byte sender address from the ABI-encoded input payload. + senderAddr := UniqueAddress() + rawWithSender := make([]byte, 101) + copy(rawWithSender[80:100], senderAddr.Bytes()) + + otherAddr := UniqueAddress() + rawWithOther := make([]byte, 101) + copy(rawWithOther[80:100], otherAddr.Bytes()) + + input0 := NewInputBuilder(). + WithIndex(0).WithBlockNumber(5). + WithRawData(rawWithSender).Build() + input1 := NewInputBuilder(). + WithIndex(1).WithBlockNumber(10). + WithRawData(rawWithOther).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1}}, 20) + s.Require().NoError(err) + + inputs, total, err := s.Repo.ListInputs( + s.Ctx, app.IApplicationAddress.String(), + repository.InputFilter{Sender: &senderAddr}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(inputs, 1) + s.Equal(uint64(1), total) + s.Equal(uint64(0), inputs[0].Index) + }) + + s.Run("Pagination", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 49).WithInputBounds(0, 4).Build() + + inputs := make([]*Input, 5) + for i := range uint64(5) { + inputs[i] = NewInputBuilder().WithIndex(i).WithBlockNumber(i*10 + 5).Build() + } + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: inputs}, 50) + s.Require().NoError(err) + + got, total, err := s.Repo.ListInputs( + s.Ctx, app.IApplicationAddress.String(), + repository.InputFilter{}, + repository.Pagination{Limit: 2, Offset: 0}, false) + s.Require().NoError(err) + s.Len(got, 2) + s.Equal(uint64(5), total) + }) + + s.Run("Descending", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 29).WithInputBounds(0, 2).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithBlockNumber(10).Build() + input2 := NewInputBuilder().WithIndex(2).WithBlockNumber(20).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1, input2}}, 30) + s.Require().NoError(err) + + inputs, _, err := s.Repo.ListInputs( + s.Ctx, app.IApplicationAddress.String(), + repository.InputFilter{}, + repository.Pagination{Limit: 10}, true) + s.Require().NoError(err) + s.Require().Len(inputs, 3) + // Descending: highest index first + s.Equal(uint64(2), inputs[0].Index) + s.Equal(uint64(1), inputs[1].Index) + s.Equal(uint64(0), inputs[2].Index) + }) +} + +func (s *InputSuite) TestGetNumberOfInputs() { + s.Run("ReturnsCount", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 19).WithInputBounds(0, 2).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithBlockNumber(10).Build() + input2 := NewInputBuilder().WithIndex(2).WithBlockNumber(15).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1, input2}}, 20) + s.Require().NoError(err) + + count, err := s.Repo.GetNumberOfInputs(s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(3), count) + }) + + s.Run("ZeroWhenEmpty", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + count, err := s.Repo.GetNumberOfInputs(s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(0), count) + }) +} + +func (s *InputSuite) TestUpdateInputSnapshotURI() { + s.Run("SetsSnapshotURI", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + uri := "/snapshots/test" + + err := s.Repo.UpdateInputSnapshotURI(s.Ctx, seed.App.ID, 0, uri) + s.Require().NoError(err) + + got, err := s.Repo.GetInput(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Require().NotNil(got.SnapshotURI) + s.Equal(uri, *got.SnapshotURI) + }) +} diff --git a/internal/repository/repotest/match_advanced_test_cases.go b/internal/repository/repotest/match_advanced_test_cases.go new file mode 100644 index 000000000..09c900655 --- /dev/null +++ b/internal/repository/repotest/match_advanced_test_cases.go @@ -0,0 +1,140 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + "encoding/hex" + + "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/common" +) + +type MatchAdvancedSuite struct { + BaseSuite +} + +func NewMatchAdvancedSuite(factory RepositoryFactory) *MatchAdvancedSuite { + return &MatchAdvancedSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *MatchAdvancedSuite) createTournamentAndMatch() ( + *SeedResult, common.Address, common.Hash, +) { + seed := Seed(s.Ctx, s.T(), s.Repo) + tournAddr := UniqueAddress() + t := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).WithAddress(tournAddr).Build() + err := s.Repo.CreateTournament(s.Ctx, seed.App.IApplicationAddress.String(), t) + s.Require().NoError(err) + + // Create commitments to satisfy match FK constraints + c1 := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(tournAddr).Build() + err = s.Repo.CreateCommitment(s.Ctx, seed.App.IApplicationAddress.String(), c1) + s.Require().NoError(err) + + c2 := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(tournAddr).Build() + err = s.Repo.CreateCommitment(s.Ctx, seed.App.IApplicationAddress.String(), c2) + s.Require().NoError(err) + + matchIDHash := UniqueHash() + match := NewMatchBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + WithIDHash(matchIDHash). + WithCommitmentOne(c1.Commitment). + WithCommitmentTwo(c2.Commitment). + Build() + err = s.Repo.CreateMatch(s.Ctx, seed.App.IApplicationAddress.String(), match) + s.Require().NoError(err) + + return seed, tournAddr, matchIDHash +} + +func (s *MatchAdvancedSuite) TestCreateMatchAdvanced() { + s.Run("CreatesSuccessfully", func() { + seed, tournAddr, matchIDHash := s.createTournamentAndMatch() + + ma := NewMatchAdvancedBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + WithIDHash(matchIDHash). + Build() + + err := s.Repo.CreateMatchAdvanced( + s.Ctx, seed.App.IApplicationAddress.String(), ma) + s.Require().NoError(err) + }) +} + +func (s *MatchAdvancedSuite) TestGetMatchAdvanced() { + s.Run("ExistingMatchAdvanced", func() { + seed, tournAddr, matchIDHash := s.createTournamentAndMatch() + + ma := NewMatchAdvancedBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + WithIDHash(matchIDHash). + Build() + + err := s.Repo.CreateMatchAdvanced( + s.Ctx, seed.App.IApplicationAddress.String(), ma) + s.Require().NoError(err) + + got, err := s.Repo.GetMatchAdvanced( + s.Ctx, seed.App.IApplicationAddress.String(), + 0, tournAddr.String(), matchIDHash.Hex(), + hex.EncodeToString(ma.OtherParent[:])) + s.Require().NoError(err) + s.Equal(ma.IDHash, got.IDHash) + s.Equal(ma.OtherParent, got.OtherParent) + }) + + s.Run("NotFound", func() { + seed, tournAddr, matchIDHash := s.createTournamentAndMatch() + nonExistent := UniqueHash() + got, err := s.Repo.GetMatchAdvanced( + s.Ctx, seed.App.IApplicationAddress.String(), + 0, tournAddr.String(), matchIDHash.Hex(), + hex.EncodeToString(nonExistent[:])) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *MatchAdvancedSuite) TestListMatchAdvances() { + s.Run("EmptyResult", func() { + seed, tournAddr, matchIDHash := s.createTournamentAndMatch() + advances, total, err := s.Repo.ListMatchAdvances( + s.Ctx, seed.App.IApplicationAddress.String(), + 0, tournAddr.String(), matchIDHash.Hex(), + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(advances) + s.Equal(uint64(0), total) + }) + + s.Run("ReturnsAll", func() { + seed, tournAddr, matchIDHash := s.createTournamentAndMatch() + for range 3 { + ma := NewMatchAdvancedBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(tournAddr). + WithIDHash(matchIDHash). + Build() + err := s.Repo.CreateMatchAdvanced( + s.Ctx, seed.App.IApplicationAddress.String(), ma) + s.Require().NoError(err) + } + + advances, total, err := s.Repo.ListMatchAdvances( + s.Ctx, seed.App.IApplicationAddress.String(), + 0, tournAddr.String(), matchIDHash.Hex(), + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(advances, 3) + s.Equal(uint64(3), total) + }) +} diff --git a/internal/repository/repotest/match_test_cases.go b/internal/repository/repotest/match_test_cases.go new file mode 100644 index 000000000..6f5515c46 --- /dev/null +++ b/internal/repository/repotest/match_test_cases.go @@ -0,0 +1,220 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/common" +) + +type MatchSuite struct { + BaseSuite +} + +func NewMatchSuite(factory RepositoryFactory) *MatchSuite { + return &MatchSuite{BaseSuite: BaseSuite{factory: factory}} +} + +// matchSetup holds a tournament and two commitments needed for match FK constraints. +type matchSetup struct { + seed *SeedResult + tournAddr common.Address + commitHash1 common.Hash + commitHash2 common.Hash +} + +func (s *MatchSuite) setupTournamentWithCommitments() *matchSetup { + seed := Seed(s.Ctx, s.T(), s.Repo) + tournAddr := UniqueAddress() + t := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).WithAddress(tournAddr).Build() + err := s.Repo.CreateTournament(s.Ctx, seed.App.IApplicationAddress.String(), t) + s.Require().NoError(err) + + c1 := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(tournAddr).Build() + err = s.Repo.CreateCommitment(s.Ctx, seed.App.IApplicationAddress.String(), c1) + s.Require().NoError(err) + + c2 := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(tournAddr).Build() + err = s.Repo.CreateCommitment(s.Ctx, seed.App.IApplicationAddress.String(), c2) + s.Require().NoError(err) + + return &matchSetup{ + seed: seed, + tournAddr: tournAddr, + commitHash1: c1.Commitment, + commitHash2: c2.Commitment, + } +} + +func (s *MatchSuite) TestCreateMatch() { + s.Run("CreatesSuccessfully", func() { + setup := s.setupTournamentWithCommitments() + match := NewMatchBuilder(setup.seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(setup.tournAddr). + WithCommitmentOne(setup.commitHash1). + WithCommitmentTwo(setup.commitHash2). + Build() + + err := s.Repo.CreateMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), match) + s.Require().NoError(err) + }) +} + +func (s *MatchSuite) TestGetMatch() { + s.Run("ExistingMatch", func() { + setup := s.setupTournamentWithCommitments() + match := NewMatchBuilder(setup.seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(setup.tournAddr). + WithCommitmentOne(setup.commitHash1). + WithCommitmentTwo(setup.commitHash2). + Build() + + err := s.Repo.CreateMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), match) + s.Require().NoError(err) + + got, err := s.Repo.GetMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), + 0, setup.tournAddr.String(), match.IDHash.Hex()) + s.Require().NoError(err) + s.Equal(match.IDHash, got.IDHash) + s.Equal(match.CommitmentOne, got.CommitmentOne) + }) + + s.Run("NotFound", func() { + setup := s.setupTournamentWithCommitments() + got, err := s.Repo.GetMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), + 0, setup.tournAddr.String(), UniqueHash().Hex()) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *MatchSuite) TestUpdateMatch() { + s.Run("UpdatesWinner", func() { + setup := s.setupTournamentWithCommitments() + match := NewMatchBuilder(setup.seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(setup.tournAddr). + WithCommitmentOne(setup.commitHash1). + WithCommitmentTwo(setup.commitHash2). + Build() + + err := s.Repo.CreateMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), match) + s.Require().NoError(err) + + match.Winner = WinnerCommitment_ONE + err = s.Repo.UpdateMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), match) + s.Require().NoError(err) + + got, err := s.Repo.GetMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), + 0, setup.tournAddr.String(), match.IDHash.Hex()) + s.Require().NoError(err) + s.Equal(WinnerCommitment_ONE, got.Winner) + }) +} + +func (s *MatchSuite) TestListMatches() { + s.Run("EmptyResult", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + matches, total, err := s.Repo.ListMatches( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.MatchFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(matches) + s.Equal(uint64(0), total) + }) + + s.Run("ReturnsAll", func() { + setup := s.setupTournamentWithCommitments() + + // Each match needs unique commitment pairs (matches_unique_pair_idx). + // Create additional commitments for more matches. + for range 3 { + c1 := NewCommitmentBuilder(setup.seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(setup.tournAddr).Build() + err := s.Repo.CreateCommitment( + s.Ctx, setup.seed.App.IApplicationAddress.String(), c1) + s.Require().NoError(err) + + c2 := NewCommitmentBuilder(setup.seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(setup.tournAddr).Build() + err = s.Repo.CreateCommitment( + s.Ctx, setup.seed.App.IApplicationAddress.String(), c2) + s.Require().NoError(err) + + m := NewMatchBuilder(setup.seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(setup.tournAddr). + WithCommitmentOne(c1.Commitment). + WithCommitmentTwo(c2.Commitment). + Build() + err = s.Repo.CreateMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), m) + s.Require().NoError(err) + } + + matches, total, err := s.Repo.ListMatches( + s.Ctx, setup.seed.App.IApplicationAddress.String(), + repository.MatchFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(matches, 3) + s.Equal(uint64(3), total) + }) + + s.Run("FilterByEpochIndex", func() { + setup := s.setupTournamentWithCommitments() + m := NewMatchBuilder(setup.seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(setup.tournAddr). + WithCommitmentOne(setup.commitHash1). + WithCommitmentTwo(setup.commitHash2). + Build() + err := s.Repo.CreateMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), m) + s.Require().NoError(err) + + epochIdx := uint64(0) + matches, total, err := s.Repo.ListMatches( + s.Ctx, setup.seed.App.IApplicationAddress.String(), + repository.MatchFilter{EpochIndex: &epochIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(matches, 1) + s.Equal(uint64(1), total) + }) + + s.Run("FilterByTournamentAddress", func() { + setup := s.setupTournamentWithCommitments() + m := NewMatchBuilder(setup.seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(setup.tournAddr). + WithCommitmentOne(setup.commitHash1). + WithCommitmentTwo(setup.commitHash2). + Build() + err := s.Repo.CreateMatch( + s.Ctx, setup.seed.App.IApplicationAddress.String(), m) + s.Require().NoError(err) + + addrStr := setup.tournAddr.String() + matches, total, err := s.Repo.ListMatches( + s.Ctx, setup.seed.App.IApplicationAddress.String(), + repository.MatchFilter{TournamentAddress: &addrStr}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(matches, 1) + s.Equal(uint64(1), total) + }) +} diff --git a/internal/repository/repotest/node_config_test_cases.go b/internal/repository/repotest/node_config_test_cases.go new file mode 100644 index 000000000..b888e6a3e --- /dev/null +++ b/internal/repository/repotest/node_config_test_cases.go @@ -0,0 +1,92 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + "encoding/json" + + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" +) + +type NodeConfigSuite struct { + BaseSuite +} + +func NewNodeConfigSuite(factory RepositoryFactory) *NodeConfigSuite { + return &NodeConfigSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *NodeConfigSuite) TestSaveAndLoadNodeConfigRaw() { + s.Run("RoundTrip", func() { + key := "test-config" + value := map[string]string{"foo": "bar"} + data, err := json.Marshal(value) + s.Require().NoError(err) + + err = s.Repo.SaveNodeConfigRaw(s.Ctx, key, data) + s.Require().NoError(err) + + got, createdAt, updatedAt, err := s.Repo.LoadNodeConfigRaw(s.Ctx, key) + s.Require().NoError(err) + + // Compare JSON semantically (PostgreSQL may reformat whitespace) + var expected, actual map[string]string + s.Require().NoError(json.Unmarshal(data, &expected)) + s.Require().NoError(json.Unmarshal(got, &actual)) + s.Equal(expected, actual) + + s.False(createdAt.IsZero()) + s.False(updatedAt.IsZero()) + }) + + s.Run("UpdateExistingKey", func() { + key := "update-test" + data1, _ := json.Marshal("value1") + data2, _ := json.Marshal("value2") + + err := s.Repo.SaveNodeConfigRaw(s.Ctx, key, data1) + s.Require().NoError(err) + + err = s.Repo.SaveNodeConfigRaw(s.Ctx, key, data2) + s.Require().NoError(err) + + got, _, _, err := s.Repo.LoadNodeConfigRaw(s.Ctx, key) + s.Require().NoError(err) + + // Compare JSON semantically + var expected, actual string + s.Require().NoError(json.Unmarshal(data2, &expected)) + s.Require().NoError(json.Unmarshal(got, &actual)) + s.Equal(expected, actual) + }) + + s.Run("NotFound", func() { + _, _, _, err := s.Repo.LoadNodeConfigRaw(s.Ctx, "nonexistent-key") + s.ErrorIs(err, repository.ErrNotFound) + }) +} + +func (s *NodeConfigSuite) TestGenericNodeConfig() { + s.Run("SaveAndLoadTyped", func() { + type TestConfig struct { + Name string `json:"name"` + Count int `json:"count"` + } + + nc := &NodeConfig[TestConfig]{ + Key: "typed-config", + Value: TestConfig{Name: "test", Count: 42}, + } + + err := repository.SaveNodeConfig(s.Ctx, s.Repo, nc) + s.Require().NoError(err) + + got, err := repository.LoadNodeConfig[TestConfig](s.Ctx, s.Repo, "typed-config") + s.Require().NoError(err) + s.Require().NotNil(got) + s.Equal("test", got.Value.Name) + s.Equal(42, got.Value.Count) + }) +} diff --git a/internal/repository/repotest/output_test_cases.go b/internal/repository/repotest/output_test_cases.go new file mode 100644 index 000000000..e554c68d9 --- /dev/null +++ b/internal/repository/repotest/output_test_cases.go @@ -0,0 +1,536 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/crypto" +) + +type OutputSuite struct { + BaseSuite +} + +func NewOutputSuite(factory RepositoryFactory) *OutputSuite { + return &OutputSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *OutputSuite) TestGetOutput() { + s.Run("ExistingOutput", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(0). + WithRawData([]byte("output-data")).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + + got, err := s.Repo.GetOutput(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(uint64(0), got.Index) + s.Equal([]byte("output-data"), got.RawData) + }) + + s.Run("NotFound", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetOutput(s.Ctx, app.IApplicationAddress.String(), 99) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *OutputSuite) TestListOutputs() { + s.Run("EmptyResult", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, app.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(outputs) + s.Equal(uint64(0), total) + }) + + s.Run("ReturnsAllOutputs", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + for i := range uint64(3) { + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + } + + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs, 3) + s.Equal(uint64(3), total) + }) + + s.Run("FilterByEpochIndex", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // EpochIndex filter also requires input.status = ACCEPTED, + // so use StoreAdvanceResult to create the output with accepted input. + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("epoch-output")}, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + epochIdx := uint64(0) + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{EpochIndex: &epochIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs, 1) + s.Equal(uint64(1), total) + }) + + s.Run("FilterByInputIndex", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + for i := range uint64(3) { + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + } + + inputIdx := uint64(0) + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{InputIndex: &inputIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs, 3) + s.Equal(uint64(3), total) + }) + + s.Run("FilterByBlockRange", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 99).WithInputBounds(0, 2).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(10).Build() + input1 := NewInputBuilder().WithIndex(1).WithBlockNumber(50).Build() + input2 := NewInputBuilder().WithIndex(2).WithBlockNumber(90).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1, input2}}, 100) + s.Require().NoError(err) + + // Store advance results to create outputs with accepted inputs + for i := range uint64(3) { + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: i, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("output-data")}, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + err = s.Repo.StoreAdvanceResult(s.Ctx, app.ID, result) + s.Require().NoError(err) + } + + // Filter for block range 40-60 (should match input1 at block 50) + blockRange := repository.Range{Start: 40, End: 60} + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, app.IApplicationAddress.String(), + repository.OutputFilter{BlockRange: &blockRange}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs, 1) + s.Equal(uint64(1), total) + }) + + s.Run("FilterByOutputType", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // OutputType filter uses SUBSTR(raw_data, 1, 4) to match the first 4 bytes + targetType := []byte{0xef, 0x01, 0xab, 0xcd} + rawWithType := make([]byte, 32) + copy(rawWithType[0:4], targetType) + + otherType := []byte{0x00, 0x00, 0x00, 0x00} + rawWithOther := make([]byte, 32) + copy(rawWithOther[0:4], otherType) + + out0 := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(0). + WithRawData(rawWithType).Build() + err := s.Repo.CreateOutput(s.Ctx, out0) + s.Require().NoError(err) + + out1 := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(1). + WithRawData(rawWithOther).Build() + err = s.Repo.CreateOutput(s.Ctx, out1) + s.Require().NoError(err) + + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{OutputType: &targetType}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs, 1) + s.Equal(uint64(1), total) + s.Equal(rawWithType, outputs[0].RawData) + }) + + s.Run("FilterByVoucherAddress", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // VoucherAddress filter uses SUBSTR(raw_data, 17, 20) + // to extract a 20-byte address at bytes 17-36 (1-indexed) + voucherAddr := UniqueAddress() + rawWithVoucher := make([]byte, 64) + copy(rawWithVoucher[16:36], voucherAddr.Bytes()) + + otherAddr := UniqueAddress() + rawWithOther := make([]byte, 64) + copy(rawWithOther[16:36], otherAddr.Bytes()) + + out0 := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(0). + WithRawData(rawWithVoucher).Build() + err := s.Repo.CreateOutput(s.Ctx, out0) + s.Require().NoError(err) + + out1 := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(1). + WithRawData(rawWithOther).Build() + err = s.Repo.CreateOutput(s.Ctx, out1) + s.Require().NoError(err) + + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{VoucherAddress: &voucherAddr}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(outputs, 1) + s.Equal(uint64(1), total) + s.Equal(rawWithVoucher, outputs[0].RawData) + }) + + s.Run("Pagination", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + for i := range uint64(5) { + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + } + + outputs, total, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, + repository.Pagination{Limit: 2, Offset: 0}, false) + s.Require().NoError(err) + s.Len(outputs, 2) + s.Equal(uint64(5), total) + }) + + s.Run("Descending", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + for i := range uint64(3) { + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + } + + outputs, _, err := s.Repo.ListOutputs( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.OutputFilter{}, + repository.Pagination{Limit: 10}, true) + s.Require().NoError(err) + s.Require().Len(outputs, 3) + // Descending: highest index first + s.Equal(uint64(2), outputs[0].Index) + s.Equal(uint64(1), outputs[1].Index) + s.Equal(uint64(0), outputs[2].Index) + }) +} + +func (s *OutputSuite) TestUpdateOutputsExecution() { + s.Run("UpdatesExecutionHash", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(0).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + + txHash := UniqueHash() + out.ExecutionTransactionHash = &txHash + err = s.Repo.UpdateOutputsExecution( + s.Ctx, seed.App.IApplicationAddress.String(), []*Output{out}, 100) + s.Require().NoError(err) + + got, err := s.Repo.GetOutput(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Require().NotNil(got.ExecutionTransactionHash) + s.Equal(txHash, *got.ExecutionTransactionHash) + }) + + // Regression guard: all output updates must be transactional. + // Verify multiple outputs are updated atomically in a single call. + s.Run("MultipleOutputsUpdatedAtomically", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // Create 3 outputs + for i := range uint64(3) { + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + } + + txHash := UniqueHash() + outputs := make([]*Output, 3) + for i := range uint64(3) { + outputs[i] = &Output{ + InputEpochApplicationID: seed.App.ID, + Index: i, + ExecutionTransactionHash: &txHash, + } + } + + err := s.Repo.UpdateOutputsExecution( + s.Ctx, seed.App.IApplicationAddress.String(), outputs, 200) + s.Require().NoError(err) + + // Verify all 3 were updated + for i := range uint64(3) { + got, err := s.Repo.GetOutput( + s.Ctx, seed.App.IApplicationAddress.String(), i) + s.Require().NoError(err) + s.Require().NotNil(got.ExecutionTransactionHash, + "output %d should have execution hash", i) + s.Equal(txHash, *got.ExecutionTransactionHash) + } + }) + + s.Run("NilExecutionHashReturnsError", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(0).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + + // ExecutionTransactionHash is nil — should fail + badOutput := &Output{ + InputEpochApplicationID: seed.App.ID, + Index: 0, + } + err = s.Repo.UpdateOutputsExecution( + s.Ctx, seed.App.IApplicationAddress.String(), + []*Output{badOutput}, 100) + s.Require().Error(err) + }) + + // Verify that a failure mid-loop rolls back all prior output updates. + // We create 3 outputs, set valid hashes on the first two, and use a + // non-existent index for the third. The third update should fail + // (RowsAffected == 0), rolling back the first two. + s.Run("RollbackOnPartialFailure", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // Create 2 valid outputs (index 0 and 1) + for i := range uint64(2) { + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + } + + txHash := UniqueHash() + outputs := []*Output{ + { + InputEpochApplicationID: seed.App.ID, + Index: 0, + ExecutionTransactionHash: &txHash, + }, + { + InputEpochApplicationID: seed.App.ID, + Index: 1, + ExecutionTransactionHash: &txHash, + }, + { + InputEpochApplicationID: seed.App.ID, + Index: 999, // non-existent + ExecutionTransactionHash: &txHash, + }, + } + + err := s.Repo.UpdateOutputsExecution( + s.Ctx, seed.App.IApplicationAddress.String(), outputs, 200) + s.Require().Error(err) + + // Verify that the first two outputs were NOT updated (rolled back) + for i := range uint64(2) { + got, err := s.Repo.GetOutput( + s.Ctx, seed.App.IApplicationAddress.String(), i) + s.Require().NoError(err) + s.Nil(got.ExecutionTransactionHash, + "output %d should not have execution hash after rollback", i) + } + + // Verify no executed outputs exist + count, err := s.Repo.GetNumberOfExecutedOutputs( + s.Ctx, seed.App.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(0), count) + }) + + // Verify that a nil hash on the second output rolls back the first + // output's successful update. This exercises the nil-check rollback + // path at the top of the loop. + s.Run("NilHashMidLoopRollsBackPrior", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + for i := range uint64(2) { + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + } + + txHash := UniqueHash() + outputs := []*Output{ + { + InputEpochApplicationID: seed.App.ID, + Index: 0, + ExecutionTransactionHash: &txHash, + }, + { + InputEpochApplicationID: seed.App.ID, + Index: 1, + // nil ExecutionTransactionHash — triggers error + }, + } + + err := s.Repo.UpdateOutputsExecution( + s.Ctx, seed.App.IApplicationAddress.String(), outputs, 200) + s.Require().Error(err) + + // First output should NOT have been updated (rolled back) + got, err := s.Repo.GetOutput( + s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Nil(got.ExecutionTransactionHash, + "output 0 should not have execution hash after rollback") + }) +} + +func (s *OutputSuite) TestGetLastOutputBeforeBlock() { + s.Run("NoOutputReturnsNil", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetLastOutputBeforeBlock( + s.Ctx, app.IApplicationAddress.String(), 100) + s.Require().NoError(err) + s.Nil(got) + }) + + s.Run("ReturnsLastOutputBeforeBlock", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 99).WithInputBounds(0, 2).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(10).Build() + input1 := NewInputBuilder().WithIndex(1).WithBlockNumber(50).Build() + input2 := NewInputBuilder().WithIndex(2).WithBlockNumber(90).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch: {input0, input1, input2}}, 100) + s.Require().NoError(err) + + // Store advance results to create outputs with accepted inputs + for i := range uint64(3) { + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: i, + Status: InputCompletionStatus_Accepted, + Outputs: [][]byte{[]byte("output-data")}, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: crypto.Keccak256Hash([]byte("machine")), + }, + } + err = s.Repo.StoreAdvanceResult(s.Ctx, app.ID, result) + s.Require().NoError(err) + } + + // Query for outputs before block 60 (should return output from input1 at block 50) + got, err := s.Repo.GetLastOutputBeforeBlock( + s.Ctx, app.IApplicationAddress.String(), 60) + s.Require().NoError(err) + s.Require().NotNil(got) + // The last output before block 60 should be from input1 (index 1) + s.Equal(uint64(1), got.InputIndex) + }) +} + +func (s *OutputSuite) TestGetNumberOfExecutedOutputs() { + s.Run("ReturnsZeroWhenNone", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + count, err := s.Repo.GetNumberOfExecutedOutputs(s.Ctx, app.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(0), count) + }) + + s.Run("ReturnsCountAfterExecution", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // Create outputs + for i := range uint64(3) { + out := NewOutputBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateOutput(s.Ctx, out) + s.Require().NoError(err) + } + + // Execute 2 of the 3 outputs + txHash := UniqueHash() + out0 := &Output{ + InputEpochApplicationID: seed.App.ID, + Index: 0, + ExecutionTransactionHash: &txHash, + } + out1 := &Output{ + InputEpochApplicationID: seed.App.ID, + Index: 1, + ExecutionTransactionHash: &txHash, + } + err := s.Repo.UpdateOutputsExecution( + s.Ctx, seed.App.IApplicationAddress.String(), + []*Output{out0, out1}, 200) + s.Require().NoError(err) + + count, err := s.Repo.GetNumberOfExecutedOutputs( + s.Ctx, seed.App.IApplicationAddress.String()) + s.Require().NoError(err) + s.Equal(uint64(2), count) + }) +} diff --git a/internal/repository/repotest/report_test_cases.go b/internal/repository/repotest/report_test_cases.go new file mode 100644 index 000000000..3ecaf1f29 --- /dev/null +++ b/internal/repository/repotest/report_test_cases.go @@ -0,0 +1,157 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" +) + +type ReportSuite struct { + BaseSuite +} + +func NewReportSuite(factory RepositoryFactory) *ReportSuite { + return &ReportSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *ReportSuite) TestGetReport() { + s.Run("ExistingReport", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + report := NewReportBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(0). + WithRawData([]byte("report-payload")).Build() + err := s.Repo.CreateReport(s.Ctx, report) + s.Require().NoError(err) + + got, err := s.Repo.GetReport(s.Ctx, seed.App.IApplicationAddress.String(), 0) + s.Require().NoError(err) + s.Equal(uint64(0), got.Index) + s.Equal([]byte("report-payload"), got.RawData) + }) + + s.Run("NotFound", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + got, err := s.Repo.GetReport(s.Ctx, app.IApplicationAddress.String(), 99) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *ReportSuite) TestListReports() { + s.Run("EmptyResult", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + reports, total, err := s.Repo.ListReports( + s.Ctx, app.IApplicationAddress.String(), + repository.ReportFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(reports) + s.Equal(uint64(0), total) + }) + + s.Run("ReturnsAllReports", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + for i := range uint64(3) { + r := NewReportBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateReport(s.Ctx, r) + s.Require().NoError(err) + } + + reports, total, err := s.Repo.ListReports( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.ReportFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(reports, 3) + s.Equal(uint64(3), total) + }) + + s.Run("FilterByEpochIndex", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + // EpochIndex filter also requires input.status = ACCEPTED, + // so use StoreAdvanceResult to create the report with accepted input. + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Reports: [][]byte{[]byte("epoch-report")}, + OutputsProof: OutputsProof{ + OutputsHash: UniqueHash(), + MachineHash: UniqueHash(), + }, + } + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + epochIdx := uint64(0) + reports, total, err := s.Repo.ListReports( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.ReportFilter{EpochIndex: &epochIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(reports, 1) + s.Equal(uint64(1), total) + }) + + s.Run("FilterByInputIndex", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + + for i := range uint64(3) { + r := NewReportBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateReport(s.Ctx, r) + s.Require().NoError(err) + } + + inputIdx := uint64(0) + reports, total, err := s.Repo.ListReports( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.ReportFilter{InputIndex: &inputIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(reports, 3) + s.Equal(uint64(3), total) + }) + + s.Run("Pagination", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + for i := range uint64(5) { + r := NewReportBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateReport(s.Ctx, r) + s.Require().NoError(err) + } + + reports, total, err := s.Repo.ListReports( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.ReportFilter{}, + repository.Pagination{Limit: 2, Offset: 0}, false) + s.Require().NoError(err) + s.Len(reports, 2) + s.Equal(uint64(5), total) + }) + + s.Run("Descending", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + for i := range uint64(3) { + r := NewReportBuilder(seed.App.ID). + WithEpochIndex(0).WithInputIndex(0).WithIndex(i).Build() + err := s.Repo.CreateReport(s.Ctx, r) + s.Require().NoError(err) + } + + reports, _, err := s.Repo.ListReports( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.ReportFilter{}, + repository.Pagination{Limit: 10}, true) + s.Require().NoError(err) + s.Require().Len(reports, 3) + // Descending: highest index first + s.Equal(uint64(2), reports[0].Index) + s.Equal(uint64(1), reports[1].Index) + s.Equal(uint64(0), reports[2].Index) + }) +} diff --git a/internal/repository/repotest/repotest.go b/internal/repository/repotest/repotest.go new file mode 100644 index 000000000..612278e10 --- /dev/null +++ b/internal/repository/repotest/repotest.go @@ -0,0 +1,69 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + "context" + "testing" + "time" + + "github.com/cartesi/rollups-node/internal/repository" + "github.com/stretchr/testify/suite" +) + +const testTimeout = 300 * time.Second + +// RepositoryFactory creates a fresh repository backed by an empty schema. +// Called once per sub-test for full isolation. +type RepositoryFactory func(ctx context.Context, t *testing.T) ( + repo repository.Repository, + cleanup func(), +) + +// BaseSuite is embedded by every per-interface suite. +// SetupSubTest calls the factory to get a fresh repo per s.Run() block. +// TearDownSubTest calls cleanup. +type BaseSuite struct { + suite.Suite + factory RepositoryFactory + Repo repository.Repository + cleanup func() + Ctx context.Context + cancel context.CancelFunc +} + +func (s *BaseSuite) SetupSuite() { + s.Ctx, s.cancel = context.WithTimeout(context.Background(), testTimeout) +} + +func (s *BaseSuite) TearDownSuite() { + s.cancel() +} + +func (s *BaseSuite) SetupSubTest() { + s.Repo, s.cleanup = s.factory(s.Ctx, s.T()) +} + +func (s *BaseSuite) TearDownSubTest() { + if s.cleanup != nil { + s.cleanup() + } +} + +// RunAllSuites is the single entry point backends call. +func RunAllSuites(t *testing.T, factory RepositoryFactory) { + t.Run("Application", func(t *testing.T) { suite.Run(t, NewApplicationSuite(factory)) }) + t.Run("Epoch", func(t *testing.T) { suite.Run(t, NewEpochSuite(factory)) }) + t.Run("Input", func(t *testing.T) { suite.Run(t, NewInputSuite(factory)) }) + t.Run("Output", func(t *testing.T) { suite.Run(t, NewOutputSuite(factory)) }) + t.Run("Report", func(t *testing.T) { suite.Run(t, NewReportSuite(factory)) }) + t.Run("StateHash", func(t *testing.T) { suite.Run(t, NewStateHashSuite(factory)) }) + t.Run("BulkOperations", func(t *testing.T) { suite.Run(t, NewBulkOperationsSuite(factory)) }) + t.Run("NodeConfig", func(t *testing.T) { suite.Run(t, NewNodeConfigSuite(factory)) }) + t.Run("Claimer", func(t *testing.T) { suite.Run(t, NewClaimerSuite(factory)) }) + t.Run("Tournament", func(t *testing.T) { suite.Run(t, NewTournamentSuite(factory)) }) + t.Run("Commitment", func(t *testing.T) { suite.Run(t, NewCommitmentSuite(factory)) }) + t.Run("Match", func(t *testing.T) { suite.Run(t, NewMatchSuite(factory)) }) + t.Run("MatchAdvanced", func(t *testing.T) { suite.Run(t, NewMatchAdvancedSuite(factory)) }) +} diff --git a/internal/repository/repotest/state_hash_test_cases.go b/internal/repository/repotest/state_hash_test_cases.go new file mode 100644 index 000000000..1f9df2b1c --- /dev/null +++ b/internal/repository/repotest/state_hash_test_cases.go @@ -0,0 +1,96 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" + "github.com/ethereum/go-ethereum/crypto" +) + +type StateHashSuite struct { + BaseSuite +} + +func NewStateHashSuite(factory RepositoryFactory) *StateHashSuite { + return &StateHashSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *StateHashSuite) TestListStateHashes() { + s.Run("EmptyResult", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + hashes, total, err := s.Repo.ListStateHashes( + s.Ctx, app.IApplicationAddress.String(), + repository.StateHashFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(hashes) + s.Equal(uint64(0), total) + }) + + s.Run("FilterByEpochIndex", func() { + // StateHashes are created by StoreAdvanceResult, tested in BulkOperationsSuite. + // This test verifies the filter works even with no data. + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epochIdx := uint64(0) + hashes, total, err := s.Repo.ListStateHashes( + s.Ctx, app.IApplicationAddress.String(), + repository.StateHashFilter{EpochIndex: &epochIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(hashes) + s.Equal(uint64(0), total) + }) + + s.Run("ReturnsStateHashesFromDaveConsensus", func() { + seed := Seed(s.Ctx, s.T(), s.Repo) + machineHash := crypto.Keccak256Hash([]byte("dave-list-machine")) + outputsHash := crypto.Keccak256Hash([]byte("dave-list-outputs")) + + hash1 := [32]byte(crypto.Keccak256Hash([]byte("list-state-1"))) + hash2 := [32]byte(crypto.Keccak256Hash([]byte("list-state-2"))) + + result := &AdvanceResult{ + EpochIndex: 0, + InputIndex: 0, + Status: InputCompletionStatus_Accepted, + Hashes: [][32]byte{hash1, hash2}, + RemainingMetaCycles: 10, + IsDaveConsensus: true, + OutputsProof: OutputsProof{ + OutputsHash: outputsHash, + MachineHash: machineHash, + }, + } + + err := s.Repo.StoreAdvanceResult(s.Ctx, seed.App.ID, result) + s.Require().NoError(err) + + // List all state hashes + hashes, total, err := s.Repo.ListStateHashes( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.StateHashFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(hashes, 3) // 2 intermediate + 1 final + s.Equal(uint64(3), total) + + // List with epoch filter + epochIdx := uint64(0) + hashes, total, err = s.Repo.ListStateHashes( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.StateHashFilter{EpochIndex: &epochIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(hashes, 3) + s.Equal(uint64(3), total) + + // Verify pagination + hashes, total, err = s.Repo.ListStateHashes( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.StateHashFilter{}, + repository.Pagination{Limit: 2, Offset: 0}, false) + s.Require().NoError(err) + s.Len(hashes, 2) + s.Equal(uint64(3), total) + }) +} diff --git a/internal/repository/repotest/tournament_test_cases.go b/internal/repository/repotest/tournament_test_cases.go new file mode 100644 index 000000000..dd74f490a --- /dev/null +++ b/internal/repository/repotest/tournament_test_cases.go @@ -0,0 +1,274 @@ +// (c) Cartesi and individual authors (see AUTHORS) +// SPDX-License-Identifier: Apache-2.0 (see LICENSE) + +package repotest + +import ( + . "github.com/cartesi/rollups-node/internal/model" + "github.com/cartesi/rollups-node/internal/repository" +) + +type TournamentSuite struct { + BaseSuite +} + +func NewTournamentSuite(factory RepositoryFactory) *TournamentSuite { + return &TournamentSuite{BaseSuite: BaseSuite{factory: factory}} +} + +func (s *TournamentSuite) seedWithEpoch() *SeedResult { + return Seed(s.Ctx, s.T(), s.Repo) +} + +func (s *TournamentSuite) TestCreateTournament() { + s.Run("CreatesSuccessfully", func() { + seed := s.seedWithEpoch() + tournament := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).Build() + + err := s.Repo.CreateTournament( + s.Ctx, seed.App.IApplicationAddress.String(), tournament) + s.Require().NoError(err) + }) +} + +func (s *TournamentSuite) TestGetTournament() { + s.Run("ExistingTournament", func() { + seed := s.seedWithEpoch() + tournament := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).Build() + + err := s.Repo.CreateTournament( + s.Ctx, seed.App.IApplicationAddress.String(), tournament) + s.Require().NoError(err) + + got, err := s.Repo.GetTournament( + s.Ctx, seed.App.IApplicationAddress.String(), tournament.Address.String()) + s.Require().NoError(err) + s.Equal(tournament.Address, got.Address) + s.Equal(tournament.Level, got.Level) + }) + + s.Run("NotFound", func() { + seed := s.seedWithEpoch() + got, err := s.Repo.GetTournament( + s.Ctx, seed.App.IApplicationAddress.String(), UniqueAddress().String()) + s.Require().NoError(err) + s.Nil(got) + }) +} + +func (s *TournamentSuite) TestUpdateTournament() { + s.Run("UpdatesFields", func() { + seed := s.seedWithEpoch() + tournament := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).Build() + + err := s.Repo.CreateTournament( + s.Ctx, seed.App.IApplicationAddress.String(), tournament) + s.Require().NoError(err) + + winnerHash := UniqueHash() + tournament.WinnerCommitment = &winnerHash + err = s.Repo.UpdateTournament( + s.Ctx, seed.App.IApplicationAddress.String(), tournament) + s.Require().NoError(err) + + got, err := s.Repo.GetTournament( + s.Ctx, seed.App.IApplicationAddress.String(), tournament.Address.String()) + s.Require().NoError(err) + s.Require().NotNil(got.WinnerCommitment) + s.Equal(winnerHash, *got.WinnerCommitment) + }) +} + +func (s *TournamentSuite) TestListTournaments() { + s.Run("EmptyResult", func() { + seed := s.seedWithEpoch() + tournaments, total, err := s.Repo.ListTournaments( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.TournamentFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Empty(tournaments) + s.Equal(uint64(0), total) + }) + + s.Run("ReturnsAll", func() { + // Create 3 root tournaments in different epochs + // (unique_root_per_epoch_idx allows only 1 root per epoch) + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + epochInputMap := make(map[*Epoch][]*Input) + for i := range uint64(3) { + e := NewEpochBuilder(app.ID). + WithIndex(i).WithStatus(EpochStatus_Closed). + WithBlocks(i*10, i*10+9).WithInputBounds(i, i).Build() + inp := NewInputBuilder().WithIndex(i).WithEpochIndex(i). + WithBlockNumber(i*10 + 5).Build() + epochInputMap[e] = []*Input{inp} + } + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), epochInputMap, 30) + s.Require().NoError(err) + + for i := range uint64(3) { + t := NewTournamentBuilder(app.ID).WithEpochIndex(i).Build() + err := s.Repo.CreateTournament( + s.Ctx, app.IApplicationAddress.String(), t) + s.Require().NoError(err) + } + + tournaments, total, err := s.Repo.ListTournaments( + s.Ctx, app.IApplicationAddress.String(), + repository.TournamentFilter{}, repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(tournaments, 3) + s.Equal(uint64(3), total) + }) + + s.Run("FilterByEpochIndex", func() { + app := NewApplicationBuilder().Create(s.Ctx, s.T(), s.Repo) + + epoch0 := NewEpochBuilder(app.ID). + WithIndex(0).WithStatus(EpochStatus_Closed). + WithBlocks(0, 9).WithInputBounds(0, 0).Build() + epoch1 := NewEpochBuilder(app.ID). + WithIndex(1).WithStatus(EpochStatus_Closed). + WithBlocks(10, 19).WithInputBounds(1, 1).Build() + + input0 := NewInputBuilder().WithIndex(0).WithBlockNumber(5).Build() + input1 := NewInputBuilder().WithIndex(1).WithEpochIndex(1).WithBlockNumber(15).Build() + + err := s.Repo.CreateEpochsAndInputs( + s.Ctx, app.IApplicationAddress.String(), + map[*Epoch][]*Input{epoch0: {input0}, epoch1: {input1}}, 20) + s.Require().NoError(err) + + t0 := NewTournamentBuilder(app.ID).WithEpochIndex(0).Build() + err = s.Repo.CreateTournament(s.Ctx, app.IApplicationAddress.String(), t0) + s.Require().NoError(err) + + t1 := NewTournamentBuilder(app.ID).WithEpochIndex(1).Build() + err = s.Repo.CreateTournament(s.Ctx, app.IApplicationAddress.String(), t1) + s.Require().NoError(err) + + epochIdx := uint64(0) + tournaments, total, err := s.Repo.ListTournaments( + s.Ctx, app.IApplicationAddress.String(), + repository.TournamentFilter{EpochIndex: &epochIdx}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(tournaments, 1) + s.Equal(uint64(1), total) + }) + + s.Run("FilterByLevel", func() { + seed := s.seedWithEpoch() + + // Root tournament at level 0 + t0 := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).WithLevel(0).Build() + err := s.Repo.CreateTournament( + s.Ctx, seed.App.IApplicationAddress.String(), t0) + s.Require().NoError(err) + + // Create commitments required by match FK constraints + c1 := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(t0.Address).Build() + err = s.Repo.CreateCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), c1) + s.Require().NoError(err) + + c2 := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(t0.Address).Build() + err = s.Repo.CreateCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), c2) + s.Require().NoError(err) + + // Create a match in the root tournament so the FK constraint is satisfied + matchIDHash := UniqueHash() + match := NewMatchBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(t0.Address). + WithIDHash(matchIDHash). + WithCommitmentOne(c1.Commitment). + WithCommitmentTwo(c2.Commitment). + Build() + err = s.Repo.CreateMatch( + s.Ctx, seed.App.IApplicationAddress.String(), match) + s.Require().NoError(err) + + // Sub-tournament at level 1 with parent referencing the match + t1 := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).WithLevel(1). + WithParent(t0.Address, matchIDHash).Build() + err = s.Repo.CreateTournament( + s.Ctx, seed.App.IApplicationAddress.String(), t1) + s.Require().NoError(err) + + level := uint64(1) + tournaments, total, err := s.Repo.ListTournaments( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.TournamentFilter{Level: &level}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(tournaments, 1) + s.Equal(uint64(1), total) + s.Equal(uint64(1), tournaments[0].Level) + }) + + s.Run("FilterByParentTournamentAddress", func() { + seed := s.seedWithEpoch() + + // Root tournament + rootAddr := UniqueAddress() + t0 := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).WithAddress(rootAddr).Build() + err := s.Repo.CreateTournament( + s.Ctx, seed.App.IApplicationAddress.String(), t0) + s.Require().NoError(err) + + // Create commitments required by match FK constraints + c1 := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(rootAddr).Build() + err = s.Repo.CreateCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), c1) + s.Require().NoError(err) + + c2 := NewCommitmentBuilder(seed.App.ID). + WithEpochIndex(0).WithTournamentAddress(rootAddr).Build() + err = s.Repo.CreateCommitment( + s.Ctx, seed.App.IApplicationAddress.String(), c2) + s.Require().NoError(err) + + // Create a match in the root tournament so the FK constraint is satisfied + matchIDHash := UniqueHash() + match := NewMatchBuilder(seed.App.ID). + WithEpochIndex(0). + WithTournamentAddress(rootAddr). + WithIDHash(matchIDHash). + WithCommitmentOne(c1.Commitment). + WithCommitmentTwo(c2.Commitment). + Build() + err = s.Repo.CreateMatch( + s.Ctx, seed.App.IApplicationAddress.String(), match) + s.Require().NoError(err) + + // Sub-tournament with parent referencing the match + t1 := NewTournamentBuilder(seed.App.ID). + WithEpochIndex(0).WithLevel(1). + WithParent(rootAddr, matchIDHash).Build() + err = s.Repo.CreateTournament( + s.Ctx, seed.App.IApplicationAddress.String(), t1) + s.Require().NoError(err) + + tournaments, total, err := s.Repo.ListTournaments( + s.Ctx, seed.App.IApplicationAddress.String(), + repository.TournamentFilter{ParentTournamentAddress: &rootAddr}, + repository.Pagination{Limit: 10}, false) + s.Require().NoError(err) + s.Len(tournaments, 1) + s.Equal(uint64(1), total) + s.Require().NotNil(tournaments[0].ParentTournamentAddress) + s.Equal(rootAddr, *tournaments[0].ParentTournamentAddress) + }) +} diff --git a/pkg/service/service.go b/pkg/service/service.go index 97660c492..4d889d779 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -266,6 +266,15 @@ func (s *Service) Stop(force bool) []error { func (s *Service) Serve() error { s.Running.Store(true) + + // Check for context cancellation before the first tick. + select { + case <-s.Context.Done(): + s.Stop(true) + return nil + default: + } + s.Tick() for s.Running.Load() { select {