diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 12db1641..fb8c9736 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -26,6 +26,12 @@ jobs: - name: Go dependencies run: go mod download + - name: Check codegen is up to date + run: | + go generate ./internal/repository/... + git diff --exit-code -- internal/repository/ + git status --porcelain -- internal/repository/ | grep -q . && echo "untracked files in internal/repository/" && exit 1 || true + - name: Install frontend dependencies run: | cd frontend diff --git a/Makefile b/Makefile index 616fd994..80e7a629 100644 --- a/Makefile +++ b/Makefile @@ -84,4 +84,4 @@ sql: # Go gen generate: - go run ./gen + go generate ./internal/repository/... diff --git a/cmd/gen/sqlc-wrapper/sqlc_wrapper.go b/cmd/gen/sqlc-wrapper/sqlc_wrapper.go new file mode 100644 index 00000000..0592d20c --- /dev/null +++ b/cmd/gen/sqlc-wrapper/sqlc_wrapper.go @@ -0,0 +1,527 @@ +// gen/sqlc-wrapper generates store.go wrapper files for each sqlc driver package under +// internal/repository//. Run via: +// +// go generate ./internal/repository/... +// +// The generator introspects *Queries methods and the model/params types in the +// driver package, then emits a store.go that wraps *Queries so it satisfies +// repository.Store using the canonical shared types in the parent package. +// This generator is specific to sqlc-generated drivers. Non-sqlc drivers should +// implement repository.Store directly by hand. +package main + +import ( + "bytes" + _ "embed" + "flag" + "fmt" + "go/format" + "go/types" + "log" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "text/template" + + "golang.org/x/tools/go/packages" +) + +//go:embed store.tmpl +var storeSrc string + +func main() { + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { + driverPkg := flag.String("pkg", "", "import path of the driver package") + out := flag.String("out", "store.go", "output filename relative to driver package directory") + flag.Parse() + + if *driverPkg == "" { + return fmt.Errorf("-pkg is required") + } + + // Resolve the driver package directory so we can overlay the output file + // with a valid stub. This prevents a stale store.go from poisoning the + // type-checker and producing cryptic "undefined" errors. + driverDir, err := pkgDir(*driverPkg) + if err != nil { + return fmt.Errorf("resolve driver dir: %w", err) + } + + outPath := filepath.Join(driverDir, *out) + if filepath.IsAbs(*out) { + outPath = *out + } + + // Stub replaces the output file during load so stale generated code is ignored. + stub := []byte("package " + filepath.Base(driverDir) + "\n") + cfg := &packages.Config{ + Mode: packages.NeedName | packages.NeedTypes | packages.NeedSyntax | packages.NeedImports, + Overlay: map[string][]byte{outPath: stub}, + } + + driverTypePkg, err := loadOnePkg(cfg, *driverPkg) + if err != nil { + return fmt.Errorf("load driver package: %w", err) + } + + repoPkgPath := parentPkg(*driverPkg) + repoTypePkg, err := loadOnePkg(cfg, repoPkgPath) + if err != nil { + return fmt.Errorf("load repo package: %w", err) + } + + if err := validateStructShapes(driverTypePkg, repoTypePkg); err != nil { + return fmt.Errorf("struct shape mismatch: %w", err) + } + if err := validateStoreCoverage(driverTypePkg, repoTypePkg); err != nil { + return err + } + + methods, err := collectMethods(driverTypePkg) + if err != nil { + return err + } + models, _ := collectTypes(driverTypePkg) + + src, err := render(tmplData{ + PkgName: driverTypePkg.Name(), + RepoPkg: repoPkgPath, + ModelTypes: models, + Methods: renderMethods(methods), + }) + if err != nil { + return fmt.Errorf("render: %w", err) + } + + if err := os.WriteFile(outPath, src, 0644); err != nil { + return fmt.Errorf("write %s: %w", outPath, err) + } + fmt.Printf("wrote %s\n", outPath) + return nil +} + +// loadOnePkg loads a single package via cfg and returns its *types.Package, +// or an error if the package fails to load or has type errors. +func loadOnePkg(cfg *packages.Config, importPath string) (*types.Package, error) { + pkgs, err := packages.Load(cfg, importPath) + if err != nil { + return nil, fmt.Errorf("load %s: %w", importPath, err) + } + if len(pkgs) != 1 { + return nil, fmt.Errorf("expected 1 package for %s, got %d", importPath, len(pkgs)) + } + pkg := pkgs[0] + if len(pkg.Errors) > 0 { + msgs := make([]string, len(pkg.Errors)) + for i, e := range pkg.Errors { + msgs[i] = e.Error() + } + return nil, fmt.Errorf("package %s has errors:\n %s", importPath, strings.Join(msgs, "\n ")) + } + return pkg.Types, nil +} + +// parentPkg returns the parent import path (everything before the last /). +// Panics if imp contains no slash — callers are expected to pass driver sub-packages. +func parentPkg(imp string) string { + i := strings.LastIndex(imp, "/") + if i < 0 { + panic(fmt.Sprintf("parentPkg: import path %q has no parent", imp)) + } + return imp[:i] +} + +// pkgDir returns the on-disk directory for an import path using `go list`. +func pkgDir(importPath string) (string, error) { + out, err := exec.Command("go", "list", "-f", "{{.Dir}}", importPath).Output() + if err != nil { + return "", fmt.Errorf("go list %s: %w", importPath, err) + } + return strings.TrimSpace(string(out)), nil +} + +// scopeStructs returns all named struct types in pkg, excluding the internal +// sqlc types Queries, DBTX, and Store. Names are returned in sorted order. +func scopeStructs(pkg *types.Package) (names []string, byName map[string]*types.Struct) { + byName = make(map[string]*types.Struct) + for _, name := range pkg.Scope().Names() { // Names() is already sorted + switch name { + case "Queries", "DBTX", "Store": + continue + } + obj, ok := pkg.Scope().Lookup(name).(*types.TypeName) + if !ok { + continue + } + named, ok := obj.Type().(*types.Named) + if !ok { + continue + } + s, ok := named.Underlying().(*types.Struct) + if !ok { + continue + } + names = append(names, name) + byName[name] = s + } + return +} + +// validateStoreCoverage checks that every method declared in repository.Store +// exists on *Queries in the driver package. Missing methods are reported by +// name so the developer knows exactly which SQL queries need to be added. +func validateStoreCoverage(driverPkg, repoPkg *types.Package) error { + queriesObj := driverPkg.Scope().Lookup("Queries") + if queriesObj == nil { + return fmt.Errorf("queries type not found in driver package") + } + queriesNamed := queriesObj.Type().(*types.Named) + queriesMS := types.NewMethodSet(types.NewPointer(queriesNamed)) + queriesMethods := make(map[string]bool) + for m := range queriesMS.Methods() { + queriesMethods[m.Obj().Name()] = true + } + + storeObj := repoPkg.Scope().Lookup("Store") + if storeObj == nil { + return fmt.Errorf("store type not found in repository package") + } + storeIface, ok := storeObj.Type().Underlying().(*types.Interface) + if !ok { + return fmt.Errorf("repository.Store is not an interface") + } + + var missing []string + for method := range storeIface.Methods() { + if name := method.Name(); !queriesMethods[name] { + missing = append(missing, name) + } + } + if len(missing) > 0 { + sort.Strings(missing) + return fmt.Errorf( + "driver *Queries is missing %d method(s) required by repository.Store:\n - %s\n\nRun sqlc generate to regenerate query methods, or add the missing SQL queries", + len(missing), strings.Join(missing, "\n - "), + ) + } + return nil +} + +// validateStructShapes checks that every model/params struct in the driver +// package has fields that exactly match the corresponding type in the repo +// (parent) package. This catches drift between sqlc-generated types and the +// canonical repository types before a broken cast reaches the compiler. +func validateStructShapes(driverPkg, repoPkg *types.Package) error { + _, repoStructs := scopeStructs(repoPkg) + driverNames, driverStructs := scopeStructs(driverPkg) + + var errs []string + for _, name := range driverNames { + repoStruct, ok := repoStructs[name] + if !ok { + // Driver has a type not in repo — fine (e.g. internal helpers). + continue + } + if err := compareStructs(name, driverStructs[name], repoStruct); err != nil { + errs = append(errs, err.Error()) + } + } + if len(errs) > 0 { + sort.Strings(errs) + return fmt.Errorf("%s", strings.Join(errs, "\n ")) + } + return nil +} + +func compareStructs(name string, driver, repo *types.Struct) error { + if driver.NumFields() != repo.NumFields() { + return fmt.Errorf("%s: field count mismatch (driver=%d, repo=%d)", + name, driver.NumFields(), repo.NumFields()) + } + for i := range driver.NumFields() { + df := driver.Field(i) + rf := repo.Field(i) + if df.Name() != rf.Name() { + return fmt.Errorf("%s: field %d name mismatch (driver=%q, repo=%q)", + name, i, df.Name(), rf.Name()) + } + if !types.Identical(df.Type(), rf.Type()) { + return fmt.Errorf("%s.%s: type mismatch (driver=%s, repo=%s)", + name, df.Name(), df.Type(), rf.Type()) + } + } + return nil +} + +// collectTypes returns model and params struct names from the driver package. +func collectTypes(pkg *types.Package) (models []string, params []string) { + names, _ := scopeStructs(pkg) + for _, name := range names { + if strings.HasSuffix(name, "Params") { + params = append(params, name) + } else { + models = append(models, name) + } + } + return +} + +type methodInfo struct { + Name string + Params []paramInfo + Results []resultInfo +} + +type paramInfo struct { + Name string + TypeStr string // local (unqualified) type name + RepoType string // "repository.X" if this is a driver model/params type; else "" +} + +type resultInfo struct { + TypeStr string + IsSlice bool + RepoType string // "repository.X" if driver type; else "" +} + +func collectMethods(pkg *types.Package) ([]methodInfo, error) { + obj := pkg.Scope().Lookup("Queries") + if obj == nil { + return nil, fmt.Errorf("queries type not found in %s", pkg.Path()) + } + named, ok := obj.Type().(*types.Named) + if !ok { + return nil, fmt.Errorf("queries is not a named type") + } + ms := types.NewMethodSet(types.NewPointer(named)) + + var out []methodInfo + for method := range ms.Methods() { + fn, ok := method.Obj().(*types.Func) + if !ok || fn.Name() == "WithTx" { + continue + } + sig := fn.Type().(*types.Signature) + mi := methodInfo{Name: fn.Name()} + + // params: skip receiver + first (context.Context) + for i := 1; i < sig.Params().Len(); i++ { + p := sig.Params().At(i) + mi.Params = append(mi.Params, makeParam(p.Name(), p.Type(), pkg.Path())) + } + // results: skip error + for r := range sig.Results().Variables() { + if r.Type().String() == "error" { + continue + } + mi.Results = append(mi.Results, makeResult(r.Type(), pkg.Path())) + } + out = append(out, mi) + } + return out, nil +} + +func makeParam(name string, t types.Type, driverPath string) paramInfo { + return paramInfo{ + Name: name, + TypeStr: localName(t, driverPath), + RepoType: repoName(t, driverPath), + } +} + +func makeResult(t types.Type, driverPath string) resultInfo { + ri := resultInfo{} + if sl, ok := t.(*types.Slice); ok { + ri.IsSlice = true + t = sl.Elem() + } + ri.TypeStr = localName(t, driverPath) + ri.RepoType = repoName(t, driverPath) + return ri +} + +func localName(t types.Type, driverPath string) string { + named, ok := t.(*types.Named) + if !ok { + return types.TypeString(t, nil) + } + if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath { + return named.Obj().Name() + } + return types.TypeString(t, func(p *types.Package) string { return p.Name() }) +} + +func repoName(t types.Type, driverPath string) string { + named, ok := t.(*types.Named) + if !ok { + return "" + } + if named.Obj().Pkg() != nil && named.Obj().Pkg().Path() == driverPath { + return "repository." + named.Obj().Name() + } + return "" +} + +// converterFn maps a type name to its converter function name: "Session" → "sessionToRepo". +func converterFn(s string) string { + if s == "" { + return "" + } + return strings.ToLower(s[:1]) + s[1:] + "ToRepo" +} + +// renderedMethod holds pre-built signature and body strings passed to the template. +type renderedMethod struct { + Signature string + Body string +} + +func renderMethods(methods []methodInfo) []renderedMethod { + out := make([]renderedMethod, len(methods)) + for i, m := range methods { + out[i] = renderedMethod{ + Signature: buildSig(m), + Body: buildBody(m), + } + } + return out +} + +func buildSig(m methodInfo) string { + var sb strings.Builder + sb.WriteString("func (s *Store) ") + sb.WriteString(m.Name) + sb.WriteString("(ctx context.Context") + for _, p := range m.Params { + sb.WriteString(", ") + sb.WriteString(p.Name) + sb.WriteString(" ") + if p.RepoType != "" { + sb.WriteString(p.RepoType) + } else { + sb.WriteString(p.TypeStr) + } + } + sb.WriteString(") (") + for _, r := range m.Results { + if r.IsSlice { + sb.WriteString("[]") + } + if r.RepoType != "" { + sb.WriteString(r.RepoType) + } else { + sb.WriteString(r.TypeStr) + } + sb.WriteString(", ") + } + sb.WriteString("error)") + return sb.String() +} + +func callArgs(m methodInfo) string { + args := make([]string, 0, len(m.Params)) + for _, p := range m.Params { + if p.RepoType != "" { + // convert repo type → driver type: DriverType(arg) + args = append(args, p.TypeStr+"("+p.Name+")") + } else { + args = append(args, p.Name) + } + } + if len(args) == 0 { + return "ctx" + } + return "ctx, " + strings.Join(args, ", ") +} + +// bodyTemplates holds the per-shape method body templates, parsed once at init. +var bodyTemplates = template.Must( + template.New("bodies").Parse(` +{{define "void"}} return mapErr({{.Call}}) +{{end}} + +{{define "scalar"}} r, err := {{.Call}} + if err != nil { + return {{.RepoType}}{}, mapErr(err) + } + return {{.Converter}}(r), nil +{{end}} + +{{define "slice"}} rows, err := {{.Call}} + if err != nil { + return nil, mapErr(err) + } + out := make([]{{.RepoType}}, len(rows)) + for i, row := range rows { + out[i] = {{.Converter}}(row) + } + return out, nil +{{end}}`), +) + +type bodyData struct { + Call string + RepoType string + Converter string +} + +func buildBody(m methodInfo) string { + call := "s.q." + m.Name + "(" + callArgs(m) + ")" + + var ( + name string + data bodyData + ) + + switch { + case len(m.Results) == 0 || m.Results[0].RepoType == "": + name = "void" + data = bodyData{Call: call} + case m.Results[0].IsSlice: + name = "slice" + data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} + default: + name = "scalar" + data = bodyData{Call: call, RepoType: m.Results[0].RepoType, Converter: converterFn(m.Results[0].TypeStr)} + } + + var buf bytes.Buffer + if err := bodyTemplates.ExecuteTemplate(&buf, name, data); err != nil { + panic(fmt.Sprintf("buildBody %s: %v", name, err)) + } + return buf.String() +} + +type tmplData struct { + PkgName string + RepoPkg string + ModelTypes []string + Methods []renderedMethod +} + +func render(data tmplData) ([]byte, error) { + t, err := template.New("store").Funcs(template.FuncMap{ + "converterFn": converterFn, + }).Parse(storeSrc) + if err != nil { + return nil, fmt.Errorf("parse template: %w", err) + } + + var buf bytes.Buffer + if err := t.Execute(&buf, data); err != nil { + return nil, fmt.Errorf("execute template: %w", err) + } + + formatted, err := format.Source(buf.Bytes()) + if err != nil { + return buf.Bytes(), fmt.Errorf("format source: %w\nraw:\n%s", err, buf.String()) + } + return formatted, nil +} diff --git a/cmd/gen/sqlc-wrapper/store.tmpl b/cmd/gen/sqlc-wrapper/store.tmpl new file mode 100644 index 00000000..02bb6fb1 --- /dev/null +++ b/cmd/gen/sqlc-wrapper/store.tmpl @@ -0,0 +1,46 @@ +// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package {{.PkgName}} + +import ( + "context" + "database/sql" + "errors" + + "{{.RepoPkg}}" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + +{{range .ModelTypes -}} +func {{converterFn .}}(v {{.}}) repository.{{.}} { + return repository.{{.}}(v) +} +{{end -}} +{{range .Methods}}{{.Signature}} { +{{.Body}}} + +{{end}} diff --git a/go.mod b/go.mod index 1e7c795e..15879c92 100644 --- a/go.mod +++ b/go.mod @@ -20,6 +20,7 @@ require ( github.com/weppos/publicsuffix-go v0.50.3 golang.org/x/crypto v0.50.0 golang.org/x/oauth2 v0.36.0 + golang.org/x/tools v0.43.0 k8s.io/apimachinery v0.36.0 k8s.io/client-go v0.36.0 modernc.org/sqlite v1.50.0 @@ -121,6 +122,7 @@ require ( go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/arch v0.22.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/mod v0.34.0 // indirect golang.org/x/net v0.52.0 // indirect golang.org/x/sync v0.20.0 // indirect golang.org/x/sys v0.43.0 // indirect diff --git a/internal/assets/assets.go b/internal/assets/assets.go index 412403c9..a5c3d79d 100644 --- a/internal/assets/assets.go +++ b/internal/assets/assets.go @@ -11,5 +11,5 @@ var FrontendAssets embed.FS // Migrations // -//go:embed migrations/*.sql +//go:embed migrations/sqlite/*.sql var Migrations embed.FS diff --git a/internal/assets/migrations/000001_init_sqlite.down.sql b/internal/assets/migrations/sqlite/000001_init_sqlite.down.sql similarity index 100% rename from internal/assets/migrations/000001_init_sqlite.down.sql rename to internal/assets/migrations/sqlite/000001_init_sqlite.down.sql diff --git a/internal/assets/migrations/000001_init_sqlite.up.sql b/internal/assets/migrations/sqlite/000001_init_sqlite.up.sql similarity index 100% rename from internal/assets/migrations/000001_init_sqlite.up.sql rename to internal/assets/migrations/sqlite/000001_init_sqlite.up.sql diff --git a/internal/assets/migrations/000002_oauth_name.down.sql b/internal/assets/migrations/sqlite/000002_oauth_name.down.sql similarity index 100% rename from internal/assets/migrations/000002_oauth_name.down.sql rename to internal/assets/migrations/sqlite/000002_oauth_name.down.sql diff --git a/internal/assets/migrations/000002_oauth_name.up.sql b/internal/assets/migrations/sqlite/000002_oauth_name.up.sql similarity index 100% rename from internal/assets/migrations/000002_oauth_name.up.sql rename to internal/assets/migrations/sqlite/000002_oauth_name.up.sql diff --git a/internal/assets/migrations/000003_oauth_sub.down.sql b/internal/assets/migrations/sqlite/000003_oauth_sub.down.sql similarity index 100% rename from internal/assets/migrations/000003_oauth_sub.down.sql rename to internal/assets/migrations/sqlite/000003_oauth_sub.down.sql diff --git a/internal/assets/migrations/000003_oauth_sub.up.sql b/internal/assets/migrations/sqlite/000003_oauth_sub.up.sql similarity index 100% rename from internal/assets/migrations/000003_oauth_sub.up.sql rename to internal/assets/migrations/sqlite/000003_oauth_sub.up.sql diff --git a/internal/assets/migrations/000004_created_at.down.sql b/internal/assets/migrations/sqlite/000004_created_at.down.sql similarity index 100% rename from internal/assets/migrations/000004_created_at.down.sql rename to internal/assets/migrations/sqlite/000004_created_at.down.sql diff --git a/internal/assets/migrations/000004_created_at.up.sql b/internal/assets/migrations/sqlite/000004_created_at.up.sql similarity index 100% rename from internal/assets/migrations/000004_created_at.up.sql rename to internal/assets/migrations/sqlite/000004_created_at.up.sql diff --git a/internal/assets/migrations/000005_oidc_session.down.sql b/internal/assets/migrations/sqlite/000005_oidc_session.down.sql similarity index 100% rename from internal/assets/migrations/000005_oidc_session.down.sql rename to internal/assets/migrations/sqlite/000005_oidc_session.down.sql diff --git a/internal/assets/migrations/000005_oidc_session.up.sql b/internal/assets/migrations/sqlite/000005_oidc_session.up.sql similarity index 100% rename from internal/assets/migrations/000005_oidc_session.up.sql rename to internal/assets/migrations/sqlite/000005_oidc_session.up.sql diff --git a/internal/assets/migrations/000006_oidc_nonce.down.sql b/internal/assets/migrations/sqlite/000006_oidc_nonce.down.sql similarity index 100% rename from internal/assets/migrations/000006_oidc_nonce.down.sql rename to internal/assets/migrations/sqlite/000006_oidc_nonce.down.sql diff --git a/internal/assets/migrations/000006_oidc_nonce.up.sql b/internal/assets/migrations/sqlite/000006_oidc_nonce.up.sql similarity index 100% rename from internal/assets/migrations/000006_oidc_nonce.up.sql rename to internal/assets/migrations/sqlite/000006_oidc_nonce.up.sql diff --git a/internal/assets/migrations/000007_oidc_pkce.down.sql b/internal/assets/migrations/sqlite/000007_oidc_pkce.down.sql similarity index 100% rename from internal/assets/migrations/000007_oidc_pkce.down.sql rename to internal/assets/migrations/sqlite/000007_oidc_pkce.down.sql diff --git a/internal/assets/migrations/000007_oidc_pkce.up.sql b/internal/assets/migrations/sqlite/000007_oidc_pkce.up.sql similarity index 100% rename from internal/assets/migrations/000007_oidc_pkce.up.sql rename to internal/assets/migrations/sqlite/000007_oidc_pkce.up.sql diff --git a/internal/assets/migrations/000008_oidc_code_reuse.down.sql b/internal/assets/migrations/sqlite/000008_oidc_code_reuse.down.sql similarity index 100% rename from internal/assets/migrations/000008_oidc_code_reuse.down.sql rename to internal/assets/migrations/sqlite/000008_oidc_code_reuse.down.sql diff --git a/internal/assets/migrations/000008_oidc_code_reuse.up.sql b/internal/assets/migrations/sqlite/000008_oidc_code_reuse.up.sql similarity index 100% rename from internal/assets/migrations/000008_oidc_code_reuse.up.sql rename to internal/assets/migrations/sqlite/000008_oidc_code_reuse.up.sql diff --git a/internal/assets/migrations/000009_oidc_userinfo_profile.down.sql b/internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.down.sql similarity index 100% rename from internal/assets/migrations/000009_oidc_userinfo_profile.down.sql rename to internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.down.sql diff --git a/internal/assets/migrations/000009_oidc_userinfo_profile.up.sql b/internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.up.sql similarity index 100% rename from internal/assets/migrations/000009_oidc_userinfo_profile.up.sql rename to internal/assets/migrations/sqlite/000009_oidc_userinfo_profile.up.sql diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 5b342c48..a840ffcb 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -144,17 +144,14 @@ func (app *BootstrapApp) Setup() error { tlog.App.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name") // Database - db, err := app.SetupDatabase(app.config.Database.Path) + store, err := app.SetupStore() if err != nil { return fmt.Errorf("failed to setup database: %w", err) } - // Queries - queries := repository.New(db) - // Services - services, err := app.initServices(queries) + services, err := app.initServices(store) if err != nil { return fmt.Errorf("failed to initialize services: %w", err) @@ -210,7 +207,7 @@ func (app *BootstrapApp) Setup() error { // Start db cleanup routine tlog.App.Debug().Msg("Starting database cleanup routine") - go app.dbCleanupRoutine(queries) + go app.dbCleanupRoutine(store) // If analytics are not disabled, start heartbeat if app.config.Analytics.Enabled { @@ -300,7 +297,7 @@ func (app *BootstrapApp) heartbeatRoutine() { } } -func (app *BootstrapApp) dbCleanupRoutine(queries *repository.Queries) { +func (app *BootstrapApp) dbCleanupRoutine(queries repository.Store) { ticker := time.NewTicker(time.Duration(30) * time.Minute) defer ticker.Stop() ctx := context.Background() diff --git a/internal/bootstrap/db_bootstrap.go b/internal/bootstrap/db_bootstrap.go index 3f48f793..4f09372a 100644 --- a/internal/bootstrap/db_bootstrap.go +++ b/internal/bootstrap/db_bootstrap.go @@ -7,6 +7,9 @@ import ( "path/filepath" "github.com/tinyauthapp/tinyauth/internal/assets" + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" + "github.com/tinyauthapp/tinyauth/internal/repository/sqlite" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/sqlite3" @@ -14,7 +17,18 @@ import ( _ "modernc.org/sqlite" ) -func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { +func (app *BootstrapApp) SetupStore() (repository.Store, error) { + switch app.config.Database.Driver { + case "memory": + return memory.New(), nil + case "sqlite", "": + return app.setupSQLite(app.config.Database.Path) + default: + return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver) + } +} + +func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) { dir := filepath.Dir(databasePath) if err := os.MkdirAll(dir, 0750); err != nil { @@ -31,7 +45,7 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { // if the sqlite connection starts being a bottleneck db.SetMaxOpenConns(1) - migrations, err := iofs.New(assets.Migrations, "migrations") + migrations, err := iofs.New(assets.Migrations, "migrations/sqlite") if err != nil { return nil, fmt.Errorf("failed to create migrations: %w", err) @@ -53,5 +67,5 @@ func (app *BootstrapApp) SetupDatabase(databasePath string) (*sql.DB, error) { return nil, fmt.Errorf("failed to migrate database: %w", err) } - return db, nil + return sqlite.NewStore(sqlite.New(db)), nil } diff --git a/internal/bootstrap/service_bootstrap.go b/internal/bootstrap/service_bootstrap.go index 09485bd0..48ded235 100644 --- a/internal/bootstrap/service_bootstrap.go +++ b/internal/bootstrap/service_bootstrap.go @@ -18,7 +18,7 @@ type Services struct { oidcService *service.OIDCService } -func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) { +func (app *BootstrapApp) initServices(queries repository.Store) (Services, error) { services := Services{} ldapService := service.NewLdapService(service.LdapServiceConfig{ diff --git a/internal/controller/oidc_controller_test.go b/internal/controller/oidc_controller_test.go index 150540fc..4f131ac7 100644 --- a/internal/controller/oidc_controller_test.go +++ b/internal/controller/oidc_controller_test.go @@ -14,10 +14,9 @@ import ( "github.com/google/go-querystring/query" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) @@ -852,14 +851,10 @@ func TestOIDCController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(model.Config{}) + store := memory.New() - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) - oidcService := service.NewOIDCService(oidcServiceCfg, queries) - err = oidcService.Init() + oidcService := service.NewOIDCService(oidcServiceCfg, store) + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { @@ -881,9 +876,4 @@ func TestOIDCController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/controller/proxy_controller_test.go b/internal/controller/proxy_controller_test.go index 7b2e3202..c7876713 100644 --- a/internal/controller/proxy_controller_test.go +++ b/internal/controller/proxy_controller_test.go @@ -2,23 +2,20 @@ package controller_test import ( "net/http/httptest" - "path" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestProxyController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ @@ -400,15 +397,10 @@ func TestProxyController(t *testing.T) { oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) + store := memory.New() docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) @@ -419,7 +411,7 @@ func TestProxyController(t *testing.T) { err = broker.Init() require.NoError(t, err) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + authService := service.NewAuthService(authServiceCfg, ldap, store, broker) err = authService.Init() require.NoError(t, err) @@ -444,9 +436,4 @@ func TestProxyController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/controller/user_controller_test.go b/internal/controller/user_controller_test.go index 4863c16e..1d2f02e9 100644 --- a/internal/controller/user_controller_test.go +++ b/internal/controller/user_controller_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "net/http" "net/http/httptest" - "path" "strings" "testing" "time" @@ -14,17 +13,16 @@ import ( "github.com/pquerna/otp/totp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestUserController(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ @@ -111,21 +109,14 @@ func TestUserController(t *testing.T) { }) } - oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) - type testCase struct { description string middlewares []gin.HandlerFunc run func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) } + store := memory.New() + tests := []testCase{ { description: "Should be able to login with valid credentials", @@ -294,7 +285,7 @@ func TestUserController(t *testing.T) { totpCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - _, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ + _, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{ UUID: "test-totp-login-uuid", Username: "test", Email: "test@example.com", @@ -418,7 +409,7 @@ func TestUserController(t *testing.T) { totpAttrCtx, }, run: func(t *testing.T, router *gin.Engine, recorder *httptest.ResponseRecorder) { - _, err := queries.CreateSession(context.TODO(), repository.CreateSessionParams{ + _, err := store.CreateSession(context.TODO(), repository.CreateSessionParams{ UUID: "test-totp-login-attributes-uuid", Username: "test", Email: "test@example.com", @@ -456,8 +447,10 @@ func TestUserController(t *testing.T) { }, } + oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) + docker := service.NewDockerService() - err = docker.Init() + err := docker.Init() require.NoError(t, err) ldap := service.NewLdapService(service.LdapServiceConfig{}) @@ -468,7 +461,7 @@ func TestUserController(t *testing.T) { err = broker.Init() require.NoError(t, err) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + authService := service.NewAuthService(authServiceCfg, ldap, store, broker) err = authService.Init() require.NoError(t, err) @@ -497,9 +490,4 @@ func TestUserController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/controller/well_known_controller_test.go b/internal/controller/well_known_controller_test.go index 7dcf2bdc..582e4842 100644 --- a/internal/controller/well_known_controller_test.go +++ b/internal/controller/well_known_controller_test.go @@ -10,10 +10,9 @@ import ( "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/model" - "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) @@ -101,15 +100,10 @@ func TestWellKnownController(t *testing.T) { }, } - app := bootstrap.NewBootstrapApp(model.Config{}) + store := memory.New() - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) - - oidcService := service.NewOIDCService(oidcServiceCfg, queries) - err = oidcService.Init() + oidcService := service.NewOIDCService(oidcServiceCfg, store) + err := oidcService.Init() require.NoError(t, err) for _, test := range tests { @@ -125,9 +119,4 @@ func TestWellKnownController(t *testing.T) { test.run(t, router, recorder) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/middleware/context_middleware_test.go b/internal/middleware/context_middleware_test.go index 5dfde3b4..d1dfa99f 100644 --- a/internal/middleware/context_middleware_test.go +++ b/internal/middleware/context_middleware_test.go @@ -5,24 +5,22 @@ import ( "encoding/base64" "net/http" "net/http/httptest" - "path" "testing" "time" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" "github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/utils/tlog" ) func TestContextMiddleware(t *testing.T) { tlog.NewTestLogger().Init() - tempDir := t.TempDir() authServiceCfg := service.AuthServiceConfig{ LocalUsers: &[]model.LocalUser{ @@ -52,7 +50,7 @@ func TestContextMiddleware(t *testing.T) { return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) } - seedSession := func(t *testing.T, queries *repository.Queries, params repository.CreateSessionParams) { + seedSession := func(t *testing.T, queries repository.Store, params repository.CreateSessionParams) { t.Helper() _, err := queries.CreateSession(context.Background(), params) require.NoError(t, err) @@ -60,7 +58,7 @@ func TestContextMiddleware(t *testing.T) { type runArgs struct { do func(req *http.Request) (*model.UserContext, *httptest.ResponseRecorder) - queries *repository.Queries + queries repository.Store } type testCase struct { @@ -272,22 +270,17 @@ func TestContextMiddleware(t *testing.T) { oauthBrokerCfgs := make(map[string]model.OAuthServiceConfig) - app := bootstrap.NewBootstrapApp(model.Config{}) - - db, err := app.SetupDatabase(path.Join(tempDir, "tinyauth.db")) - require.NoError(t, err) - - queries := repository.New(db) + store := memory.New() ldap := service.NewLdapService(service.LdapServiceConfig{}) - err = ldap.Init() + err := ldap.Init() require.NoError(t, err) broker := service.NewOAuthBrokerService(oauthBrokerCfgs) err = broker.Init() require.NoError(t, err) - authService := service.NewAuthService(authServiceCfg, ldap, queries, broker) + authService := service.NewAuthService(authServiceCfg, ldap, store, broker) err = authService.Init() require.NoError(t, err) @@ -317,12 +310,7 @@ func TestContextMiddleware(t *testing.T) { return captured, recorder } - test.run(t, runArgs{do: do, queries: queries}) + test.run(t, runArgs{do: do, queries: store}) }) } - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) } diff --git a/internal/model/config.go b/internal/model/config.go index 95870e3d..cffe8734 100644 --- a/internal/model/config.go +++ b/internal/model/config.go @@ -4,7 +4,8 @@ package model func NewDefaultConfiguration() *Config { return &Config{ Database: DatabaseConfig{ - Path: "./tinyauth.db", + Driver: "sqlite", + Path: "./tinyauth.db", }, Analytics: AnalyticsConfig{ Enabled: true, @@ -82,7 +83,8 @@ type Config struct { } type DatabaseConfig struct { - Path string `description:"The path to the database, including file name." yaml:"path"` + Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"` + Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"` } type AnalyticsConfig struct { diff --git a/internal/repository/memory/memory_test.go b/internal/repository/memory/memory_test.go new file mode 100644 index 00000000..292d2abf --- /dev/null +++ b/internal/repository/memory/memory_test.go @@ -0,0 +1,427 @@ +package memory_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/tinyauthapp/tinyauth/internal/repository" + "github.com/tinyauthapp/tinyauth/internal/repository/memory" +) + +var ctx = context.Background() + +func TestCreateAndGetSession(t *testing.T) { + s := memory.New() + sess, err := s.CreateSession(ctx, repository.CreateSessionParams{ + UUID: "uuid-1", + Username: "alice", + Expiry: 9999, + }) + require.NoError(t, err) + assert.Equal(t, "uuid-1", sess.UUID) + assert.Equal(t, "alice", sess.Username) + + got, err := s.GetSession(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, sess, got) +} + +func TestGetSession_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetSession(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestUpdateSession(t *testing.T) { + s := memory.New() + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1", Username: "alice"}) + require.NoError(t, err) + + updated, err := s.UpdateSession(ctx, repository.UpdateSessionParams{ + UUID: "uuid-1", + Username: "bob", + Email: "bob@example.com", + }) + require.NoError(t, err) + assert.Equal(t, "bob", updated.Username) + assert.Equal(t, "bob@example.com", updated.Email) + + got, err := s.GetSession(ctx, "uuid-1") + require.NoError(t, err) + assert.Equal(t, updated, got) +} + +func TestUpdateSession_NotFound(t *testing.T) { + s := memory.New() + _, err := s.UpdateSession(ctx, repository.UpdateSessionParams{UUID: "missing"}) + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteSession(t *testing.T) { + s := memory.New() + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "uuid-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteSession(ctx, "uuid-1")) + + _, err = s.GetSession(ctx, "uuid-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteExpiredSessions(t *testing.T) { + s := memory.New() + _, err := s.CreateSession(ctx, repository.CreateSessionParams{UUID: "expired", Expiry: 10}) + require.NoError(t, err) + _, err = s.CreateSession(ctx, repository.CreateSessionParams{UUID: "valid", Expiry: 100}) + require.NoError(t, err) + + require.NoError(t, s.DeleteExpiredSessions(ctx, 50)) + + _, err = s.GetSession(ctx, "expired") + assert.ErrorIs(t, err, repository.ErrNotFound) + + _, err = s.GetSession(ctx, "valid") + assert.NoError(t, err) +} + +func TestCreateAndGetOidcCode(t *testing.T) { + s := memory.New() + code, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{ + Sub: "sub-1", + CodeHash: "hash-1", + Scope: "openid", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", code.Sub) + + // destructive read removes the record + got, err := s.GetOidcCode(ctx, "hash-1") + require.NoError(t, err) + assert.Equal(t, code, got) + + _, err = s.GetOidcCode(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCode_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCode(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeBySub(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + + // destructive — gone after read + _, err = s.GetOidcCodeBySub(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeBySub_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCodeBySub(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeUnsafe(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeUnsafe(ctx, "hash-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) + + // non-destructive — still present + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.NoError(t, err) +} + +func TestGetOidcCodeUnsafe_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCodeUnsafe(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcCodeBySubUnsafe(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + got, err := s.GetOidcCodeBySubUnsafe(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "hash-1", got.CodeHash) + + // non-destructive — still present + _, err = s.GetOidcCodeBySubUnsafe(ctx, "sub-1") + assert.NoError(t, err) +} + +func TestGetOidcCodeBySubUnsafe_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcCodeBySubUnsafe(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestCreateOidcCode_UniqueSubConstraint(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_codes.sub") +} + +func TestDeleteOidcCode(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcCode(ctx, "hash-1")) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcCodeBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcCodeBySub(ctx, "sub-1")) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteExpiredOidcCodes(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-1", CodeHash: "hash-1", ExpiresAt: 10}) + require.NoError(t, err) + _, err = s.CreateOidcCode(ctx, repository.CreateOidcCodeParams{Sub: "sub-2", CodeHash: "hash-2", ExpiresAt: 100}) + require.NoError(t, err) + + deleted, err := s.DeleteExpiredOidcCodes(ctx, 50) + require.NoError(t, err) + require.Len(t, deleted, 1) + assert.Equal(t, "hash-1", deleted[0].CodeHash) + + _, err = s.GetOidcCodeUnsafe(ctx, "hash-2") + assert.NoError(t, err) +} + +func TestCreateAndGetOidcToken(t *testing.T) { + s := memory.New() + tok, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-hash-1", + CodeHash: "code-hash-1", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", tok.Sub) + + got, err := s.GetOidcToken(ctx, "at-hash-1") + require.NoError(t, err) + assert.Equal(t, tok, got) +} + +func TestGetOidcToken_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcToken(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestCreateOidcToken_UniqueSubConstraint(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-2"}) + assert.ErrorContains(t, err, "UNIQUE constraint failed: oidc_tokens.sub") +} + +func TestGetOidcTokenByRefreshToken(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + got, err := s.GetOidcTokenByRefreshToken(ctx, "rt-1") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) +} + +func TestGetOidcTokenByRefreshToken_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcTokenByRefreshToken(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestGetOidcTokenBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + }) + require.NoError(t, err) + + got, err := s.GetOidcTokenBySub(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, "at-1", got.AccessTokenHash) +} + +func TestGetOidcTokenBySub_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcTokenBySub(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestUpdateOidcTokenByRefreshToken(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + RefreshTokenHash: "rt-1", + }) + require.NoError(t, err) + + updated, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ + RefreshTokenHash_2: "rt-1", + AccessTokenHash: "at-2", + RefreshTokenHash: "rt-2", + TokenExpiresAt: 200, + RefreshTokenExpiresAt: 400, + }) + require.NoError(t, err) + assert.Equal(t, "at-2", updated.AccessTokenHash) + assert.Equal(t, "rt-2", updated.RefreshTokenHash) + + // old key gone, new key present + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) + + got, err := s.GetOidcToken(ctx, "at-2") + require.NoError(t, err) + assert.Equal(t, "sub-1", got.Sub) +} + +func TestUpdateOidcTokenByRefreshToken_NotFound(t *testing.T) { + s := memory.New() + _, err := s.UpdateOidcTokenByRefreshToken(ctx, repository.UpdateOidcTokenByRefreshTokenParams{ + RefreshTokenHash_2: "missing", + }) + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcToken(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcToken(ctx, "at-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcTokenBySub(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{Sub: "sub-1", AccessTokenHash: "at-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcTokenBySub(ctx, "sub-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcTokenByCodeHash(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", + AccessTokenHash: "at-1", + CodeHash: "code-1", + }) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcTokenByCodeHash(ctx, "code-1")) + + _, err = s.GetOidcToken(ctx, "at-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteExpiredOidcTokens(t *testing.T) { + s := memory.New() + // expired by TokenExpiresAt + _, err := s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-1", AccessTokenHash: "at-1", + TokenExpiresAt: 10, RefreshTokenExpiresAt: 100, + }) + require.NoError(t, err) + // expired by RefreshTokenExpiresAt + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-2", AccessTokenHash: "at-2", + TokenExpiresAt: 100, RefreshTokenExpiresAt: 10, + }) + require.NoError(t, err) + // valid + _, err = s.CreateOidcToken(ctx, repository.CreateOidcTokenParams{ + Sub: "sub-3", AccessTokenHash: "at-3", + TokenExpiresAt: 100, RefreshTokenExpiresAt: 100, + }) + require.NoError(t, err) + + deleted, err := s.DeleteExpiredOidcTokens(ctx, repository.DeleteExpiredOidcTokensParams{ + TokenExpiresAt: 50, + RefreshTokenExpiresAt: 50, + }) + require.NoError(t, err) + assert.Len(t, deleted, 2) + + _, err = s.GetOidcToken(ctx, "at-3") + assert.NoError(t, err) +} + +func TestCreateAndGetOidcUserInfo(t *testing.T) { + s := memory.New() + u, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{ + Sub: "sub-1", + Name: "Alice", + Email: "alice@example.com", + }) + require.NoError(t, err) + assert.Equal(t, "sub-1", u.Sub) + + got, err := s.GetOidcUserInfo(ctx, "sub-1") + require.NoError(t, err) + assert.Equal(t, u, got) +} + +func TestGetOidcUserInfo_NotFound(t *testing.T) { + s := memory.New() + _, err := s.GetOidcUserInfo(ctx, "missing") + assert.ErrorIs(t, err, repository.ErrNotFound) +} + +func TestDeleteOidcUserInfo(t *testing.T) { + s := memory.New() + _, err := s.CreateOidcUserInfo(ctx, repository.CreateOidcUserInfoParams{Sub: "sub-1"}) + require.NoError(t, err) + + require.NoError(t, s.DeleteOidcUserInfo(ctx, "sub-1")) + + _, err = s.GetOidcUserInfo(ctx, "sub-1") + assert.ErrorIs(t, err, repository.ErrNotFound) +} diff --git a/internal/repository/memory/oidc_queries.go b/internal/repository/memory/oidc_queries.go new file mode 100644 index 00000000..80305fc0 --- /dev/null +++ b/internal/repository/memory/oidc_queries.go @@ -0,0 +1,241 @@ +package memory + +import ( + "context" + "fmt" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, c := range s.oidcCodes { + if c.Sub == arg.Sub { + return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub") + } + } + code := repository.OidcCode(arg) + s.oidcCodes[arg.CodeHash] = code + return code, nil +} + +// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + delete(s.oidcCodes, codeHash) + return c, nil +} + +// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING). +func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + c, ok := s.oidcCodes[codeHash] + if !ok { + return repository.OidcCode{}, repository.ErrNotFound + } + return c, nil +} + +// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT). +func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, c := range s.oidcCodes { + if c.Sub == sub { + return c, nil + } + } + return repository.OidcCode{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcCodes, codeHash) + return nil +} + +func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, c := range s.oidcCodes { + if c.Sub == sub { + delete(s.oidcCodes, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcCode + for k, c := range s.oidcCodes { + if c.ExpiresAt < expiresAt { + deleted = append(deleted, c) + delete(s.oidcCodes, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + // Enforce sub UNIQUE constraint + for _, t := range s.oidcTokens { + if t.Sub == arg.Sub { + return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub") + } + } + tok := repository.OidcToken{ + Sub: arg.Sub, + AccessTokenHash: arg.AccessTokenHash, + RefreshTokenHash: arg.RefreshTokenHash, + CodeHash: arg.CodeHash, + Scope: arg.Scope, + ClientID: arg.ClientID, + TokenExpiresAt: arg.TokenExpiresAt, + RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt, + Nonce: arg.Nonce, + } + s.oidcTokens[arg.AccessTokenHash] = tok + return tok, nil +} + +func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + t, ok := s.oidcTokens[accessTokenHash] + if !ok { + return repository.OidcToken{}, repository.ErrNotFound + } + return t, nil +} + +func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.RefreshTokenHash == refreshTokenHash { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) { + s.mu.RLock() + defer s.mu.RUnlock() + for _, t := range s.oidcTokens { + if t.Sub == sub { + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.RefreshTokenHash == arg.RefreshTokenHash_2 { + delete(s.oidcTokens, k) + t.AccessTokenHash = arg.AccessTokenHash + t.RefreshTokenHash = arg.RefreshTokenHash + t.TokenExpiresAt = arg.TokenExpiresAt + t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt + s.oidcTokens[arg.AccessTokenHash] = t + return t, nil + } + } + return repository.OidcToken{}, repository.ErrNotFound +} + +func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcTokens, accessTokenHash) + return nil +} + +func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.Sub == sub { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, t := range s.oidcTokens { + if t.CodeHash == codeHash { + delete(s.oidcTokens, k) + } + } + return nil +} + +func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + s.mu.Lock() + defer s.mu.Unlock() + var deleted []repository.OidcToken + for k, t := range s.oidcTokens { + if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt { + deleted = append(deleted, t) + delete(s.oidcTokens, k) + } + } + return deleted, nil +} + +func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + s.mu.Lock() + defer s.mu.Unlock() + u := repository.OidcUserinfo(arg) + s.oidcUsers[arg.Sub] = u + return u, nil +} + +func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) { + s.mu.RLock() + defer s.mu.RUnlock() + u, ok := s.oidcUsers[sub] + if !ok { + return repository.OidcUserinfo{}, repository.ErrNotFound + } + return u, nil +} + +func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.oidcUsers, sub) + return nil +} diff --git a/internal/repository/memory/session_queries.go b/internal/repository/memory/session_queries.go new file mode 100644 index 00000000..2edde6b1 --- /dev/null +++ b/internal/repository/memory/session_queries.go @@ -0,0 +1,63 @@ +package memory + +import ( + "context" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess := repository.Session(arg) + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) { + s.mu.RLock() + defer s.mu.RUnlock() + sess, ok := s.sessions[uuid] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + return sess, nil +} + +func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + s.mu.Lock() + defer s.mu.Unlock() + sess, ok := s.sessions[arg.UUID] + if !ok { + return repository.Session{}, repository.ErrNotFound + } + sess.Username = arg.Username + sess.Email = arg.Email + sess.Name = arg.Name + sess.Provider = arg.Provider + sess.TotpPending = arg.TotpPending + sess.OAuthGroups = arg.OAuthGroups + sess.Expiry = arg.Expiry + sess.OAuthName = arg.OAuthName + sess.OAuthSub = arg.OAuthSub + s.sessions[arg.UUID] = sess + return sess, nil +} + +func (s *Store) DeleteSession(_ context.Context, uuid string) error { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.sessions, uuid) + return nil +} + +func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error { + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.sessions { + if v.Expiry < expiry { + delete(s.sessions, k) + } + } + return nil +} diff --git a/internal/repository/memory/store.go b/internal/repository/memory/store.go new file mode 100644 index 00000000..969cba66 --- /dev/null +++ b/internal/repository/memory/store.go @@ -0,0 +1,27 @@ +// Package memory provides an in-memory implementation of repository.Store for use in tests. +package memory + +import ( + "sync" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store is a thread-safe in-memory implementation of repository.Store. +type Store struct { + mu sync.RWMutex + sessions map[string]repository.Session + oidcCodes map[string]repository.OidcCode + oidcTokens map[string]repository.OidcToken + oidcUsers map[string]repository.OidcUserinfo +} + +// New returns a new empty in-memory Store. +func New() repository.Store { + return &Store{ + sessions: make(map[string]repository.Session), + oidcCodes: make(map[string]repository.OidcCode), + oidcTokens: make(map[string]repository.OidcToken), + oidcUsers: make(map[string]repository.OidcUserinfo), + } +} diff --git a/internal/repository/models.go b/internal/repository/models.go index bc2e2c66..3f58dd66 100644 --- a/internal/repository/models.go +++ b/internal/repository/models.go @@ -1,9 +1,22 @@ -// Code generated by sqlc. DO NOT EDIT. -// versions: -// sqlc v1.30.0 - package repository +// Shared model and parameter types for all storage drivers. +// sqlc-generated driver packages use these via the conversion layer in their store.go. + +type Session struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} + type OidcCode struct { Sub string CodeHash string @@ -49,7 +62,7 @@ type OidcUserinfo struct { Address string } -type Session struct { +type CreateSessionParams struct { UUID string Username string Email string @@ -62,3 +75,74 @@ type Session struct { OAuthName string OAuthSub string } + +type UpdateSessionParams struct { + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + OAuthName string + OAuthSub string + UUID string +} + +type CreateOidcCodeParams struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type CreateOidcTokenParams struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + CodeHash string + Nonce string +} + +type UpdateOidcTokenByRefreshTokenParams struct { + AccessTokenHash string + RefreshTokenHash string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + RefreshTokenHash_2 string +} + +type DeleteExpiredOidcTokensParams struct { + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 +} + +type CreateOidcUserInfoParams struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} diff --git a/internal/repository/db.go b/internal/repository/sqlite/db.go similarity index 93% rename from internal/repository/db.go rename to internal/repository/sqlite/db.go index 998bfd3b..51a4906a 100644 --- a/internal/repository/db.go +++ b/internal/repository/sqlite/db.go @@ -1,8 +1,8 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 -package repository +package sqlite import ( "context" diff --git a/internal/repository/sqlite/generate.go b/internal/repository/sqlite/generate.go new file mode 100644 index 00000000..5f6011f1 --- /dev/null +++ b/internal/repository/sqlite/generate.go @@ -0,0 +1,3 @@ +package sqlite + +//go:generate go run github.com/tinyauthapp/tinyauth/cmd/gen/sqlc-wrapper -pkg github.com/tinyauthapp/tinyauth/internal/repository/sqlite diff --git a/internal/repository/sqlite/models.go b/internal/repository/sqlite/models.go new file mode 100644 index 00000000..fd6f78da --- /dev/null +++ b/internal/repository/sqlite/models.go @@ -0,0 +1,64 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package sqlite + +type OidcCode struct { + Sub string + CodeHash string + Scope string + RedirectURI string + ClientID string + ExpiresAt int64 + Nonce string + CodeChallenge string +} + +type OidcToken struct { + Sub string + AccessTokenHash string + RefreshTokenHash string + CodeHash string + Scope string + ClientID string + TokenExpiresAt int64 + RefreshTokenExpiresAt int64 + Nonce string +} + +type OidcUserinfo struct { + Sub string + Name string + PreferredUsername string + Email string + Groups string + UpdatedAt int64 + GivenName string + FamilyName string + MiddleName string + Nickname string + Profile string + Picture string + Website string + Gender string + Birthdate string + Zoneinfo string + Locale string + PhoneNumber string + Address string +} + +type Session struct { + UUID string + Username string + Email string + Name string + Provider string + TotpPending bool + OAuthGroups string + Expiry int64 + CreatedAt int64 + OAuthName string + OAuthSub string +} diff --git a/internal/repository/oidc_queries.sql.go b/internal/repository/sqlite/oidc_queries.sql.go similarity index 99% rename from internal/repository/oidc_queries.sql.go rename to internal/repository/sqlite/oidc_queries.sql.go index 7caac9d4..e5d08bc2 100644 --- a/internal/repository/oidc_queries.sql.go +++ b/internal/repository/sqlite/oidc_queries.sql.go @@ -1,9 +1,9 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: oidc_queries.sql -package repository +package sqlite import ( "context" diff --git a/internal/repository/session_queries.sql.go b/internal/repository/sqlite/session_queries.sql.go similarity index 98% rename from internal/repository/session_queries.sql.go rename to internal/repository/sqlite/session_queries.sql.go index c846c3f9..7792fc4b 100644 --- a/internal/repository/session_queries.sql.go +++ b/internal/repository/sqlite/session_queries.sql.go @@ -1,9 +1,9 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.30.0 +// sqlc v1.31.1 // source: session_queries.sql -package repository +package sqlite import ( "context" diff --git a/internal/repository/sqlite/store.go b/internal/repository/sqlite/store.go new file mode 100644 index 00000000..f316efa4 --- /dev/null +++ b/internal/repository/sqlite/store.go @@ -0,0 +1,224 @@ +// Code generated by cmd/gen/sqlc-wrapper. DO NOT EDIT. +package sqlite + +import ( + "context" + "database/sql" + "errors" + + "github.com/tinyauthapp/tinyauth/internal/repository" +) + +// Store wraps *Queries and implements repository.Store. +type Store struct { + q *Queries +} + +// NewStore wraps a *Queries to satisfy repository.Store. +func NewStore(q *Queries) repository.Store { + return &Store{q: q} +} + +var errMap = []struct { + from error + to error +}{ + {sql.ErrNoRows, repository.ErrNotFound}, +} + +func mapErr(err error) error { + for _, e := range errMap { + if errors.Is(err, e.from) { + return e.to + } + } + return err +} + +func oidcCodeToRepo(v OidcCode) repository.OidcCode { + return repository.OidcCode(v) +} +func oidcTokenToRepo(v OidcToken) repository.OidcToken { + return repository.OidcToken(v) +} +func oidcUserinfoToRepo(v OidcUserinfo) repository.OidcUserinfo { + return repository.OidcUserinfo(v) +} +func sessionToRepo(v Session) repository.Session { + return repository.Session(v) +} +func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) { + r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg)) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) { + r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) { + r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg)) + if err != nil { + return repository.OidcUserinfo{}, mapErr(err) + } + return oidcUserinfoToRepo(r), nil +} + +func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) { + r, err := s.q.CreateSession(ctx, CreateSessionParams(arg)) + if err != nil { + return repository.Session{}, mapErr(err) + } + return sessionToRepo(r), nil +} + +func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) { + rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt) + if err != nil { + return nil, mapErr(err) + } + out := make([]repository.OidcCode, len(rows)) + for i, row := range rows { + out[i] = oidcCodeToRepo(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) { + rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg)) + if err != nil { + return nil, mapErr(err) + } + out := make([]repository.OidcToken, len(rows)) + for i, row := range rows { + out[i] = oidcTokenToRepo(row) + } + return out, nil +} + +func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error { + return mapErr(s.q.DeleteExpiredSessions(ctx, expiry)) +} + +func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error { + return mapErr(s.q.DeleteOidcCode(ctx, codeHash)) +} + +func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub)) +} + +func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error { + return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash)) +} + +func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error { + return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)) +} + +func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub)) +} + +func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error { + return mapErr(s.q.DeleteOidcUserInfo(ctx, sub)) +} + +func (s *Store) DeleteSession(ctx context.Context, uuid string) error { + return mapErr(s.q.DeleteSession(ctx, uuid)) +} + +func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCode(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySub(ctx, sub) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) { + r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash) + if err != nil { + return repository.OidcCode{}, mapErr(err) + } + return oidcCodeToRepo(r), nil +} + +func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcToken(ctx, accessTokenHash) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) { + r, err := s.q.GetOidcTokenBySub(ctx, sub) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) { + r, err := s.q.GetOidcUserInfo(ctx, sub) + if err != nil { + return repository.OidcUserinfo{}, mapErr(err) + } + return oidcUserinfoToRepo(r), nil +} + +func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) { + r, err := s.q.GetSession(ctx, uuid) + if err != nil { + return repository.Session{}, mapErr(err) + } + return sessionToRepo(r), nil +} + +func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) { + r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg)) + if err != nil { + return repository.OidcToken{}, mapErr(err) + } + return oidcTokenToRepo(r), nil +} + +func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) { + r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg)) + if err != nil { + return repository.Session{}, mapErr(err) + } + return sessionToRepo(r), nil +} diff --git a/internal/repository/store.go b/internal/repository/store.go new file mode 100644 index 00000000..302f2f10 --- /dev/null +++ b/internal/repository/store.go @@ -0,0 +1,47 @@ +package repository + +import ( + "context" + "errors" +) + +// ErrNotFound is returned by Store methods when the requested record does not exist. +var ErrNotFound = errors.New("not found") + +// Store is the interface that all storage drivers must implement. +// The sqlc-generated *Queries struct satisfies this interface for SQLite. +// Future drivers (postgres, etc.) must return the shared types defined in this package. +type Store interface { + // Sessions + CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) + GetSession(ctx context.Context, uuid string) (Session, error) + UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) + DeleteSession(ctx context.Context, uuid string) error + DeleteExpiredSessions(ctx context.Context, expiry int64) error + + // OIDC codes + CreateOidcCode(ctx context.Context, arg CreateOidcCodeParams) (OidcCode, error) + GetOidcCode(ctx context.Context, codeHash string) (OidcCode, error) + GetOidcCodeBySub(ctx context.Context, sub string) (OidcCode, error) + GetOidcCodeUnsafe(ctx context.Context, codeHash string) (OidcCode, error) + GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (OidcCode, error) + DeleteOidcCode(ctx context.Context, codeHash string) error + DeleteOidcCodeBySub(ctx context.Context, sub string) error + DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]OidcCode, error) + + // OIDC tokens + CreateOidcToken(ctx context.Context, arg CreateOidcTokenParams) (OidcToken, error) + GetOidcToken(ctx context.Context, accessTokenHash string) (OidcToken, error) + GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (OidcToken, error) + GetOidcTokenBySub(ctx context.Context, sub string) (OidcToken, error) + UpdateOidcTokenByRefreshToken(ctx context.Context, arg UpdateOidcTokenByRefreshTokenParams) (OidcToken, error) + DeleteOidcToken(ctx context.Context, accessTokenHash string) error + DeleteOidcTokenBySub(ctx context.Context, sub string) error + DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error + DeleteExpiredOidcTokens(ctx context.Context, arg DeleteExpiredOidcTokensParams) ([]OidcToken, error) + + // OIDC userinfo + CreateOidcUserInfo(ctx context.Context, arg CreateOidcUserInfoParams) (OidcUserinfo, error) + GetOidcUserInfo(ctx context.Context, sub string) (OidcUserinfo, error) + DeleteOidcUserInfo(ctx context.Context, sub string) error +} diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 16c53fe0..c6a12944 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -2,7 +2,6 @@ package service import ( "context" - "database/sql" "errors" "fmt" "net/http" @@ -96,14 +95,14 @@ type AuthService struct { loginMutex sync.RWMutex ldapGroupsMutex sync.RWMutex ldap *LdapService - queries *repository.Queries + queries repository.Store oauthBroker *OAuthBrokerService lockdown *Lockdown lockdownCtx context.Context lockdownCancelFunc context.CancelFunc } -func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *repository.Queries, oauthBroker *OAuthBrokerService) *AuthService { +func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries repository.Store, oauthBroker *OAuthBrokerService) *AuthService { return &AuthService{ config: config, loginAttempts: make(map[string]*LoginAttempt), @@ -421,7 +420,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito session, err := auth.queries.GetSession(ctx, uuid) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return nil, errors.New("session not found") } return nil, err diff --git a/internal/service/oidc_service.go b/internal/service/oidc_service.go index 1e1c1986..65b5ccda 100644 --- a/internal/service/oidc_service.go +++ b/internal/service/oidc_service.go @@ -7,7 +7,6 @@ import ( "crypto/rsa" "crypto/sha256" "crypto/x509" - "database/sql" "encoding/base64" "encoding/json" "encoding/pem" @@ -121,7 +120,7 @@ type OIDCServiceConfig struct { type OIDCService struct { config OIDCServiceConfig - queries *repository.Queries + queries repository.Store clients map[string]model.OIDCClientConfig privateKey *rsa.PrivateKey publicKey crypto.PublicKey @@ -129,7 +128,7 @@ type OIDCService struct { isConfigured bool } -func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService { +func NewOIDCService(config OIDCServiceConfig, queries repository.Store) *OIDCService { return &OIDCService{ config: config, queries: queries, @@ -422,7 +421,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client oidcCode, err := service.queries.GetOidcCode(c, codeHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcCode{}, ErrCodeNotFound } return repository.OidcCode{}, err @@ -566,7 +565,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken)) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return TokenResponse{}, ErrTokenNotFound } return TokenResponse{}, err @@ -645,7 +644,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re entry, err := service.queries.GetOidcToken(c, tokenHash) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { return repository.OidcToken{}, ErrTokenNotFound } return repository.OidcToken{}, err @@ -733,15 +732,15 @@ func (service *OIDCService) Hash(token string) string { func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error { err := service.queries.DeleteOidcCodeBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcTokenBySub(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } err = service.queries.DeleteOidcUserInfo(ctx, sub) - if err != nil && !errors.Is(err, sql.ErrNoRows) { + if err != nil && !errors.Is(err, repository.ErrNotFound) { return err } return nil @@ -786,7 +785,7 @@ func (service *OIDCService) Cleanup() { token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub) if err != nil { - if errors.Is(err, sql.ErrNoRows) { + if errors.Is(err, repository.ErrNotFound) { continue } tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub") diff --git a/sql/oidc_queries.sql b/sql/sqlite/oidc_queries.sql similarity index 100% rename from sql/oidc_queries.sql rename to sql/sqlite/oidc_queries.sql diff --git a/sql/oidc_schemas.sql b/sql/sqlite/oidc_schemas.sql similarity index 100% rename from sql/oidc_schemas.sql rename to sql/sqlite/oidc_schemas.sql diff --git a/sql/session_queries.sql b/sql/sqlite/session_queries.sql similarity index 100% rename from sql/session_queries.sql rename to sql/sqlite/session_queries.sql diff --git a/sql/session_schemas.sql b/sql/sqlite/session_schemas.sql similarity index 100% rename from sql/session_schemas.sql rename to sql/sqlite/session_schemas.sql diff --git a/sqlc.yml b/sqlc.yml index de08738a..e7b2c4b4 100644 --- a/sqlc.yml +++ b/sqlc.yml @@ -1,12 +1,12 @@ version: "2" sql: - engine: "sqlite" - queries: "sql/*_queries.sql" - schema: "sql/*_schemas.sql" + queries: "sql/sqlite/*_queries.sql" + schema: "sql/sqlite/*_schemas.sql" gen: go: - package: "repository" - out: "internal/repository" + package: "sqlite" + out: "internal/repository/sqlite" rename: uuid: "UUID" oauth_groups: "OAuthGroups"