From 7e9d21462d52627b24dcb564a70716621ae1be6c Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Thu, 30 Apr 2026 18:16:50 +1200 Subject: [PATCH 1/5] refactor(db): use new store interface --- internal/assets/assets.go | 2 +- .../{ => sqlite}/000001_init_sqlite.down.sql | 0 .../{ => sqlite}/000001_init_sqlite.up.sql | 0 .../{ => sqlite}/000002_oauth_name.down.sql | 0 .../{ => sqlite}/000002_oauth_name.up.sql | 0 .../{ => sqlite}/000003_oauth_sub.down.sql | 0 .../{ => sqlite}/000003_oauth_sub.up.sql | 0 .../{ => sqlite}/000004_created_at.down.sql | 0 .../{ => sqlite}/000004_created_at.up.sql | 0 .../{ => sqlite}/000005_oidc_session.down.sql | 0 .../{ => sqlite}/000005_oidc_session.up.sql | 0 .../{ => sqlite}/000006_oidc_nonce.down.sql | 0 .../{ => sqlite}/000006_oidc_nonce.up.sql | 0 .../{ => sqlite}/000007_oidc_pkce.down.sql | 0 .../{ => sqlite}/000007_oidc_pkce.up.sql | 0 .../000008_oidc_code_reuse.down.sql | 0 .../000008_oidc_code_reuse.up.sql | 0 .../000009_oidc_userinfo_profile.down.sql | 0 .../000009_oidc_userinfo_profile.up.sql | 0 internal/bootstrap/app_bootstrap.go | 11 +-- internal/bootstrap/db_bootstrap.go | 19 ++++- internal/bootstrap/service_bootstrap.go | 2 +- internal/controller/oidc_controller_test.go | 13 +--- internal/controller/proxy_controller_test.go | 14 +--- internal/controller/user_controller_test.go | 26 +++---- .../controller/well_known_controller_test.go | 14 +--- internal/model/config.go | 2 +- internal/repository/models.go | 73 ++++--------------- internal/repository/{ => sqlite}/db.go | 4 +- internal/repository/sqlite/models.go | 64 ++++++++++++++++ .../{ => sqlite}/oidc_queries.sql.go | 4 +- .../{ => sqlite}/session_queries.sql.go | 4 +- internal/repository/store.go | 41 +++++++++++ internal/service/auth_service.go | 4 +- internal/service/oidc_service.go | 4 +- sql/{ => sqlite}/oidc_queries.sql | 0 sql/{ => sqlite}/oidc_schemas.sql | 0 sql/{ => sqlite}/session_queries.sql | 0 sql/{ => sqlite}/session_schemas.sql | 0 sqlc.yml | 8 +- 40 files changed, 171 insertions(+), 138 deletions(-) rename internal/assets/migrations/{ => sqlite}/000001_init_sqlite.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000001_init_sqlite.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000002_oauth_name.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000002_oauth_name.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000003_oauth_sub.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000003_oauth_sub.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000004_created_at.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000004_created_at.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000005_oidc_session.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000005_oidc_session.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000006_oidc_nonce.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000006_oidc_nonce.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000007_oidc_pkce.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000007_oidc_pkce.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000008_oidc_code_reuse.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000008_oidc_code_reuse.up.sql (100%) rename internal/assets/migrations/{ => sqlite}/000009_oidc_userinfo_profile.down.sql (100%) rename internal/assets/migrations/{ => sqlite}/000009_oidc_userinfo_profile.up.sql (100%) rename internal/repository/{ => sqlite}/db.go (93%) create mode 100644 internal/repository/sqlite/models.go rename internal/repository/{ => sqlite}/oidc_queries.sql.go (99%) rename internal/repository/{ => sqlite}/session_queries.sql.go (98%) create mode 100644 internal/repository/store.go rename sql/{ => sqlite}/oidc_queries.sql (100%) rename sql/{ => sqlite}/oidc_schemas.sql (100%) rename sql/{ => sqlite}/session_queries.sql (100%) rename sql/{ => sqlite}/session_schemas.sql (100%) diff --git a/internal/assets/assets.go b/internal/assets/assets.go index 412403c9..a5c3d79d 100644 --- a/internal/assets/assets.go +++ b/internal/assets/assets.go @@ -11,5 +11,5 @@ var FrontendAssets embed.FS // Migrations // -//go:embed migrations/*.sql +//go:embed migrations/sqlite/*.sql var Migrations embed.FS diff --git a/internal/assets/migrations/000001_init_sqlite.down.sql b/internal/assets/migrations/sqlite/000001_init_sqlite.down.sql similarity index 100% rename from internal/assets/migrations/000001_init_sqlite.down.sql rename to internal/assets/migrations/sqlite/000001_init_sqlite.down.sql diff --git a/internal/assets/migrations/000001_init_sqlite.up.sql b/internal/assets/migrations/sqlite/000001_init_sqlite.up.sql similarity index 100% rename from internal/assets/migrations/000001_init_sqlite.up.sql rename to internal/assets/migrations/sqlite/000001_init_sqlite.up.sql diff --git a/internal/assets/migrations/000002_oauth_name.down.sql b/internal/assets/migrations/sqlite/000002_oauth_name.down.sql similarity index 100% rename from internal/assets/migrations/000002_oauth_name.down.sql rename to internal/assets/migrations/sqlite/000002_oauth_name.down.sql diff --git a/internal/assets/migrations/000002_oauth_name.up.sql b/internal/assets/migrations/sqlite/000002_oauth_name.up.sql similarity index 100% rename from internal/assets/migrations/000002_oauth_name.up.sql rename to internal/assets/migrations/sqlite/000002_oauth_name.up.sql diff --git a/internal/assets/migrations/000003_oauth_sub.down.sql b/internal/assets/migrations/sqlite/000003_oauth_sub.down.sql similarity index 100% rename from internal/assets/migrations/000003_oauth_sub.down.sql rename to internal/assets/migrations/sqlite/000003_oauth_sub.down.sql diff --git a/internal/assets/migrations/000003_oauth_sub.up.sql b/internal/assets/migrations/sqlite/000003_oauth_sub.up.sql similarity index 100% rename from internal/assets/migrations/000003_oauth_sub.up.sql rename to internal/assets/migrations/sqlite/000003_oauth_sub.up.sql diff --git a/internal/assets/migrations/000004_created_at.down.sql b/internal/assets/migrations/sqlite/000004_created_at.down.sql similarity index 100% rename from internal/assets/migrations/000004_created_at.down.sql rename to internal/assets/migrations/sqlite/000004_created_at.down.sql diff --git a/internal/assets/migrations/000004_created_at.up.sql b/internal/assets/migrations/sqlite/000004_created_at.up.sql similarity index 100% rename from internal/assets/migrations/000004_created_at.up.sql rename to internal/assets/migrations/sqlite/000004_created_at.up.sql diff --git a/internal/assets/migrations/000005_oidc_session.down.sql b/internal/assets/migrations/sqlite/000005_oidc_session.down.sql similarity index 100% rename from internal/assets/migrations/000005_oidc_session.down.sql rename to internal/assets/migrations/sqlite/000005_oidc_session.down.sql diff --git a/internal/assets/migrations/000005_oidc_session.up.sql b/internal/assets/migrations/sqlite/000005_oidc_session.up.sql similarity index 100% rename from internal/assets/migrations/000005_oidc_session.up.sql rename to internal/assets/migrations/sqlite/000005_oidc_session.up.sql diff --git a/internal/assets/migrations/000006_oidc_nonce.down.sql b/internal/assets/migrations/sqlite/000006_oidc_nonce.down.sql similarity index 100% rename from internal/assets/migrations/000006_oidc_nonce.down.sql rename to internal/assets/migrations/sqlite/000006_oidc_nonce.down.sql diff --git a/internal/assets/migrations/000006_oidc_nonce.up.sql b/internal/assets/migrations/sqlite/000006_oidc_nonce.up.sql similarity index 100% rename from internal/assets/migrations/000006_oidc_nonce.up.sql rename to internal/assets/migrations/sqlite/000006_oidc_nonce.up.sql diff --git a/internal/assets/migrations/000007_oidc_pkce.down.sql b/internal/assets/migrations/sqlite/000007_oidc_pkce.down.sql similarity index 100% rename from internal/assets/migrations/000007_oidc_pkce.down.sql rename to internal/assets/migrations/sqlite/000007_oidc_pkce.down.sql diff --git a/internal/assets/migrations/000007_oidc_pkce.up.sql b/internal/assets/migrations/sqlite/000007_oidc_pkce.up.sql similarity index 100% rename from internal/assets/migrations/000007_oidc_pkce.up.sql rename to internal/assets/migrations/sqlite/000007_oidc_pkce.up.sql diff --git a/internal/assets/migrations/000008_oidc_code_reuse.down.sql b/internal/assets/migrations/sqlite/000008_oidc_code_reuse.down.sql similarity index 100% rename from internal/assets/migrations/000008_oidc_code_reuse.down.sql rename to internal/assets/migrations/sqlite/000008_oidc_code_reuse.down.sql diff --git a/internal/assets/migrations/000008_oidc_code_reuse.up.sql b/internal/assets/migrations/sqlite/000008_oidc_code_reuse.up.sql similarity index 100% rename from internal/assets/migrations/000008_oidc_code_reuse.up.sql rename to internal/assets/migrations/sqlite/000008_oidc_code_reuse.up.sql diff --git a/internal/assets/migrations/000009_oidc_userinfo_profile.down.sql b/internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.down.sql similarity index 100% rename from internal/assets/migrations/000009_oidc_userinfo_profile.down.sql rename to internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.down.sql diff --git a/internal/assets/migrations/000009_oidc_userinfo_profile.up.sql b/internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.up.sql similarity index 100% rename from internal/assets/migrations/000009_oidc_userinfo_profile.up.sql rename to internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.up.sql diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 5b342c48..a840ffcb 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -144,17 +144,14 @@ func (app *BootstrapApp) Setup() error { tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name") // Database - db, err := app.SetupDatabase(app.config.Database.Path) + store, err := app.SetupStore() if err != nil { return fmt.Errorf("failed to setup database: %w", err) } - // Queries - queries := repository.New(db) - // Services - services, err := app.initServices(queries) + services, err := app.initServices(store) if err != nil { return fmt.Errorf("failed to initialize services: %w", err) @@ -210,7 +207,7 @@ func (app *BootstrapApp) Setup() error { // Start db cleanup routine tlog.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanupRoutine(queries) + go app.dbCleanupRoutine(store) // If analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { @@ -300,7 +297,7 @@ func (app *BootstrapApp) heartbeatRoutine() { } } -func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { +func (app *BootstrapApp) dbCleanupRoutine(queries repository.Store) { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() ctx := context.Background() diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 3f48f793..efc21311 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -7,6 +7,8 @@ import ( "path/filepath" "github.com/tinyauthapp/tinyauth/internal/assets" + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/sqlite3" @@ -14,7 +16,18 @@ import ( _ "modernc.org/sqlite" ) -func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { +func (app *BootstrapApp) SetupStore() (repository.Store, error) { + return app.setupSQLite(app.config.Database.Path) +} + +// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store. +// Useful for testing or when constructing a store outside of a BootstrapApp. +func NewSQLiteStore(databasePath string) (repository.Store, error) { + app := &BootstrapApp{} + return app.setupSQLite(databasePath) +} + +func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) { dir := filepath.Dir(databasePath) if err := os.MkdirAll(dir, 0750); err != nil { @@ -31,7 +44,7 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { // if the sqlite connection starts being a bottleneck db.SetMaxOpenConns(1) - migrations, err := iofs.New(assets.Migrations, "migrations") + migrations, err := iofs.New(assets.Migrations, "migrations/sqlite") if err != nil { return nil, fmt.Errorf("failed to create migrations: %w", err) @@ -53,5 +66,5 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { return nil, fmt.Errorf("failed to migrate database: %w", err) } - return db, nil + return sqlite.New(db), nil } diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 09485bd0..48ded235 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -18,7 +18,7 @@ type Services struct { oidcService *service.OIDCService } -func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) { +func (app *BootstrapApp) initServices(queries repository.Store) (Services, error) { services := Services{} ldapService := service.NewLdapService(service.LdapServiceConfig{ diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 150540fc..88690410 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -17,7 +17,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) @@ -852,13 +851,10 @@ func TestOIDCController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) - queries := repository.New(db) - oidcService := service.NewOIDCService(oidcServiceCfg, queries) + oidcService := service.NewOIDCService(oidcServiceCfg, store) err = oidcService.Init() require.NoError(t, err) @@ -881,9 +877,4 @@ func TestOIDCController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 7b2e3202..f84d791b 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -11,7 +11,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) @@ -400,13 +399,9 @@ func TestProxyController(t *testing.T) { oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) - queries := repository.New(db) - docker := service.NewDockerService() err = docker.Init() require.NoError(t, err) @@ -419,7 +414,7 @@ func TestProxyController(t *testing.T) { err = broker.Init() require.NoError(t, err) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + authService := service.NewAuthService(authServiceCfg, ldap, store, broker) err = authService.Init() require.NoError(t, err) @@ -444,9 +439,4 @@ func TestProxyController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 4863c16e..4184274d 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -111,15 +111,6 @@ func TestUserController(t *testing.T) { }) } - oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) - type testCase struct { description string middlewares []gin.HandlerFunc @@ -294,7 +285,7 @@ func TestUserController(t *testing.T) { totpCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - _, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ + _, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{ UUID: "test-totp-login-uuid", Username: "test", Email: "test@example.com", @@ -418,7 +409,7 @@ func TestUserController(t *testing.T) { totpAttrCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - _, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ + _, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{ UUID: "test-totp-login-attributes-uuid", Username: "test", Email: "test@example.com", @@ -456,6 +447,12 @@ func TestUserController(t *testing.T) { }, } + oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) + + store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) + require.NoError(t, err) + + docker := service.NewDockerService() err = docker.Init() require.NoError(t, err) @@ -468,7 +465,7 @@ func TestUserController(t *testing.T) { err = broker.Init() require.NoError(t, err) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + authService := service.NewAuthService(authServiceCfg, ldap, store, broker) err = authService.Init() require.NoError(t, err) @@ -497,9 +494,4 @@ func TestUserController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 7dcf2bdc..2bb8cfe1 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -13,7 +13,6 @@ import ( "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) @@ -101,14 +100,10 @@ func TestWellKnownController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) + store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) require.NoError(t, err) - queries := repository.New(db) - - oidcService := service.NewOIDCService(oidcServiceCfg, queries) + oidcService := service.NewOIDCService(oidcServiceCfg, store) err = oidcService.Init() require.NoError(t, err) @@ -125,9 +120,4 @@ func TestWellKnownController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/model/config.go b/internal/model/config.go index 95870e3d..c2125f92 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -82,7 +82,7 @@ type Config struct { } type DatabaseConfig struct { - Path string `description:"The path to the database, including file name." yaml:"path"` + Path string `description:"The path to the SQLite database, including file name." yaml:"path"` } type AnalyticsConfig struct { diff --git a/internal/repository/models.go b/internal/repository/models.go index bc2e2c66..0c33e038 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -1,64 +1,19 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.30.0 - package repository -type OidcCode struct { - Sub string - CodeHash string - Scope string - RedirectURI string - ClientID string - ExpiresAt int64 - Nonce string - CodeChallenge string -} +// This file is a stop-gap until more drivers are added. It re-exports the models from the sqlite package so that the rest +// of the codebase can import them from a single location without needing to know about the underlying database implementation. -type OidcToken struct { - Sub string - AccessTokenHash string - RefreshTokenHash string - CodeHash string - Scope string - ClientID string - TokenExpiresAt int64 - RefreshTokenExpiresAt int64 - Nonce string -} +import "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" -type OidcUserinfo struct { - Sub string - Name string - PreferredUsername string - Email string - Groups string - UpdatedAt int64 - GivenName string - FamilyName string - MiddleName string - Nickname string - Profile string - Picture string - Website string - Gender string - Birthdate string - Zoneinfo string - Locale string - PhoneNumber string - Address string -} +type Session = sqlite.Session +type OidcCode = sqlite.OidcCode +type OidcToken = sqlite.OidcToken +type OidcUserinfo = sqlite.OidcUserinfo -type Session struct { - UUID string - Username string - Email string - Name string - Provider string - TotpPending bool - OAuthGroups string - Expiry int64 - CreatedAt int64 - OAuthName string - OAuthSub string -} +type CreateSessionParams = sqlite.CreateSessionParams +type UpdateSessionParams = sqlite.UpdateSessionParams +type CreateOidcCodeParams = sqlite.CreateOidcCodeParams +type CreateOidcTokenParams = sqlite.CreateOidcTokenParams +type UpdateOidcTokenByRefreshTokenParams = sqlite.UpdateOidcTokenByRefreshTokenParams +type DeleteExpiredOidcTokensParams = sqlite.DeleteExpiredOidcTokensParams +type CreateOidcUserInfoParams = sqlite.CreateOidcUserInfoParams diff --git a/internal/repository/db.go b/internal/repository/sqlite/db.go similarity index 93% rename from internal/repository/db.go rename to internal/repository/sqlite/db.go index 998bfd3b..ee310fc2 100644 --- a/internal/repository/db.go +++ b/internal/repository/sqlite/db.go @@ -1,8 +1,8 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.0 -package repository +package sqlite import ( "context" diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go new file mode 100644 index 00000000..caf37f4c --- /dev/null +++ b/internal/repository/sqlite/models.go @@ -0,0 +1,64 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.0 + +package sqlite + +type OidcCode struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type OidcToken struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + CodeHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + Nonce string +} + +type OidcUserinfo struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} + +type Session struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/sqlite/oidc_queries.sql.go similarity index 99% rename from internal/repository/oidc_queries.sql.go rename to internal/repository/sqlite/oidc_queries.sql.go index 7caac9d4..027ac421 100644 --- a/internal/repository/oidc_queries.sql.go +++ b/internal/repository/sqlite/oidc_queries.sql.go @@ -1,9 +1,9 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.0 // source: oidc_queries.sql -package repository +package sqlite import ( "context" diff --git a/internal/repository/session_queries.sql.go b/internal/repository/sqlite/session_queries.sql.go similarity index 98% rename from internal/repository/session_queries.sql.go rename to internal/repository/sqlite/session_queries.sql.go index c846c3f9..4271b727 100644 --- a/internal/repository/session_queries.sql.go +++ b/internal/repository/sqlite/session_queries.sql.go @@ -1,9 +1,9 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.0 // source: session_queries.sql -package repository +package sqlite import ( "context" diff --git a/internal/repository/store.go b/internal/repository/store.go new file mode 100644 index 00000000..765df6a5 --- /dev/null +++ b/internal/repository/store.go @@ -0,0 +1,41 @@ +package repository + +import "context" + +// Store is the interface that all storage drivers must implement. +// The sqlc-generated *Queries struct satisfies this interface for SQLite. +// Future drivers (postgres, etc.) must return the shared types defined in this package. +type Store interface { + // Sessions + CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) + GetSession(ctx context.Context, uuid string) (Session, error) + UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) + DeleteSession(ctx context.Context, uuid string) error + DeleteExpiredSessions(ctx context.Context, expiry int64) error + + // OIDC codes + CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) + GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) + GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) + GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) + GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) + DeleteOidcCode(ctx context.Context, codeHash string) error + DeleteOidcCodeBySub(ctx context.Context, sub string) error + DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) + + // OIDC tokens + CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) + GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) + GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) + GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) + UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) + DeleteOidcToken(ctx context.Context, accessTokenHash string) error + DeleteOidcTokenBySub(ctx context.Context, sub string) error + DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error + DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) + + // OIDC userinfo + CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) + GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) + DeleteOidcUserInfo(ctx context.Context, sub string) error +} diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 16c53fe0..29f491f1 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -96,14 +96,14 @@ type AuthService struct { loginMutex sync.RWMutex ldapGroupsMutex sync.RWMutex ldap *LdapService - queries *repository.Queries + queries repository.Store oauthBroker *OAuthBrokerService lockdown *Lockdown lockdownCtx context.Context lockdownCancelFunc context.CancelFunc } -func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { +func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries repository.Store, oauthBroker *OAuthBrokerService) *AuthService { return &AuthService{ config: config, loginAttempts: make(map[string]*LoginAttempt), diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 1e1c1986..d6b11628 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -121,7 +121,7 @@ type OIDCServiceConfig struct { type OIDCService struct { config OIDCServiceConfig - queries *repository.Queries + queries repository.Store clients map[string]model.OIDCClientConfig privateKey *rsa.PrivateKey publicKey crypto.PublicKey @@ -129,7 +129,7 @@ type OIDCService struct { isConfigured bool } -func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { +func NewOIDCService(config OIDCServiceConfig, queries repository.Store) *OIDCService { return &OIDCService{ config: config, queries: queries, diff --git a/sql/oidc_queries.sql b/sql/sqlite/oidc_queries.sql similarity index 100% rename from sql/oidc_queries.sql rename to sql/sqlite/oidc_queries.sql diff --git a/sql/oidc_schemas.sql b/sql/sqlite/oidc_schemas.sql similarity index 100% rename from sql/oidc_schemas.sql rename to sql/sqlite/oidc_schemas.sql diff --git a/sql/session_queries.sql b/sql/sqlite/session_queries.sql similarity index 100% rename from sql/session_queries.sql rename to sql/sqlite/session_queries.sql diff --git a/sql/session_schemas.sql b/sql/sqlite/session_schemas.sql similarity index 100% rename from sql/session_schemas.sql rename to sql/sqlite/session_schemas.sql diff --git a/sqlc.yml b/sqlc.yml index de08738a..e7b2c4b4 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -1,12 +1,12 @@ version: "2" sql: - engine: "sqlite" - queries: "sql/*_queries.sql" - schema: "sql/*_schemas.sql" + queries: "sql/sqlite/*_queries.sql" + schema: "sql/sqlite/*_schemas.sql" gen: go: - package: "repository" - out: "internal/repository" + package: "sqlite" + out: "internal/repository/sqlite" rename: uuid: "UUID" oauth_groups: "OAuthGroups" From 95661052456240ae6ec89ebf2fed915dac0b429b Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Sun, 3 May 2026 13:49:24 +1200 Subject: [PATCH 2/5] feat(db): add code gen to build sqlc-compatible wrappers --- .github/workflows/ci.yml | 6 + cmd/gen/sqlc-wrapper/main.go | 522 +++++++++++++++++++++++++ go.mod | 2 + internal/bootstrap/db_bootstrap.go | 2 +- internal/repository/models.go | 163 +++++++- internal/repository/sqlite/generate.go | 3 + internal/repository/sqlite/store.go | 206 ++++++++++ 7 files changed, 886 insertions(+), 18 deletions(-) create mode 100644 cmd/gen/sqlc-wrapper/main.go create mode 100644 internal/repository/sqlite/generate.go create mode 100644 internal/repository/sqlite/store.go diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 12db1641..fb8c9736 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,12 @@ jobs: - name: Go dependencies run: go mod download + - name: Check codegen is up to date + run: | + go generate ./internal/repository/... + git diff --exit-code -- internal/repository/ + git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true + - name: Install frontend dependencies run: | cd frontend diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/main.go new file mode 100644 index 00000000..e66ae8ee --- /dev/null +++ b/cmd/gen/sqlc-wrapper/main.go @@ -0,0 +1,522 @@ +// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under +// internal/repository//. Run via: +// +// go generate ./internal/repository/... +// +// The generator introspects *Queries methods and the model/params types in the +// driver package, then emits a store.go that wraps *Queries so it satisfies +// repository.Store using the canonical shared types in the parent package. +// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should +// implement repository.Store directly by hand. +package main + +import ( + "bytes" + "flag" + "fmt" + "go/format" + "go/types" + "log" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "text/template" + + "golang.org/x/tools/go/packages" +) + +func main() { + driverPkg := flag.String("pkg", "", "import path of the driver package") + out := flag.String("out", "store.go", "output filename relative to driver package directory") + flag.Parse() + + if *driverPkg == "" { + log.Fatal("-pkg is required") + } + + // Resolve the driver package directory so we can overlay the output file + // with a valid stub. This prevents a stale store.go from poisoning the + // type-checker and producing cryptic "undefined" errors. + driverDir, err := pkgDir(*driverPkg) + if err != nil { + log.Fatalf("resolve driver dir: %v", err) + } + outPath := filepath.Join(driverDir, *out) + if filepath.IsAbs(*out) { + outPath = *out + } + + // Stub replaces the output file during load so stale generated code is ignored. + stub := []byte("package " + filepath.Base(driverDir) + "\n") + + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports, + Overlay: map[string][]byte{outPath: stub}, + } + pkgs, err := packages.Load(cfg, *driverPkg) + if err != nil { + log.Fatalf("load %s: %v", *driverPkg, err) + } + if len(pkgs) != 1 { + log.Fatalf("expected 1 package, got %d", len(pkgs)) + } + pkg := pkgs[0] + if len(pkg.Errors) > 0 { + for _, e := range pkg.Errors { + log.Printf("package error: %v", e) + } + log.Fatal("package has errors") + } + + repoPkg := parentPkg(*driverPkg) + + // Load the parent (repository) package so we can validate struct shapes. + repoPkgs, err := packages.Load(cfg, repoPkg) + if err != nil { + log.Fatalf("load repo pkg %s: %v", repoPkg, err) + } + if len(repoPkgs) != 1 || len(repoPkgs[0].Errors) > 0 { + log.Fatalf("could not load repo package %s cleanly", repoPkg) + } + if err := validateStructShapes(pkg.Types, repoPkgs[0].Types); err != nil { + log.Fatalf("struct shape mismatch: %v", err) + } + + // Check *Queries covers every method in repository.Store before generating. + if err := validateStoreCoverage(pkg.Types, repoPkgs[0].Types); err != nil { + log.Fatalf("%v", err) + } + + methods, err := collectMethods(pkg.Types) + if err != nil { + log.Fatal(err) + } + + models, _ := collectTypes(pkg.Types) + + data := tmplData{ + PkgName: pkg.Name, + RepoPkg: repoPkg, + ModelTypes: models, + Methods: renderMethods(methods), + } + + src, err := render(data) + if err != nil { + log.Fatalf("render: %v", err) + } + + if err := os.WriteFile(outPath, src, 0644); err != nil { + log.Fatalf("write %s: %v", outPath, err) + } + fmt.Printf("wrote %s\n", outPath) +} + +func parentPkg(imp string) string { + parts := strings.Split(imp, "/") + return strings.Join(parts[:len(parts)-1], "/") +} + +// pkgDir returns the on-disk directory for an import path using `go list`. +func pkgDir(importPath string) (string, error) { + out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output() + if err != nil { + return "", fmt.Errorf("go list %s: %w", importPath, err) + } + return strings.TrimSpace(string(out)), nil +} + +// validateStoreCoverage checks that every method declared in repository.Store +// exists on *Queries in the driver package. Missing methods are reported by +// name so the developer knows exactly which SQL queries need to be added. +func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { + // Collect *Queries method names. + queriesObj := driverPkg.Scope().Lookup("Queries") + if queriesObj == nil { + return fmt.Errorf("Queries type not found in driver package") + } + queriesNamed := queriesObj.Type().(*types.Named) + queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed)) + queriesMethods := make(map[string]bool) + for m := range queriesMS.Methods() { + queriesMethods[m.Obj().Name()] = true + } + + // Collect repository.Store interface methods. + storeObj := repoPkg.Scope().Lookup("Store") + if storeObj == nil { + return fmt.Errorf("Store type not found in repository package") + } + storeIface, ok := storeObj.Type().Underlying().(*types.Interface) + if !ok { + return fmt.Errorf("repository.Store is not an interface") + } + + var missing []string + for i := range storeIface.NumMethods() { + name := storeIface.Method(i).Name() + if !queriesMethods[name] { + missing = append(missing, name) + } + } + if len(missing) > 0 { + sort.Strings(missing) + return fmt.Errorf( + "driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries.", + len(missing), strings.Join(missing, "\n - "), + ) + } + return nil +} + +type methodInfo struct { + Name string + Params []paramInfo + Results []resultInfo +} + +type paramInfo struct { + Name string + TypeStr string // local (unqualified) type name + RepoType string // "repository.X" if this is a driver model/params type; else "" +} + +type resultInfo struct { + TypeStr string + IsSlice bool + RepoType string // "repository.X" if driver type; else "" +} + +func collectMethods(pkg *types.Package) ([]methodInfo, error) { + obj := pkg.Scope().Lookup("Queries") + if obj == nil { + return nil, fmt.Errorf("queries type not found in %s", pkg.Path()) + } + named, ok := obj.Type().(*types.Named) + if !ok { + return nil, fmt.Errorf("queries is not a named type") + } + ms := types.NewMethodSet(types.NewPointer(named)) + + var out []methodInfo + for method := range ms.Methods() { + fn, ok := method.Obj().(*types.Func) + if !ok || fn.Name() == "WithTx" { + continue + } + sig := fn.Type().(*types.Signature) + mi := methodInfo{Name: fn.Name()} + + // params: skip receiver + first (context.Context) + for i := 1; i < sig.Params().Len(); i++ { + p := sig.Params().At(i) + mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path())) + } + // results: skip error + for r := range sig.Results().Variables() { + if r.Type().String() == "error" { + continue + } + mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path())) + } + out = append(out, mi) + } + return out, nil +} + +func makeParam(name string, t types.Type, driverPath string) paramInfo { + pi := paramInfo{Name: name} + pi.TypeStr = localName(t, driverPath) + pi.RepoType = repoName(t, driverPath) + return pi +} + +func makeResult(t types.Type, driverPath string) resultInfo { + ri := resultInfo{} + if sl, ok := t.(*types.Slice); ok { + ri.IsSlice = true + t = sl.Elem() + } + ri.TypeStr = localName(t, driverPath) + ri.RepoType = repoName(t, driverPath) + return ri +} + +func localName(t types.Type, driverPath string) string { + named, ok := t.(*types.Named) + if !ok { + return types.TypeString(t, nil) + } + if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath { + return named.Obj().Name() + } + return types.TypeString(t, func(p *types.Package) string { return p.Name() }) +} + +func repoName(t types.Type, driverPath string) string { + named, ok := t.(*types.Named) + if !ok { + return "" + } + if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath { + return "repository." + named.Obj().Name() + } + return "" +} + +func collectTypes(pkg *types.Package) (models []string, params []string) { + for _, name := range pkg.Scope().Names() { + obj := pkg.Scope().Lookup(name) + if obj == nil { + continue + } + tn, ok := obj.(*types.TypeName) + if !ok { + continue + } + named, ok := tn.Type().(*types.Named) + if !ok { + continue + } + if _, ok := named.Underlying().(*types.Struct); !ok { + continue + } + switch name { + case "Queries", "DBTX", "Store": + continue + } + if strings.HasSuffix(name, "Params") { + params = append(params, name) + } else { + models = append(models, name) + } + } + return +} + +// validateStructShapes checks that every model/params struct in the driver +// package has fields that exactly match the corresponding type in the repo +// (parent) package. This catches drift between sqlc-generated types and the +// canonical repository types before a broken cast reaches the compiler. +func validateStructShapes(driverPkg, repoPkg *types.Package) error { + var errs []string + for _, name := range driverPkg.Scope().Names() { + obj := driverPkg.Scope().Lookup(name) + if obj == nil { + continue + } + tn, ok := obj.(*types.TypeName) + if !ok { + continue + } + named, ok := tn.Type().(*types.Named) + if !ok { + continue + } + driverStruct, ok := named.Underlying().(*types.Struct) + if !ok { + continue + } + switch name { + case "Queries", "DBTX", "Store": + continue + } + + repoObj := repoPkg.Scope().Lookup(name) + if repoObj == nil { + // Driver has a type not in repo — that's fine (e.g. internal helpers). + continue + } + repoNamed, ok := repoObj.Type().(*types.Named) + if !ok { + continue + } + repoStruct, ok := repoNamed.Underlying().(*types.Struct) + if !ok { + errs = append(errs, fmt.Sprintf("%s: repo type is not a struct", name)) + continue + } + + if err := compareStructs(name, driverStruct, repoStruct); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + return fmt.Errorf("%s", strings.Join(errs, "\n ")) + } + return nil +} + +func compareStructs(name string, driver, repo *types.Struct) error { + if driver.NumFields() != repo.NumFields() { + return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)", + name, driver.NumFields(), repo.NumFields()) + } + for i := range driver.NumFields() { + df := driver.Field(i) + rf := repo.Field(i) + if df.Name() != rf.Name() { + return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)", + name, i, df.Name(), rf.Name()) + } + if !types.Identical(df.Type(), rf.Type()) { + return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)", + name, df.Name(), df.Type(), rf.Type()) + } + } + return nil +} + +// converterFn: "Session" -> "sessionToRepo" +func converterFn(s string) string { + if s == "" { + return "" + } + r := []rune(s) + r[0] = []rune(strings.ToLower(string(r[0])))[0] + return string(r) + "ToRepo" +} + +// renderedMethod is the pre-built method body passed to the template. +type renderedMethod struct { + Signature string + Body string +} + +// renderMethods converts []methodInfo into fully pre-rendered signature+body strings. +func renderMethods(methods []methodInfo) []renderedMethod { + var out []renderedMethod + for _, m := range methods { + out = append(out, renderedMethod{ + Signature: buildSig(m), + Body: buildBody(m), + }) + } + return out +} + +func buildSig(m methodInfo) string { + var sb strings.Builder + sb.WriteString("func (s *Store) ") + sb.WriteString(m.Name) + sb.WriteString("(ctx context.Context") + for _, p := range m.Params { + sb.WriteString(", ") + sb.WriteString(p.Name) + sb.WriteString(" ") + if p.RepoType != "" { + sb.WriteString(p.RepoType) + } else { + sb.WriteString(p.TypeStr) + } + } + sb.WriteString(") (") + for _, r := range m.Results { + if r.IsSlice { + sb.WriteString("[]") + } + if r.RepoType != "" { + sb.WriteString(r.RepoType) + } else { + sb.WriteString(r.TypeStr) + } + sb.WriteString(", ") + } + sb.WriteString("error)") + return sb.String() +} + +func callArgs(m methodInfo) string { + var args []string + for _, p := range m.Params { + if p.RepoType != "" { + // convert repo type → driver type: DriverType(arg) + args = append(args, p.TypeStr+"("+p.Name+")") + } else { + args = append(args, p.Name) + } + } + if len(args) == 0 { + return "ctx" + } + return "ctx, " + strings.Join(args, ", ") +} + +func buildBody(m methodInfo) string { + call := "s.q." + m.Name + "(" + callArgs(m) + ")" + + // no repo-typed result → direct return + if len(m.Results) == 0 || m.Results[0].RepoType == "" { + return "\treturn " + call + "\n" + } + + r := m.Results[0] + if r.IsSlice { + return fmt.Sprintf( + "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", + call, r.RepoType, converterFn(r.TypeStr), + ) + } + return fmt.Sprintf( + "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, err\n\t}\n\treturn %s(r), nil\n", + call, r.RepoType, converterFn(r.TypeStr), + ) +} + +type tmplData struct { + PkgName string + RepoPkg string + ModelTypes []string + Methods []renderedMethod +} + +const storeSrc = `// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package {{.PkgName}} + +import ( + "context" + + "{{.RepoPkg}}" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +{{range .ModelTypes -}} +func {{converterFn .}}(v {{.}}) repository.{{.}} { + return repository.{{.}}(v) +} +{{end -}} +{{range .Methods}}{{.Signature}} { +{{.Body}}} + +{{end}}` + +func render(data tmplData) ([]byte, error) { + t, err := template.New("store").Funcs(template.FuncMap{ + "converterFn": converterFn, + }).Parse(storeSrc) + if err != nil { + return nil, fmt.Errorf("parse template: %w", err) + } + + var buf bytes.Buffer + if err := t.Execute(&buf, data); err != nil { + return nil, fmt.Errorf("execute template: %w", err) + } + + formatted, err := format.Source(buf.Bytes()) + if err != nil { + return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String()) + } + return formatted, nil +} diff --git a/go.mod b/go.mod index 1e7c795e..15879c92 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/weppos/publicsuffix-go v0.50.3 golang.org/x/crypto v0.50.0 golang.org/x/oauth2 v0.36.0 + golang.org/x/tools v0.43.0 k8s.io/apimachinery v0.36.0 k8s.io/client-go v0.36.0 modernc.org/sqlite v1.50.0 @@ -121,6 +122,7 @@ require ( go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/arch v0.22.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/mod v0.34.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.43.0 // indirect diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index efc21311..2279cb23 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -66,5 +66,5 @@ func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, err return nil, fmt.Errorf("failed to migrate database: %w", err) } - return sqlite.New(db), nil + return sqlite.NewStore(sqlite.New(db)), nil } diff --git a/internal/repository/models.go b/internal/repository/models.go index 0c33e038..3f58dd66 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -1,19 +1,148 @@ package repository -// This file is a stop-gap until more drivers are added. It re-exports the models from the sqlite package so that the rest -// of the codebase can import them from a single location without needing to know about the underlying database implementation. - -import "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" - -type Session = sqlite.Session -type OidcCode = sqlite.OidcCode -type OidcToken = sqlite.OidcToken -type OidcUserinfo = sqlite.OidcUserinfo - -type CreateSessionParams = sqlite.CreateSessionParams -type UpdateSessionParams = sqlite.UpdateSessionParams -type CreateOidcCodeParams = sqlite.CreateOidcCodeParams -type CreateOidcTokenParams = sqlite.CreateOidcTokenParams -type UpdateOidcTokenByRefreshTokenParams = sqlite.UpdateOidcTokenByRefreshTokenParams -type DeleteExpiredOidcTokensParams = sqlite.DeleteExpiredOidcTokensParams -type CreateOidcUserInfoParams = sqlite.CreateOidcUserInfoParams +// Shared model and parameter types for all storage drivers. +// sqlc-generated driver packages use these via the conversion layer in their store.go. + +type Session struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} + +type OidcCode struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type OidcToken struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + CodeHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + Nonce string +} + +type OidcUserinfo struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} + +type CreateSessionParams struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} + +type UpdateSessionParams struct { + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + OAuthName string + OAuthSub string + UUID string +} + +type CreateOidcCodeParams struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type CreateOidcTokenParams struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + CodeHash string + Nonce string +} + +type UpdateOidcTokenByRefreshTokenParams struct { + AccessTokenHash string + RefreshTokenHash string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + RefreshTokenHash_2 string +} + +type DeleteExpiredOidcTokensParams struct { + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 +} + +type CreateOidcUserInfoParams struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} diff --git a/internal/repository/sqlite/generate.go b/internal/repository/sqlite/generate.go new file mode 100644 index 00000000..5f6011f1 --- /dev/null +++ b/internal/repository/sqlite/generate.go @@ -0,0 +1,3 @@ +package sqlite + +//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go new file mode 100644 index 00000000..65b4e190 --- /dev/null +++ b/internal/repository/sqlite/store.go @@ -0,0 +1,206 @@ +// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package sqlite + +import ( + "context" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +func oidcCodeToRepo(v OidcCode) repository.OidcCode { + return repository.OidcCode(v) +} +func oidcTokenToRepo(v OidcToken) repository.OidcToken { + return repository.OidcToken(v) +} +func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo { + return repository.OidcUserinfo(v) +} +func sessionToRepo(v Session) repository.Session { + return repository.Session(v) +} +func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) + if err != nil { + return repository.OidcUserinfo{}, err + } + return oidcUserinfoToRepo(r), nil +} + +func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) + if err != nil { + return repository.Session{}, err + } + return sessionToRepo(r), nil +} + +func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { + rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) + if err != nil { + return nil, err + } + out := make([]repository.OidcCode, len(rows)) + for i, row := range rows { + out[i] = oidcCodeToRepo(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) + if err != nil { + return nil, err + } + out := make([]repository.OidcToken, len(rows)) + for i, row := range rows { + out[i] = oidcTokenToRepo(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { + return s.q.DeleteExpiredSessions(ctx, expiry) +} + +func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { + return s.q.DeleteOidcCode(ctx, codeHash) +} + +func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { + return s.q.DeleteOidcCodeBySub(ctx, sub) +} + +func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { + return s.q.DeleteOidcToken(ctx, accessTokenHash) +} + +func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { + return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash) +} + +func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { + return s.q.DeleteOidcTokenBySub(ctx, sub) +} + +func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { + return s.q.DeleteOidcUserInfo(ctx, sub) +} + +func (s *Store) DeleteSession(ctx context.Context, uuid string) error { + return s.q.DeleteSession(ctx, uuid) +} + +func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCode(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySub(ctx, sub) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, err + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcToken(ctx, accessTokenHash) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenBySub(ctx, sub) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { + r, err := s.q.GetOidcUserInfo(ctx, sub) + if err != nil { + return repository.OidcUserinfo{}, err + } + return oidcUserinfoToRepo(r), nil +} + +func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { + r, err := s.q.GetSession(ctx, uuid) + if err != nil { + return repository.Session{}, err + } + return sessionToRepo(r), nil +} + +func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, err + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) + if err != nil { + return repository.Session{}, err + } + return sessionToRepo(r), nil +} From ef8bbd8c9fad4903b47c5510582b41f57cd795db Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Mon, 4 May 2026 05:02:27 +1200 Subject: [PATCH 3/5] feat(db): add `memory` storage driver removes the sqlite dependency for tests, also brings back the option for users to run zero persistence instances of tinyauth. adds new mapErr fn for sqlc wrapper gen to prevent sql errors from leaking out of the store implementation. --- cmd/gen/sqlc-wrapper/main.go | 24 +- internal/bootstrap/db_bootstrap.go | 17 +- internal/controller/oidc_controller_test.go | 7 +- internal/controller/proxy_controller_test.go | 9 +- internal/controller/user_controller_test.go | 12 +- .../controller/well_known_controller_test.go | 7 +- .../middleware/context_middleware_test.go | 26 +- internal/model/config.go | 6 +- internal/repository/memory/oidc_queries.go | 241 ++++++++++++++++++ internal/repository/memory/session_queries.go | 63 +++++ internal/repository/memory/store.go | 27 ++ internal/repository/sqlite/store.go | 68 +++-- internal/repository/store.go | 8 +- internal/service/auth_service.go | 3 +- internal/service/oidc_service.go | 15 +- 15 files changed, 443 insertions(+), 90 deletions(-) create mode 100644 internal/repository/memory/oidc_queries.go create mode 100644 internal/repository/memory/session_queries.go create mode 100644 internal/repository/memory/store.go diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/main.go index e66ae8ee..d6cb6318 100644 --- a/cmd/gen/sqlc-wrapper/main.go +++ b/cmd/gen/sqlc-wrapper/main.go @@ -449,18 +449,18 @@ func buildBody(m methodInfo) string { // no repo-typed result → direct return if len(m.Results) == 0 || m.Results[0].RepoType == "" { - return "\treturn " + call + "\n" + return "\treturn mapErr(" + call + ")\n" } r := m.Results[0] if r.IsSlice { return fmt.Sprintf( - "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", + "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } return fmt.Sprintf( - "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, err\n\t}\n\treturn %s(r), nil\n", + "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n", call, r.RepoType, converterFn(r.TypeStr), ) } @@ -477,6 +477,8 @@ package {{.PkgName}} import ( "context" + "database/sql" + "errors" "{{.RepoPkg}}" ) @@ -491,6 +493,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + {{range .ModelTypes -}} func {{converterFn .}}(v {{.}}) repository.{{.}} { return repository.{{.}}(v) diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 2279cb23..4f09372a 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -8,6 +8,7 @@ import ( "github.com/tinyauthapp/tinyauth/internal/assets" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/golang-migrate/migrate/v4" @@ -17,14 +18,14 @@ import ( ) func (app *BootstrapApp) SetupStore() (repository.Store, error) { - return app.setupSQLite(app.config.Database.Path) -} - -// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store. -// Useful for testing or when constructing a store outside of a BootstrapApp. -func NewSQLiteStore(databasePath string) (repository.Store, error) { - app := &BootstrapApp{} - return app.setupSQLite(databasePath) + switch app.config.Database.Driver { + case "memory": + return memory.New(), nil + case "sqlite", "": + return app.setupSQLite(app.config.Database.Path) + default: + return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver) + } } func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) { diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 88690410..4f131ac7 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -14,9 +14,9 @@ import ( "github.com/google/go-querystring/query" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) @@ -851,11 +851,10 @@ func TestOIDCController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index f84d791b..c7876713 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -2,22 +2,20 @@ package controller_test import ( "net/http/httptest" - "path" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestProxyController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ @@ -399,11 +397,10 @@ func TestProxyController(t *testing.T) { oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 4184274d..1d2f02e9 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "path" "strings" "testing" "time" @@ -14,17 +13,16 @@ import ( "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestUserController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ @@ -117,6 +115,8 @@ func TestUserController(t *testing.T) { run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } + store := memory.New() + tests := []testCase{ { description: "Should be able to login with valid credentials", @@ -449,12 +449,8 @@ func TestUserController(t *testing.T) { oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 2bb8cfe1..582e4842 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -10,9 +10,9 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) @@ -100,11 +100,10 @@ func TestWellKnownController(t *testing.T) { }, } - store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) + store := memory.New() oidcService := service.NewOIDCService(oidcServiceCfg, store) - err = oidcService.Init() + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 5dfde3b4..d1dfa99f 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -5,24 +5,22 @@ import ( "encoding/base64" "net/http" "net/http/httptest" - "path" "testing" "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestContextMiddleware(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ @@ -52,7 +50,7 @@ func TestContextMiddleware(t *testing.T) { return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) } - seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) { + seedSession := func(t *testing.T, queries repository.Store, params repository.CreateSessionParams) { t.Helper() _, err := queries.CreateSession(context.Background(), params) require.NoError(t, err) @@ -60,7 +58,7 @@ func TestContextMiddleware(t *testing.T) { type runArgs struct { do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) - queries *repository.Queries + queries repository.Store } type testCase struct { @@ -272,22 +270,17 @@ func TestContextMiddleware(t *testing.T) { oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) + store := memory.New() ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() + err := ldap.Init() require.NoError(t, err) broker := service.NewOAuthBrokerService(oauthBrokerCfgs) err = broker.Init() require.NoError(t, err) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + authService := service.NewAuthService(authServiceCfg, ldap, store, broker) err = authService.Init() require.NoError(t, err) @@ -317,12 +310,7 @@ func TestContextMiddleware(t *testing.T) { return captured, recorder } - test.run(t, runArgs{do: do, queries: queries}) + test.run(t, runArgs{do: do, queries: store}) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/model/config.go b/internal/model/config.go index c2125f92..cffe8734 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -4,7 +4,8 @@ package model func NewDefaultConfiguration() *Config { return &Config{ Database: DatabaseConfig{ - Path: "./tinyauth.db", + Driver: "sqlite", + Path: "./tinyauth.db", }, Analytics: AnalyticsConfig{ Enabled: true, @@ -82,7 +83,8 @@ type Config struct { } type DatabaseConfig struct { - Path string `description:"The path to the SQLite database, including file name." yaml:"path"` + Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"` + Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"` } type AnalyticsConfig struct { diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go new file mode 100644 index 00000000..80305fc0 --- /dev/null +++ b/internal/repository/memory/oidc_queries.go @@ -0,0 +1,241 @@ +package memory + +import ( + "context" + "fmt" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, c := range s.oidcCodes { + if c.Sub == arg.Sub { + return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub") + } + } + code := repository.OidcCode(arg) + s.oidcCodes[arg.CodeHash] = code + return code, nil +} + +// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + delete(s.oidcCodes, codeHash) + return c, nil +} + +// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + return c, nil +} + +// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, c := range s.oidcCodes { + if c.Sub == sub { + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcCodes, codeHash) + return nil +} + +func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcCode + for k, c := range s.oidcCodes { + if c.ExpiresAt < expiresAt { + deleted = append(deleted, c) + delete(s.oidcCodes, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, t := range s.oidcTokens { + if t.Sub == arg.Sub { + return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub") + } + } + tok := repository.OidcToken{ + Sub: arg.Sub, + AccessTokenHash: arg.AccessTokenHash, + RefreshTokenHash: arg.RefreshTokenHash, + CodeHash: arg.CodeHash, + Scope: arg.Scope, + ClientID: arg.ClientID, + TokenExpiresAt: arg.TokenExpiresAt, + RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt, + Nonce: arg.Nonce, + } + s.oidcTokens[arg.AccessTokenHash] = tok + return tok, nil +} + +func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + t, ok := s.oidcTokens[accessTokenHash] + if !ok { + return repository.OidcToken{}, repository.ErrNotFound + } + return t, nil +} + +func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.RefreshTokenHash == refreshTokenHash { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.Sub == sub { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.RefreshTokenHash == arg.RefreshTokenHash_2 { + delete(s.oidcTokens, k) + t.AccessTokenHash = arg.AccessTokenHash + t.RefreshTokenHash = arg.RefreshTokenHash + t.TokenExpiresAt = arg.TokenExpiresAt + t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt + s.oidcTokens[arg.AccessTokenHash] = t + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcTokens, accessTokenHash) + return nil +} + +func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.Sub == sub { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.CodeHash == codeHash { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcToken + for k, t := range s.oidcTokens { + if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { + deleted = append(deleted, t) + delete(s.oidcTokens, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + s.mu.Lock() + defer s.mu.Unlock() + u := repository.OidcUserinfo(arg) + s.oidcUsers[arg.Sub] = u + return u, nil +} + +func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + u, ok := s.oidcUsers[sub] + if !ok { + return repository.OidcUserinfo{}, repository.ErrNotFound + } + return u, nil +} + +func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcUsers, sub) + return nil +} diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go new file mode 100644 index 00000000..2edde6b1 --- /dev/null +++ b/internal/repository/memory/session_queries.go @@ -0,0 +1,63 @@ +package memory + +import ( + "context" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess := repository.Session(arg) + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.sessions[uuid] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + return sess, nil +} + +func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess, ok := s.sessions[arg.UUID] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + sess.Username = arg.Username + sess.Email = arg.Email + sess.Name = arg.Name + sess.Provider = arg.Provider + sess.TotpPending = arg.TotpPending + sess.OAuthGroups = arg.OAuthGroups + sess.Expiry = arg.Expiry + sess.OAuthName = arg.OAuthName + sess.OAuthSub = arg.OAuthSub + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) DeleteSession(_ context.Context, uuid string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, uuid) + return nil +} + +func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.sessions { + if v.Expiry < expiry { + delete(s.sessions, k) + } + } + return nil +} diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go new file mode 100644 index 00000000..969cba66 --- /dev/null +++ b/internal/repository/memory/store.go @@ -0,0 +1,27 @@ +// Package memory provides an in-memory implementation of repository.Store for use in tests. +package memory + +import ( + "sync" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store is a thread-safe in-memory implementation of repository.Store. +type Store struct { + mu sync.RWMutex + sessions map[string]repository.Session + oidcCodes map[string]repository.OidcCode + oidcTokens map[string]repository.OidcToken + oidcUsers map[string]repository.OidcUserinfo +} + +// New returns a new empty in-memory Store. +func New() repository.Store { + return &Store{ + sessions: make(map[string]repository.Session), + oidcCodes: make(map[string]repository.OidcCode), + oidcTokens: make(map[string]repository.OidcToken), + oidcUsers: make(map[string]repository.OidcUserinfo), + } +} diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go index 65b4e190..f316efa4 100644 --- a/internal/repository/sqlite/store.go +++ b/internal/repository/sqlite/store.go @@ -3,6 +3,8 @@ package sqlite import ( "context" + "database/sql" + "errors" "github.com/tinyauthapp/tinyauth/internal/repository" ) @@ -17,6 +19,22 @@ func NewStore(q *Queries) repository.Store { return &Store{q: q} } +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + func oidcCodeToRepo(v OidcCode) repository.OidcCode { return repository.OidcCode(v) } @@ -32,7 +50,7 @@ func sessionToRepo(v Session) repository.Session { func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -40,7 +58,7 @@ func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCod func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -48,7 +66,7 @@ func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTo func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -56,7 +74,7 @@ func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOid func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -64,7 +82,7 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcCode, len(rows)) for i, row := range rows { @@ -76,7 +94,7 @@ func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([] func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) if err != nil { - return nil, err + return nil, mapErr(err) } out := make([]repository.OidcToken, len(rows)) for i, row := range rows { @@ -86,41 +104,41 @@ func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.Dele } func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { - return s.q.DeleteExpiredSessions(ctx, expiry) + return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) } func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcCode(ctx, codeHash) + return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) } func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcCodeBySub(ctx, sub) + return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) } func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { - return s.q.DeleteOidcToken(ctx, accessTokenHash) + return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) } func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { - return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash) + return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) } func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { - return s.q.DeleteOidcTokenBySub(ctx, sub) + return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) } func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { - return s.q.DeleteOidcUserInfo(ctx, sub) + return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) } func (s *Store) DeleteSession(ctx context.Context, uuid string) error { - return s.q.DeleteSession(ctx, uuid) + return mapErr(s.q.DeleteSession(ctx, uuid)) } func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCode(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -128,7 +146,7 @@ func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.Oi func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySub(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -136,7 +154,7 @@ func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.Oi func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -144,7 +162,7 @@ func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (reposit func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) if err != nil { - return repository.OidcCode{}, err + return repository.OidcCode{}, mapErr(err) } return oidcCodeToRepo(r), nil } @@ -152,7 +170,7 @@ func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (reposit func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcToken(ctx, accessTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -160,7 +178,7 @@ func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repos func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -168,7 +186,7 @@ func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { r, err := s.q.GetOidcTokenBySub(ctx, sub) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -176,7 +194,7 @@ func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.O func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { r, err := s.q.GetOidcUserInfo(ctx, sub) if err != nil { - return repository.OidcUserinfo{}, err + return repository.OidcUserinfo{}, mapErr(err) } return oidcUserinfoToRepo(r), nil } @@ -184,7 +202,7 @@ func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.Oid func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { r, err := s.q.GetSession(ctx, uuid) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } @@ -192,7 +210,7 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) if err != nil { - return repository.OidcToken{}, err + return repository.OidcToken{}, mapErr(err) } return oidcTokenToRepo(r), nil } @@ -200,7 +218,7 @@ func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repositor func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) if err != nil { - return repository.Session{}, err + return repository.Session{}, mapErr(err) } return sessionToRepo(r), nil } diff --git a/internal/repository/store.go b/internal/repository/store.go index 765df6a5..302f2f10 100644 --- a/internal/repository/store.go +++ b/internal/repository/store.go @@ -1,6 +1,12 @@ package repository -import "context" +import ( + "context" + "errors" +) + +// ErrNotFound is returned by Store methods when the requested record does not exist. +var ErrNotFound = errors.New("not found") // Store is the interface that all storage drivers must implement. // The sqlc-generated *Queries struct satisfies this interface for SQLite. diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 29f491f1..c6a12944 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -2,7 +2,6 @@ package service import ( "context" - "database/sql" "errors" "fmt" "net/http" @@ -421,7 +420,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito session, err := auth.queries.GetSession(ctx, uuid) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return nil, errors.New("session not found") } return nil, err diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index d6b11628..65b5ccda 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -7,7 +7,6 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" - "database/sql" "encoding/base64" "encoding/json" "encoding/pem" @@ -422,7 +421,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client oidcCode, err := service.queries.GetOidcCode(c, codeHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcCode{}, ErrCodeNotFound } return repository.OidcCode{}, err @@ -566,7 +565,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return TokenResponse{}, ErrTokenNotFound } return TokenResponse{}, err @@ -645,7 +644,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re entry, err := service.queries.GetOidcToken(c, tokenHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcToken{}, ErrTokenNotFound } return repository.OidcToken{}, err @@ -733,15 +732,15 @@ func (service *OIDCService) Hash(token string) string { func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { err := service.queries.DeleteOidcCodeBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcTokenBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcUserInfo(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } return nil @@ -786,7 +785,7 @@ func (service *OIDCService) Cleanup() { token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { continue } tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub") From db911a41c3df499da383d5e73d4be2d776378424 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Thu, 7 May 2026 18:55:51 +1200 Subject: [PATCH 4/5] refactor(db): cleanup sqlc-wrapper gen --- Makefile | 2 +- .../sqlc-wrapper/{main.go => sqlc_wrapper.go} | 445 +++++++++--------- cmd/gen/sqlc-wrapper/store.tmpl | 46 ++ internal/repository/sqlite/db.go | 2 +- internal/repository/sqlite/models.go | 2 +- .../repository/sqlite/oidc_queries.sql.go | 2 +- .../repository/sqlite/session_queries.sql.go | 2 +- 7 files changed, 267 insertions(+), 234 deletions(-) rename cmd/gen/sqlc-wrapper/{main.go => sqlc_wrapper.go} (65%) create mode 100644 cmd/gen/sqlc-wrapper/store.tmpl diff --git a/Makefile b/Makefile index 616fd994..80e7a629 100644 --- a/Makefile +++ b/Makefile @@ -84,4 +84,4 @@ sql: # Go gen generate: - go run ./gen + go generate ./internal/repository/... diff --git a/cmd/gen/sqlc-wrapper/main.go b/cmd/gen/sqlc-wrapper/sqlc_wrapper.go similarity index 65% rename from cmd/gen/sqlc-wrapper/main.go rename to cmd/gen/sqlc-wrapper/sqlc_wrapper.go index d6cb6318..0592d20c 100644 --- a/cmd/gen/sqlc-wrapper/main.go +++ b/cmd/gen/sqlc-wrapper/sqlc_wrapper.go @@ -12,6 +12,7 @@ package main import ( "bytes" + _ "embed" "flag" "fmt" "go/format" @@ -27,13 +28,22 @@ import ( "golang.org/x/tools/go/packages" ) +//go:embed store.tmpl +var storeSrc string + func main() { + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { driverPkg := flag.String("pkg", "", "import path of the driver package") out := flag.String("out", "store.go", "output filename relative to driver package directory") flag.Parse() if *driverPkg == "" { - log.Fatal("-pkg is required") + return fmt.Errorf("-pkg is required") } // Resolve the driver package directory so we can overlay the output file @@ -41,8 +51,9 @@ func main() { // type-checker and producing cryptic "undefined" errors. driverDir, err := pkgDir(*driverPkg) if err != nil { - log.Fatalf("resolve driver dir: %v", err) + return fmt.Errorf("resolve driver dir: %w", err) } + outPath := filepath.Join(driverDir, *out) if filepath.IsAbs(*out) { outPath = *out @@ -50,73 +61,81 @@ func main() { // Stub replaces the output file during load so stale generated code is ignored. stub := []byte("package " + filepath.Base(driverDir) + "\n") - cfg := &packages.Config{ Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports, Overlay: map[string][]byte{outPath: stub}, } - pkgs, err := packages.Load(cfg, *driverPkg) + + driverTypePkg, err := loadOnePkg(cfg, *driverPkg) if err != nil { - log.Fatalf("load %s: %v", *driverPkg, err) - } - if len(pkgs) != 1 { - log.Fatalf("expected 1 package, got %d", len(pkgs)) - } - pkg := pkgs[0] - if len(pkg.Errors) > 0 { - for _, e := range pkg.Errors { - log.Printf("package error: %v", e) - } - log.Fatal("package has errors") + return fmt.Errorf("load driver package: %w", err) } - repoPkg := parentPkg(*driverPkg) - - // Load the parent (repository) package so we can validate struct shapes. - repoPkgs, err := packages.Load(cfg, repoPkg) + repoPkgPath := parentPkg(*driverPkg) + repoTypePkg, err := loadOnePkg(cfg, repoPkgPath) if err != nil { - log.Fatalf("load repo pkg %s: %v", repoPkg, err) - } - if len(repoPkgs) != 1 || len(repoPkgs[0].Errors) > 0 { - log.Fatalf("could not load repo package %s cleanly", repoPkg) - } - if err := validateStructShapes(pkg.Types, repoPkgs[0].Types); err != nil { - log.Fatalf("struct shape mismatch: %v", err) + return fmt.Errorf("load repo package: %w", err) } - // Check *Queries covers every method in repository.Store before generating. - if err := validateStoreCoverage(pkg.Types, repoPkgs[0].Types); err != nil { - log.Fatalf("%v", err) + if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil { + return fmt.Errorf("struct shape mismatch: %w", err) + } + if err := validateStoreCoverage(driverTypePkg, repoTypePkg); err != nil { + return err } - methods, err := collectMethods(pkg.Types) + methods, err := collectMethods(driverTypePkg) if err != nil { - log.Fatal(err) + return err } + models, _ := collectTypes(driverTypePkg) - models, _ := collectTypes(pkg.Types) - - data := tmplData{ - PkgName: pkg.Name, - RepoPkg: repoPkg, + src, err := render(tmplData{ + PkgName: driverTypePkg.Name(), + RepoPkg: repoPkgPath, ModelTypes: models, Methods: renderMethods(methods), - } - - src, err := render(data) + }) if err != nil { - log.Fatalf("render: %v", err) + return fmt.Errorf("render: %w", err) } if err := os.WriteFile(outPath, src, 0644); err != nil { - log.Fatalf("write %s: %v", outPath, err) + return fmt.Errorf("write %s: %w", outPath, err) } fmt.Printf("wrote %s\n", outPath) + return nil +} + +// loadOnePkg loads a single package via cfg and returns its *types.Package, +// or an error if the package fails to load or has type errors. +func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) { + pkgs, err := packages.Load(cfg, importPath) + if err != nil { + return nil, fmt.Errorf("load %s: %w", importPath, err) + } + if len(pkgs) != 1 { + return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs)) + } + pkg := pkgs[0] + if len(pkg.Errors) > 0 { + msgs := make([]string, len(pkg.Errors)) + for i, e := range pkg.Errors { + msgs[i] = e.Error() + } + return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n ")) + } + return pkg.Types, nil } +// parentPkg returns the parent import path (everything before the last /). +// Panics if imp contains no slash — callers are expected to pass driver sub-packages. func parentPkg(imp string) string { - parts := strings.Split(imp, "/") - return strings.Join(parts[:len(parts)-1], "/") + i := strings.LastIndex(imp, "/") + if i < 0 { + panic(fmt.Sprintf("parentPkg: import path %q has no parent", imp)) + } + return imp[:i] } // pkgDir returns the on-disk directory for an import path using `go list`. @@ -128,14 +147,40 @@ func pkgDir(importPath string) (string, error) { return strings.TrimSpace(string(out)), nil } +// scopeStructs returns all named struct types in pkg, excluding the internal +// sqlc types Queries, DBTX, and Store. Names are returned in sorted order. +func scopeStructs(pkg *types.Package) (names []string, byName map[string]*types.Struct) { + byName = make(map[string]*types.Struct) + for _, name := range pkg.Scope().Names() { // Names() is already sorted + switch name { + case "Queries", "DBTX", "Store": + continue + } + obj, ok := pkg.Scope().Lookup(name).(*types.TypeName) + if !ok { + continue + } + named, ok := obj.Type().(*types.Named) + if !ok { + continue + } + s, ok := named.Underlying().(*types.Struct) + if !ok { + continue + } + names = append(names, name) + byName[name] = s + } + return +} + // validateStoreCoverage checks that every method declared in repository.Store // exists on *Queries in the driver package. Missing methods are reported by // name so the developer knows exactly which SQL queries need to be added. func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { - // Collect *Queries method names. queriesObj := driverPkg.Scope().Lookup("Queries") if queriesObj == nil { - return fmt.Errorf("Queries type not found in driver package") + return fmt.Errorf("queries type not found in driver package") } queriesNamed := queriesObj.Type().(*types.Named) queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed)) @@ -144,10 +189,9 @@ func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { queriesMethods[m.Obj().Name()] = true } - // Collect repository.Store interface methods. storeObj := repoPkg.Scope().Lookup("Store") if storeObj == nil { - return fmt.Errorf("Store type not found in repository package") + return fmt.Errorf("store type not found in repository package") } storeIface, ok := storeObj.Type().Underlying().(*types.Interface) if !ok { @@ -155,22 +199,80 @@ func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { } var missing []string - for i := range storeIface.NumMethods() { - name := storeIface.Method(i).Name() - if !queriesMethods[name] { + for method := range storeIface.Methods() { + if name := method.Name(); !queriesMethods[name] { missing = append(missing, name) } } if len(missing) > 0 { sort.Strings(missing) return fmt.Errorf( - "driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries.", + "driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries", len(missing), strings.Join(missing, "\n - "), ) } return nil } +// validateStructShapes checks that every model/params struct in the driver +// package has fields that exactly match the corresponding type in the repo +// (parent) package. This catches drift between sqlc-generated types and the +// canonical repository types before a broken cast reaches the compiler. +func validateStructShapes(driverPkg, repoPkg *types.Package) error { + _, repoStructs := scopeStructs(repoPkg) + driverNames, driverStructs := scopeStructs(driverPkg) + + var errs []string + for _, name := range driverNames { + repoStruct, ok := repoStructs[name] + if !ok { + // Driver has a type not in repo — fine (e.g. internal helpers). + continue + } + if err := compareStructs(name, driverStructs[name], repoStruct); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + sort.Strings(errs) + return fmt.Errorf("%s", strings.Join(errs, "\n ")) + } + return nil +} + +func compareStructs(name string, driver, repo *types.Struct) error { + if driver.NumFields() != repo.NumFields() { + return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)", + name, driver.NumFields(), repo.NumFields()) + } + for i := range driver.NumFields() { + df := driver.Field(i) + rf := repo.Field(i) + if df.Name() != rf.Name() { + return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)", + name, i, df.Name(), rf.Name()) + } + if !types.Identical(df.Type(), rf.Type()) { + return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)", + name, df.Name(), df.Type(), rf.Type()) + } + } + return nil +} + +// collectTypes returns model and params struct names from the driver package. +func collectTypes(pkg *types.Package) (models []string, params []string) { + names, _ := scopeStructs(pkg) + for _, name := range names { + if strings.HasSuffix(name, "Params") { + params = append(params, name) + } else { + models = append(models, name) + } + } + return +} + type methodInfo struct { Name string Params []paramInfo @@ -227,10 +329,11 @@ func collectMethods(pkg *types.Package) ([]methodInfo, error) { } func makeParam(name string, t types.Type, driverPath string) paramInfo { - pi := paramInfo{Name: name} - pi.TypeStr = localName(t, driverPath) - pi.RepoType = repoName(t, driverPath) - return pi + return paramInfo{ + Name: name, + TypeStr: localName(t, driverPath), + RepoType: repoName(t, driverPath), + } } func makeResult(t types.Type, driverPath string) resultInfo { @@ -266,133 +369,27 @@ func repoName(t types.Type, driverPath string) string { return "" } -func collectTypes(pkg *types.Package) (models []string, params []string) { - for _, name := range pkg.Scope().Names() { - obj := pkg.Scope().Lookup(name) - if obj == nil { - continue - } - tn, ok := obj.(*types.TypeName) - if !ok { - continue - } - named, ok := tn.Type().(*types.Named) - if !ok { - continue - } - if _, ok := named.Underlying().(*types.Struct); !ok { - continue - } - switch name { - case "Queries", "DBTX", "Store": - continue - } - if strings.HasSuffix(name, "Params") { - params = append(params, name) - } else { - models = append(models, name) - } - } - return -} - -// validateStructShapes checks that every model/params struct in the driver -// package has fields that exactly match the corresponding type in the repo -// (parent) package. This catches drift between sqlc-generated types and the -// canonical repository types before a broken cast reaches the compiler. -func validateStructShapes(driverPkg, repoPkg *types.Package) error { - var errs []string - for _, name := range driverPkg.Scope().Names() { - obj := driverPkg.Scope().Lookup(name) - if obj == nil { - continue - } - tn, ok := obj.(*types.TypeName) - if !ok { - continue - } - named, ok := tn.Type().(*types.Named) - if !ok { - continue - } - driverStruct, ok := named.Underlying().(*types.Struct) - if !ok { - continue - } - switch name { - case "Queries", "DBTX", "Store": - continue - } - - repoObj := repoPkg.Scope().Lookup(name) - if repoObj == nil { - // Driver has a type not in repo — that's fine (e.g. internal helpers). - continue - } - repoNamed, ok := repoObj.Type().(*types.Named) - if !ok { - continue - } - repoStruct, ok := repoNamed.Underlying().(*types.Struct) - if !ok { - errs = append(errs, fmt.Sprintf("%s: repo type is not a struct", name)) - continue - } - - if err := compareStructs(name, driverStruct, repoStruct); err != nil { - errs = append(errs, err.Error()) - } - } - if len(errs) > 0 { - return fmt.Errorf("%s", strings.Join(errs, "\n ")) - } - return nil -} - -func compareStructs(name string, driver, repo *types.Struct) error { - if driver.NumFields() != repo.NumFields() { - return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)", - name, driver.NumFields(), repo.NumFields()) - } - for i := range driver.NumFields() { - df := driver.Field(i) - rf := repo.Field(i) - if df.Name() != rf.Name() { - return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)", - name, i, df.Name(), rf.Name()) - } - if !types.Identical(df.Type(), rf.Type()) { - return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)", - name, df.Name(), df.Type(), rf.Type()) - } - } - return nil -} - -// converterFn: "Session" -> "sessionToRepo" +// converterFn maps a type name to its converter function name: "Session" → "sessionToRepo". func converterFn(s string) string { if s == "" { return "" } - r := []rune(s) - r[0] = []rune(strings.ToLower(string(r[0])))[0] - return string(r) + "ToRepo" + return strings.ToLower(s[:1]) + s[1:] + "ToRepo" } -// renderedMethod is the pre-built method body passed to the template. +// renderedMethod holds pre-built signature and body strings passed to the template. type renderedMethod struct { Signature string Body string } -// renderMethods converts []methodInfo into fully pre-rendered signature+body strings. func renderMethods(methods []methodInfo) []renderedMethod { - var out []renderedMethod - for _, m := range methods { - out = append(out, renderedMethod{ + out := make([]renderedMethod, len(methods)) + for i, m := range methods { + out[i] = renderedMethod{ Signature: buildSig(m), Body: buildBody(m), - }) + } } return out } @@ -429,7 +426,7 @@ func buildSig(m methodInfo) string { } func callArgs(m methodInfo) string { - var args []string + args := make([]string, 0, len(m.Params)) for _, p := range m.Params { if p.RepoType != "" { // convert repo type → driver type: DriverType(arg) @@ -444,80 +441,70 @@ func callArgs(m methodInfo) string { return "ctx, " + strings.Join(args, ", ") } -func buildBody(m methodInfo) string { - call := "s.q." + m.Name + "(" + callArgs(m) + ")" +// bodyTemplates holds the per-shape method body templates, parsed once at init. +var bodyTemplates = template.Must( + template.New("bodies").Parse(` +{{define "void"}} return mapErr({{.Call}}) +{{end}} - // no repo-typed result → direct return - if len(m.Results) == 0 || m.Results[0].RepoType == "" { - return "\treturn mapErr(" + call + ")\n" +{{define "scalar"}} r, err := {{.Call}} + if err != nil { + return {{.RepoType}}{}, mapErr(err) } + return {{.Converter}}(r), nil +{{end}} - r := m.Results[0] - if r.IsSlice { - return fmt.Sprintf( - "\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n", - call, r.RepoType, converterFn(r.TypeStr), - ) +{{define "slice"}} rows, err := {{.Call}} + if err != nil { + return nil, mapErr(err) } - return fmt.Sprintf( - "\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n", - call, r.RepoType, converterFn(r.TypeStr), - ) -} - -type tmplData struct { - PkgName string - RepoPkg string - ModelTypes []string - Methods []renderedMethod -} - -const storeSrc = `// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. -package {{.PkgName}} - -import ( - "context" - "database/sql" - "errors" - - "{{.RepoPkg}}" + out := make([]{{.RepoType}}, len(rows)) + for i, row := range rows { + out[i] = {{.Converter}}(row) + } + return out, nil +{{end}}`), ) -// Store wraps *Queries and implements repository.Store. -type Store struct { - q *Queries +type bodyData struct { + Call string + RepoType string + Converter string } -// NewStore wraps a *Queries to satisfy repository.Store. -func NewStore(q *Queries) repository.Store { - return &Store{q: q} -} +func buildBody(m methodInfo) string { + call := "s.q." + m.Name + "(" + callArgs(m) + ")" -var errMap = []struct { - from error - to error -}{ - {sql.ErrNoRows, repository.ErrNotFound}, -} + var ( + name string + data bodyData + ) -func mapErr(err error) error { - for _, e := range errMap { - if errors.Is(err, e.from) { - return e.to - } + switch { + case len(m.Results) == 0 || m.Results[0].RepoType == "": + name = "void" + data = bodyData{Call: call} + case m.Results[0].IsSlice: + name = "slice" + data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} + default: + name = "scalar" + data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} } - return err -} -{{range .ModelTypes -}} -func {{converterFn .}}(v {{.}}) repository.{{.}} { - return repository.{{.}}(v) + var buf bytes.Buffer + if err := bodyTemplates.ExecuteTemplate(&buf, name, data); err != nil { + panic(fmt.Sprintf("buildBody %s: %v", name, err)) + } + return buf.String() } -{{end -}} -{{range .Methods}}{{.Signature}} { -{{.Body}}} -{{end}}` +type tmplData struct { + PkgName string + RepoPkg string + ModelTypes []string + Methods []renderedMethod +} func render(data tmplData) ([]byte, error) { t, err := template.New("store").Funcs(template.FuncMap{ diff --git a/cmd/gen/sqlc-wrapper/store.tmpl b/cmd/gen/sqlc-wrapper/store.tmpl new file mode 100644 index 00000000..02bb6fb1 --- /dev/null +++ b/cmd/gen/sqlc-wrapper/store.tmpl @@ -0,0 +1,46 @@ +// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package {{.PkgName}} + +import ( + "context" + "database/sql" + "errors" + + "{{.RepoPkg}}" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + +{{range .ModelTypes -}} +func {{converterFn .}}(v {{.}}) repository.{{.}} { + return repository.{{.}}(v) +} +{{end -}} +{{range .Methods}}{{.Signature}} { +{{.Body}}} + +{{end}} diff --git a/internal/repository/sqlite/db.go b/internal/repository/sqlite/db.go index ee310fc2..51a4906a 100644 --- a/internal/repository/sqlite/db.go +++ b/internal/repository/sqlite/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.0 +// sqlc v1.31.1 package sqlite diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go index caf37f4c..fd6f78da 100644 --- a/internal/repository/sqlite/models.go +++ b/internal/repository/sqlite/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.0 +// sqlc v1.31.1 package sqlite diff --git a/internal/repository/sqlite/oidc_queries.sql.go b/internal/repository/sqlite/oidc_queries.sql.go index 027ac421..e5d08bc2 100644 --- a/internal/repository/sqlite/oidc_queries.sql.go +++ b/internal/repository/sqlite/oidc_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.0 +// sqlc v1.31.1 // source: oidc_queries.sql package sqlite diff --git a/internal/repository/sqlite/session_queries.sql.go b/internal/repository/sqlite/session_queries.sql.go index 4271b727..7792fc4b 100644 --- a/internal/repository/sqlite/session_queries.sql.go +++ b/internal/repository/sqlite/session_queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.31.0 +// sqlc v1.31.1 // source: session_queries.sql package sqlite From 374b87964fe1c1aa9c67c3601b1dc3acfc45a650 Mon Sep 17 00:00:00 2001 From: Scott McKendry Date: Sun, 10 May 2026 14:01:08 +1200 Subject: [PATCH 5/5] test(db): add memory driver tests --- internal/repository/memory/memory_test.go | 427 ++++++++++++++++++++++ 1 file changed, 427 insertions(+) create mode 100644 internal/repository/memory/memory_test.go diff --git a/internal/repository/memory/memory_test.go b/internal/repository/memory/memory_test.go new file mode 100644 index 00000000..292d2abf --- /dev/null +++ b/internal/repository/memory/memory_test.go @@ -0,0 +1,427 @@ +package memory_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" +) + +var ctx = context.Background() + +func TestCreateAndGetSession(t *testing.T) { + s := memory.New() + sess, err := s.CreateSession(ctx, repository.CreateSessionParams{ + UUID: "uuid-1", + Username: "alice", + Expiry: 9999, + }) + require.NoError(t, err) + assert.Equal(t, "uuid-1", sess.UUID) + assert.Equal(t, "alice", sess.Username) + + got, err := s.GetSession(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, sess, got) +} + +func TestGetSession_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetSession(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestUpdateSession(t *testing.T) { + s := memory.New() + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"}) + require.NoError(t, err) + + updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{ + UUID: "uuid-1", + Username: "bob", + Email: "bob@example.com", + }) + require.NoError(t, err) + assert.Equal(t, "bob", updated.Username) + assert.Equal(t, "bob@example.com", updated.Email) + + got, err := s.GetSession(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, updated, got) +} + +func TestUpdateSession_NotFound(t *testing.T) { + s := memory.New() + _, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"}) + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteSession(t *testing.T) { + s := memory.New() + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteSession(ctx, "uuid-1")) + + _, err = s.GetSession(ctx, "uuid-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteExpiredSessions(t *testing.T) { + s := memory.New() + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10}) + require.NoError(t, err) + _, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100}) + require.NoError(t, err) + + require.NoError(t, s.DeleteExpiredSessions(ctx, 50)) + + _, err = s.GetSession(ctx, "expired") + assert.ErrorIs(t, err, repository.ErrNotFound) + + _, err = s.GetSession(ctx, "valid") + assert.NoError(t, err) +} + +func TestCreateAndGetOidcCode(t *testing.T) { + s := memory.New() + code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{ + Sub: "sub-1", + CodeHash: "hash-1", + Scope: "openid", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", code.Sub) + + // destructive read removes the record + got, err := s.GetOidcCode(ctx, "hash-1") + require.NoError(t, err) + assert.Equal(t, code, got) + + _, err = s.GetOidcCode(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCode_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCode(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeBySub(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + + // destructive — gone after read + _, err = s.GetOidcCodeBySub(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeBySub_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCodeBySub(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeUnsafe(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeUnsafe(ctx, "hash-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + + // non-destructive — still present + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.NoError(t, err) +} + +func TestGetOidcCodeUnsafe_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCodeUnsafe(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeBySubUnsafe(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "hash-1", got.CodeHash) + + // non-destructive — still present + _, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1") + assert.NoError(t, err) +} + +func TestGetOidcCodeBySubUnsafe_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCodeBySubUnsafe(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestCreateOidcCode_UniqueSubConstraint(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub") +} + +func TestDeleteOidcCode(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcCode(ctx, "hash-1")) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcCodeBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1")) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteExpiredOidcCodes(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10}) + require.NoError(t, err) + _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100}) + require.NoError(t, err) + + deleted, err := s.DeleteExpiredOidcCodes(ctx, 50) + require.NoError(t, err) + require.Len(t, deleted, 1) + assert.Equal(t, "hash-1", deleted[0].CodeHash) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-2") + assert.NoError(t, err) +} + +func TestCreateAndGetOidcToken(t *testing.T) { + s := memory.New() + tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-hash-1", + CodeHash: "code-hash-1", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", tok.Sub) + + got, err := s.GetOidcToken(ctx, "at-hash-1") + require.NoError(t, err) + assert.Equal(t, tok, got) +} + +func TestGetOidcToken_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcToken(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestCreateOidcToken_UniqueSubConstraint(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub") +} + +func TestGetOidcTokenByRefreshToken(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) +} + +func TestGetOidcTokenByRefreshToken_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcTokenByRefreshToken(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcTokenBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + }) + require.NoError(t, err) + + got, err := s.GetOidcTokenBySub(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "at-1", got.AccessTokenHash) +} + +func TestGetOidcTokenBySub_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcTokenBySub(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestUpdateOidcTokenByRefreshToken(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ + RefreshTokenHash_2: "rt-1", + AccessTokenHash: "at-2", + RefreshTokenHash: "rt-2", + TokenExpiresAt: 200, + RefreshTokenExpiresAt: 400, + }) + require.NoError(t, err) + assert.Equal(t, "at-2", updated.AccessTokenHash) + assert.Equal(t, "rt-2", updated.RefreshTokenHash) + + // old key gone, new key present + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + + got, err := s.GetOidcToken(ctx, "at-2") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) +} + +func TestUpdateOidcTokenByRefreshToken_NotFound(t *testing.T) { + s := memory.New() + _, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ + RefreshTokenHash_2: "missing", + }) + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcToken(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcToken(ctx, "at-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcTokenBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcTokenByCodeHash(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + CodeHash: "code-1", + }) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteExpiredOidcTokens(t *testing.T) { + s := memory.New() + // expired by TokenExpiresAt + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", AccessTokenHash: "at-1", + TokenExpiresAt: 10, RefreshTokenExpiresAt: 100, + }) + require.NoError(t, err) + // expired by RefreshTokenExpiresAt + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-2", AccessTokenHash: "at-2", + TokenExpiresAt: 100, RefreshTokenExpiresAt: 10, + }) + require.NoError(t, err) + // valid + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-3", AccessTokenHash: "at-3", + TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, + }) + require.NoError(t, err) + + deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: 50, + RefreshTokenExpiresAt: 50, + }) + require.NoError(t, err) + assert.Len(t, deleted, 2) + + _, err = s.GetOidcToken(ctx, "at-3") + assert.NoError(t, err) +} + +func TestCreateAndGetOidcUserInfo(t *testing.T) { + s := memory.New() + u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{ + Sub: "sub-1", + Name: "Alice", + Email: "alice@example.com", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", u.Sub) + + got, err := s.GetOidcUserInfo(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, u, got) +} + +func TestGetOidcUserInfo_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcUserInfo(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcUserInfo(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1")) + + _, err = s.GetOidcUserInfo(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +}