From 54bf79234abd44a95a7171ca76de5c25a95250a9 Mon Sep 17 00:00:00 2001 From: Richard Wooding Date: Tue, 24 Feb 2026 15:15:54 +0200 Subject: [PATCH 1/3] refactor: extract dialect-agnostic schema types into schema package Move FieldSchema, Schema, NewSchema, and related types from pg/ into a new schema/ package to decouple them from PostgreSQL. The pg package re-exports them as type aliases, maintaining full backward compatibility. Co-Authored-By: Claude Opus 4.6 --- cel2sql.go | 8 +++--- pg/provider.go | 74 +++++++----------------------------------------- schema/schema.go | 70 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 67 deletions(-) create mode 100644 schema/schema.go diff --git a/cel2sql.go b/cel2sql.go index 5827061..52d91b4 100644 --- a/cel2sql.go +++ b/cel2sql.go @@ -18,7 +18,7 @@ import ( "github.com/google/cel-go/common/overloads" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" - "github.com/spandigital/cel2sql/v3/pg" + "github.com/spandigital/cel2sql/v3/schema" ) // Implementations based on `google/cel-go`'s unparser @@ -62,7 +62,7 @@ type ConvertOption func(*convertOptions) // convertOptions holds configuration options for the Convert function. 
type convertOptions struct { - schemas map[string]pg.Schema + schemas map[string]schema.Schema ctx context.Context logger *slog.Logger maxDepth int // Maximum recursion depth (0 = use default) @@ -76,7 +76,7 @@ type convertOptions struct { // // schemas := provider.GetSchemas() // sql, err := cel2sql.Convert(ast, cel2sql.WithSchemas(schemas)) -func WithSchemas(schemas map[string]pg.Schema) ConvertOption { +func WithSchemas(schemas map[string]schema.Schema) ConvertOption { return func(o *convertOptions) { o.schemas = schemas } @@ -310,7 +310,7 @@ func ConvertParameterized(ast *cel.Ast, opts ...ConvertOption) (*Result, error) type converter struct { str strings.Builder typeMap map[int64]*exprpb.Type - schemas map[string]pg.Schema + schemas map[string]schema.Schema ctx context.Context logger *slog.Logger depth int // Current recursion depth diff --git a/pg/provider.go b/pg/provider.go index f48565b..0565251 100644 --- a/pg/provider.go +++ b/pg/provider.go @@ -13,6 +13,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/spandigital/cel2sql/v3/schema" "github.com/spandigital/cel2sql/v3/sqltypes" ) @@ -36,70 +37,17 @@ const ( errMsgUnknownType = "unknown type in schema" ) -// FieldSchema represents a PostgreSQL field type with name, type, and optional nested schema. -type FieldSchema struct { - Name string - Type string // PostgreSQL type name (text, integer, boolean, etc.) - Repeated bool // true for arrays - Dimensions int // number of array dimensions (1 for integer[], 2 for integer[][], etc.) - Schema []FieldSchema // for composite types - IsJSON bool // true for json/jsonb types - IsJSONB bool // true for jsonb (vs json) - ElementType string // for arrays: element type name -} - -// Schema represents a PostgreSQL table schema with O(1) field lookup. -// It contains a slice of fields for ordered iteration and a map index for fast lookups. 
-type Schema struct { - fields []FieldSchema - fieldIndex map[string]*FieldSchema -} - -// NewSchema creates a new Schema with field indexing for O(1) lookups. -// This improves performance for tables with many columns. -func NewSchema(fields []FieldSchema) Schema { - index := make(map[string]*FieldSchema, len(fields)) - for i := range fields { - index[fields[i].Name] = &fields[i] - - // Build indices for nested schemas recursively - if len(fields[i].Schema) > 0 { - fields[i].Schema = rebuildSchemaIndex(fields[i].Schema) - } - } - - return Schema{ - fields: fields, - fieldIndex: index, - } -} - -// rebuildSchemaIndex recursively rebuilds indices for nested schemas. -// This is used internally when converting old-style []FieldSchema to new Schema struct. -func rebuildSchemaIndex(oldSchema []FieldSchema) []FieldSchema { - // For nested schemas, we need to ensure they're properly indexed too - // But since nested schemas are stored as []FieldSchema in FieldSchema.Schema, - // we keep them as slices but process them when needed - return oldSchema -} +// FieldSchema is an alias for schema.FieldSchema for backward compatibility. +// New code should prefer schema.FieldSchema directly. +type FieldSchema = schema.FieldSchema -// Fields returns the ordered slice of field schemas. -// Use this when you need to iterate over fields in their defined order. -func (s Schema) Fields() []FieldSchema { - return s.fields -} - -// FindField performs an O(1) lookup for a field by name. -// Returns the field schema and true if found, nil and false otherwise. -func (s Schema) FindField(name string) (*FieldSchema, bool) { - field, found := s.fieldIndex[name] - return field, found -} +// Schema is an alias for schema.Schema for backward compatibility. +// New code should prefer schema.Schema directly. +type Schema = schema.Schema -// Len returns the number of fields in the schema. -func (s Schema) Len() int { - return len(s.fields) -} +// NewSchema creates a new Schema. 
This is an alias for schema.NewSchema. +// New code should prefer schema.NewSchema directly. +var NewSchema = schema.NewSchema // TypeProvider interface for PostgreSQL type providers type TypeProvider interface { @@ -294,7 +242,7 @@ func (p *typeProvider) findSchema(typeName string) (Schema, bool) { } // For nested types, traverse the schema hierarchy using O(1) lookups - currentFields := schema.fields + currentFields := schema.Fields() for _, tn := range typeNames[1:] { // Use O(1) indexed lookup instead of linear search var nestedField *FieldSchema diff --git a/schema/schema.go b/schema/schema.go new file mode 100644 index 0000000..5bb3547 --- /dev/null +++ b/schema/schema.go @@ -0,0 +1,70 @@ +// Package schema provides dialect-agnostic database schema types for CEL to SQL conversion. +// These types describe column names, types, array dimensions, and JSON flags without +// coupling to any specific SQL dialect. +package schema + +// FieldSchema represents a database field type with name, type, and optional nested schema. +// This type is dialect-agnostic and used by all SQL dialect providers. +type FieldSchema struct { + Name string + Type string // SQL type name (text, integer, boolean, etc.) + Repeated bool // true for arrays + Dimensions int // number of array dimensions (1 for integer[], 2 for integer[][], etc.) + Schema []FieldSchema // for composite types + IsJSON bool // true for json/jsonb types + IsJSONB bool // true for jsonb (vs json) + ElementType string // for arrays: element type name +} + +// Schema represents a table schema with O(1) field lookup. +// It contains a slice of fields for ordered iteration and a map index for fast lookups. +type Schema struct { + fields []FieldSchema + fieldIndex map[string]*FieldSchema +} + +// NewSchema creates a new Schema with field indexing for O(1) lookups. +// This improves performance for tables with many columns. 
+func NewSchema(fields []FieldSchema) Schema { + index := make(map[string]*FieldSchema, len(fields)) + for i := range fields { + index[fields[i].Name] = &fields[i] + + // Build indices for nested schemas recursively + if len(fields[i].Schema) > 0 { + fields[i].Schema = rebuildSchemaIndex(fields[i].Schema) + } + } + + return Schema{ + fields: fields, + fieldIndex: index, + } +} + +// rebuildSchemaIndex recursively rebuilds indices for nested schemas. +// This is used internally when converting old-style []FieldSchema to new Schema struct. +func rebuildSchemaIndex(oldSchema []FieldSchema) []FieldSchema { + // For nested schemas, we need to ensure they're properly indexed too + // But since nested schemas are stored as []FieldSchema in FieldSchema.Schema, + // we keep them as slices but process them when needed + return oldSchema +} + +// Fields returns the ordered slice of field schemas. +// Use this when you need to iterate over fields in their defined order. +func (s Schema) Fields() []FieldSchema { + return s.fields +} + +// FindField performs an O(1) lookup for a field by name. +// Returns the field schema and true if found, nil and false otherwise. +func (s Schema) FindField(name string) (*FieldSchema, bool) { + field, found := s.fieldIndex[name] + return field, found +} + +// Len returns the number of fields in the schema. +func (s Schema) Len() int { + return len(s.fields) +} From 107ffec1a0a46471c8d5c2de1f67d4674c5da2f5 Mon Sep 17 00:00:00 2001 From: Richard Wooding Date: Wed, 25 Feb 2026 10:38:05 +0200 Subject: [PATCH 2/3] feat: add multi-dialect SQL support for MySQL, SQLite, DuckDB, and BigQuery Add support for 5 SQL dialects (PostgreSQL, MySQL, SQLite, DuckDB, BigQuery) via a pluggable Dialect interface with ~40 methods covering SQL generation differences. Includes per-dialect type providers with LoadTableSchema support, shared test infrastructure, integration tests, and dialect-specific index analysis recommendations. 
Co-Authored-By: Claude Opus 4.6 --- CHANGELOG.md | 14 + CLAUDE.md | 194 ++++-- CONTRIBUTING.md | 86 ++- README.md | 290 +++++++-- analysis.go | 184 +++--- analysis_test.go | 344 ++++++++++- bigquery/provider.go | 228 ++++++++ bigquery/provider_test.go | 274 +++++++++ bigquery/testdata/provider_seed.yaml | 41 ++ bigquery_integration_test.go | 441 ++++++++++++++ cel2sql.go | 846 +++++++++------------------ dialect/bigquery/dialect.go | 503 ++++++++++++++++ dialect/bigquery/index_advisor.go | 86 +++ dialect/bigquery/regex.go | 137 +++++ dialect/bigquery/validation.go | 56 ++ dialect/dialect.go | 234 ++++++++ dialect/duckdb/dialect.go | 473 +++++++++++++++ dialect/duckdb/index_advisor.go | 92 +++ dialect/duckdb/regex.go | 137 +++++ dialect/duckdb/validation.go | 55 ++ dialect/index_advisor.go | 61 ++ dialect/mysql/dialect.go | 475 +++++++++++++++ dialect/mysql/index_advisor.go | 93 +++ dialect/mysql/regex.go | 139 +++++ dialect/mysql/validation.go | 91 +++ dialect/postgres/dialect.go | 496 ++++++++++++++++ dialect/postgres/index_advisor.go | 110 ++++ dialect/postgres/regex.go | 143 +++++ dialect/postgres/validation.go | 66 +++ dialect/registry.go | 42 ++ dialect/sqlite/dialect.go | 462 +++++++++++++++ dialect/sqlite/index_advisor.go | 73 +++ dialect/sqlite/validation.go | 68 +++ duckdb/provider.go | 251 ++++++++ duckdb/provider_test.go | 157 +++++ errors.go | 3 + examples/index_analysis/main.go | 127 ++-- go.mod | 51 +- go.sum | 173 +++++- json.go | 79 +-- mysql/provider.go | 226 +++++++ mysql/provider_test.go | 280 +++++++++ mysql_integration_test.go | 447 ++++++++++++++ sqlite/provider.go | 280 +++++++++ sqlite/provider_test.go | 348 +++++++++++ sqlite_integration_test.go | 518 ++++++++++++++++ testcases/array_tests.go | 69 +++ testcases/basic_tests.go | 141 +++++ testcases/cast_tests.go | 81 +++ testcases/comprehension_tests.go | 64 ++ testcases/fixtures.go | 93 +++ testcases/json_tests.go | 49 ++ testcases/operator_tests.go | 91 +++ testcases/parameterized_tests.go 
| 111 ++++ testcases/regex_tests.go | 93 +++ testcases/string_tests.go | 69 +++ testcases/testcases.go | 98 ++++ testcases/timestamp_tests.go | 120 ++++ testdata/bigquery_seed.yaml | 78 +++ testutil/env.go | 310 ++++++++++ testutil/runner.go | 180 ++++++ testutil/runner_bigquery_test.go | 16 + testutil/runner_duckdb_test.go | 16 + testutil/runner_mysql_test.go | 12 + testutil/runner_pg_test.go | 16 + testutil/runner_sqlite_test.go | 16 + timestamps.go | 119 ++-- 67 files changed, 10830 insertions(+), 986 deletions(-) create mode 100644 bigquery/provider.go create mode 100644 bigquery/provider_test.go create mode 100644 bigquery/testdata/provider_seed.yaml create mode 100644 bigquery_integration_test.go create mode 100644 dialect/bigquery/dialect.go create mode 100644 dialect/bigquery/index_advisor.go create mode 100644 dialect/bigquery/regex.go create mode 100644 dialect/bigquery/validation.go create mode 100644 dialect/dialect.go create mode 100644 dialect/duckdb/dialect.go create mode 100644 dialect/duckdb/index_advisor.go create mode 100644 dialect/duckdb/regex.go create mode 100644 dialect/duckdb/validation.go create mode 100644 dialect/index_advisor.go create mode 100644 dialect/mysql/dialect.go create mode 100644 dialect/mysql/index_advisor.go create mode 100644 dialect/mysql/regex.go create mode 100644 dialect/mysql/validation.go create mode 100644 dialect/postgres/dialect.go create mode 100644 dialect/postgres/index_advisor.go create mode 100644 dialect/postgres/regex.go create mode 100644 dialect/postgres/validation.go create mode 100644 dialect/registry.go create mode 100644 dialect/sqlite/dialect.go create mode 100644 dialect/sqlite/index_advisor.go create mode 100644 dialect/sqlite/validation.go create mode 100644 duckdb/provider.go create mode 100644 duckdb/provider_test.go create mode 100644 mysql/provider.go create mode 100644 mysql/provider_test.go create mode 100644 mysql_integration_test.go create mode 100644 sqlite/provider.go create mode 100644 
sqlite/provider_test.go create mode 100644 sqlite_integration_test.go create mode 100644 testcases/array_tests.go create mode 100644 testcases/basic_tests.go create mode 100644 testcases/cast_tests.go create mode 100644 testcases/comprehension_tests.go create mode 100644 testcases/fixtures.go create mode 100644 testcases/json_tests.go create mode 100644 testcases/operator_tests.go create mode 100644 testcases/parameterized_tests.go create mode 100644 testcases/regex_tests.go create mode 100644 testcases/string_tests.go create mode 100644 testcases/testcases.go create mode 100644 testcases/timestamp_tests.go create mode 100644 testdata/bigquery_seed.yaml create mode 100644 testutil/env.go create mode 100644 testutil/runner.go create mode 100644 testutil/runner_bigquery_test.go create mode 100644 testutil/runner_duckdb_test.go create mode 100644 testutil/runner_mysql_test.go create mode 100644 testutil/runner_pg_test.go create mode 100644 testutil/runner_sqlite_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 413e6e8..d7bf631 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,20 @@ ## [Unreleased] +### Added +- **Multi-Dialect SQL Support** + - Introduced `Dialect` interface for pluggable SQL generation (`dialect/dialect.go`) + - PostgreSQL dialect extracted from converter into `dialect/postgres/` (zero behavior change) + - MySQL dialect implementation (`dialect/mysql/`) + - SQLite dialect implementation (`dialect/sqlite/`) + - DuckDB dialect implementation (`dialect/duckdb/`) + - BigQuery dialect implementation (`dialect/bigquery/`) + - `WithDialect()` option for `Convert()` and `ConvertParameterized()` (defaults to PostgreSQL) + - Per-dialect type providers: `mysql/provider.go`, `sqlite/provider.go`, `duckdb/provider.go`, `bigquery/provider.go` + - Dialect-agnostic schema types in `schema/` package + - Shared test case infrastructure (`testcases/`, `testutil/`) with per-dialect expected SQL + - Dialect registry for name-based lookup 
(`dialect.Register()`, `dialect.Get()`) + ## [3.5.0] - 2026-01-08 ### Changed diff --git a/CLAUDE.md b/CLAUDE.md index 4549260..a77e51b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,7 +4,7 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co ## Project Overview -cel2sql converts CEL (Common Expression Language) expressions to PostgreSQL SQL conditions. It specifically targets PostgreSQL standard SQL and was recently migrated from BigQuery. +cel2sql converts CEL (Common Expression Language) expressions to SQL conditions. It supports multiple SQL dialects: PostgreSQL (default), MySQL, SQLite, DuckDB, and BigQuery. **Module**: `github.com/spandigital/cel2sql/v3` **Go Version**: 1.24+ @@ -70,10 +70,31 @@ go test -v -run TestFunctionName ./... 6. **`pg/provider.go`** - PostgreSQL type provider for CEL type system - Maps PostgreSQL types to CEL types - - Supports dynamic schema loading from live databases + - Supports dynamic schema loading from live databases via `LoadTableSchema` - Handles composite types and arrays -7. **`sqltypes/types.go`** - Custom SQL type definitions for CEL (DATE, TIME, DATETIME, INTERVAL) +7. **`mysql/provider.go`** - MySQL type provider + - Maps MySQL types to CEL types + - `LoadTableSchema` uses `information_schema.columns` with `table_schema = DATABASE()` + - Accepts `*sql.DB` (caller owns connection) + +8. **`sqlite/provider.go`** - SQLite type provider + - Maps SQLite type affinity to CEL types + - `LoadTableSchema` uses `PRAGMA table_info` with table name validation + - Accepts `*sql.DB` (caller owns connection) + +9. **`duckdb/provider.go`** - DuckDB type provider + - Maps DuckDB types to CEL types, detects array types from `[]` suffix + - `LoadTableSchema` uses `information_schema.columns` + - Accepts `*sql.DB` (works with any DuckDB driver) + +10. 
**`bigquery/provider.go`** - BigQuery type provider + - Maps BigQuery types to CEL types + - `LoadTableSchema` uses BigQuery client API (`Table.Metadata`) + - Handles nested RECORD types recursively + - Accepts `*bigquery.Client` + dataset ID + +11. **`sqltypes/types.go`** - Custom SQL type definitions for CEL (DATE, TIME, DATETIME, INTERVAL) ### Type System Integration @@ -198,11 +219,14 @@ These validations prevent PostgreSQL syntax errors and ensure predictable behavi - Include package comments for main packages ### Testing Guidelines -- Use PostgreSQL schemas (`pg.Schema`) in tests, not BigQuery -- Use `pg.NewTypeProvider()` for schema definitions +- Use the dialect-specific schema/provider for each dialect's tests +- Use `pg.NewTypeProvider()` for PostgreSQL, `mysql.NewTypeProvider()` for MySQL, etc. - Include tests for nested types, arrays, and JSON fields -- Verify SQL output matches PostgreSQL syntax (single quotes, proper functions) -- Use testcontainers for integration tests with real PostgreSQL +- Verify SQL output matches the target dialect's syntax +- Use testcontainers for integration tests (PostgreSQL, MySQL, BigQuery) +- Use in-memory databases for SQLite integration tests (no Docker needed) +- DuckDB integration tests require CGO; use unit tests for type mapping validation +- Provider tests live in `{dialect}/provider_test.go` ### Performance Benchmarks @@ -334,8 +358,12 @@ benchstat bench-old.txt bench-new.txt ## Common Patterns -### Creating Type Providers +### Creating Type Providers (Pre-defined Schemas) + +All dialects support pre-defined schemas via `NewTypeProvider`: + ```go +// PostgreSQL schema := pg.NewSchema([]pg.FieldSchema{ {Name: "field_name", Type: "text", Repeated: false}, {Name: "array_field", Type: "text", Repeated: true}, @@ -343,19 +371,57 @@ schema := pg.NewSchema([]pg.FieldSchema{ {Name: "composite_field", Type: "composite", Schema: []pg.FieldSchema{...}}, }) provider := pg.NewTypeProvider(map[string]pg.Schema{"TableName": 
schema}) + +// MySQL (same schema types, dialect-specific type names) +schema := mysql.NewSchema([]mysql.FieldSchema{ + {Name: "name", Type: "varchar"}, + {Name: "metadata", Type: "json", IsJSON: true}, +}) +provider := mysql.NewTypeProvider(map[string]mysql.Schema{"TableName": schema}) + +// SQLite, DuckDB, BigQuery follow the same pattern with their own type names ``` ### Dynamic Schema Loading + +All dialects support runtime schema introspection from live databases via `LoadTableSchema`: + ```go +// PostgreSQL — accepts connection string, manages its own pool provider, err := pg.NewTypeProviderWithConnection(ctx, connectionString) if err != nil { return err } defer provider.Close() +err = provider.LoadTableSchema(ctx, "tableName") + +// MySQL — accepts *sql.DB, caller owns connection +db, _ := sql.Open("mysql", "user:pass@tcp(host:3306)/db?parseTime=true") +provider, err := mysql.NewTypeProviderWithConnection(ctx, db) +err = provider.LoadTableSchema(ctx, "tableName") + +// SQLite — accepts *sql.DB, uses PRAGMA table_info (validates table name) +db, _ := sql.Open("sqlite", "mydb.sqlite") +provider, err := sqlite.NewTypeProviderWithConnection(ctx, db) +err = provider.LoadTableSchema(ctx, "tableName") + +// DuckDB — accepts *sql.DB, works with any DuckDB driver +db, _ := sql.Open("duckdb", "mydb.duckdb") +provider, err := duckdb.NewTypeProviderWithConnection(ctx, db) +err = provider.LoadTableSchema(ctx, "tableName") +// BigQuery — accepts *bigquery.Client + dataset ID +client, _ := bigquery.NewClient(ctx, "project-id") +provider, err := bqprovider.NewTypeProviderWithClient(ctx, client, "dataset_id") err = provider.LoadTableSchema(ctx, "tableName") ``` +**Key differences per dialect:** +- **PostgreSQL**: `NewTypeProviderWithConnection(ctx, connString)` — owns its pgxpool, `Close()` releases it +- **MySQL/SQLite/DuckDB**: `NewTypeProviderWithConnection(ctx, *sql.DB)` — caller owns DB, `Close()` is no-op +- **BigQuery**: `NewTypeProviderWithClient(ctx, 
*bigquery.Client, datasetID)` — caller owns client, `Close()` is no-op +- **SQLite**: Table name validated via regex (`^[a-zA-Z_][a-zA-Z0-9_]*$`) since PRAGMA doesn't support parameterized queries + ### CEL Environment Setup ```go env, err := cel.NewEnv( @@ -377,7 +443,18 @@ sqlCondition, err := cel2sql.Convert(ast) ### Query Analysis and Index Recommendations -cel2sql can analyze CEL expressions and recommend database indexes to optimize performance. +cel2sql can analyze CEL expressions and recommend **dialect-specific** database indexes to optimize performance. + +#### Architecture + +Index analysis uses the **IndexAdvisor** interface (`dialect/index_advisor.go`): +- **Pattern detection** stays centralized in `analysis.go` (walks the CEL AST once) +- **DDL generation** is delegated to per-dialect `IndexAdvisor` implementations +- Each built-in dialect implements `IndexAdvisor` on its `*Dialect` struct +- Use `dialect.GetIndexAdvisor(d)` to type-assert a dialect to `IndexAdvisor` +- Unsupported patterns return `nil` (silently skipped) + +**PatternTypes** detected: `PatternComparison`, `PatternJSONAccess`, `PatternRegexMatch`, `PatternArrayMembership`, `PatternArrayComprehension`, `PatternJSONArrayComprehension`. 
#### Using AnalyzeQuery @@ -387,16 +464,18 @@ if issues != nil && issues.Err() != nil { return issues.Err() } +// PostgreSQL (default) sql, recommendations, err := cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) + +// Or with a specific dialect +sql, recommendations, err := cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(schemas), + cel2sql.WithDialect(mysql.New())) if err != nil { return err } -// Use the generated SQL -rows, err := db.Query("SELECT * FROM people WHERE " + sql) - -// Review and apply index recommendations for _, rec := range recommendations { fmt.Printf("Column: %s, Type: %s\n", rec.Column, rec.IndexType) fmt.Printf("Reason: %s\n", rec.Reason) @@ -404,42 +483,51 @@ for _, rec := range recommendations { } ``` -#### Index Recommendation Types - -AnalyzeQuery detects patterns and recommends appropriate index types: +#### Per-Dialect Index Types -- **B-tree indexes**: Comparison operations (`==, >, <, >=, <=`) - - Best for: Equality checks, range queries, sorting - - Example: `person.age > 18` → B-tree on `person.age` +| Pattern | PostgreSQL | MySQL | SQLite | DuckDB | BigQuery | +|---------|-----------|-------|--------|--------|----------| +| Comparison | BTREE | BTREE | BTREE | ART | CLUSTERING | +| JSON access | GIN | BTREE (functional) | _(nil)_ | ART | SEARCH_INDEX | +| Regex match | GIN + pg_trgm | FULLTEXT | _(nil)_ | _(nil)_ | _(nil)_ | +| Array membership | GIN | _(nil)_ | _(nil)_ | ART | _(nil)_ | +| Array comprehension | GIN | _(nil)_ | _(nil)_ | ART | _(nil)_ | +| JSON array comprehension | GIN | BTREE (functional) | _(nil)_ | ART | SEARCH_INDEX | -- **GIN indexes**: JSON/JSONB path operations, array operations - - Best for: JSON field access, array membership, containment - - Example: `person.metadata.verified == true` → GIN on `person.metadata` - - Example: `"premium" in person.tags` → GIN on `person.tags` - -- **GIN indexes with pg_trgm**: Regex pattern matching - - Best for: Text search, pattern matching, fuzzy matching - - 
Requires: PostgreSQL pg_trgm extension - - Example: `person.email.matches(r"@example\.com$")` → GIN on `person.email` +**Per-dialect DDL examples:** +- **PostgreSQL**: `CREATE INDEX idx_col_gin ON table_name USING GIN (col);` +- **MySQL**: `CREATE INDEX idx_col_btree ON table_name (col);` / `CREATE FULLTEXT INDEX ...` +- **SQLite**: `CREATE INDEX idx_col ON table_name (col);` +- **DuckDB**: `CREATE INDEX idx_col ON table_name (col);` (ART by default) +- **BigQuery**: `ALTER TABLE t SET OPTIONS (clustering_columns=['col']);` / `CREATE SEARCH INDEX ...` #### IndexRecommendation Structure ```go type IndexRecommendation struct { Column string // Full column name (e.g., "person.metadata") - IndexType string // "BTREE", "GIN", or "GIST" - Expression string // Complete CREATE INDEX statement + IndexType string // Dialect-specific: "BTREE", "GIN", "ART", "CLUSTERING", "SEARCH_INDEX", etc. + Expression string // Complete DDL statement for the target dialect Reason string // Explanation of why this index is recommended } ``` +#### Implementation Files + +- `dialect/index_advisor.go` — `IndexAdvisor` interface, `PatternType`, `IndexPattern`, `GetIndexAdvisor()` helper +- `dialect/postgres/index_advisor.go` — PostgreSQL: BTREE, GIN, GIN+pg_trgm +- `dialect/mysql/index_advisor.go` — MySQL: BTREE, FULLTEXT +- `dialect/sqlite/index_advisor.go` — SQLite: BTREE only +- `dialect/duckdb/index_advisor.go` — DuckDB: ART +- `dialect/bigquery/index_advisor.go` — BigQuery: CLUSTERING, SEARCH_INDEX + #### When to Use - **Development**: Discover which indexes your queries need - **Performance tuning**: Identify missing indexes causing slow queries - **Production monitoring**: Analyze user-generated filter expressions -See `examples/index_analysis/` for a complete working example. +See `examples/index_analysis/` for a complete working example with all 5 dialects. ### Logging and Observability @@ -731,23 +819,23 @@ For detailed security information, see the security documentation. 
## Important Notes ### Migration Context -This project was migrated from BigQuery to PostgreSQL in v2.0: -- All `cloud.google.com/go/bigquery` dependencies removed -- `bq/` package removed entirely -- PostgreSQL-specific syntax (single quotes, POSITION(), ARRAY_LENGTH(,1), etc.) -- Comprehensive JSON/JSONB support added -- Dynamic schema loading added +This project was originally BigQuery-only, migrated to PostgreSQL in v2.0, and expanded to multi-dialect in v3.0: +- v2.0: All `cloud.google.com/go/bigquery` dependencies removed, `bq/` package removed +- v3.0: Multi-dialect support added (PostgreSQL, MySQL, SQLite, DuckDB, BigQuery) +- Each dialect has its own type provider with `LoadTableSchema` support +- BigQuery dependency re-added for BigQuery dialect support ### Things to Avoid -- Do NOT add BigQuery dependencies back - Do NOT remove protobuf dependencies (required by CEL) - Do NOT use direct SQL string concatenation (use proper escaping) - Do NOT ignore context cancellation in database operations +- Do NOT use `PRAGMA` with user-controlled table names without validation (SQLite) +- Do NOT assume a specific dialect — use the dialect interface for dialect-specific behavior ### When Adding Features -1. Consider PostgreSQL-specific SQL syntax -2. Add comprehensive tests with realistic schemas -3. Update type mappings in `pg/provider.go` if needed +1. Consider all supported SQL dialects, not just PostgreSQL +2. Add comprehensive tests with realistic schemas for each affected dialect +3. Update type mappings in the appropriate `{dialect}/provider.go` if needed 4. Document new CEL operators/functions in README.md 5. Ensure backward compatibility 6. 
Run `make ci` before committing @@ -756,18 +844,40 @@ This project was migrated from BigQuery to PostgreSQL in v2.0: ``` cel2sql/ ├── cel2sql.go # Main conversion engine +├── analysis.go # Query analysis and index recommendations (multi-dialect) ├── comprehensions.go # CEL comprehensions support ├── json.go # JSON/JSONB handling ├── operators.go # Operator conversion ├── timestamps.go # Timestamp/duration handling ├── utils.go # Utility functions +├── schema/ # Dialect-agnostic schema types +│ └── schema.go # FieldSchema, Schema with O(1) lookup ├── pg/ # PostgreSQL type provider -│ └── provider.go +│ └── provider.go # LoadTableSchema via information_schema + pgxpool +├── mysql/ # MySQL type provider +│ └── provider.go # LoadTableSchema via information_schema + *sql.DB +├── sqlite/ # SQLite type provider +│ └── provider.go # LoadTableSchema via PRAGMA table_info + *sql.DB +├── duckdb/ # DuckDB type provider +│ └── provider.go # LoadTableSchema via information_schema + *sql.DB +├── bigquery/ # BigQuery type provider +│ └── provider.go # LoadTableSchema via BigQuery client API +├── dialect/ # Dialect interface and implementations +│ ├── dialect.go # Core Dialect interface (~40 methods) +│ ├── index_advisor.go # IndexAdvisor interface, PatternType, IndexPattern +│ ├── postgres/ # PostgreSQL dialect + IndexAdvisor (BTREE, GIN, GIN+trgm) +│ ├── mysql/ # MySQL dialect + IndexAdvisor (BTREE, FULLTEXT) +│ ├── sqlite/ # SQLite dialect + IndexAdvisor (BTREE only) +│ ├── duckdb/ # DuckDB dialect + IndexAdvisor (ART) +│ └── bigquery/ # BigQuery dialect + IndexAdvisor (CLUSTERING, SEARCH_INDEX) ├── sqltypes/ # Custom SQL types for CEL │ └── types.go +├── testcases/ # Shared test cases with per-dialect expected SQL +├── testutil/ # Multi-dialect test runner + env factories └── examples/ # Usage examples ├── basic/ ├── comprehensions/ + ├── index_analysis/ # Multi-dialect index recommendation demo └── load_table_schema/ ``` diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 
b4206b8..cb312b2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -137,16 +137,44 @@ import ( 1. Add the function mapping in `cel2sql.go` 2. Add comprehensive tests in `cel2sql_test.go` 3. Update the README with documentation -4. Ensure PostgreSQL compatibility +4. Ensure the function works with dialect abstraction -### PostgreSQL Focus +### Multi-Dialect Architecture -This project targets PostgreSQL. When adding features: +cel2sql supports PostgreSQL (default), MySQL, SQLite, DuckDB, and BigQuery. When adding features: -- Use PostgreSQL-specific SQL syntax -- Test with realistic PostgreSQL schemas -- Use `pgx/v5` driver patterns -- Avoid BigQuery-specific features +- Call `con.dialect.*` methods for any SQL that differs between databases +- Standard SQL (AND, OR, =, !=, etc.) stays inline in the converter +- Add expected SQL for all dialects in `testcases/*.go` +- Run `make ci` to verify all dialects pass + +### Adding a New Dialect + +To add support for a new SQL dialect: + +1. **Create the dialect package**: `dialect//dialect.go` + - Implement the `dialect.Dialect` interface (~40 methods) + - Register with `dialect.Register()` in `init()` + +2. **Create regex conversion** (if applicable): `dialect//regex.go` + - Convert RE2 patterns to the dialect's regex format + - Include ReDoS protection (pattern length, nesting limits) + +3. **Create validation**: `dialect//validation.go` + - Field name validation, reserved keywords + +4. **Create type provider**: `/provider.go` + - Map native database types to CEL types + +5. **Add env factory**: `testutil/env.go` + - Add `EnvFactory()` function + - Update `DialectEnvFactory()` switch + +6. **Add test runner**: `testutil/runner__test.go` + +7. **Add expected SQL to all test case files** in `testcases/` + +8. **Update `dialect/dialect.go`** to add the dialect name constant ## Pull Request Process @@ -174,30 +202,36 @@ This project targets PostgreSQL. 
When adding features: ``` cel2sql/ -├── cel2sql.go # Main conversion engine -├── cel2sql_test.go # Main tests -├── pg/ # PostgreSQL type provider -│ ├── provider.go # Type provider implementation -│ └── provider_test.go # Type provider tests -├── sqltypes/ # Custom SQL types -│ └── types.go # CEL type definitions -├── examples/ # Usage examples -│ ├── basic/ # Basic usage example -│ │ ├── main.go -│ │ └── README.md -│ ├── load_table_schema/ # Dynamic schema loading example -│ │ ├── main.go -│ │ └── README.md -│ └── README.md # Examples overview -└── test/ # Test utilities - └── testdata.go # Test schemas +├── cel2sql.go # Main conversion engine (uses dialect interface) +├── cel2sql_test.go # Main tests +├── dialect/ # Dialect interface + implementations +│ ├── dialect.go # Interface definition + Name type +│ ├── registry.go # Name→Dialect lookup +│ ├── postgres/ # PostgreSQL dialect +│ ├── mysql/ # MySQL dialect +│ ├── sqlite/ # SQLite dialect +│ ├── duckdb/ # DuckDB dialect +│ └── bigquery/ # BigQuery dialect +├── pg/ # PostgreSQL type provider +├── mysql/ # MySQL type provider +├── sqlite/ # SQLite type provider +├── duckdb/ # DuckDB type provider +├── bigquery/ # BigQuery type provider +├── schema/ # Dialect-agnostic schema types +├── sqltypes/ # Custom SQL types for CEL +├── testcases/ # Shared test cases with per-dialect expected SQL +├── testutil/ # Test runner + env factories +└── examples/ # Usage examples ``` ### Key Components -- **cel2sql.go**: Core conversion logic from CEL AST to SQL -- **pg/provider.go**: PostgreSQL type system integration +- **cel2sql.go**: Core conversion logic from CEL AST to SQL (calls dialect methods) +- **dialect/dialect.go**: Dialect interface defining all SQL generation points +- **dialect/*/dialect.go**: Per-dialect SQL generation implementations +- **pg/provider.go**, **mysql/provider.go**, etc.: Type system integration per dialect - **sqltypes/types.go**: Custom SQL type definitions for CEL +- **testcases/*.go**: Shared 
test cases with expected SQL for all dialects ## Debugging diff --git a/README.md b/README.md index 93ae0b4..4774982 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,17 @@ # cel2sql -> Convert [CEL (Common Expression Language)](https://cel.dev/) expressions to PostgreSQL SQL +> Convert [CEL (Common Expression Language)](https://cel.dev/) expressions to SQL for PostgreSQL, MySQL, SQLite, DuckDB, and BigQuery [![Go Version](https://img.shields.io/badge/Go-1.24%2B-blue)](https://golang.org) -[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-17-blue)](https://www.postgresql.org) +[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-17-336791)](https://www.postgresql.org) +[![MySQL](https://img.shields.io/badge/MySQL-8.0-4479A1)](https://www.mysql.com) +[![SQLite](https://img.shields.io/badge/SQLite-3-003B57)](https://www.sqlite.org) +[![DuckDB](https://img.shields.io/badge/DuckDB-1.x-FFF000)](https://duckdb.org) +[![BigQuery](https://img.shields.io/badge/BigQuery-GCP-4285F4)](https://cloud.google.com/bigquery) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) [![Benchmarks](https://img.shields.io/badge/benchmarks-performance%20tracking-green)](https://spandigital.github.io/cel2sql/dev/bench/) -**cel2sql** makes it easy to build dynamic SQL queries using CEL expressions. Write type-safe, expressive filters in CEL and automatically convert them to PostgreSQL-compatible SQL. +**cel2sql** makes it easy to build dynamic SQL queries using CEL expressions. Write type-safe, expressive filters in CEL and automatically convert them to SQL for your database of choice. ## Quick Start @@ -61,10 +65,10 @@ func main() { ## Why cel2sql? 
+✅ **Multi-Dialect**: PostgreSQL, MySQL, SQLite, DuckDB, and BigQuery from a single API ✅ **Type-Safe**: Catch errors at compile time, not runtime -✅ **PostgreSQL 17**: Fully compatible with the latest PostgreSQL ✅ **Rich Features**: JSON/JSONB, arrays, regex, timestamps, and more -✅ **Well-Tested**: 100+ tests including integration tests with real PostgreSQL +✅ **Well-Tested**: 100+ tests including integration tests with real databases ✅ **Easy to Use**: Simple API, comprehensive documentation ✅ **Secure by Default**: Built-in protections against SQL injection and ReDoS attacks ✅ **Performance Tracked**: [Continuous benchmark monitoring](https://spandigital.github.io/cel2sql/dev/bench/) to prevent regressions @@ -117,29 +121,92 @@ sql, err := cel2sql.Convert(ast, ``` **Available Options:** +- `WithDialect(dialect.Dialect)` - Select target SQL dialect (default: PostgreSQL) - `WithSchemas(map[string]pg.Schema)` - Provide table schemas for JSON detection - `WithContext(context.Context)` - Enable cancellation and timeouts - `WithLogger(*slog.Logger)` - Enable structured logging - `WithMaxDepth(int)` - Set custom recursion depth limit (default: 100) +## Multi-Dialect Support + +cel2sql supports 5 SQL dialects. 
PostgreSQL is the default; select other dialects with `WithDialect()`: + +```go +import ( + "github.com/spandigital/cel2sql/v3" + "github.com/spandigital/cel2sql/v3/dialect/mysql" + "github.com/spandigital/cel2sql/v3/dialect/sqlite" + "github.com/spandigital/cel2sql/v3/dialect/duckdb" + "github.com/spandigital/cel2sql/v3/dialect/bigquery" +) + +// PostgreSQL (default - no option needed) +sql, err := cel2sql.Convert(ast) + +// MySQL +sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(mysql.New())) + +// SQLite +sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(sqlite.New())) + +// DuckDB +sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(duckdb.New())) + +// BigQuery +sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(bigquery.New())) +``` + +### Dialect Comparison + +| Feature | PostgreSQL | MySQL | SQLite | DuckDB | BigQuery | +|---------|-----------|-------|--------|--------|----------| +| String concat | `\|\|` | `CONCAT()` | `\|\|` | `\|\|` | `\|\|` | +| Regex | `~ / ~*` | `REGEXP` | unsupported | `~ / ~*` | `REGEXP_CONTAINS()` | +| JSON access | `->>'f'` | `->>'$.f'` | `json_extract()` | `->>'f'` | `JSON_VALUE()` | +| Arrays | `ARRAY[...]` | JSON arrays | JSON arrays | `[...]` | `[...]` | +| UNNEST | `UNNEST(x)` | `JSON_TABLE(...)` | `json_each(x)` | `UNNEST(x)` | `UNNEST(x)` | +| Param placeholder | `$1, $2` | `?, ?` | `?, ?` | `$1, $2` | `@p1, @p2` | +| Timestamp cast | `TIMESTAMP WITH TIME ZONE` | `DATETIME` | `datetime()` | `TIMESTAMPTZ` | `TIMESTAMP` | +| Contains | `POSITION()` | `LOCATE()` | `INSTR()` | `CONTAINS()` | `INSTR()` | +| Index analysis | BTREE, GIN, GIN+trgm | BTREE, FULLTEXT | BTREE | ART | CLUSTERING, SEARCH_INDEX | + +### Per-Dialect Type Providers + +Each dialect has its own type provider for mapping database types to CEL types. 
All providers support both pre-defined schemas (`NewTypeProvider`) and dynamic schema loading (`LoadTableSchema`): + +```go +import "github.com/spandigital/cel2sql/v3/pg" // PostgreSQL (pgxpool connection string) +import "github.com/spandigital/cel2sql/v3/mysql" // MySQL (*sql.DB) +import "github.com/spandigital/cel2sql/v3/sqlite" // SQLite (*sql.DB) +import "github.com/spandigital/cel2sql/v3/duckdb" // DuckDB (*sql.DB) +import "github.com/spandigital/cel2sql/v3/bigquery" // BigQuery (*bigquery.Client) +``` + ## Query Analysis and Index Recommendations -cel2sql can analyze your CEL queries and recommend database indexes to optimize performance. The `AnalyzeQuery()` function returns both the converted SQL and actionable index recommendations. +cel2sql can analyze your CEL queries and recommend database indexes to optimize performance. The `AnalyzeQuery()` function returns both the converted SQL and **dialect-specific** index recommendations. ### How It Works -`AnalyzeQuery()` examines your CEL expression and detects patterns that would benefit from specific PostgreSQL index types: +`AnalyzeQuery()` examines your CEL expression and detects patterns that would benefit from indexing, then generates dialect-appropriate DDL: -- **JSON/JSONB path operations** (`->>, ?`) → GIN indexes -- **Array operations** (comprehensions, `IN` clauses) → GIN indexes -- **Regex matching** (`matches()`) → GIN indexes with `pg_trgm` extension -- **Comparison operations** (`==, >, <, >=, <=`) → B-tree indexes +- **Comparison operations** (`==, >, <, >=, <=`) → B-tree (PG/MySQL/SQLite), ART (DuckDB), Clustering (BigQuery) +- **JSON/JSONB path operations** (`->>, ?`) → GIN (PG), functional index (MySQL), Search Index (BigQuery), ART (DuckDB) +- **Regex matching** (`matches()`) → GIN with pg_trgm (PG), FULLTEXT (MySQL) +- **Array operations** (comprehensions, `IN` clauses) → GIN (PG), ART (DuckDB) ### Usage ```go +// PostgreSQL (default dialect) sql, recommendations, err := 
cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) + +// Or specify a dialect +sql, recommendations, err = cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(schemas), + cel2sql.WithDialect(mysql.New())) + if err != nil { log.Fatal(err) } @@ -153,38 +220,48 @@ for _, rec := range recommendations { fmt.Printf("Type: %s\n", rec.IndexType) fmt.Printf("Reason: %s\n", rec.Reason) fmt.Printf("Execute: %s\n\n", rec.Expression) - - // Apply the recommendation - // _, err := db.Exec(rec.Expression) } ``` +### Per-Dialect Index Types + +| Pattern | PostgreSQL | MySQL | SQLite | DuckDB | BigQuery | +|---------|-----------|-------|--------|--------|----------| +| Comparison | BTREE | BTREE | BTREE | ART | CLUSTERING | +| JSON access | GIN | BTREE (functional) | _(skip)_ | ART | SEARCH_INDEX | +| Regex | GIN + pg_trgm | FULLTEXT | _(skip)_ | _(skip)_ | _(skip)_ | +| Array membership | GIN | _(skip)_ | _(skip)_ | ART | _(skip)_ | +| Comprehension | GIN | _(skip)_ | _(skip)_ | ART | _(skip)_ | + +Unsupported patterns are silently skipped (no recommendation emitted). + ### Example ```go -// Query with multiple index-worthy patterns -celExpr := `person.age > 18 && - person.email.matches(r"@example\.com$") && - person.metadata.verified == true` - +celExpr := `person.age > 18 && person.metadata.verified == true` ast, _ := env.Compile(celExpr) -sql, recs, _ := cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) -// Generated SQL: -// person.age > 18 AND person.email ~ '@example\.com$' -// AND person.metadata->>'verified' = 'true' +// PostgreSQL recommendations +sql, recs, _ := cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) +// Recommendations: +// 1. CREATE INDEX idx_person_age_btree ON table_name (person.age); +// 2. CREATE INDEX idx_person_metadata_gin ON table_name USING GIN (person.metadata); +// MySQL recommendations +sql, recs, _ = cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(schemas), + cel2sql.WithDialect(mysql.New())) // Recommendations: // 1.
CREATE INDEX idx_person_age_btree ON table_name (person.age); -// Reason: Comparison operations benefit from B-tree for range queries -// -// 2. CREATE INDEX idx_person_email_gin_trgm ON table_name -// USING GIN (person.email gin_trgm_ops); -// Reason: Regex matching benefits from GIN index with pg_trgm -// -// 3. CREATE INDEX idx_person_metadata_gin ON table_name -// USING GIN (person.metadata); -// Reason: JSON path operations benefit from GIN index +// 2. CREATE INDEX idx_person_metadata_json ON table_name ((CAST(person.metadata->>'$.path' AS CHAR(255)))); + +// BigQuery recommendations +sql, recs, _ = cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(schemas), + cel2sql.WithDialect(bigquery.New())) +// Recommendations: +// 1. ALTER TABLE table_name SET OPTIONS (clustering_columns=['person.age']); +// 2. CREATE SEARCH INDEX idx_person_metadata ON table_name (person.metadata); ``` ### When to Use @@ -193,7 +270,7 @@ sql, recs, _ := cel2sql.AnalyzeQuery(ast, cel2sql.WithSchemas(schemas)) - **Performance tuning**: Identify missing indexes causing slow queries - **Production monitoring**: Analyze user-generated filter expressions -See `examples/index_analysis/` for a complete working example. +See `examples/index_analysis/` for a complete working example with all 5 dialects. 
## Parameterized Queries @@ -393,42 +470,159 @@ See [Regex Matching documentation](docs/regex-matching.md) for complete details, ## Type Mapping -| CEL Type | PostgreSQL Type | -|----------|-----------------| -| `int` | `bigint` | -| `double` | `double precision` | -| `bool` | `boolean` | -| `string` | `text` | -| `bytes` | `bytea` | -| `list` | `ARRAY` | -| `timestamp` | `timestamp with time zone` | -| `duration` | `INTERVAL` | +| CEL Type | PostgreSQL | MySQL | SQLite | DuckDB | BigQuery | +|----------|-----------|-------|--------|--------|----------| +| `int` | `bigint` | `SIGNED` | `INTEGER` | `BIGINT` | `INT64` | +| `double` | `double precision` | `DECIMAL` | `REAL` | `DOUBLE` | `FLOAT64` | +| `bool` | `boolean` | `UNSIGNED` | `INTEGER` | `BOOLEAN` | `BOOL` | +| `string` | `text` | `CHAR` | `TEXT` | `VARCHAR` | `STRING` | +| `bytes` | `bytea` | `BINARY` | `BLOB` | `BLOB` | `BYTES` | +| `list` | `ARRAY` | JSON array | JSON array | `LIST` | `ARRAY` | +| `timestamp` | `timestamptz` | `DATETIME` | `datetime()` | `TIMESTAMPTZ` | `TIMESTAMP` | +| `duration` | `INTERVAL` | `INTERVAL` | string modifier | `INTERVAL` | `INTERVAL` | ## Dynamic Schema Loading -Load table schemas directly from your PostgreSQL database: +Load table schemas directly from your database at runtime instead of defining them manually. Each dialect provider supports introspecting table schemas from a live database connection. 
+ +### PostgreSQL ```go -// Connect to database and load schema +import "github.com/spandigital/cel2sql/v3/pg" + +// PostgreSQL accepts a connection string and manages its own connection pool provider, _ := pg.NewTypeProviderWithConnection(ctx, "postgres://user:pass@localhost/db") defer provider.Close() -// Load table schema dynamically provider.LoadTableSchema(ctx, "users") -// Use with CEL env, _ := cel.NewEnv( cel.CustomTypeProvider(provider), cel.Variable("user", cel.ObjectType("users")), ) ``` +### MySQL + +```go +import ( + "database/sql" + _ "github.com/go-sql-driver/mysql" + "github.com/spandigital/cel2sql/v3/mysql" +) + +// MySQL accepts a *sql.DB — you own the connection +db, _ := sql.Open("mysql", "user:pass@tcp(localhost:3306)/mydb?parseTime=true") +defer db.Close() + +provider, _ := mysql.NewTypeProviderWithConnection(ctx, db) +provider.LoadTableSchema(ctx, "users") + +env, _ := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("user", cel.ObjectType("users")), +) + +sql, _ := cel2sql.Convert(ast, cel2sql.WithDialect(mysqlDialect.New()), + cel2sql.WithSchemas(provider.GetSchemas())) +``` + +### SQLite + +```go +import ( + "database/sql" + _ "modernc.org/sqlite" + "github.com/spandigital/cel2sql/v3/sqlite" +) + +db, _ := sql.Open("sqlite", "mydb.sqlite") +defer db.Close() + +provider, _ := sqlite.NewTypeProviderWithConnection(ctx, db) +provider.LoadTableSchema(ctx, "users") + +env, _ := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("user", cel.ObjectType("users")), +) + +sql, _ := cel2sql.Convert(ast, cel2sql.WithDialect(sqliteDialect.New()), + cel2sql.WithSchemas(provider.GetSchemas())) +``` + +### DuckDB + +```go +import ( + "database/sql" + "github.com/spandigital/cel2sql/v3/duckdb" +) + +// DuckDB accepts *sql.DB — works with any DuckDB driver (requires CGO) +db, _ := sql.Open("duckdb", "mydb.duckdb") +defer db.Close() + +provider, _ := duckdb.NewTypeProviderWithConnection(ctx, db) +provider.LoadTableSchema(ctx, 
"users") + +env, _ := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("user", cel.ObjectType("users")), +) + +sql, _ := cel2sql.Convert(ast, cel2sql.WithDialect(duckdbDialect.New()), + cel2sql.WithSchemas(provider.GetSchemas())) +``` + +### BigQuery + +```go +import ( + "cloud.google.com/go/bigquery" + bqprovider "github.com/spandigital/cel2sql/v3/bigquery" +) + +// BigQuery uses the BigQuery client API (not database/sql) +client, _ := bigquery.NewClient(ctx, "my-project") +defer client.Close() + +provider, _ := bqprovider.NewTypeProviderWithClient(ctx, client, "my_dataset") +provider.LoadTableSchema(ctx, "users") + +env, _ := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("user", cel.ObjectType("users")), +) + +sql, _ := cel2sql.Convert(ast, cel2sql.WithDialect(bigqueryDialect.New()), + cel2sql.WithSchemas(provider.GetSchemas())) +``` + +### Notes + +- **PostgreSQL** manages its own connection pool via `pgxpool` — call `provider.Close()` when done. +- **MySQL, SQLite, DuckDB** accept a `*sql.DB` you provide — you own the connection lifecycle. `Close()` is a no-op. +- **BigQuery** accepts a `*bigquery.Client` + dataset ID — you own the client lifecycle. `Close()` is a no-op. +- All providers also support pre-defined schemas via `NewTypeProvider(schemas)` if you don't need runtime introspection. + See [Getting Started Guide](docs/getting-started.md) for more details. ## Requirements - Go 1.24 or higher -- PostgreSQL 17 (also compatible with PostgreSQL 15+) + +### CGO Requirement (DuckDB only) + +The DuckDB dialect's `LoadTableSchema` requires a DuckDB Go driver (e.g., `github.com/marcboeker/go-duckdb`) which depends on **CGO** and a C/C++ compiler. 
This means: + +- You must have `CGO_ENABLED=1` (the Go default on most platforms) +- A C/C++ compiler must be installed (GCC, Clang, or MSVC) +- Cross-compilation requires a C cross-compiler for the target platform + +**All other dialects (PostgreSQL, MySQL, SQLite, BigQuery) use pure Go drivers and do not require CGO.** + +If you only use DuckDB with pre-defined schemas via `duckdb.NewTypeProvider()` (no live database connection), CGO is **not** required. ## Contributing diff --git a/analysis.go b/analysis.go index e1bbd46..3dbbfeb 100644 --- a/analysis.go +++ b/analysis.go @@ -4,16 +4,18 @@ package cel2sql import ( "fmt" "log/slog" - "strings" "time" "github.com/google/cel-go/cel" "github.com/google/cel-go/common/operators" "github.com/google/cel-go/common/overloads" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/dialect/postgres" ) -// Index type constants for recommendations +// Index type constants for recommendations (kept for backward compatibility). 
const ( // IndexTypeBTree represents a B-tree index for efficient range queries and equality checks IndexTypeBTree = "BTREE" @@ -29,10 +31,10 @@ type IndexRecommendation struct { // Column is the database column that should be indexed Column string - // IndexType specifies the PostgreSQL index type (e.g., "BTREE", "GIN", "GIST") + // IndexType specifies the index type (e.g., "BTREE", "GIN", "ART", "CLUSTERING") IndexType string - // Expression is the complete CREATE INDEX statement that can be executed directly + // Expression is the complete DDL statement that can be executed directly Expression string // Reason explains why this index is recommended and what query patterns it optimizes @@ -44,21 +46,25 @@ type analysisConverter struct { *converter recommendations map[string]*IndexRecommendation // Key: column name, Value: recommendation visitedColumns map[string]bool // Track which columns have been accessed + advisor dialect.IndexAdvisor // Dialect-specific index advisor } -// AnalyzeQuery converts a CEL AST to PostgreSQL SQL and provides index recommendations. +// AnalyzeQuery converts a CEL AST to SQL and provides dialect-specific index recommendations. // It analyzes the query patterns to suggest indexes that would optimize performance. // // The function detects patterns that benefit from specific index types: -// - JSON/JSONB path operations (->>, ?) 
→ GIN indexes -// - Array operations (UNNEST, comprehensions) → GIN indexes -// - Regex matching (matches()) → GIN indexes with pg_trgm extension -// - Frequently accessed fields in comparisons → B-tree indexes +// - JSON/JSONB path operations → GIN indexes (PostgreSQL), functional indexes (MySQL), search indexes (BigQuery) +// - Array operations → GIN indexes (PostgreSQL), ART indexes (DuckDB) +// - Regex matching → GIN indexes with pg_trgm (PostgreSQL), FULLTEXT indexes (MySQL) +// - Comparison operations → B-tree indexes (PostgreSQL/MySQL/SQLite), ART (DuckDB), clustering (BigQuery) +// +// Use WithDialect() to get dialect-specific index recommendations. Defaults to PostgreSQL. // // Example: // // sql, recommendations, err := cel2sql.AnalyzeQuery(ast, -// cel2sql.WithSchemas(schemas)) +// cel2sql.WithSchemas(schemas), +// cel2sql.WithDialect(mysql.New())) // if err != nil { // return err // } @@ -84,6 +90,19 @@ func AnalyzeQuery(ast *cel.Ast, opts ...ConvertOption) (string, []IndexRecommend opt(options) } + // Default to PostgreSQL dialect if none specified + d := options.dialect + if d == nil { + d = postgres.New() + } + + // Get the IndexAdvisor for the dialect (all built-in dialects implement it) + advisor, hasAdvisor := dialect.GetIndexAdvisor(d) + if !hasAdvisor { + // Fallback: use PostgreSQL advisor for backward compatibility + advisor = postgres.New() + } + // Convert AST to CheckedExpr checkedExpr, err := cel.AstToCheckedExpr(ast) if err != nil { @@ -104,6 +123,7 @@ func AnalyzeQuery(ast *cel.Ast, opts ...ConvertOption) (string, []IndexRecommend converter: baseConverter, recommendations: make(map[string]*IndexRecommendation), visitedColumns: make(map[string]bool), + advisor: advisor, } // Analyze the expression tree to collect index patterns @@ -121,6 +141,7 @@ func AnalyzeQuery(ast *cel.Ast, opts ...ConvertOption) (string, []IndexRecommend if options.logger != nil { options.logger.Debug("query analysis completed", "sql", sql, + "dialect", d.Name(), 
"recommendation_count", len(analyzer.recommendations), "duration", duration) } @@ -225,24 +246,18 @@ func (a *analysisConverter) analyzeCall(expr *exprpb.Expr) error { switch fun { case overloads.Matches: - // Regex matching benefits from GIN index with pg_trgm extension - if err := a.recommendRegexIndex(expr); err != nil { - return err - } + // Regex matching benefits from dialect-specific indexes + a.recommendRegexIndex(expr) case operators.Equals, operators.NotEquals, operators.Greater, operators.GreaterEquals, operators.Less, operators.LessEquals: - // Comparison operations benefit from B-tree indexes - if err := a.recommendComparisonIndex(expr); err != nil { - return err - } + // Comparison operations benefit from indexes + a.recommendComparisonIndex(expr) case operators.In: - // IN operations on arrays benefit from GIN indexes - if err := a.recommendArrayIndex(expr); err != nil { - return err - } + // IN operations on arrays benefit from indexes + a.recommendArrayIndex(expr) } return nil @@ -259,26 +274,13 @@ func (a *analysisConverter) analyzeComprehension(expr *exprpb.Expr) error { // Check if this is a JSON array comprehension if a.isJSONArrayField(iterRange) { - // Extract the column name from the iter range if column := a.extractColumnName(iterRange); column != "" { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN (%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("JSONB array comprehension on '%s' benefits from GIN index for efficient array element access", column), - }) + a.recommendForPattern(column, dialect.PatternJSONArrayComprehension) } } else { // Regular array comprehension if column := a.extractColumnName(iterRange); column != "" { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN 
(%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("Array comprehension on '%s' benefits from GIN index for efficient array operations", column), - }) + a.recommendForPattern(column, dialect.PatternArrayComprehension) } } @@ -299,13 +301,7 @@ func (a *analysisConverter) analyzeSelect(expr *exprpb.Expr) error { // Check if the parent field is JSON if a.isFieldJSON(tableName, operandField) { column := tableName + "." + operandField - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN (%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("JSON path operations on '%s' benefit from GIN index for efficient nested field access", column), - }) + a.recommendForPattern(column, dialect.PatternJSONAccess) } } } @@ -315,15 +311,9 @@ func (a *analysisConverter) analyzeSelect(expr *exprpb.Expr) error { tableName := identExpr.GetName() if a.isFieldJSON(tableName, fieldName) { column := tableName + "." + fieldName - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN (%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("JSON field '%s' benefits from GIN index for efficient access", column), - }) + a.recommendForPattern(column, dialect.PatternJSONAccess) } - // Track column access for potential B-tree indexes + // Track column access for potential indexes fullColumn := tableName + "." 
+ fieldName a.visitedColumns[fullColumn] = true } @@ -331,8 +321,8 @@ func (a *analysisConverter) analyzeSelect(expr *exprpb.Expr) error { return nil } -// recommendRegexIndex recommends a GIN index with pg_trgm for regex operations -func (a *analysisConverter) recommendRegexIndex(expr *exprpb.Expr) error { +// recommendRegexIndex recommends an index for regex operations +func (a *analysisConverter) recommendRegexIndex(expr *exprpb.Expr) { c := expr.GetCallExpr() target := c.GetTarget() @@ -342,54 +332,38 @@ func (a *analysisConverter) recommendRegexIndex(expr *exprpb.Expr) error { if target != nil { if column := a.extractColumnName(target); column != "" { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin_trgm ON table_name USING GIN (%s gin_trgm_ops);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("Regex matching on '%s' benefits from GIN index with pg_trgm extension for pattern matching", column), - }) + a.recommendForPattern(column, dialect.PatternRegexMatch) } } - - return nil } -// recommendComparisonIndex recommends a B-tree index for comparison operations -func (a *analysisConverter) recommendComparisonIndex(expr *exprpb.Expr) error { +// recommendComparisonIndex recommends an index for comparison operations +func (a *analysisConverter) recommendComparisonIndex(expr *exprpb.Expr) { c := expr.GetCallExpr() args := c.GetArgs() if len(args) < 2 { - return nil + return } lhs := args[0] // Extract column from left-hand side if column := a.extractColumnName(lhs); column != "" { - // Check if this is a JSON field (skip B-tree recommendation for JSON) + // Check if this is a JSON field (skip comparison recommendation for JSON) if !a.isJSONField(lhs) { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeBTree, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_btree ON table_name (%s);", - sanitizeIndexName(column), 
column), - Reason: fmt.Sprintf("Comparison operations on '%s' benefit from B-tree index for efficient range queries and equality checks", column), - }) + a.recommendForPattern(column, dialect.PatternComparison) } } - - return nil } -// recommendArrayIndex recommends a GIN index for array containment operations -func (a *analysisConverter) recommendArrayIndex(expr *exprpb.Expr) error { +// recommendArrayIndex recommends an index for array containment operations +func (a *analysisConverter) recommendArrayIndex(expr *exprpb.Expr) { c := expr.GetCallExpr() args := c.GetArgs() if len(args) < 2 { - return nil + return } rhs := args[1] @@ -397,17 +371,26 @@ func (a *analysisConverter) recommendArrayIndex(expr *exprpb.Expr) error { // Check if the right-hand side is an array field if a.isFieldArray(a.extractTableName(rhs), a.extractFieldName(rhs)) { if column := a.extractColumnName(rhs); column != "" { - a.addRecommendation(column, &IndexRecommendation{ - Column: column, - IndexType: IndexTypeGIN, - Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON table_name USING GIN (%s);", - sanitizeIndexName(column), column), - Reason: fmt.Sprintf("Array membership tests on '%s' benefit from GIN index for efficient element lookups", column), - }) + a.recommendForPattern(column, dialect.PatternArrayMembership) } } +} - return nil +// recommendForPattern asks the dialect's IndexAdvisor for a recommendation and stores it. 
+func (a *analysisConverter) recommendForPattern(column string, pattern dialect.PatternType) { + rec := a.advisor.RecommendIndex(dialect.IndexPattern{ + Column: column, + Pattern: pattern, + }) + if rec == nil { + return + } + a.addRecommendation(column, &IndexRecommendation{ + Column: rec.Column, + IndexType: rec.IndexType, + Expression: rec.Expression, + Reason: rec.Reason, + }) } // extractColumnName extracts the full column name (table.column) from an expression @@ -458,34 +441,25 @@ func (a *analysisConverter) isJSONField(expr *exprpb.Expr) bool { return false } -// addRecommendation adds or updates an index recommendation +// addRecommendation adds or updates an index recommendation. +// When a more specialized recommendation exists for a column, it takes priority. func (a *analysisConverter) addRecommendation(column string, rec *IndexRecommendation) { // Only add if we don't already have a recommendation for this column - // or if the new recommendation is more specific (e.g., GIN over BTREE) + // or if the new recommendation is more specific existing, exists := a.recommendations[column] if !exists { a.recommendations[column] = rec return } - // GIN indexes are more versatile than BTREE for JSON/array operations - // If we already have a BTREE recommendation and we're suggesting GIN, upgrade it - if existing.IndexType == IndexTypeBTree && rec.IndexType == IndexTypeGIN { + // More specialized index types take priority over basic B-tree/comparison indexes + if isBasicIndexType(existing.IndexType) && !isBasicIndexType(rec.IndexType) { a.recommendations[column] = rec } } -// sanitizeIndexName creates a safe index name from a column name -func sanitizeIndexName(column string) string { - // Replace dots and special characters with underscores - sanitized := strings.ReplaceAll(column, ".", "_") - sanitized = strings.ReplaceAll(sanitized, " ", "_") - sanitized = strings.ReplaceAll(sanitized, "-", "_") - - // PostgreSQL index names are limited to 63 characters - if 
len(sanitized) > 50 { - sanitized = sanitized[:50] - } - - return sanitized +// isBasicIndexType returns true if the index type is a basic comparison index +// that should be upgraded when a more specialized recommendation is available. +func isBasicIndexType(indexType string) bool { + return indexType == IndexTypeBTree || indexType == "ART" || indexType == "CLUSTERING" } diff --git a/analysis_test.go b/analysis_test.go index 212ae7c..5081f47 100644 --- a/analysis_test.go +++ b/analysis_test.go @@ -6,9 +6,21 @@ import ( "testing" "github.com/google/cel-go/cel" + "github.com/spandigital/cel2sql/v3/dialect" + dialectbq "github.com/spandigital/cel2sql/v3/dialect/bigquery" + dialectduckdb "github.com/spandigital/cel2sql/v3/dialect/duckdb" + dialectmysql "github.com/spandigital/cel2sql/v3/dialect/mysql" + dialectpg "github.com/spandigital/cel2sql/v3/dialect/postgres" + dialectsqlite "github.com/spandigital/cel2sql/v3/dialect/sqlite" "github.com/spandigital/cel2sql/v3/pg" ) +// Test column name constants to avoid repetition. 
+const ( + colPersonEmail = "person.email" + colPersonMetadata = "person.metadata" +) + func TestAnalyzeQuery_JSONPathOperations(t *testing.T) { schema := pg.NewSchema([]pg.FieldSchema{ {Name: "id", Type: "text"}, @@ -34,14 +46,14 @@ func TestAnalyzeQuery_JSONPathOperations(t *testing.T) { { name: "simple JSON path access", expression: `person.metadata.name == "John"`, - expectedColumn: "person.metadata", + expectedColumn: colPersonMetadata, expectedType: "GIN", expectReason: "JSON path operations", }, { name: "nested JSON path access", expression: `person.metadata.profile.age > 18`, - expectedColumn: "person.metadata", + expectedColumn: colPersonMetadata, expectedType: "GIN", expectReason: "JSON path operations", }, @@ -120,7 +132,7 @@ func TestAnalyzeQuery_RegexOperations(t *testing.T) { // Check that we got a GIN index recommendation with pg_trgm found := false for _, rec := range recommendations { - if rec.Column == "person.email" && rec.IndexType == IndexTypeGIN { + if rec.Column == colPersonEmail && rec.IndexType == IndexTypeGIN { found = true if !strings.Contains(rec.Reason, "Regex matching") { t.Errorf("expected reason to mention regex matching, got %q", rec.Reason) @@ -370,9 +382,9 @@ func TestAnalyzeQuery_MultipleRecommendations(t *testing.T) { switch rec.Column { case "person.age": foundAge = rec.IndexType == IndexTypeBTree - case "person.email": + case colPersonEmail: foundEmail = rec.IndexType == IndexTypeGIN - case "person.metadata": + case colPersonMetadata: foundMetadata = rec.IndexType == IndexTypeGIN } } @@ -494,10 +506,330 @@ func TestAnalyzeQuery_IndexRecommendationPriority(t *testing.T) { // We should get a GIN recommendation for metadata, not BTREE for _, rec := range recommendations { - if rec.Column == "person.metadata" { + if rec.Column == colPersonMetadata { if rec.IndexType != IndexTypeGIN { t.Errorf("expected GIN index for JSON field, got %s", rec.IndexType) } } } } + +func TestAnalyzeQuery_WithDialect(t *testing.T) { + // Test that each 
dialect produces its own appropriate index types and DDL + schema := pg.NewSchema([]pg.FieldSchema{ + {Name: "id", Type: "bigint"}, + {Name: "age", Type: "integer"}, + {Name: "email", Type: "text"}, + {Name: "tags", Type: "text", Repeated: true}, + {Name: "metadata", Type: "jsonb", IsJSON: true, IsJSONB: true}, + }) + provider := pg.NewTypeProvider(map[string]pg.Schema{"person": schema}) + + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("person", cel.ObjectType("person")), + ) + if err != nil { + t.Fatalf("failed to create CEL environment: %v", err) + } + + type dialectTestCase struct { + name string + dialect dialect.Dialect + // Per-pattern expected results + comparisonType string // Expected IndexType for comparisons + comparisonContain string // Substring expected in Expression + jsonType string // Expected IndexType for JSON access + jsonContain string // Substring expected in Expression + } + + dialects := []dialectTestCase{ + { + name: "PostgreSQL", + dialect: dialectpg.New(), + comparisonType: "BTREE", + comparisonContain: "CREATE INDEX", + jsonType: "GIN", + jsonContain: "USING GIN", + }, + { + name: "MySQL", + dialect: dialectmysql.New(), + comparisonType: "BTREE", + comparisonContain: "CREATE INDEX", + jsonType: "BTREE", + jsonContain: "CAST", + }, + { + name: "SQLite", + dialect: dialectsqlite.New(), + comparisonType: "BTREE", + comparisonContain: "CREATE INDEX", + jsonType: "", // SQLite doesn't support JSON indexes + jsonContain: "", + }, + { + name: "DuckDB", + dialect: dialectduckdb.New(), + comparisonType: "ART", + comparisonContain: "CREATE INDEX", + jsonType: "ART", + jsonContain: "CREATE INDEX", + }, + { + name: "BigQuery", + dialect: dialectbq.New(), + comparisonType: "CLUSTERING", + comparisonContain: "clustering_columns", + jsonType: "SEARCH_INDEX", + jsonContain: "SEARCH INDEX", + }, + } + + for _, dt := range dialects { + t.Run(dt.name+"_comparison", func(t *testing.T) { + ast, issues := env.Compile(`person.age 
> 18`) + if issues != nil && issues.Err() != nil { + t.Fatalf("failed to compile expression: %v", issues.Err()) + } + + _, recommendations, err := AnalyzeQuery(ast, + WithSchemas(provider.GetSchemas()), + WithDialect(dt.dialect)) + if err != nil { + t.Fatalf("AnalyzeQuery failed: %v", err) + } + + found := false + for _, rec := range recommendations { + if rec.Column == "person.age" { + found = true + if rec.IndexType != dt.comparisonType { + t.Errorf("expected index type %q, got %q", dt.comparisonType, rec.IndexType) + } + if !strings.Contains(rec.Expression, dt.comparisonContain) { + t.Errorf("expected expression to contain %q, got %q", dt.comparisonContain, rec.Expression) + } + } + } + if !found { + t.Errorf("expected recommendation for person.age, got: %+v", recommendations) + } + }) + + t.Run(dt.name+"_json", func(t *testing.T) { + ast, issues := env.Compile(`person.metadata.verified == true`) + if issues != nil && issues.Err() != nil { + t.Fatalf("failed to compile expression: %v", issues.Err()) + } + + _, recommendations, err := AnalyzeQuery(ast, + WithSchemas(provider.GetSchemas()), + WithDialect(dt.dialect)) + if err != nil { + t.Fatalf("AnalyzeQuery failed: %v", err) + } + + if dt.jsonType == "" { + // This dialect doesn't recommend JSON indexes; verify none present for metadata + for _, rec := range recommendations { + if rec.Column == colPersonMetadata { + t.Errorf("expected no recommendation for JSON on %s, got: %+v", dt.name, rec) + } + } + return + } + + found := false + for _, rec := range recommendations { + if rec.Column == colPersonMetadata { + found = true + if rec.IndexType != dt.jsonType { + t.Errorf("expected index type %q, got %q", dt.jsonType, rec.IndexType) + } + if !strings.Contains(rec.Expression, dt.jsonContain) { + t.Errorf("expected expression to contain %q, got %q", dt.jsonContain, rec.Expression) + } + } + } + if !found { + t.Errorf("expected JSON recommendation for person.metadata on %s, got: %+v", dt.name, recommendations) + } + 
}) + } +} + +func TestAnalyzeQuery_UnsupportedPatternReturnsNil(t *testing.T) { + // SQLite should not produce recommendations for regex patterns + schema := pg.NewSchema([]pg.FieldSchema{ + {Name: "id", Type: "text"}, + {Name: "email", Type: "text"}, + }) + provider := pg.NewTypeProvider(map[string]pg.Schema{"person": schema}) + + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("person", cel.ObjectType("person")), + ) + if err != nil { + t.Fatalf("failed to create CEL environment: %v", err) + } + + // Note: We use person.email == "test" rather than matches() because SQLite + // doesn't support regex in SQL generation. We test the advisor directly instead. + advisor := dialectsqlite.New() + rec := advisor.RecommendIndex(dialect.IndexPattern{ + Column: colPersonEmail, + Pattern: dialect.PatternRegexMatch, + }) + if rec != nil { + t.Errorf("expected nil recommendation for regex on SQLite, got: %+v", rec) + } + + // Also verify SQLite returns nil for array patterns + rec = advisor.RecommendIndex(dialect.IndexPattern{ + Column: "person.tags", + Pattern: dialect.PatternArrayMembership, + }) + if rec != nil { + t.Errorf("expected nil recommendation for array membership on SQLite, got: %+v", rec) + } + + // But comparisons should still work + ast, issues := env.Compile(`person.email == "test@example.com"`) + if issues != nil && issues.Err() != nil { + t.Fatalf("failed to compile expression: %v", issues.Err()) + } + + _, recommendations, err := AnalyzeQuery(ast, + WithSchemas(provider.GetSchemas()), + WithDialect(dialectsqlite.New())) + if err != nil { + t.Fatalf("AnalyzeQuery failed: %v", err) + } + + found := false + for _, rec := range recommendations { + if rec.Column == colPersonEmail && rec.IndexType == "BTREE" { + found = true + } + } + if !found { + t.Errorf("expected BTREE recommendation for person.email on SQLite, got: %+v", recommendations) + } +} + +func TestAnalyzeQuery_AllDialectsSupportsIndexAnalysis(t *testing.T) { + // Verify 
that all built-in dialects report SupportsIndexAnalysis() = true + dialects := []dialect.Dialect{ + dialectpg.New(), + dialectmysql.New(), + dialectsqlite.New(), + dialectduckdb.New(), + dialectbq.New(), + } + + for _, d := range dialects { + t.Run(string(d.Name()), func(t *testing.T) { + if !d.SupportsIndexAnalysis() { + t.Errorf("%s should support index analysis", d.Name()) + } + + // Also verify the dialect implements IndexAdvisor + advisor, ok := dialect.GetIndexAdvisor(d) + if !ok { + t.Fatalf("%s does not implement IndexAdvisor", d.Name()) + } + + patterns := advisor.SupportedPatterns() + if len(patterns) == 0 { + t.Errorf("%s reports no supported patterns", d.Name()) + } + }) + } +} + +func TestAnalyzeQuery_IndexAdvisorSupportedPatterns(t *testing.T) { + tests := []struct { + name string + dialect dialect.Dialect + expectedPatterns []dialect.PatternType + }{ + { + name: "PostgreSQL supports all patterns", + dialect: dialectpg.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternRegexMatch, + dialect.PatternArrayMembership, + dialect.PatternArrayComprehension, + dialect.PatternJSONArrayComprehension, + }, + }, + { + name: "MySQL supports comparison, JSON, regex, JSON array", + dialect: dialectmysql.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternRegexMatch, + dialect.PatternJSONArrayComprehension, + }, + }, + { + name: "SQLite supports only comparison", + dialect: dialectsqlite.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + }, + }, + { + name: "DuckDB supports comparison, JSON, arrays", + dialect: dialectduckdb.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternArrayMembership, + dialect.PatternArrayComprehension, + dialect.PatternJSONArrayComprehension, + }, + }, + { + name: "BigQuery supports comparison, 
JSON, JSON array", + dialect: dialectbq.New(), + expectedPatterns: []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternJSONArrayComprehension, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + advisor, ok := dialect.GetIndexAdvisor(tt.dialect) + if !ok { + t.Fatalf("dialect does not implement IndexAdvisor") + } + + patterns := advisor.SupportedPatterns() + if len(patterns) != len(tt.expectedPatterns) { + t.Errorf("expected %d patterns, got %d: %v", len(tt.expectedPatterns), len(patterns), patterns) + } + + for _, expected := range tt.expectedPatterns { + found := false + for _, actual := range patterns { + if actual == expected { + found = true + break + } + } + if !found { + t.Errorf("expected pattern %d not found in supported patterns", expected) + } + } + }) + } +} diff --git a/bigquery/provider.go b/bigquery/provider.go new file mode 100644 index 0000000..ce6dd57 --- /dev/null +++ b/bigquery/provider.go @@ -0,0 +1,228 @@ +// Package bigquery provides BigQuery type provider for CEL type system integration. +package bigquery + +import ( + "context" + "errors" + "fmt" + "strings" + + bq "cloud.google.com/go/bigquery" + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/schema" +) + +// Sentinel errors for the bigquery package. +var ( + // ErrInvalidSchema indicates a problem with the provided schema or database introspection. + ErrInvalidSchema = errors.New("invalid schema") +) + +// FieldSchema is an alias for schema.FieldSchema. +type FieldSchema = schema.FieldSchema + +// Schema is an alias for schema.Schema. +type Schema = schema.Schema + +// NewSchema creates a new Schema. This is an alias for schema.NewSchema. 
+var NewSchema = schema.NewSchema + +// TypeProvider interface for BigQuery type providers. +type TypeProvider interface { + types.Provider + LoadTableSchema(ctx context.Context, tableName string) error + GetSchemas() map[string]Schema + Close() +} + +type typeProvider struct { + schemas map[string]Schema + client *bq.Client + datasetID string +} + +// NewTypeProvider creates a new BigQuery type provider with pre-defined schemas. +func NewTypeProvider(schemas map[string]Schema) TypeProvider { + return &typeProvider{schemas: schemas} +} + +// NewTypeProviderWithClient creates a new BigQuery type provider that can introspect database schemas. +// The caller owns the *bigquery.Client and is responsible for closing it. +func NewTypeProviderWithClient(_ context.Context, client *bq.Client, datasetID string) (TypeProvider, error) { + if client == nil { + return nil, fmt.Errorf("%w: BigQuery client must not be nil", ErrInvalidSchema) + } + if datasetID == "" { + return nil, fmt.Errorf("%w: dataset ID must not be empty", ErrInvalidSchema) + } + + return &typeProvider{ + schemas: make(map[string]Schema), + client: client, + datasetID: datasetID, + }, nil +} + +// LoadTableSchema loads schema information for a table from BigQuery using the client API. +func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) error { + if tp.client == nil { + return fmt.Errorf("%w: no BigQuery client available", ErrInvalidSchema) + } + + meta, err := tp.client.Dataset(tp.datasetID).Table(tableName).Metadata(ctx) + if err != nil { + return fmt.Errorf("%w: failed to get table metadata", ErrInvalidSchema) + } + + fields := bigquerySchemaToFieldSchemas(meta.Schema) + if len(fields) == 0 { + return fmt.Errorf("%w: table %q has no columns", ErrInvalidSchema, tableName) + } + + tp.schemas[tableName] = NewSchema(fields) + return nil +} + +// bigquerySchemaToFieldSchemas converts a BigQuery schema to a slice of FieldSchemas. 
+func bigquerySchemaToFieldSchemas(bqSchema bq.Schema) []FieldSchema { + fields := make([]FieldSchema, 0, len(bqSchema)) + for _, f := range bqSchema { + fields = append(fields, bigqueryFieldToFieldSchema(f)) + } + return fields +} + +// bigqueryFieldToFieldSchema converts a BigQuery FieldSchema to our FieldSchema. +func bigqueryFieldToFieldSchema(f *bq.FieldSchema) FieldSchema { + typeName := bigqueryFieldTypeToString(f.Type) + isJSON := f.Type == bq.JSONFieldType + repeated := f.Repeated + + field := FieldSchema{ + Name: f.Name, + Type: typeName, + Repeated: repeated, + IsJSON: isJSON, + } + + // Handle nested RECORD types recursively + if f.Type == bq.RecordFieldType && len(f.Schema) > 0 { + field.Schema = bigquerySchemaToFieldSchemas(f.Schema) + } + + if repeated { + field.Dimensions = 1 + field.ElementType = typeName + } + + return field +} + +// bigqueryFieldTypeToString converts a BigQuery FieldType to a string type name. +func bigqueryFieldTypeToString(ft bq.FieldType) string { + return strings.ToLower(string(ft)) +} + +// Close is a no-op since we don't own the *bigquery.Client. +func (tp *typeProvider) Close() { + // No-op: caller owns the *bigquery.Client connection +} + +// GetSchemas returns the schemas known to this type provider. +func (tp *typeProvider) GetSchemas() map[string]Schema { + return tp.schemas +} + +// EnumValue implements types.Provider. +func (tp *typeProvider) EnumValue(_ string) ref.Val { + return types.NewErr("unknown enum value") +} + +// FindIdent implements types.Provider. +func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { + return nil, false +} + +// FindStructType implements types.Provider. +func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { + if _, ok := tp.schemas[structType]; ok { + return types.NewObjectType(structType), true + } + return nil, false +} + +// FindStructFieldNames implements types.Provider. 
+func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + fields := s.Fields() + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + return names, true +} + +// FindStructFieldType implements types.Provider. +func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + field, found := s.FindField(fieldName) + if !found { + return nil, false + } + + exprType := bigqueryTypeToCELExprType(field) + celType, err := types.ExprTypeToType(exprType) + if err != nil { + return nil, false + } + + return &types.FieldType{ + Type: celType, + }, true +} + +// NewValue implements types.Provider. +func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { + return types.NewErr("unknown type in schema") +} + +// bigqueryTypeToCELExprType converts a BigQuery field schema to a CEL expression type. +func bigqueryTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { + baseType := bigqueryBaseTypeToCEL(field.Type) + if field.Repeated { + return decls.NewListType(baseType) + } + return baseType +} + +// bigqueryBaseTypeToCEL converts a BigQuery type name to a CEL expression type. 
+func bigqueryBaseTypeToCEL(typeName string) *exprpb.Type { + switch typeName { + case "STRING", "string": + return decls.String + case "INT64", "int64", "INTEGER", "integer": + return decls.Int + case "FLOAT64", "float64", "FLOAT", "float", "NUMERIC", "numeric": + return decls.Double + case "BOOL", "bool", "BOOLEAN", "boolean": + return decls.Bool + case "BYTES", "bytes": + return decls.Bytes + case "JSON", "json": + return decls.Dyn + case "TIMESTAMP", "timestamp": + return decls.Timestamp + default: + return decls.Dyn + } +} diff --git a/bigquery/provider_test.go b/bigquery/provider_test.go new file mode 100644 index 0000000..0583d24 --- /dev/null +++ b/bigquery/provider_test.go @@ -0,0 +1,274 @@ +package bigquery_test + +import ( + "bytes" + "context" + _ "embed" + "runtime" + "strings" + "testing" + + bq "cloud.google.com/go/bigquery" + "github.com/google/cel-go/common/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + tcbigquery "github.com/testcontainers/testcontainers-go/modules/gcloud/bigquery" + "google.golang.org/api/option" + "google.golang.org/api/option/internaloption" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/spandigital/cel2sql/v3/bigquery" +) + +//go:embed testdata/provider_seed.yaml +var providerSeedYAML []byte + +const ( + testProjectID = "test-project" + testDataset = "testdataset" +) + +func TestNewTypeProvider(t *testing.T) { + schemas := map[string]bigquery.Schema{ + "users": bigquery.NewSchema([]bigquery.FieldSchema{ + {Name: "id", Type: "INTEGER"}, + {Name: "name", Type: "STRING"}, + }), + } + + provider := bigquery.NewTypeProvider(schemas) + require.NotNil(t, provider) + + got := provider.GetSchemas() + assert.Len(t, got, 1) + assert.Contains(t, got, "users") +} + +func TestNewTypeProviderWithClient_NilClient(t *testing.T) { + _, err := bigquery.NewTypeProviderWithClient(context.Background(), nil, "dataset") 
+ require.Error(t, err) + assert.ErrorIs(t, err, bigquery.ErrInvalidSchema) +} + +func TestNewTypeProviderWithClient_EmptyDataset(t *testing.T) { + // We can't create a real client without credentials, so test the nil case + _, err := bigquery.NewTypeProviderWithClient(context.Background(), nil, "") + require.Error(t, err) + assert.ErrorIs(t, err, bigquery.ErrInvalidSchema) +} + +func TestLoadTableSchema_NoClient(t *testing.T) { + provider := bigquery.NewTypeProvider(nil) + err := provider.LoadTableSchema(context.Background(), "test") + require.Error(t, err) + assert.ErrorIs(t, err, bigquery.ErrInvalidSchema) +} + +func TestTypeProvider_FindStructType(t *testing.T) { + schemas := map[string]bigquery.Schema{ + "users": bigquery.NewSchema([]bigquery.FieldSchema{ + {Name: "id", Type: "INTEGER"}, + {Name: "name", Type: "STRING"}, + }), + } + provider := bigquery.NewTypeProvider(schemas) + + typ, found := provider.FindStructType("users") + assert.True(t, found) + assert.NotNil(t, typ) + + _, found = provider.FindStructType("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldNames(t *testing.T) { + schemas := map[string]bigquery.Schema{ + "users": bigquery.NewSchema([]bigquery.FieldSchema{ + {Name: "id", Type: "INTEGER"}, + {Name: "name", Type: "STRING"}, + {Name: "email", Type: "STRING"}, + }), + } + provider := bigquery.NewTypeProvider(schemas) + + names, found := provider.FindStructFieldNames("users") + assert.True(t, found) + assert.ElementsMatch(t, []string{"id", "name", "email"}, names) + + _, found = provider.FindStructFieldNames("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldType(t *testing.T) { + schemas := map[string]bigquery.Schema{ + "test_table": bigquery.NewSchema([]bigquery.FieldSchema{ + {Name: "str_field", Type: "STRING"}, + {Name: "int_field", Type: "INTEGER"}, + {Name: "int64_field", Type: "INT64"}, + {Name: "float_field", Type: "FLOAT64"}, + {Name: "bool_field", Type: "BOOL"}, + {Name: 
"bytes_field", Type: "BYTES"}, + {Name: "json_field", Type: "JSON"}, + {Name: "ts_field", Type: "TIMESTAMP"}, + {Name: "str_lower", Type: "string"}, + {Name: "int_lower", Type: "integer"}, + }), + } + provider := bigquery.NewTypeProvider(schemas) + + tests := []struct { + fieldName string + wantType *types.Type + wantFound bool + }{ + {"str_field", types.StringType, true}, + {"int_field", types.IntType, true}, + {"int64_field", types.IntType, true}, + {"float_field", types.DoubleType, true}, + {"bool_field", types.BoolType, true}, + {"bytes_field", types.BytesType, true}, + {"json_field", types.DynType, true}, + {"ts_field", types.TimestampType, true}, + {"str_lower", types.StringType, true}, + {"int_lower", types.IntType, true}, + {"nonexistent", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("test_table", tt.fieldName) + assert.Equal(t, tt.wantFound, found) + if found { + assert.Equal(t, tt.wantType, got.Type) + } + }) + } +} + +func TestTypeProvider_Close(_ *testing.T) { + provider := bigquery.NewTypeProvider(nil) + // Close should not panic + provider.Close() +} + +// setupBigQueryContainer starts a BigQuery emulator container and returns a client. 
+func setupBigQueryContainer(ctx context.Context, t *testing.T) (*tcbigquery.Container, *bq.Client) { + t.Helper() + + container, err := tcbigquery.Run(ctx, + "ghcr.io/goccy/bigquery-emulator:0.6.6", + tcbigquery.WithProjectID(testProjectID), + tcbigquery.WithDataYAML(bytes.NewReader(providerSeedYAML)), + testcontainers.WithImagePlatform("linux/amd64"), + ) + if err != nil { + if runtime.GOARCH == "arm64" || strings.Contains(err.Error(), "no image found") || strings.Contains(err.Error(), "container exited") { + t.Skipf("Skipping BigQuery integration test: emulator not available on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + t.Fatalf("Failed to start BigQuery emulator container: %v", err) + } + + opts := []option.ClientOption{ + option.WithEndpoint(container.URI()), + option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), + option.WithoutAuthentication(), + internaloption.SkipDialSettingsValidation(), + } + + client, err := bq.NewClient(ctx, container.ProjectID(), opts...) 
+ if err != nil { + if termErr := container.Terminate(ctx); termErr != nil { + t.Logf("failed to terminate container: %v", termErr) + } + t.Fatalf("Failed to create BigQuery client: %v", err) + } + + return container, client +} + +func TestLoadTableSchema_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, client := setupBigQueryContainer(ctx, t) + defer func() { + if closeErr := client.Close(); closeErr != nil { + t.Logf("failed to close client: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + provider, err := bigquery.NewTypeProviderWithClient(ctx, client, testDataset) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "test_data") + require.NoError(t, err) + + // Verify schema was loaded + schemas := provider.GetSchemas() + assert.Contains(t, schemas, "test_data") + + // Verify FindStructType + typ, found := provider.FindStructType("test_data") + assert.True(t, found) + assert.NotNil(t, typ) + + // Verify FindStructFieldNames + names, found := provider.FindStructFieldNames("test_data") + assert.True(t, found) + assert.Contains(t, names, "id") + assert.Contains(t, names, "text_val") + assert.Contains(t, names, "int_val") + + // Verify type mappings from loaded schema + tests := []struct { + fieldName string + wantType *types.Type + }{ + {"id", types.IntType}, + {"text_val", types.StringType}, + {"int_val", types.IntType}, + {"float_val", types.DoubleType}, + {"bool_val", types.BoolType}, + } + + for _, tt := range tests { + t.Run("type_"+tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("test_data", tt.fieldName) + assert.True(t, found, "field %q should be found", tt.fieldName) + if found { + assert.Equal(t, tt.wantType, got.Type, "field %q type mismatch", tt.fieldName) + } + }) + } +} + +func TestLoadTableSchema_NonexistentTable(t 
*testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, client := setupBigQueryContainer(ctx, t) + defer func() { + if closeErr := client.Close(); closeErr != nil { + t.Logf("failed to close client: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + provider, err := bigquery.NewTypeProviderWithClient(ctx, client, testDataset) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "nonexistent_table") + require.Error(t, err) + assert.ErrorIs(t, err, bigquery.ErrInvalidSchema) +} diff --git a/bigquery/testdata/provider_seed.yaml b/bigquery/testdata/provider_seed.yaml new file mode 100644 index 0000000..8afab3f --- /dev/null +++ b/bigquery/testdata/provider_seed.yaml @@ -0,0 +1,41 @@ +projects: + - id: test-project + datasets: + - id: testdataset + tables: + - id: test_data + columns: + - name: id + type: INT64 + - name: text_val + type: STRING + - name: int_val + type: INT64 + - name: float_val + type: FLOAT64 + - name: bool_val + type: BOOL + - name: nullable_text + type: STRING + - name: nullable_int + type: INT64 + data: + - id: 1 + text_val: "hello" + int_val: 10 + float_val: 10.5 + bool_val: true + nullable_text: "present" + nullable_int: 100 + - id: 2 + text_val: "world" + int_val: 20 + float_val: 20.5 + bool_val: false + - id: 3 + text_val: "test" + int_val: 30 + float_val: 30.5 + bool_val: true + nullable_text: "here" + nullable_int: 200 diff --git a/bigquery_integration_test.go b/bigquery_integration_test.go new file mode 100644 index 0000000..a56546d --- /dev/null +++ b/bigquery_integration_test.go @@ -0,0 +1,441 @@ +package cel2sql_test + +import ( + "bytes" + "context" + _ "embed" + "fmt" + "runtime" + "strings" + "testing" + + "cloud.google.com/go/bigquery" + "github.com/google/cel-go/cel" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + 
tcbigquery "github.com/testcontainers/testcontainers-go/modules/gcloud/bigquery" + "google.golang.org/api/option" + "google.golang.org/api/option/internaloption" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/spandigital/cel2sql/v3" + bigqueryDialect "github.com/spandigital/cel2sql/v3/dialect/bigquery" + "github.com/spandigital/cel2sql/v3/pg" +) + +//go:embed testdata/bigquery_seed.yaml +var bigQuerySeedYAML []byte + +const ( + bigQueryProjectID = "test-project" + bigQueryDataset = "testdataset" +) + +// setupBigQueryContainer starts a BigQuery emulator container and returns a client. +// Returns nil container and client if the emulator cannot start (e.g., on arm64). +func setupBigQueryContainer(ctx context.Context, t *testing.T) (*tcbigquery.Container, *bigquery.Client) { + t.Helper() + + container, err := tcbigquery.Run(ctx, + "ghcr.io/goccy/bigquery-emulator:0.6.6", + tcbigquery.WithProjectID(bigQueryProjectID), + tcbigquery.WithDataYAML(bytes.NewReader(bigQuerySeedYAML)), + testcontainers.WithImagePlatform("linux/amd64"), + ) + if err != nil { + // The BigQuery emulator only provides amd64 images. On arm64 (Apple Silicon), + // it crashes under QEMU emulation due to Go runtime lfstack.push issues. + if runtime.GOARCH == "arm64" || strings.Contains(err.Error(), "no image found") || strings.Contains(err.Error(), "container exited") { + t.Skipf("Skipping BigQuery integration test: emulator not available on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) + } + t.Fatalf("Failed to start BigQuery emulator container: %v", err) + } + + opts := []option.ClientOption{ + option.WithEndpoint(container.URI()), + option.WithGRPCDialOption(grpc.WithTransportCredentials(insecure.NewCredentials())), + option.WithoutAuthentication(), + internaloption.SkipDialSettingsValidation(), + } + + client, err := bigquery.NewClient(ctx, container.ProjectID(), opts...) 
+ if err != nil { + if termErr := container.Terminate(ctx); termErr != nil { + t.Logf("failed to terminate container: %v", termErr) + } + t.Fatalf("Failed to create BigQuery client: %v", err) + } + + return container, client +} + +// bigQueryCount executes a count query and returns the result. +func bigQueryCount(ctx context.Context, t *testing.T, client *bigquery.Client, query string) int { + t.Helper() + + q := client.Query(query) + it, err := q.Read(ctx) + require.NoError(t, err, "Failed to execute query: %s", query) + + var row []bigquery.Value + err = it.Next(&row) + require.NoError(t, err, "Failed to read query result: %s", query) + require.Len(t, row, 1, "Expected exactly one column in COUNT(*) result") + + switch v := row[0].(type) { + case int64: + return int(v) + case float64: + return int(v) + default: + t.Fatalf("Unexpected type %T for COUNT(*) result: %v", row[0], row[0]) + return 0 + } +} + +// TestBigQueryOperatorsIntegration validates operator conversions against a BigQuery emulator. 
+func TestBigQueryOperatorsIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, client := setupBigQueryContainer(ctx, t) + defer func() { + if closeErr := client.Close(); closeErr != nil { + t.Logf("failed to close client: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + env, err := cel.NewEnv( + cel.Variable("id", cel.IntType), + cel.Variable("text_val", cel.StringType), + cel.Variable("int_val", cel.IntType), + cel.Variable("float_val", cel.DoubleType), + cel.Variable("bool_val", cel.BoolType), + cel.Variable("nullable_text", cel.StringType), + cel.Variable("nullable_int", cel.IntType), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(bigqueryDialect.New()) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + // Comparison operators + { + name: "Equality string", + celExpr: `text_val == "hello"`, + expectedRows: 1, + description: "String equality comparison", + }, + { + name: "Equality integer", + celExpr: `int_val == 20`, + expectedRows: 1, + description: "Integer equality comparison", + }, + { + name: "Equality float", + celExpr: `float_val == 10.5`, + expectedRows: 1, + description: "Float equality comparison", + }, + { + name: "Equality boolean", + celExpr: `bool_val == true`, + expectedRows: 3, + description: "Boolean equality comparison", + }, + { + name: "Not equal", + celExpr: `text_val != "hello"`, + expectedRows: 4, + description: "Not equal comparison", + }, + { + name: "Less than", + celExpr: `int_val < 15`, + expectedRows: 2, // 10, 5 + description: "Less than comparison", + }, + { + name: "Less than or equal", + celExpr: `int_val <= 15`, + expectedRows: 3, // 10, 5, 15 + description: "Less than or equal comparison", + }, + { + name: "Greater than", + celExpr: `int_val > 15`, + expectedRows: 2, // 20, 
30 + description: "Greater than comparison", + }, + { + name: "Greater than or equal", + celExpr: `int_val >= 15`, + expectedRows: 3, // 20, 30, 15 + description: "Greater than or equal comparison", + }, + + // Logical operators + { + name: "Logical AND", + celExpr: `int_val > 10 && bool_val == true`, + expectedRows: 2, // rows 3 (30,true) and 5 (15,true) + description: "Logical AND operator", + }, + { + name: "Logical OR", + celExpr: `int_val < 10 || bool_val == false`, + expectedRows: 2, // rows 2 (20,false) and 4 (5,false) + description: "Logical OR operator", + }, + { + name: "Logical NOT", + celExpr: `!bool_val`, + expectedRows: 2, // rows 2 and 4 + description: "Logical NOT operator", + }, + { + name: "Complex logical expression", + celExpr: `(int_val > 10 && bool_val) || int_val < 10`, + expectedRows: 3, // rows 3, 5, 4 + description: "Complex nested logical operators", + }, + + // Arithmetic operators + { + name: "Addition", + celExpr: `int_val + 10 == 20`, + expectedRows: 1, // 10 + 10 = 20 + description: "Addition operator", + }, + { + name: "Subtraction", + celExpr: `int_val - 5 == 15`, + expectedRows: 1, // 20 - 5 = 15 + description: "Subtraction operator", + }, + { + name: "Multiplication", + celExpr: `int_val * 2 == 20`, + expectedRows: 1, // 10 * 2 = 20 + description: "Multiplication operator", + }, + { + name: "Division", + celExpr: `int_val / 2 == 10`, + expectedRows: 1, // 20 / 2 = 10 + description: "Division operator", + }, + { + name: "Modulo", + celExpr: `int_val % 10 == 0`, + expectedRows: 3, // 10, 20, 30 + description: "Modulo operator", + }, + { + name: "Complex arithmetic", + celExpr: `(int_val * 2) + 5 > 30`, + expectedRows: 3, // (20*2)+5=45, (30*2)+5=65, (15*2)+5=35 + description: "Complex arithmetic expression", + }, + + // String operators + { + name: "String concatenation", + celExpr: `text_val + "!" 
== "hello!"`, + expectedRows: 1, + description: "String concatenation (||)", + }, + { + name: "String contains", + celExpr: `text_val.contains("world")`, + expectedRows: 2, // "world", "hello world" + description: "String contains function (INSTR)", + }, + { + name: "String startsWith", + celExpr: `text_val.startsWith("hello")`, + expectedRows: 2, // "hello", "hello world" + description: "String startsWith function (LIKE)", + }, + { + name: "String endsWith", + celExpr: `text_val.endsWith("world")`, + expectedRows: 2, // "world", "hello world" + description: "String endsWith function (LIKE)", + }, + + // Regex (BigQuery uses REGEXP_CONTAINS with RE2) + { + name: "Regex match", + celExpr: `text_val.matches(r"^hello")`, + expectedRows: 2, // "hello", "hello world" + description: "Regex match (REGEXP_CONTAINS)", + }, + { + name: "Regex simple pattern", + celExpr: `text_val.matches(r"test")`, + expectedRows: 2, // "test", "testing" + description: "Regex simple pattern", + }, + + // Complex combined operators + { + name: "Complex multi-operator expression", + celExpr: `int_val > 10 && bool_val && text_val.contains("test")`, + expectedRows: 2, // rows 3 and 5 + description: "Complex expression with multiple operator types", + }, + { + name: "Nested parenthesized operators", + celExpr: `((int_val + 5) * 2 > 30) && (text_val.contains("test") || bool_val)`, + expectedRows: 2, // rows 3 and 5 + description: "Deeply nested operators with parentheses", + }, + { + name: "Triple negation", + celExpr: `!!!bool_val`, + expectedRows: 2, + description: "Multiple NOT operators", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + sqlCondition, err := cel2sql.Convert(ast, dialectOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + 
t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := fmt.Sprintf("SELECT COUNT(*) FROM `%s.test_data` WHERE %s", bigQueryDataset, sqlCondition) + t.Logf("Full SQL Query: %s", query) + + actualRows := bigQueryCount(ctx, t, client, query) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s (expected %d rows, got %d rows)", + tt.description, tt.expectedRows, actualRows) + }) + } +} + +// TestBigQueryJSONIntegration validates JSON operations against a BigQuery emulator. +func TestBigQueryJSONIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, client := setupBigQueryContainer(ctx, t) + defer func() { + if closeErr := client.Close(); closeErr != nil { + t.Logf("failed to close client: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + // Set up CEL environment with schema for JSON detection + productSchema := pg.NewSchema([]pg.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + {Name: "price", Type: "double precision"}, + {Name: "metadata", Type: "json", IsJSON: true}, + }) + + schemas := map[string]pg.Schema{ + "product": productSchema, + } + + provider := pg.NewTypeProvider(schemas) + + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("product", cel.ObjectType("product")), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(bigqueryDialect.New()) + schemaOpt := cel2sql.WithSchemas(schemas) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + { + name: "JSON field access", + celExpr: `product.metadata.brand == "Acme"`, + expectedRows: 2, + 
description: "JSON field access with JSON_VALUE", + }, + { + name: "JSON field access different value", + celExpr: `product.metadata.color == "blue"`, + expectedRows: 1, + description: "JSON field access with different value", + }, + { + name: "JSON with regular field", + celExpr: `product.metadata.brand == "Acme" && product.price > 30.0`, + expectedRows: 1, // Doohickey (Acme, 39.99) + description: "JSON field combined with regular field comparison", + }, + { + name: "JSON field existence", + celExpr: `has(product.metadata.brand)`, + expectedRows: 3, // All rows have 'brand' + description: "JSON field existence check", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + sqlCondition, err := cel2sql.Convert(ast, dialectOpt, schemaOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := fmt.Sprintf("SELECT COUNT(*) FROM `%s.products` product WHERE %s", bigQueryDataset, sqlCondition) + t.Logf("Full SQL Query: %s", query) + + actualRows := bigQueryCount(ctx, t, client, query) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s", tt.description) + }) + } +} diff --git a/cel2sql.go b/cel2sql.go index 52d91b4..9c70b0e 100644 --- a/cel2sql.go +++ b/cel2sql.go @@ -1,13 +1,12 @@ -// Package cel2sql converts CEL (Common Expression Language) expressions to PostgreSQL SQL conditions. +// Package cel2sql converts CEL (Common Expression Language) expressions to SQL conditions. 
+// It supports multiple SQL dialects through the dialect interface, with PostgreSQL as the default. package cel2sql import ( "context" - "encoding/hex" "fmt" "log/slog" "math" - "regexp" "slices" "strconv" "strings" @@ -18,26 +17,16 @@ import ( "github.com/google/cel-go/common/overloads" exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/dialect/postgres" "github.com/spandigital/cel2sql/v3/schema" ) // Implementations based on `google/cel-go`'s unparser // https://github.com/google/cel-go/blob/master/parser/unparser.go -// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). +// Resource limit constants. const ( - // maxRegexPatternLength is the maximum allowed length for regex patterns - // to prevent processing extremely long patterns that could cause DoS. - maxRegexPatternLength = 500 - - // maxRegexGroups is the maximum number of capture groups allowed in a pattern - // to prevent memory exhaustion and slow matching. - maxRegexGroups = 20 - - // maxRegexNestingDepth is the maximum nesting depth for groups and quantifiers - // to prevent catastrophic backtracking. - maxRegexNestingDepth = 10 - // defaultMaxRecursionDepth is the default maximum recursion depth for visit() // to prevent stack overflow from deeply nested expressions (CWE-674: Uncontrolled Recursion). defaultMaxRecursionDepth = 100 @@ -65,8 +54,23 @@ type convertOptions struct { schemas map[string]schema.Schema ctx context.Context logger *slog.Logger - maxDepth int // Maximum recursion depth (0 = use default) - maxOutputLen int // Maximum SQL output length (0 = use default) + maxDepth int // Maximum recursion depth (0 = use default) + maxOutputLen int // Maximum SQL output length (0 = use default) + dialect dialect.Dialect // SQL dialect (nil = PostgreSQL default) +} + +// WithDialect sets the SQL dialect for conversion. +// If not provided, PostgreSQL is used as the default dialect. 
+// +// Example: +// +// import "github.com/spandigital/cel2sql/v3/dialect/mysql" +// +// sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(mysql.New())) +func WithDialect(d dialect.Dialect) ConvertOption { + return func(o *convertOptions) { + o.dialect = d + } } // WithSchemas provides schema information for proper JSON/JSONB field handling. @@ -177,16 +181,20 @@ type Result struct { Parameters []any // Parameter values in order ($1, $2, etc.) } -// Convert converts a CEL AST to a PostgreSQL SQL WHERE clause condition. -// Options can be provided to configure the conversion behavior. +// Convert converts a CEL AST to a SQL WHERE clause condition. +// By default, PostgreSQL SQL is generated. Use WithDialect to select a different dialect. // -// Example without options: +// Example without options (PostgreSQL): // // sql, err := cel2sql.Convert(ast) // // Example with schema information for JSON/JSONB support: // // sql, err := cel2sql.Convert(ast, cel2sql.WithSchemas(schemas)) +// +// Example with a different dialect: +// +// sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(mysql.New())) func Convert(ast *cel.Ast, opts ...ConvertOption) (string, error) { start := time.Now() @@ -199,6 +207,11 @@ func Convert(ast *cel.Ast, opts ...ConvertOption) (string, error) { opt(options) } + // Default to PostgreSQL dialect if none specified + if options.dialect == nil { + options.dialect = postgres.New() + } + options.logger.Debug("starting CEL to SQL conversion") checkedExpr, err := cel.AstToCheckedExpr(ast) @@ -212,6 +225,7 @@ func Convert(ast *cel.Ast, opts ...ConvertOption) (string, error) { schemas: options.schemas, ctx: options.ctx, logger: options.logger, + dialect: options.dialect, maxDepth: options.maxDepth, maxOutputLen: options.maxOutputLen, } @@ -227,15 +241,16 @@ func Convert(ast *cel.Ast, opts ...ConvertOption) (string, error) { options.logger.LogAttrs(context.Background(), slog.LevelDebug, "conversion completed", slog.String("sql", result), + 
slog.String("dialect", string(options.dialect.Name())), slog.Duration("duration", duration), ) return result, nil } -// ConvertParameterized converts a CEL AST to a parameterized PostgreSQL SQL WHERE clause. -// Returns both the SQL string with placeholders ($1, $2, etc.) and the parameter values. -// This enables query plan caching and provides additional SQL injection protection. +// ConvertParameterized converts a CEL AST to a parameterized SQL WHERE clause. +// Returns both the SQL string with placeholders and the parameter values. +// By default uses PostgreSQL ($1, $2). Use WithDialect for other placeholder styles. // // Constants that are parameterized: // - String literals: 'John' → $1 @@ -268,6 +283,11 @@ func ConvertParameterized(ast *cel.Ast, opts ...ConvertOption) (*Result, error) opt(options) } + // Default to PostgreSQL dialect if none specified + if options.dialect == nil { + options.dialect = postgres.New() + } + options.logger.Debug("starting parameterized CEL to SQL conversion") checkedExpr, err := cel.AstToCheckedExpr(ast) @@ -281,6 +301,7 @@ func ConvertParameterized(ast *cel.Ast, opts ...ConvertOption) (*Result, error) schemas: options.schemas, ctx: options.ctx, logger: options.logger, + dialect: options.dialect, maxDepth: options.maxDepth, maxOutputLen: options.maxOutputLen, parameterize: true, // Enable parameterization @@ -313,13 +334,14 @@ type converter struct { schemas map[string]schema.Schema ctx context.Context logger *slog.Logger + dialect dialect.Dialect depth int // Current recursion depth maxDepth int // Maximum allowed recursion depth maxOutputLen int // Maximum allowed SQL output length comprehensionDepth int // Current comprehension nesting depth parameterize bool // Enable parameterized output parameters []any // Collected parameters for parameterized queries - paramCount int // Parameter counter for placeholders ($1, $2, etc.) 
+ paramCount int // Parameter counter for placeholders } // checkContext checks if the context has been cancelled or expired. @@ -599,6 +621,30 @@ func (con *converter) visitCallBinary(expr *exprpb.Expr) error { rhsParen = isSamePrecedence(fun, rhs) } + // Handle string concatenation via dialect before writing LHS. + // This allows MySQL to use CONCAT() instead of ||. + if fun == operators.Add && + ((lhsType.GetPrimitive() == exprpb.Type_STRING && rhsType.GetPrimitive() == exprpb.Type_STRING) || + (isStringLiteral(lhs) || isStringLiteral(rhs))) { + return con.dialect.WriteStringConcat(&con.str, + func() error { return con.visitMaybeNested(lhs, lhsParen) }, + func() error { return con.visitMaybeNested(rhs, rhsParen) }, + ) + } + + // Handle array membership (IN operator with list) via dialect before writing LHS. + // This allows dialects like SQLite to use a fundamentally different pattern + // (e.g., "elem IN (SELECT value FROM json_each(array))") instead of "elem = ANY(array)". + if fun == operators.In && isListType(rhsType) { + // Non-JSON list membership + if !isFieldAccessExpression(rhs) || !con.isJSONArrayField(rhs) { + return con.dialect.WriteArrayMembership(&con.str, + func() error { return con.visitMaybeNested(lhs, lhsParen) }, + func() error { return con.visitMaybeNested(rhs, rhsParen) }, + ) + } + } + // Check if we need numeric casting for JSON text extraction needsNumericCasting := false if con.isJSONTextExtraction(lhs) && isNumericComparison(fun) && isNumericType(rhsType) { @@ -611,7 +657,8 @@ func (con *converter) visitCallBinary(expr *exprpb.Expr) error { } if needsNumericCasting { - con.str.WriteString(")::numeric") + con.str.WriteString(")") + con.dialect.WriteCastToNumeric(&con.str) } var operator string if fun == operators.Add && (lhsType.GetPrimitive() == exprpb.Type_STRING && rhsType.GetPrimitive() == exprpb.Type_STRING) { @@ -655,28 +702,25 @@ func (con *converter) visitCallBinary(expr *exprpb.Expr) error { if fun == operators.In && 
(isListType(rhsType) || isFieldAccessExpression(rhs)) { // Check if we're dealing with a JSON array if isFieldAccessExpression(rhs) && con.isJSONArrayField(rhs) { - // For JSON arrays, use jsonb_array_elements with ANY + // For JSON arrays, use dialect-specific JSON array membership jsonFunc := con.getJSONArrayFunction(rhs) - con.str.WriteString("ANY(ARRAY(SELECT ") // For nested JSON access like settings.permissions, we need to handle differently if con.isNestedJSONAccess(rhs) { - // Use text extraction for the array elements - con.str.WriteString("jsonb_array_elements_text(") - // Generate the JSON path with -> instead of ->> to preserve JSONB type - if err := con.visitNestedJSONForArray(rhs); err != nil { + // Use dialect-specific nested JSON array membership + if err := con.dialect.WriteNestedJSONArrayMembership(&con.str, func() error { + return con.visitNestedJSONForArray(rhs) + }); err != nil { return err } - con.str.WriteString(")))") return nil } // For direct JSON array access - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visitMaybeNested(rhs, rhsParen); err != nil { + if err := con.dialect.WriteJSONArrayMembership(&con.str, jsonFunc, func() error { + return con.visitMaybeNested(rhs, rhsParen) + }); err != nil { return err } - con.str.WriteString(")))") return nil } con.str.WriteString("ANY(") @@ -728,27 +772,27 @@ func (con *converter) callContains(target *exprpb.Expr, args []*exprpb.Expr) err return nil } - // For regular strings, use POSITION - con.str.WriteString("POSITION(") - for i, arg := range args { - err := con.visit(arg) - if err != nil { - return err - } - if i < len(args)-1 { - con.str.WriteString(" IN ") - } - } - if target != nil { - con.str.WriteString(" IN ") - nested := isBinaryOrTernaryOperator(target) - err := con.visitMaybeNested(target, nested) - if err != nil { - return err - } - } - con.str.WriteString(") > 0") - return nil + // For regular strings, use dialect-specific contains + return 
con.dialect.WriteContains(&con.str, + func() error { + if target != nil { + nested := isBinaryOrTernaryOperator(target) + return con.visitMaybeNested(target, nested) + } + return nil + }, + func() error { + for i, arg := range args { + if err := con.visit(arg); err != nil { + return err + } + if i < len(args)-1 { + con.str.WriteString(", ") + } + } + return nil + }, + ) } func (con *converter) callStartsWith(target *exprpb.Expr, args []*exprpb.Expr) error { @@ -780,14 +824,16 @@ func (con *converter) callStartsWith(target *exprpb.Expr, args []*exprpb.Expr) e escaped := escapeLikePattern(prefix) con.str.WriteString("'") con.str.WriteString(escaped) - con.str.WriteString("%' ESCAPE E'\\\\'") + con.str.WriteString("%'") + con.dialect.WriteLikeEscape(&con.str) } else { // For non-literal patterns, escape special characters at runtime and concatenate with % con.str.WriteString("REPLACE(REPLACE(REPLACE(") if err := con.visit(args[0]); err != nil { return err } - con.str.WriteString(", '\\\\', '\\\\\\\\'), '%', '\\%'), '_', '\\_') || '%' ESCAPE E'\\\\'") + con.str.WriteString(", '\\\\', '\\\\\\\\'), '%', '\\%'), '_', '\\_') || '%'") + con.dialect.WriteLikeEscape(&con.str) } return nil @@ -795,8 +841,7 @@ func (con *converter) callStartsWith(target *exprpb.Expr, args []*exprpb.Expr) e func (con *converter) callEndsWith(target *exprpb.Expr, args []*exprpb.Expr) error { // CEL endsWith function: string.endsWith(suffix) - // Convert to PostgreSQL: string LIKE '%suffix' - // or for more robust handling: RIGHT(string, LENGTH(suffix)) = suffix + // Convert to SQL: string LIKE '%suffix' if target == nil || len(args) == 0 { return fmt.Errorf("%w: endsWith function requires both string and suffix arguments", ErrInvalidArguments) @@ -822,14 +867,16 @@ func (con *converter) callEndsWith(target *exprpb.Expr, args []*exprpb.Expr) err escaped := escapeLikePattern(suffix) con.str.WriteString("'%") con.str.WriteString(escaped) - con.str.WriteString("' ESCAPE E'\\\\'") + 
con.str.WriteString("'") + con.dialect.WriteLikeEscape(&con.str) } else { // For non-literal patterns, escape special characters at runtime and concatenate with % con.str.WriteString("'%' || REPLACE(REPLACE(REPLACE(") if err := con.visit(args[0]); err != nil { return err } - con.str.WriteString(", '\\\\', '\\\\\\\\'), '%', '\\%'), '_', '\\_') ESCAPE E'\\\\'") + con.str.WriteString(", '\\\\', '\\\\\\\\'), '%', '\\%'), '_', '\\_')") + con.dialect.WriteLikeEscape(&con.str) } return nil @@ -841,40 +888,44 @@ func (con *converter) callCasting(function string, _ *exprpb.Expr, args []*exprp } arg := args[0] if function == overloads.TypeConvertInt && isTimestampType(con.getType(arg)) { - con.str.WriteString("EXTRACT(EPOCH FROM ") - if err := con.visit(arg); err != nil { - return err - } - con.str.WriteString(")::bigint") - return nil + return con.dialect.WriteEpochExtract(&con.str, func() error { + return con.visit(arg) + }) } con.str.WriteString("CAST(") if err := con.visit(arg); err != nil { return err } con.str.WriteString(" AS ") + // Map CEL type conversion function to dialect-specific type name + var celTypeName string switch function { case overloads.TypeConvertBool: - con.str.WriteString("BOOLEAN") + celTypeName = "bool" case overloads.TypeConvertBytes: - con.str.WriteString("BYTEA") + celTypeName = "bytes" case overloads.TypeConvertDouble: - con.str.WriteString("DOUBLE PRECISION") + celTypeName = "double" case overloads.TypeConvertInt: - con.str.WriteString("BIGINT") + celTypeName = "int" case overloads.TypeConvertString: - con.str.WriteString("TEXT") + celTypeName = "string" case overloads.TypeConvertUint: - con.str.WriteString("BIGINT") + celTypeName = "uint" } + con.dialect.WriteTypeName(&con.str, celTypeName) con.str.WriteString(")") return nil } -// callMatches handles CEL matches() function with RE2 to POSIX regex conversion +// callMatches handles CEL matches() function with regex conversion func (con *converter) callMatches(target *exprpb.Expr, args 
[]*exprpb.Expr) error { // CEL matches function: string.matches(pattern) or matches(string, pattern) - // Convert to PostgreSQL: string ~ 'posix_pattern' + + // Check if the dialect supports regex + if !con.dialect.SupportsRegex() { + return fmt.Errorf("%w: regex matching is not supported by %s dialect", ErrUnsupportedDialectFeature, con.dialect.Name()) + } // Get the string to match against var stringExpr *exprpb.Expr @@ -896,22 +947,16 @@ func (con *converter) callMatches(target *exprpb.Expr, args []*exprpb.Expr) erro return fmt.Errorf("%w: matches function requires both string and pattern arguments", ErrInvalidArguments) } - // Visit the string expression - if err := con.visit(stringExpr); err != nil { - return err - } - - // Visit the pattern expression and convert from RE2 to POSIX if it's a string literal + // Visit the pattern expression and convert if it's a string literal if constExpr := patternExpr.GetConstExpr(); constExpr != nil && constExpr.GetStringValue() != "" { - // Convert RE2 pattern to POSIX re2Pattern := constExpr.GetStringValue() // Reject patterns containing null bytes if strings.Contains(re2Pattern, "\x00") { return fmt.Errorf("%w: regex patterns cannot contain null bytes", ErrInvalidRegexPattern) } - // Convert RE2 to POSIX with security validation - posixPattern, caseInsensitive, err := convertRE2ToPOSIX(re2Pattern) + // Convert RE2 to dialect-native format with security validation + convertedPattern, caseInsensitive, err := con.dialect.ConvertRegex(re2Pattern) if err != nil { return fmt.Errorf("%w: %w", ErrInvalidRegexPattern, err) } @@ -919,32 +964,23 @@ func (con *converter) callMatches(target *exprpb.Expr, args []*exprpb.Expr) erro con.logger.LogAttrs(context.Background(), slog.LevelDebug, "regex pattern conversion", slog.String("original_pattern", re2Pattern), - slog.String("converted_pattern", posixPattern), + slog.String("converted_pattern", convertedPattern), slog.Bool("case_insensitive", caseInsensitive), + slog.String("dialect", 
string(con.dialect.Name())), ) - // Use ~* for case-insensitive matching, ~ for case-sensitive - if caseInsensitive { - con.str.WriteString(" ~* ") - } else { - con.str.WriteString(" ~ ") - } - - // Write the converted pattern as a string literal - escaped := strings.ReplaceAll(posixPattern, "'", "''") - con.str.WriteString("'") - con.str.WriteString(escaped) - con.str.WriteString("'") - } else { - // For non-literal patterns, we can't convert at compile time - // Just use the pattern as-is with case-sensitive operator - con.str.WriteString(" ~ ") - if err := con.visit(patternExpr); err != nil { - return err - } + // Use dialect-specific regex match writing + return con.dialect.WriteRegexMatch(&con.str, func() error { + return con.visit(stringExpr) + }, convertedPattern, caseInsensitive) } - - return nil + // For non-literal patterns, we can't convert at compile time + // Visit the string, then write regex operator, then visit the pattern + if err := con.visit(stringExpr); err != nil { + return err + } + con.str.WriteString(" ~ ") + return con.visit(patternExpr) } // callLowerASCII handles CEL lowerAscii() string function @@ -1469,53 +1505,36 @@ func (con *converter) callSplit(target *exprpb.Expr, args []*exprpb.Expr) error } // Generate SQL based on limit value + writeStr := func() error { + nested := isBinaryOrTernaryOperator(stringExpr) + return con.visitMaybeNested(stringExpr, nested) + } + writeDelim := func() error { + return con.visit(delimiterExpr) + } + switch { case limit == 0: // Empty array - con.str.WriteString("ARRAY[]::text[]") + con.dialect.WriteEmptyTypedArray(&con.str, "text") return nil case limit == 1: // Return original string as single-element array - con.str.WriteString("ARRAY[") - nested := isBinaryOrTernaryOperator(stringExpr) - if err := con.visitMaybeNested(stringExpr, nested); err != nil { + con.dialect.WriteArrayLiteralOpen(&con.str) + if err := writeStr(); err != nil { return err } - con.str.WriteString("]") + 
con.dialect.WriteArrayLiteralClose(&con.str) return nil case limit == -1: - // Unlimited splits (default PostgreSQL behavior) - con.str.WriteString("STRING_TO_ARRAY(") - nested := isBinaryOrTernaryOperator(stringExpr) - if err := con.visitMaybeNested(stringExpr, nested); err != nil { - return err - } - con.str.WriteString(", ") - if err := con.visit(delimiterExpr); err != nil { - return err - } - con.str.WriteString(")") - return nil + // Unlimited splits + return con.dialect.WriteSplit(&con.str, writeStr, writeDelim) case limit > 1: - // Arbitrary positive limit - use array slicing with REGEXP_SPLIT_TO_ARRAY - // REGEXP_SPLIT_TO_ARRAY is more powerful and allows us to limit splits - // Result: (REGEXP_SPLIT_TO_ARRAY(string, delimiter))[1:limit] - con.str.WriteString("(STRING_TO_ARRAY(") - nested := isBinaryOrTernaryOperator(stringExpr) - if err := con.visitMaybeNested(stringExpr, nested); err != nil { - return err - } - con.str.WriteString(", ") - if err := con.visit(delimiterExpr); err != nil { - return err - } - con.str.WriteString("))[1:") - con.str.WriteString(strconv.FormatInt(limit, 10)) - con.str.WriteString("]") - return nil + // Positive limit - use dialect-specific split with limit + return con.dialect.WriteSplitWithLimit(&con.str, writeStr, writeDelim, limit) default: // Negative limits other than -1 are not supported @@ -1559,26 +1578,18 @@ func (con *converter) callJoin(target *exprpb.Expr, args []*exprpb.Expr) error { } } - // Generate SQL - con.str.WriteString("ARRAY_TO_STRING(") - nested := isBinaryOrTernaryOperator(arrayExpr) - if err := con.visitMaybeNested(arrayExpr, nested); err != nil { - return err + // Generate SQL using dialect-specific join + writeArray := func() error { + nested := isBinaryOrTernaryOperator(arrayExpr) + return con.visitMaybeNested(arrayExpr, nested) } - con.str.WriteString(", ") - - // Use provided delimiter or empty string default + var writeDelim func() error if delimiterExpr != nil { - if err := 
con.visit(delimiterExpr); err != nil { - return err + writeDelim = func() error { + return con.visit(delimiterExpr) } - } else { - con.str.WriteString("''") } - - // Third parameter: null_string (use empty string to replace nulls) - con.str.WriteString(", '')") - return nil + return con.dialect.WriteJoin(&con.str, writeArray, writeDelim) } // callFormat handles CEL format() function @@ -1810,28 +1821,17 @@ func (con *converter) visitCallFunc(expr *exprpb.Expr) error { case isListType(argType): // Check if this is a JSON array field if con.isJSONArrayField(argExpr) { - // For JSON arrays, use jsonb_array_length wrapped in COALESCE - con.str.WriteString("COALESCE(jsonb_array_length(") - err := con.visit(argExpr) - if err != nil { - return err - } - con.str.WriteString("), 0)") - return nil + // For JSON arrays, use dialect-specific JSON array length + return con.dialect.WriteJSONArrayLength(&con.str, func() error { + return con.visit(argExpr) + }) } - // For PostgreSQL, we need to specify the array dimension - // Detect the dimension from schema if available, otherwise default to 1 + // For native arrays, use dialect-specific array length dimension := con.getArrayDimension(argExpr) - - // Wrap in COALESCE to handle NULL arrays (ARRAY_LENGTH returns NULL for NULL input) - con.str.WriteString("COALESCE(ARRAY_LENGTH(") - nested := isBinaryOrTernaryOperator(argExpr) - err := con.visitMaybeNested(argExpr, nested) - if err != nil { - return err - } - fmt.Fprintf(&con.str, ", %d), 0)", dimension) - return nil + return con.dialect.WriteArrayLength(&con.str, dimension, func() error { + nested := isBinaryOrTernaryOperator(argExpr) + return con.visitMaybeNested(argExpr, nested) + }) default: return newConversionErrorf(errMsgUnsupportedType, "size() argument type: %s", argType.String()) } @@ -1900,13 +1900,9 @@ func (con *converter) visitCallListIndex(expr *exprpb.Expr) error { return fmt.Errorf("%w: list index operator requires list and index arguments", ErrInvalidArguments) } 
l := args[0] - nested := isBinaryOrTernaryOperator(l) - if err := con.visitMaybeNested(l, nested); err != nil { - return err - } - con.str.WriteString("[") index := args[1] - // PostgreSQL arrays are 1-indexed, CEL is 0-indexed, so add 1 + + // Check for constant index if constExpr := index.GetConstExpr(); constExpr != nil { idx := constExpr.GetInt64Value() if idx == math.MaxInt64 { @@ -1915,15 +1911,19 @@ func (con *converter) visitCallListIndex(expr *exprpb.Expr) error { if idx < 0 { return fmt.Errorf("%w: negative array index %d is not supported", ErrInvalidArguments, idx) } - con.str.WriteString(strconv.FormatInt(idx+1, 10)) - } else { - if err := con.visit(index); err != nil { - return err - } - con.str.WriteString(" + 1") + return con.dialect.WriteListIndexConst(&con.str, func() error { + nested := isBinaryOrTernaryOperator(l) + return con.visitMaybeNested(l, nested) + }, idx) } - con.str.WriteString("]") - return nil + + // Dynamic index + return con.dialect.WriteListIndex(&con.str, func() error { + nested := isBinaryOrTernaryOperator(l) + return con.visitMaybeNested(l, nested) + }, func() error { + return con.visit(index) + }) } func (con *converter) visitCallUnary(expr *exprpb.Expr) error { @@ -1992,36 +1992,17 @@ func (con *converter) visitComprehension(expr *exprpb.Expr) error { // Comprehension visit functions - Phase 1 placeholder implementations func (con *converter) visitAllComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for ALL comprehension: all elements must satisfy the predicate - // Pattern: NOT EXISTS (SELECT 1 FROM UNNEST(array) AS item WHERE NOT predicate) - // For JSON arrays: NOT EXISTS (SELECT 1 FROM jsonb_array_elements(json_field) AS item WHERE NOT predicate) - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (ALL)") } iterRange := comprehension.GetIterRange() - isJSONArray := 
con.isJSONArrayField(iterRange) con.str.WriteString("NOT EXISTS (SELECT 1 FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in ALL comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in ALL comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in ALL comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) @@ -2038,36 +2019,17 @@ func (con *converter) visitAllComprehension(expr *exprpb.Expr, info *Comprehensi } func (con *converter) visitExistsComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for EXISTS comprehension: at least one element satisfies the predicate - // Pattern: EXISTS (SELECT 1 FROM UNNEST(array) AS item WHERE predicate) - // For JSON arrays: EXISTS (SELECT 1 FROM jsonb_array_elements(json_field) AS item WHERE predicate) - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (EXISTS)") } iterRange := comprehension.GetIterRange() - isJSONArray := con.isJSONArrayField(iterRange) con.str.WriteString("EXISTS (SELECT 1 FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in EXISTS comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter 
range in EXISTS comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in EXISTS comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) @@ -2083,36 +2045,17 @@ func (con *converter) visitExistsComprehension(expr *exprpb.Expr, info *Comprehe } func (con *converter) visitExistsOneComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for EXISTS_ONE comprehension: exactly one element satisfies the predicate - // Pattern: (SELECT COUNT(*) FROM UNNEST(array) AS item WHERE predicate) = 1 - // For JSON arrays: (SELECT COUNT(*) FROM jsonb_array_elements(json_field) AS item WHERE predicate) = 1 - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (EXISTS_ONE)") } iterRange := comprehension.GetIterRange() - isJSONArray := con.isJSONArrayField(iterRange) con.str.WriteString("(SELECT COUNT(*) FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in EXISTS_ONE comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in EXISTS_ONE comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in EXISTS_ONE comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) @@ -2128,52 +2071,29 @@ func (con *converter) visitExistsOneComprehension(expr *exprpb.Expr, info *Compr } func (con *converter) visitMapComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error 
{ - // Generate SQL for MAP comprehension: transform elements using the transform expression - // Pattern: ARRAY(SELECT transform FROM UNNEST(array) AS item [WHERE filter]) - // For JSON arrays: ARRAY(SELECT transform FROM jsonb_array_elements(json_field) AS item [WHERE filter]) - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (MAP)") } iterRange := comprehension.GetIterRange() - isJSONArray := con.isJSONArrayField(iterRange) - - con.str.WriteString("ARRAY(SELECT ") - // Visit the transform expression + con.dialect.WriteArraySubqueryOpen(&con.str) if info.Transform != nil { if err := con.visit(info.Transform); err != nil { return wrapConversionError(err, "visiting transform in MAP comprehension") } } else { - // If no transform, just return the variable itself con.str.WriteString(info.IterVar) } - + con.dialect.WriteArraySubqueryExprClose(&con.str) con.str.WriteString(" FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in MAP comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in MAP comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in MAP comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) - // Add filter condition if present (for map with filter) if info.Filter != nil { con.str.WriteString(" WHERE ") if err := con.visit(info.Filter); err != nil { @@ -2186,38 +2106,20 @@ func (con *converter) visitMapComprehension(expr *exprpb.Expr, info *Comprehensi } func (con *converter) 
visitFilterComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for FILTER comprehension: return elements that satisfy the predicate - // Pattern: ARRAY(SELECT item FROM UNNEST(array) AS item WHERE predicate) - // For JSON arrays: ARRAY(SELECT item FROM jsonb_array_elements(json_field) AS item WHERE predicate) - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (FILTER)") } iterRange := comprehension.GetIterRange() - isJSONArray := con.isJSONArrayField(iterRange) - con.str.WriteString("ARRAY(SELECT ") + con.dialect.WriteArraySubqueryOpen(&con.str) con.str.WriteString(info.IterVar) + con.dialect.WriteArraySubqueryExprClose(&con.str) con.str.WriteString(" FROM ") - - if isJSONArray { - jsonFunc := con.getJSONArrayFunction(iterRange) - con.str.WriteString(jsonFunc) - con.str.WriteString("(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in FILTER comprehension") - } - con.str.WriteString(")") - } else { - con.str.WriteString("UNNEST(") - if err := con.visit(iterRange); err != nil { - return wrapConversionError(err, "visiting iter range in FILTER comprehension") - } - con.str.WriteString(")") + if err := con.writeComprehensionSource(iterRange); err != nil { + return wrapConversionError(err, "visiting iter range in FILTER comprehension") } - con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) @@ -2233,37 +2135,27 @@ func (con *converter) visitFilterComprehension(expr *exprpb.Expr, info *Comprehe } func (con *converter) visitTransformListComprehension(expr *exprpb.Expr, info *ComprehensionInfo) error { - // Generate SQL for TRANSFORM_LIST comprehension: similar to MAP but may have different semantics - // Pattern: ARRAY(SELECT transform FROM UNNEST(array) AS item [WHERE filter]) - comprehension := expr.GetComprehensionExpr() if comprehension == nil { return 
newConversionError(errMsgUnsupportedComprehension, "expression is not a comprehension (TRANSFORM_LIST)") } - con.str.WriteString("ARRAY(SELECT ") - - // Visit the transform expression + con.dialect.WriteArraySubqueryOpen(&con.str) if info.Transform != nil { if err := con.visit(info.Transform); err != nil { return wrapConversionError(err, "visiting transform in TRANSFORM_LIST comprehension") } } else { - // If no transform, just return the variable itself con.str.WriteString(info.IterVar) } - - con.str.WriteString(" FROM UNNEST(") - - // Visit the iterable range (the array/list being comprehended over) - if err := con.visit(comprehension.GetIterRange()); err != nil { + con.dialect.WriteArraySubqueryExprClose(&con.str) + con.str.WriteString(" FROM ") + if err := con.writeComprehensionSource(comprehension.GetIterRange()); err != nil { return wrapConversionError(err, "visiting iter range in TRANSFORM_LIST comprehension") } - - con.str.WriteString(") AS ") + con.str.WriteString(" AS ") con.str.WriteString(info.IterVar) - // Add filter condition if present if info.Filter != nil { con.str.WriteString(" WHERE ") if err := con.visit(info.Filter); err != nil { @@ -2305,7 +2197,7 @@ func (con *converter) visitConst(expr *exprpb.Expr) error { case *exprpb.Constant_Int64Value: if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, c.GetInt64Value()) } else { i := strconv.FormatInt(c.GetInt64Value(), 10) @@ -2314,7 +2206,7 @@ func (con *converter) visitConst(expr *exprpb.Expr) error { case *exprpb.Constant_Uint64Value: if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, c.GetUint64Value()) } else { ui := strconv.FormatUint(c.GetUint64Value(), 10) @@ -2323,7 +2215,7 @@ func (con *converter) visitConst(expr 
*exprpb.Expr) error { case *exprpb.Constant_DoubleValue: if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, c.GetDoubleValue()) } else { d := strconv.FormatFloat(c.GetDoubleValue(), 'g', -1, 64) @@ -2338,31 +2230,26 @@ func (con *converter) visitConst(expr *exprpb.Expr) error { if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, str) } else { - // Use single quotes for PostgreSQL string literals - // Escape single quotes by doubling them - escaped := strings.ReplaceAll(str, "'", "''") - con.str.WriteString("'") - con.str.WriteString(escaped) - con.str.WriteString("'") + con.dialect.WriteStringLiteral(&con.str, str) } case *exprpb.Constant_BytesValue: b := c.GetBytesValue() if con.parameterize { con.paramCount++ - fmt.Fprintf(&con.str, "$%d", con.paramCount) + con.dialect.WriteParamPlaceholder(&con.str, con.paramCount) con.parameters = append(con.parameters, b) } else { // Validate byte array length to prevent resource exhaustion (CWE-400) if len(b) > maxByteArrayLength { return fmt.Errorf("%w: %d bytes exceeds limit of %d bytes", ErrInvalidByteArrayLength, len(b), maxByteArrayLength) } - con.str.WriteString("'\\x") - con.str.WriteString(hex.EncodeToString(b)) - con.str.WriteString("'") + if err := con.dialect.WriteBytesLiteral(&con.str, b); err != nil { + return err + } } default: return newConversionErrorf(errMsgUnsupportedExpression, "constant type: %T", c.ConstantKind) @@ -2374,7 +2261,7 @@ func (con *converter) visitIdent(expr *exprpb.Expr) error { identName := expr.GetIdentExpr().GetName() // Validate identifier name for security (prevent SQL injection) - if err := validateFieldName(identName); err != nil { + if err := con.dialect.ValidateFieldName(identName); err != nil { return 
fmt.Errorf("%w: %w", ErrInvalidFieldName, err) } @@ -2382,7 +2269,8 @@ func (con *converter) visitIdent(expr *exprpb.Expr) error { if con.needsNumericCasting(identName) { con.str.WriteString("(") con.str.WriteString(identName) - con.str.WriteString(")::numeric") + con.str.WriteString(")") + con.dialect.WriteCastToNumeric(&con.str) } else { con.str.WriteString(identName) } @@ -2392,7 +2280,7 @@ func (con *converter) visitIdent(expr *exprpb.Expr) error { func (con *converter) visitList(expr *exprpb.Expr) error { l := expr.GetListExpr() elems := l.GetElements() - con.str.WriteString("ARRAY[") + con.dialect.WriteArrayLiteralOpen(&con.str) for i, elem := range elems { err := con.visit(elem) if err != nil { @@ -2402,7 +2290,7 @@ func (con *converter) visitList(expr *exprpb.Expr) error { con.str.WriteString(", ") } } - con.str.WriteString("]") + con.dialect.WriteArrayLiteralClose(&con.str) return nil } @@ -2411,7 +2299,7 @@ func (con *converter) visitSelect(expr *exprpb.Expr) error { // Validate field name for security (prevent SQL injection) fieldName := sel.GetField() - if err := validateFieldName(fieldName); err != nil { + if err := con.dialect.ValidateFieldName(fieldName); err != nil { return fmt.Errorf("%w: %w", ErrInvalidFieldName, err) } @@ -2433,34 +2321,35 @@ func (con *converter) visitSelect(expr *exprpb.Expr) error { nested := !sel.GetTestOnly() && isBinaryOrTernaryOperator(sel.GetOperand()) - if useJSONObjectAccess && con.isNumericJSONField(fieldName) { - // For numeric JSON fields, wrap in parentheses for casting - con.str.WriteString("(") - } - - err := con.visitMaybeNested(sel.GetOperand(), nested) - if err != nil { - return err + writeBase := func() error { + return con.visitMaybeNested(sel.GetOperand(), nested) } switch { case useJSONPath: - // Use ->> for text extraction - con.str.WriteString("->>") - con.str.WriteString("'") - con.str.WriteString(escapeJSONFieldName(fieldName)) - con.str.WriteString("'") + // Use dialect-specific JSON field access (text 
extraction) + if err := con.dialect.WriteJSONFieldAccess(&con.str, writeBase, fieldName, true); err != nil { + return err + } case useJSONObjectAccess: - // Use -> for JSON object field access in comprehensions - con.str.WriteString("->>'") - con.str.WriteString(escapeJSONFieldName(fieldName)) - con.str.WriteString("'") - if con.isNumericJSONField(fieldName) { + // Use dialect-specific JSON object field access in comprehensions + isNumeric := con.isNumericJSONField(fieldName) + if isNumeric { + con.str.WriteString("(") + } + if err := con.dialect.WriteJSONFieldAccess(&con.str, writeBase, fieldName, true); err != nil { + return err + } + if isNumeric { // Close parentheses and add numeric cast - con.str.WriteString(")::numeric") + con.str.WriteString(")") + con.dialect.WriteCastToNumeric(&con.str) } default: // Regular field selection + if err := writeBase(); err != nil { + return err + } con.str.WriteString(".") con.str.WriteString(fieldName) } @@ -2476,25 +2365,10 @@ func (con *converter) visitHasFunction(expr *exprpb.Expr) error { // Check if this is a direct JSON field access (e.g., table.json_column.key) if con.isDirectJSONFieldAccess(operand, field) { - // For direct JSON field access, use the appropriate existence operator - err := con.visitMaybeNested(operand, isBinaryOrTernaryOperator(operand)) - if err != nil { - return err - } - - // Check if this is a JSONB field - if con.isJSONBField(operand) { - // Use JSONB's ? operator for existence check - con.str.WriteString(" ? 
'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - } else { - // For JSON fields, check if the field is not null - con.str.WriteString("->'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("' IS NOT NULL") - } - return nil + isJSONB := con.isJSONBField(operand) + return con.dialect.WriteJSONExistence(&con.str, isJSONB, field, func() error { + return con.visitMaybeNested(operand, isBinaryOrTernaryOperator(operand)) + }) } // Check if this is a nested JSON path (e.g., table.json_column.key.subkey) @@ -2532,27 +2406,12 @@ func (con *converter) isDirectJSONFieldAccess(operand *exprpb.Expr, _ string) bo // visitNestedJSONHas handles has() for deeply nested JSON paths func (con *converter) visitNestedJSONHas(expr *exprpb.Expr) error { - // For nested JSON paths, we use jsonb_extract_path_text and check for NOT NULL - // This is more reliable than trying to use ? operator on nested paths - con.str.WriteString("jsonb_extract_path_text(") - // Get the root JSON column and remaining path segments rootColumn, pathSegments := con.getJSONRootAndPath(expr) - // Visit the root column without adding JSON access operators - if err := con.visitJSONColumnReference(rootColumn); err != nil { - return err - } - - // Add path segments as arguments - for _, segment := range pathSegments { - con.str.WriteString(", '") - con.str.WriteString(escapeJSONFieldName(segment)) - con.str.WriteString("'") - } - - con.str.WriteString(") IS NOT NULL") - return nil + return con.dialect.WriteJSONExtractPath(&con.str, pathSegments, func() error { + return con.visitJSONColumnReference(rootColumn) + }) } // visitJSONColumnReference visits a JSON column reference without adding JSON access operators @@ -2676,7 +2535,7 @@ func (con *converter) visitStructMsg(expr *exprpb.Expr) error { func (con *converter) visitStructMap(expr *exprpb.Expr) error { m := expr.GetStructExpr() entries := m.GetEntries() - con.str.WriteString("ROW(") + 
con.dialect.WriteStructOpen(&con.str) for i, entry := range entries { v := entry.GetValue() if err := con.visit(v); err != nil { @@ -2686,10 +2545,27 @@ func (con *converter) visitStructMap(expr *exprpb.Expr) error { con.str.WriteString(", ") } } - con.str.WriteString(")") + con.dialect.WriteStructClose(&con.str) return nil } +// writeComprehensionSource writes the source expression for a comprehension (UNNEST or JSON function). +func (con *converter) writeComprehensionSource(iterRange *exprpb.Expr) error { + isJSONArray := con.isJSONArrayField(iterRange) + if isJSONArray { + jsonFunc := con.getJSONArrayFunction(iterRange) + isJSONB := con.isJSONBField(iterRange) + // Determine if we need text extraction or object extraction + asText := strings.HasSuffix(jsonFunc, "_text") + return con.dialect.WriteJSONArrayElements(&con.str, isJSONB, asText, func() error { + return con.visit(iterRange) + }) + } + return con.dialect.WriteUnnest(&con.str, func() error { + return con.visit(iterRange) + }) +} + func (con *converter) visitMaybeNested(expr *exprpb.Expr, nested bool) error { if nested { con.str.WriteString("(") @@ -2767,181 +2643,3 @@ func isBinaryOrTernaryOperator(expr *exprpb.Expr) bool { _, isBinaryOp := operators.FindReverseBinaryOperator(expr.GetCallExpr().GetFunction()) return isBinaryOp || isSamePrecedence(operators.Conditional, expr) } - -// convertRE2ToPOSIX converts an RE2 regex pattern to POSIX ERE format for PostgreSQL. -// It performs security validation to prevent ReDoS attacks (CWE-1333). -// Returns: (posixPattern, caseInsensitive, error) -// Note: This is a basic conversion for common patterns. Full RE2 to POSIX conversion is complex. -func convertRE2ToPOSIX(re2Pattern string) (string, bool, error) { - // 1. 
Check pattern length to prevent processing extremely long patterns - if len(re2Pattern) > maxRegexPatternLength { - return "", false, fmt.Errorf("%w: pattern length %d exceeds limit of %d characters", ErrInvalidRegexPattern, len(re2Pattern), maxRegexPatternLength) - } - - // 2. Extract case-insensitive flag if present - caseInsensitive := false - if strings.HasPrefix(re2Pattern, "(?i)") { - caseInsensitive = true - re2Pattern = strings.TrimPrefix(re2Pattern, "(?i)") - } - - // 3. Detect unsupported RE2 features and return errors - // Lookahead assertions - if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { - return "", false, fmt.Errorf("%w: lookahead assertions (?=...), (?!...) are not supported in PostgreSQL POSIX regex", ErrInvalidRegexPattern) - } - // Lookbehind assertions - if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?<!") { - return "", false, fmt.Errorf("%w: lookbehind assertions (?<=...), (?<!...) are not supported in PostgreSQL POSIX regex", ErrInvalidRegexPattern) - } - // Other inline flags (after we've already handled (?i)) - if strings.Contains(re2Pattern, "(?m") || strings.Contains(re2Pattern, "(?s") || strings.Contains(re2Pattern, "(?-") { - return "", false, fmt.Errorf("%w: inline flags other than (?i) are not supported in PostgreSQL POSIX regex", ErrInvalidRegexPattern) - } - - // 4. Detect catastrophic nested quantifiers that cause exponential backtracking - // Patterns like (a+)+, (a*)*, (x+x+)+, ((a)+b)+, etc. are extremely dangerous - - // Check for doubled quantifiers - if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { - return "", false, fmt.Errorf("%w: regex contains catastrophic nested quantifiers that could cause ReDoS", ErrInvalidRegexPattern) - } - - // Check for groups that contain quantifiers and are themselves quantified - // This catches patterns like (a+)+, ((a)+b)+, (a*b*)*, etc.
- // We need to check if any opening paren eventually leads to a closing paren followed by a quantifier, - // and if there are quantifiers between those parens. - depth := 0 - groupHasQuantifier := make([]bool, 0) - - for i := 0; i < len(re2Pattern); i++ { - char := re2Pattern[i] - - // Skip escaped characters - if i > 0 && re2Pattern[i-1] == '\\' { - continue - } - - switch char { - case '(': - depth++ - groupHasQuantifier = append(groupHasQuantifier, false) - case ')': - if depth > 0 { - depth-- - // Check if the closing paren is followed by a quantifier - if i+1 < len(re2Pattern) { - nextChar := re2Pattern[i+1] - if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { - // This group is quantified. Check if it contains quantifiers - if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { - return "", false, fmt.Errorf("%w: regex contains catastrophic nested quantifiers that could cause ReDoS", ErrInvalidRegexPattern) - } - } - } - if len(groupHasQuantifier) > 0 { - // Pop the last group - if len(groupHasQuantifier) > 1 { - // If inner group had quantifier, mark outer group as having quantifier too - if groupHasQuantifier[len(groupHasQuantifier)-1] { - groupHasQuantifier[len(groupHasQuantifier)-2] = true - } - } - groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] - } - } - case '*', '+', '?': - // Mark that current group contains a quantifier - if len(groupHasQuantifier) > 0 { - groupHasQuantifier[len(groupHasQuantifier)-1] = true - } - case '{': - // Brace quantifier {n,m} - if len(groupHasQuantifier) > 0 { - groupHasQuantifier[len(groupHasQuantifier)-1] = true - } - } - } - - // 5. 
Count and limit capture groups to prevent memory exhaustion - groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, `\(`) - if groupCount > maxRegexGroups { - return "", false, fmt.Errorf("%w: regex contains %d capture groups, exceeds limit of %d", ErrInvalidRegexPattern, groupCount, maxRegexGroups) - } - - // 6. Detect exponential alternation patterns like (a|a)*b or (a|ab)* - alternationPattern := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) - if alternationPattern.MatchString(re2Pattern) { - // Check if alternation has overlapping branches (more dangerous) - // This is a simple heuristic - full analysis would be more complex - return "", false, fmt.Errorf("%w: regex contains quantified alternation that could cause ReDoS", ErrInvalidRegexPattern) - } - - // 7. Check nesting depth to prevent deeply nested patterns - maxDepth := 0 - currentDepth := 0 - for _, char := range re2Pattern { - if char == '(' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { - currentDepth++ - if currentDepth > maxDepth { - maxDepth = currentDepth - } - } else if char == ')' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { - currentDepth-- - } - } - if maxDepth > maxRegexNestingDepth { - return "", false, fmt.Errorf("%w: nesting depth %d exceeds limit of %d", ErrInvalidRegexPattern, maxDepth, maxRegexNestingDepth) - } - - // Passed all security checks - proceed with conversion - posixPattern := re2Pattern - - // Basic conversions for common differences between RE2 and POSIX: - - // 1. Word boundaries: \b -> [[:<:]] and [[:<:]] (PostgreSQL extension) - // Note: PostgreSQL supports \y for word boundaries in some contexts - posixPattern = strings.ReplaceAll(posixPattern, `\b`, `\y`) - - // 2. Non-word boundaries: \B -> [^[:alnum:]_] (approximate) - // This is a simplification; exact conversion is complex - posixPattern = strings.ReplaceAll(posixPattern, `\B`, `[^[:alnum:]_]`) - - // 3. 
Digit shortcuts: \d -> [[:digit:]] or [0-9] - posixPattern = strings.ReplaceAll(posixPattern, `\d`, `[[:digit:]]`) - - // 4. Non-digit shortcuts: \D -> [^[:digit:]] or [^0-9] - posixPattern = strings.ReplaceAll(posixPattern, `\D`, `[^[:digit:]]`) - - // 5. Word character shortcuts: \w -> [[:alnum:]_] - posixPattern = strings.ReplaceAll(posixPattern, `\w`, `[[:alnum:]_]`) - - // 6. Non-word character shortcuts: \W -> [^[:alnum:]_] - posixPattern = strings.ReplaceAll(posixPattern, `\W`, `[^[:alnum:]_]`) - - // 7. Whitespace shortcuts: \s -> [[:space:]] - posixPattern = strings.ReplaceAll(posixPattern, `\s`, `[[:space:]]`) - - // 8. Non-whitespace shortcuts: \S -> [^[:space:]] - posixPattern = strings.ReplaceAll(posixPattern, `\S`, `[^[:space:]]`) - - // 9. Non-capturing groups: (?:...) -> (...) - // POSIX ERE doesn't have non-capturing groups, so convert to regular groups - posixPattern = strings.ReplaceAll(posixPattern, `(?:`, `(`) - - // Note: Unsupported RE2 features that are now validated and return errors: - // - Lookahead/lookbehind assertions (?=...), (?!...), (?<=...), (?<!...) - ERROR - // - Case-insensitive flag (?i) - CONVERTED (returned as separate boolean) - // - Other inline flags (?m), (?s) - ERROR - // - // Converted features: - // - Non-capturing groups (?:...) - Converted to regular groups (...) - // - Character class shortcuts (\d, \w, \s, etc.) - Converted to POSIX equivalents - - return posixPattern, caseInsensitive, nil -} diff --git a/dialect/bigquery/dialect.go b/dialect/bigquery/dialect.go new file mode 100644 index 0000000..d097e9a --- /dev/null +++ b/dialect/bigquery/dialect.go @@ -0,0 +1,503 @@ +// Package bigquery implements the BigQuery SQL dialect for cel2sql. +package bigquery + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// Dialect implements dialect.Dialect for BigQuery. +type Dialect struct{} + +// New creates a new BigQuery dialect.
+func New() *Dialect { + return &Dialect{} +} + +func init() { + dialect.Register(dialect.BigQuery, func() dialect.Dialect { return New() }) +} + +// Ensure Dialect implements dialect.Dialect at compile time. +var _ dialect.Dialect = (*Dialect)(nil) + +// Name returns the dialect name. +func (d *Dialect) Name() dialect.Name { return dialect.BigQuery } + +// --- Literals --- + +// WriteStringLiteral writes a BigQuery string literal with \' escaping. +func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) { + escaped := strings.ReplaceAll(value, "'", "\\'") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") +} + +// WriteBytesLiteral writes a BigQuery octal-encoded byte literal (b"..."). +func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error { + w.WriteString("b\"") + for _, b := range value { + fmt.Fprintf(w, "\\%03o", b) + } + w.WriteString("\"") + return nil +} + +// WriteParamPlaceholder writes a BigQuery named parameter (@p1, @p2, ...). +func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, paramIndex int) { + fmt.Fprintf(w, "@p%d", paramIndex) +} + +// --- Operators --- + +// WriteStringConcat writes BigQuery string concatenation using ||. +func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error { + if err := writeLHS(); err != nil { + return err + } + w.WriteString(" || ") + return writeRHS() +} + +// WriteRegexMatch writes a BigQuery regex match using REGEXP_CONTAINS. +func (d *Dialect) WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, _ bool) error { + w.WriteString("REGEXP_CONTAINS(") + if err := writeTarget(); err != nil { + return err + } + w.WriteString(", '") + escaped := strings.ReplaceAll(pattern, "'", "\\'") + w.WriteString(escaped) + w.WriteString("')") + return nil +} + +// WriteLikeEscape writes the BigQuery LIKE escape clause.
+func (d *Dialect) WriteLikeEscape(w *strings.Builder) { + w.WriteString(" ESCAPE '\\\\'") +} + +// WriteArrayMembership writes a BigQuery array membership test using IN UNNEST(). +func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + if err := writeElem(); err != nil { + return err + } + w.WriteString(" IN UNNEST(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Type Casting --- + +// WriteCastToNumeric writes a BigQuery numeric cast (CAST(... AS FLOAT64)). +func (d *Dialect) WriteCastToNumeric(w *strings.Builder) { + // BigQuery doesn't have a ::type cast syntax; this is used after expressions. + // For BigQuery, the converter should use CAST(expr AS FLOAT64) instead. + w.WriteString("::FLOAT64") +} + +// WriteTypeName writes a BigQuery type name for CAST expressions. +func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) { + switch celTypeName { + case "bool": + w.WriteString("BOOL") + case "bytes": + w.WriteString("BYTES") + case "double": + w.WriteString("FLOAT64") + case "int": + w.WriteString("INT64") + case "string": + w.WriteString("STRING") + case "uint": + w.WriteString("INT64") + default: + w.WriteString(strings.ToUpper(celTypeName)) + } +} + +// WriteEpochExtract writes UNIX_SECONDS(expr) for BigQuery. +func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error { + w.WriteString("UNIX_SECONDS(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteTimestampCast writes a BigQuery CAST to TIMESTAMP. +func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error { + w.WriteString("CAST(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(" AS TIMESTAMP)") + return nil +} + +// --- Arrays --- + +// WriteArrayLiteralOpen writes the BigQuery array literal opening ([). 
+func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) { + w.WriteString("[") +} + +// WriteArrayLiteralClose writes the BigQuery array literal closing (]). +func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) { + w.WriteString("]") +} + +// WriteArrayLength writes ARRAY_LENGTH(expr) for BigQuery. +func (d *Dialect) WriteArrayLength(w *strings.Builder, _ int, writeExpr func() error) error { + w.WriteString("ARRAY_LENGTH(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteListIndex writes BigQuery 0-indexed array access using OFFSET. +func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error { + if err := writeArray(); err != nil { + return err + } + w.WriteString("[OFFSET(") + if err := writeIndex(); err != nil { + return err + } + w.WriteString(")]") + return nil +} + +// WriteListIndexConst writes BigQuery constant array index access using OFFSET. +func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error { + if err := writeArray(); err != nil { + return err + } + fmt.Fprintf(w, "[OFFSET(%d)]", index) + return nil +} + +// WriteEmptyTypedArray writes an empty BigQuery typed array. +func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, typeName string) { + w.WriteString("ARRAY<") + w.WriteString(bigqueryTypeName(typeName)) + w.WriteString(">[]") +} + +// --- JSON --- + +// WriteJSONFieldAccess writes BigQuery JSON field access using JSON_VALUE. 
+func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error { + escaped := escapeJSONFieldName(fieldName) + if isFinal { + w.WriteString("JSON_VALUE(") + } else { + w.WriteString("JSON_QUERY(") + } + if err := writeBase(); err != nil { + return err + } + w.WriteString(", '$.") + w.WriteString(escaped) + w.WriteString("')") + return nil +} + +// WriteJSONExistence writes a BigQuery JSON key existence check. +func (d *Dialect) WriteJSONExistence(w *strings.Builder, _ bool, fieldName string, writeBase func() error) error { + escaped := escapeJSONFieldName(fieldName) + w.WriteString("JSON_VALUE(") + if err := writeBase(); err != nil { + return err + } + w.WriteString(", '$.") + w.WriteString(escaped) + w.WriteString("') IS NOT NULL") + return nil +} + +// WriteJSONArrayElements writes BigQuery JSON array expansion. +func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, _ bool, _ bool, writeExpr func() error) error { + w.WriteString("UNNEST(JSON_QUERY_ARRAY(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// WriteJSONArrayLength writes ARRAY_LENGTH(JSON_QUERY_ARRAY(expr)) for BigQuery. +func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error { + w.WriteString("ARRAY_LENGTH(JSON_QUERY_ARRAY(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// WriteJSONExtractPath writes BigQuery JSON path existence using JSON_VALUE. +func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error { + w.WriteString("JSON_VALUE(") + if err := writeRoot(); err != nil { + return err + } + w.WriteString(", '$") + for _, segment := range pathSegments { + w.WriteString(".") + w.WriteString(escapeJSONFieldName(segment)) + } + w.WriteString("') IS NOT NULL") + return nil +} + +// WriteJSONArrayMembership writes BigQuery JSON array membership. 
+func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { + w.WriteString("UNNEST(JSON_VALUE_ARRAY(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// WriteNestedJSONArrayMembership writes BigQuery nested JSON array membership. +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { + w.WriteString("UNNEST(JSON_VALUE_ARRAY(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// --- Timestamps --- + +// WriteDuration writes a BigQuery INTERVAL literal. +func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) { + fmt.Fprintf(w, "INTERVAL %d %s", value, unit) +} + +// WriteInterval writes a BigQuery INTERVAL expression. +func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error { + w.WriteString("INTERVAL ") + if err := writeValue(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(unit) + return nil +} + +// WriteExtract writes a BigQuery EXTRACT expression. +// BigQuery uses DAYOFWEEK (1=Sunday) instead of DOW (0=Sunday). +func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error { + isDOW := part == "DOW" + bqPart := part + if isDOW { + bqPart = "DAYOFWEEK" + w.WriteString("(") + } + w.WriteString("EXTRACT(") + w.WriteString(bqPart) + w.WriteString(" FROM ") + if err := writeExpr(); err != nil { + return err + } + if writeTZ != nil { + w.WriteString(" AT TIME ZONE ") + if err := writeTZ(); err != nil { + return err + } + } + w.WriteString(")") + if isDOW { + // BigQuery DAYOFWEEK: 1=Sunday, 2=Monday, ..., 7=Saturday + // CEL getDayOfWeek: 0=Sunday, 1=Monday, ..., 6=Saturday + w.WriteString(" - 1)") + } + return nil +} + +// WriteTimestampArithmetic writes BigQuery timestamp arithmetic using functions. 
+// BigQuery uses TIMESTAMP_ADD/TIMESTAMP_SUB instead of + / - operators. +func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error { + if op == "+" { + w.WriteString("TIMESTAMP_ADD(") + } else { + w.WriteString("TIMESTAMP_SUB(") + } + if err := writeTS(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDur(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- String Functions --- + +// WriteContains writes INSTR(haystack, needle) != 0 for BigQuery. +func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error { + w.WriteString("INSTR(") + if err := writeHaystack(); err != nil { + return err + } + w.WriteString(", ") + if err := writeNeedle(); err != nil { + return err + } + w.WriteString(") != 0") + return nil +} + +// WriteSplit writes BigQuery string split using SPLIT. +func (d *Dialect) WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error { + w.WriteString("SPLIT(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteSplitWithLimit writes BigQuery SPLIT with array slice. +func (d *Dialect) WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, limit int64) error { + w.WriteString("ARRAY(SELECT x FROM UNNEST(SPLIT(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + fmt.Fprintf(w, ")) AS x WITH OFFSET WHERE OFFSET < %d)", limit) + return nil +} + +// WriteJoin writes BigQuery array join using ARRAY_TO_STRING. 
+func (d *Dialect) WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error { + w.WriteString("ARRAY_TO_STRING(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Comprehensions --- + +// WriteUnnest writes BigQuery UNNEST for array unnesting. +func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error { + w.WriteString("UNNEST(") + if err := writeSource(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteArraySubqueryOpen writes ARRAY(SELECT for BigQuery. +func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) { + w.WriteString("ARRAY(SELECT ") +} + +// WriteArraySubqueryExprClose is a no-op for BigQuery. +func (d *Dialect) WriteArraySubqueryExprClose(_ *strings.Builder) { +} + +// --- Struct --- + +// WriteStructOpen writes the BigQuery struct literal opening. +func (d *Dialect) WriteStructOpen(w *strings.Builder) { + w.WriteString("STRUCT(") +} + +// WriteStructClose writes the BigQuery struct literal closing. +func (d *Dialect) WriteStructClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Validation --- + +// MaxIdentifierLength returns 300 for BigQuery. +func (d *Dialect) MaxIdentifierLength() int { + return 300 +} + +// ValidateFieldName validates a field name against BigQuery naming rules. +func (d *Dialect) ValidateFieldName(name string) error { + return validateFieldName(name) +} + +// ReservedKeywords returns the set of reserved SQL keywords for BigQuery. +func (d *Dialect) ReservedKeywords() map[string]bool { + return reservedSQLKeywords +} + +// --- Regex --- + +// ConvertRegex converts an RE2 regex pattern for BigQuery. +// BigQuery uses RE2 natively, so minimal conversion is needed. 
+func (d *Dialect) ConvertRegex(re2Pattern string) (string, bool, error) { + return convertRE2ToBigQuery(re2Pattern) +} + +// SupportsRegex returns true as BigQuery supports RE2 regex natively. +func (d *Dialect) SupportsRegex() bool { return true } + +// --- Capabilities --- + +// SupportsNativeArrays returns true as BigQuery has native array types. +func (d *Dialect) SupportsNativeArrays() bool { return true } + +// SupportsJSONB returns false as BigQuery has a single JSON type. +func (d *Dialect) SupportsJSONB() bool { return false } + +// SupportsIndexAnalysis returns true as BigQuery index analysis is supported. +func (d *Dialect) SupportsIndexAnalysis() bool { return true } + +// --- Internal helpers --- + +// escapeJSONFieldName escapes special characters in JSON field names for BigQuery. +func escapeJSONFieldName(fieldName string) string { + return strings.ReplaceAll(fieldName, "'", "\\'") +} + +// bigqueryTypeName converts a CEL/common type name to a BigQuery type name. +func bigqueryTypeName(typeName string) string { + switch strings.ToLower(typeName) { + case "text", "string", "varchar": + return "STRING" + case "int", "integer", "bigint", "int64": + return "INT64" + case "double", "float", "real", "float64": + return "FLOAT64" + case "boolean", "bool": + return "BOOL" + case "bytes", "bytea", "blob": + return "BYTES" + default: + return strings.ToUpper(typeName) + } +} diff --git a/dialect/bigquery/index_advisor.go b/dialect/bigquery/index_advisor.go new file mode 100644 index 0000000..bbc9bce --- /dev/null +++ b/dialect/bigquery/index_advisor.go @@ -0,0 +1,86 @@ +package bigquery + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// BigQuery index type constants. +const ( + IndexTypeClustering = "CLUSTERING" + IndexTypeSearchIndex = "SEARCH_INDEX" +) + +// RecommendIndex generates a BigQuery-specific index recommendation for the given pattern. +// BigQuery uses clustering keys and search indexes. 
Returns nil for unsupported patterns. +func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation { + table := pattern.TableHint + if table == "" { + table = "table_name" + } + col := pattern.Column + safeName := sanitizeIndexName(col) + + switch pattern.Pattern { + case dialect.PatternComparison: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeClustering, + Expression: fmt.Sprintf("ALTER TABLE %s SET OPTIONS (clustering_columns=['%s']);", + table, col), + Reason: fmt.Sprintf("Comparison operations on '%s' benefit from clustering for efficient partition pruning and range scans", col), + } + + case dialect.PatternJSONAccess: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeSearchIndex, + Expression: fmt.Sprintf("CREATE SEARCH INDEX idx_%s ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSON field access on '%s' benefits from a search index for efficient nested field lookups", col), + } + + case dialect.PatternRegexMatch: + // BigQuery does not have specialized regex indexes + return nil + + case dialect.PatternArrayMembership, dialect.PatternArrayComprehension: + // BigQuery arrays do not benefit from standalone indexes + return nil + + case dialect.PatternJSONArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeSearchIndex, + Expression: fmt.Sprintf("CREATE SEARCH INDEX idx_%s ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSON array operations on '%s' may benefit from a search index", col), + } + } + + return nil +} + +// SupportedPatterns returns the pattern types supported by BigQuery. +func (d *Dialect) SupportedPatterns() []dialect.PatternType { + return []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternJSONArrayComprehension, + } +} + +// sanitizeIndexName creates a safe index name from a column name. 
+func sanitizeIndexName(column string) string { + sanitized := strings.ReplaceAll(column, ".", "_") + sanitized = strings.ReplaceAll(sanitized, " ", "_") + sanitized = strings.ReplaceAll(sanitized, "-", "_") + + if len(sanitized) > 50 { + sanitized = sanitized[:50] + } + + return sanitized +} diff --git a/dialect/bigquery/regex.go b/dialect/bigquery/regex.go new file mode 100644 index 0000000..2fcfbf7 --- /dev/null +++ b/dialect/bigquery/regex.go @@ -0,0 +1,137 @@ +package bigquery + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). +const ( + maxRegexPatternLength = 500 + maxRegexGroups = 20 + maxRegexNestingDepth = 10 +) + +// convertRE2ToBigQuery converts an RE2 regex pattern to BigQuery-compatible format. +// BigQuery uses RE2 natively, so most patterns pass through unchanged. +// Returns the converted pattern, whether it's case-insensitive, and any error. +func convertRE2ToBigQuery(re2Pattern string) (string, bool, error) { + // 1. Pattern length validation + if len(re2Pattern) > maxRegexPatternLength { + return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) + } + + // 2. Validate pattern compiles + if _, err := regexp.Compile(re2Pattern); err != nil { + return "", false, fmt.Errorf("invalid regex pattern: %w", err) + } + + // 3. Detect unsupported features (lookahead/lookbehind not in RE2 anyway) + if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { + return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported") + } + if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?...) are not supported") + } + + // 4. Detect catastrophic nested quantifiers + if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + + // 5. 
Check for nested quantifiers in groups + depth := 0 + groupHasQuantifier := make([]bool, 0) + for i := 0; i < len(re2Pattern); i++ { + char := re2Pattern[i] + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch char { + case '(': + depth++ + groupHasQuantifier = append(groupHasQuantifier, false) + case ')': + if depth > 0 { + depth-- + if i+1 < len(re2Pattern) { + nextChar := re2Pattern[i+1] + if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { + if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + } + } + if len(groupHasQuantifier) > 0 { + groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] + } + } + case '*', '+', '?', '{': + for j := range groupHasQuantifier { + groupHasQuantifier[j] = true + } + } + } + + // 6. Check group count limit + groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(") + if groupCount > maxRegexGroups { + return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) + } + + // 7. Check for quantified alternation + quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) + if quantifiedAlternation.MatchString(re2Pattern) { + return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") + } + + // 8. 
Check nesting depth + maxDepthVal := 0 + currentDepth := 0 + for i := 0; i < len(re2Pattern); i++ { + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch re2Pattern[i] { + case '(': + currentDepth++ + if currentDepth > maxDepthVal { + maxDepthVal = currentDepth + } + case ')': + if currentDepth > 0 { + currentDepth-- + } + } + } + if maxDepthVal > maxRegexNestingDepth { + return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) + } + + // Process pattern: BigQuery uses RE2 natively, so minimal conversion needed + caseInsensitive := false + pattern := re2Pattern + + // Handle (?i) flag - BigQuery REGEXP_CONTAINS embeds the flag in the pattern + if strings.HasPrefix(pattern, "(?i)") { + caseInsensitive = true + pattern = pattern[4:] + } + + // Handle inline flags other than (?i) at start + if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { + return "", false, errors.New("inline flags other than (?i) are not supported in BigQuery regex") + } + + // Convert non-capturing groups (?:...) to regular groups (...) + pattern = strings.ReplaceAll(pattern, "(?:", "(") + + // BigQuery RE2 supports \d, \w, \s, \b natively - no conversion needed + + return pattern, caseInsensitive, nil +} diff --git a/dialect/bigquery/validation.go b/dialect/bigquery/validation.go new file mode 100644 index 0000000..0ae982d --- /dev/null +++ b/dialect/bigquery/validation.go @@ -0,0 +1,56 @@ +package bigquery + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +var ( + // fieldNameRegexp validates BigQuery identifier format. + fieldNameRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + + // reservedSQLKeywords contains BigQuery reserved keywords. 
+ reservedSQLKeywords = map[string]bool{ + "all": true, "and": true, "any": true, "array": true, "as": true, + "asc": true, "assert_rows_modified": true, "at": true, "between": true, + "by": true, "case": true, "cast": true, "collate": true, "contains": true, + "create": true, "cross": true, "cube": true, "current": true, + "default": true, "define": true, "desc": true, "distinct": true, + "else": true, "end": true, "enum": true, "escape": true, "except": true, + "exclude": true, "exists": true, "extract": true, "false": true, + "fetch": true, "following": true, "for": true, "from": true, "full": true, + "group": true, "grouping": true, "groups": true, "hash": true, + "having": true, "if": true, "ignore": true, "in": true, "inner": true, + "intersect": true, "interval": true, "into": true, "is": true, + "join": true, "lateral": true, "left": true, "like": true, "limit": true, + "lookup": true, "merge": true, "natural": true, "new": true, "no": true, + "not": true, "null": true, "nulls": true, "of": true, "on": true, + "or": true, "order": true, "outer": true, "over": true, + "partition": true, "preceding": true, "proto": true, "range": true, + "recursive": true, "respect": true, "right": true, "rollup": true, + "rows": true, "select": true, "set": true, "some": true, "struct": true, + "tablesample": true, "then": true, "to": true, "treat": true, + "true": true, "unbounded": true, "union": true, "unnest": true, + "using": true, "when": true, "where": true, "window": true, + "with": true, "within": true, + } +) + +// validateFieldName validates that a field name follows BigQuery naming conventions. 
+func validateFieldName(name string) error { + if len(name) == 0 { + return errors.New("field name cannot be empty") + } + + if !fieldNameRegexp.MatchString(name) { + return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) + } + + if reservedSQLKeywords[strings.ToLower(name)] { + return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) + } + + return nil +} diff --git a/dialect/dialect.go b/dialect/dialect.go new file mode 100644 index 0000000..7b0e30b --- /dev/null +++ b/dialect/dialect.go @@ -0,0 +1,234 @@ +// Package dialect defines the interface for SQL dialect-specific code generation. +// Each supported database implements this interface to produce correct SQL syntax. +package dialect + +import ( + "errors" + "strings" +) + +// Name represents a SQL dialect name. +type Name string + +// Supported SQL dialect names. +const ( + PostgreSQL Name = "postgresql" + MySQL Name = "mysql" + SQLite Name = "sqlite" + DuckDB Name = "duckdb" + BigQuery Name = "bigquery" +) + +// ErrUnsupportedFeature indicates that the requested feature is not supported by this dialect. +var ErrUnsupportedFeature = errors.New("unsupported dialect feature") + +// Dialect defines the interface for SQL dialect-specific code generation. +// The converter calls these methods at every point where SQL syntax diverges +// between databases. Methods receive a *strings.Builder that shares the +// converter's output buffer, and callback functions for writing sub-expressions. +type Dialect interface { + // Name returns the dialect name. + Name() Name + + // --- Literals --- + + // WriteStringLiteral writes a string literal in the dialect's syntax. + // For PostgreSQL: 'value' with '' escaping. + WriteStringLiteral(w *strings.Builder, value string) + + // WriteBytesLiteral writes a byte array literal in the dialect's syntax. + // For PostgreSQL: '\xDEADBEEF'. 
+ WriteBytesLiteral(w *strings.Builder, value []byte) error + + // WriteParamPlaceholder writes a parameter placeholder. + // For PostgreSQL: $1, $2. For MySQL: ?, ?. For BigQuery: @p1, @p2. + WriteParamPlaceholder(w *strings.Builder, paramIndex int) + + // --- Operators --- + + // WriteStringConcat writes a string concatenation expression. + // For PostgreSQL: lhs || rhs. For MySQL: CONCAT(lhs, rhs). + WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error + + // WriteRegexMatch writes a regex match expression. + // For PostgreSQL: expr ~ 'pattern' or expr ~* 'pattern'. + WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, caseInsensitive bool) error + + // WriteLikeEscape writes the LIKE escape clause. + // For PostgreSQL: ESCAPE E'\\'. For MySQL: ESCAPE '\\'. + WriteLikeEscape(w *strings.Builder) + + // WriteArrayMembership writes an array membership test. + // For PostgreSQL: elem = ANY(array). For MySQL: JSON_CONTAINS(array, elem). + WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error + + // --- Type Casting --- + + // WriteCastToNumeric writes a cast to numeric type. + // For PostgreSQL: ::numeric. For MySQL: CAST(... AS DECIMAL). + WriteCastToNumeric(w *strings.Builder) + + // WriteTypeName writes a type name for CAST expressions. + // For PostgreSQL: BOOLEAN, BYTEA, DOUBLE PRECISION, BIGINT, TEXT. + WriteTypeName(w *strings.Builder, celTypeName string) + + // WriteEpochExtract writes extraction of epoch from a timestamp. + // For PostgreSQL: EXTRACT(EPOCH FROM expr)::bigint. + WriteEpochExtract(w *strings.Builder, writeExpr func() error) error + + // WriteTimestampCast writes a cast to timestamp type. + // For PostgreSQL: CAST(expr AS TIMESTAMP WITH TIME ZONE). + WriteTimestampCast(w *strings.Builder, writeExpr func() error) error + + // --- Arrays --- + + // WriteArrayLiteralOpen writes the opening of an array literal. + // For PostgreSQL: ARRAY[. For DuckDB: [. 
+ WriteArrayLiteralOpen(w *strings.Builder) + + // WriteArrayLiteralClose writes the closing of an array literal. + // For PostgreSQL: ]. For DuckDB: ]. + WriteArrayLiteralClose(w *strings.Builder) + + // WriteArrayLength writes an array length expression. + // For PostgreSQL: COALESCE(ARRAY_LENGTH(expr, dimension), 0). + WriteArrayLength(w *strings.Builder, dimension int, writeExpr func() error) error + + // WriteListIndex writes a list index expression. + // For PostgreSQL: array[index + 1] (1-indexed). + WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error + + // WriteListIndexConst writes a constant list index. + // For PostgreSQL: array[idx+1] (converts 0-indexed to 1-indexed). + WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error + + // WriteEmptyTypedArray writes an empty typed array literal. + // For PostgreSQL: ARRAY[]::text[]. + WriteEmptyTypedArray(w *strings.Builder, typeName string) + + // --- JSON --- + + // WriteJSONFieldAccess writes JSON field access. + // For PostgreSQL: base->>'field' (text) or base->'field' (json). + // For SQLite: json_extract(base, '$.field'). + // writeBase writes the base expression; the dialect wraps or appends as needed. + WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error + + // WriteJSONExistence writes a JSON key existence check. + // For PostgreSQL (JSONB): ? 'key'. For PostgreSQL (JSON): ->'key' IS NOT NULL. + WriteJSONExistence(w *strings.Builder, isJSONB bool, fieldName string, writeBase func() error) error + + // WriteJSONArrayElements writes a call to extract JSON array elements. + // For PostgreSQL: jsonb_array_elements(expr) or json_array_elements(expr). + WriteJSONArrayElements(w *strings.Builder, isJSONB bool, asText bool, writeExpr func() error) error + + // WriteJSONArrayLength writes a JSON array length expression. + // For PostgreSQL: COALESCE(jsonb_array_length(expr), 0). 
+ WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error + + // WriteJSONExtractPath writes a JSON path extraction function. + // For PostgreSQL: jsonb_extract_path_text(root, 'seg1', 'seg2'). + WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error + + // WriteJSONArrayMembership writes a JSON array membership test for the IN operator. + // For PostgreSQL: ANY(ARRAY(SELECT jsonb_func(expr))). + WriteJSONArrayMembership(w *strings.Builder, jsonFunc string, writeExpr func() error) error + + // WriteNestedJSONArrayMembership writes a nested JSON array membership test. + // For PostgreSQL: ANY(ARRAY(SELECT jsonb_array_elements_text(expr))). + WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error + + // --- Timestamps --- + + // WriteDuration writes a duration/interval literal. + // For PostgreSQL: INTERVAL N UNIT. + WriteDuration(w *strings.Builder, value int64, unit string) + + // WriteInterval writes an INTERVAL expression from a variable. + // For PostgreSQL: INTERVAL expr UNIT. + WriteInterval(w *strings.Builder, writeValue func() error, unit string) error + + // WriteExtract writes a timestamp field extraction expression. + // Handles DOW conversion, Month/DOY adjustment, and timezone support. + WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error + + // WriteTimestampArithmetic writes timestamp arithmetic. + // For PostgreSQL: timestamp +/- interval. + WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error + + // --- String Functions --- + + // WriteContains writes a string contains expression. + // For PostgreSQL: POSITION(needle IN haystack) > 0. + WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error + + // WriteSplit writes a string split expression. + // For PostgreSQL: STRING_TO_ARRAY(string, delimiter). 
+ WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error + + // WriteSplitWithLimit writes a string split expression with a limit. + // For PostgreSQL: (STRING_TO_ARRAY(string, delimiter))[1:limit]. + WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, limit int64) error + + // WriteJoin writes an array join expression. + // For PostgreSQL: ARRAY_TO_STRING(array, delimiter, ''). + WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error + + // --- Comprehensions --- + + // WriteUnnest writes the UNNEST source for comprehensions. + // For PostgreSQL: UNNEST(array). For MySQL: JSON_TABLE(...). + WriteUnnest(w *strings.Builder, writeSource func() error) error + + // WriteArraySubqueryOpen writes the prefix before the transform expression + // in an array-building subquery. + // For PostgreSQL: "ARRAY(SELECT ". For SQLite: "(SELECT json_group_array(". + WriteArraySubqueryOpen(w *strings.Builder) + + // WriteArraySubqueryExprClose writes the suffix after the transform expression + // and before FROM in an array-building subquery. + // For PostgreSQL: "" (nothing). For SQLite: ")". + WriteArraySubqueryExprClose(w *strings.Builder) + + // --- Struct --- + + // WriteStructOpen writes the opening of a struct/row literal. + // For PostgreSQL: ROW(. For BigQuery: STRUCT(. + WriteStructOpen(w *strings.Builder) + + // WriteStructClose writes the closing of a struct/row literal. + // For PostgreSQL: ). For BigQuery: ). + WriteStructClose(w *strings.Builder) + + // --- Validation --- + + // MaxIdentifierLength returns the maximum identifier length for this dialect. + // For PostgreSQL: 63. For MySQL: 64. For SQLite: unlimited (0). + MaxIdentifierLength() int + + // ValidateFieldName validates a field name for this dialect. + ValidateFieldName(name string) error + + // ReservedKeywords returns the set of reserved SQL keywords for this dialect. 
+ ReservedKeywords() map[string]bool + + // --- Regex --- + + // ConvertRegex converts an RE2 regex pattern to the dialect's native format. + // Returns: (convertedPattern, caseInsensitive, error). + ConvertRegex(re2Pattern string) (pattern string, caseInsensitive bool, err error) + + // SupportsRegex indicates whether this dialect supports regex matching. + SupportsRegex() bool + + // --- Capabilities --- + + // SupportsNativeArrays indicates whether this dialect has native array types. + SupportsNativeArrays() bool + + // SupportsJSONB indicates whether this dialect has a distinct JSONB type. + SupportsJSONB() bool + + // SupportsIndexAnalysis indicates whether index analysis is supported. + SupportsIndexAnalysis() bool +} diff --git a/dialect/duckdb/dialect.go b/dialect/duckdb/dialect.go new file mode 100644 index 0000000..5a22235 --- /dev/null +++ b/dialect/duckdb/dialect.go @@ -0,0 +1,473 @@ +// Package duckdb implements the DuckDB SQL dialect for cel2sql. +package duckdb + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// Dialect implements dialect.Dialect for DuckDB. +type Dialect struct{} + +// New creates a new DuckDB dialect. +func New() *Dialect { + return &Dialect{} +} + +func init() { + dialect.Register(dialect.DuckDB, func() dialect.Dialect { return New() }) +} + +// Ensure Dialect implements dialect.Dialect at compile time. +var _ dialect.Dialect = (*Dialect)(nil) + +// Name returns the dialect name. +func (d *Dialect) Name() dialect.Name { return dialect.DuckDB } + +// --- Literals --- + +// WriteStringLiteral writes a DuckDB string literal with ” escaping. +func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) { + escaped := strings.ReplaceAll(value, "'", "''") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") +} + +// WriteBytesLiteral writes a DuckDB hex-encoded byte literal ('\x...'). 
+func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error { + w.WriteString("'\\x") + for _, b := range value { + fmt.Fprintf(w, "%02x", b) + } + w.WriteString("'") + return nil +} + +// WriteParamPlaceholder writes a DuckDB positional parameter ($1, $2, ...). +func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, paramIndex int) { + fmt.Fprintf(w, "$%d", paramIndex) +} + +// --- Operators --- + +// WriteStringConcat writes DuckDB string concatenation using ||. +func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error { + if err := writeLHS(); err != nil { + return err + } + w.WriteString(" || ") + return writeRHS() +} + +// WriteRegexMatch writes a DuckDB regex match expression using ~ or ~*. +func (d *Dialect) WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, caseInsensitive bool) error { + if err := writeTarget(); err != nil { + return err + } + if caseInsensitive { + w.WriteString(" ~* ") + } else { + w.WriteString(" ~ ") + } + escaped := strings.ReplaceAll(pattern, "'", "''") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") + return nil +} + +// WriteLikeEscape writes the DuckDB LIKE escape clause. +func (d *Dialect) WriteLikeEscape(w *strings.Builder) { + w.WriteString(" ESCAPE '\\\\'") +} + +// WriteArrayMembership writes a DuckDB array membership test using = ANY(). +func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + if err := writeElem(); err != nil { + return err + } + w.WriteString(" = ANY(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Type Casting --- + +// WriteCastToNumeric writes a DuckDB numeric cast (::DOUBLE). +func (d *Dialect) WriteCastToNumeric(w *strings.Builder) { + w.WriteString("::DOUBLE") +} + +// WriteTypeName writes a DuckDB type name for CAST expressions. 
+func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) { + switch celTypeName { + case "bool": + w.WriteString("BOOLEAN") + case "bytes": + w.WriteString("BLOB") + case "double": + w.WriteString("DOUBLE") + case "int": + w.WriteString("BIGINT") + case "string": + w.WriteString("VARCHAR") + case "uint": + w.WriteString("UBIGINT") + default: + w.WriteString(strings.ToUpper(celTypeName)) + } +} + +// WriteEpochExtract writes EXTRACT(EPOCH FROM expr)::BIGINT for DuckDB. +func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error { + w.WriteString("EXTRACT(EPOCH FROM ") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")::BIGINT") + return nil +} + +// WriteTimestampCast writes a DuckDB CAST to TIMESTAMPTZ. +func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error { + w.WriteString("CAST(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(" AS TIMESTAMPTZ)") + return nil +} + +// --- Arrays --- + +// WriteArrayLiteralOpen writes the DuckDB array literal opening ([). +func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) { + w.WriteString("[") +} + +// WriteArrayLiteralClose writes the DuckDB array literal closing (]). +func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) { + w.WriteString("]") +} + +// WriteArrayLength writes COALESCE(array_length(expr), 0) for DuckDB. +func (d *Dialect) WriteArrayLength(w *strings.Builder, _ int, writeExpr func() error) error { + w.WriteString("COALESCE(array_length(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteListIndex writes DuckDB 1-indexed array access. 
+func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error { + if err := writeArray(); err != nil { + return err + } + w.WriteString("[") + if err := writeIndex(); err != nil { + return err + } + w.WriteString(" + 1]") + return nil +} + +// WriteListIndexConst writes DuckDB constant array index access (0-indexed to 1-indexed). +func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error { + if err := writeArray(); err != nil { + return err + } + fmt.Fprintf(w, "[%d]", index+1) + return nil +} + +// WriteEmptyTypedArray writes an empty DuckDB typed array. +func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, typeName string) { + w.WriteString("[]::") //nolint:gocritic + w.WriteString(typeName) + w.WriteString("[]") +} + +// --- JSON --- + +// WriteJSONFieldAccess writes DuckDB JSON field access using -> or ->> operators. +func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error { + if err := writeBase(); err != nil { + return err + } + escaped := escapeJSONFieldName(fieldName) + if isFinal { + w.WriteString("->>'") + } else { + w.WriteString("->'") + } + w.WriteString(escaped) + w.WriteString("'") + return nil +} + +// WriteJSONExistence writes a DuckDB JSON key existence check using json_exists. +func (d *Dialect) WriteJSONExistence(w *strings.Builder, _ bool, fieldName string, writeBase func() error) error { + w.WriteString("json_exists(") + if err := writeBase(); err != nil { + return err + } + escaped := escapeJSONFieldName(fieldName) + w.WriteString(", '$.") + w.WriteString(escaped) + w.WriteString("')") + return nil +} + +// WriteJSONArrayElements writes DuckDB JSON array expansion using json_each. 
+func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, _ bool, _ bool, writeExpr func() error) error { + w.WriteString("json_each(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteJSONArrayLength writes COALESCE(json_array_length(expr), 0) for DuckDB. +func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error { + w.WriteString("COALESCE(json_array_length(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteJSONExtractPath writes DuckDB JSON path existence using json_exists. +func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error { + w.WriteString("json_exists(") + if err := writeRoot(); err != nil { + return err + } + w.WriteString(", '$") + for _, segment := range pathSegments { + w.WriteString(".") + w.WriteString(escapeJSONFieldName(segment)) + } + w.WriteString("')") + return nil +} + +// WriteJSONArrayMembership writes DuckDB JSON array membership using json_each. +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { + w.WriteString("(SELECT value FROM json_each(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// WriteNestedJSONArrayMembership writes DuckDB nested JSON array membership. +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { + w.WriteString("(SELECT value FROM json_each(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// --- Timestamps --- + +// WriteDuration writes a DuckDB INTERVAL literal. +func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) { + fmt.Fprintf(w, "INTERVAL %d %s", value, unit) +} + +// WriteInterval writes a DuckDB INTERVAL expression. 
+func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error { + w.WriteString("INTERVAL ") + if err := writeValue(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(unit) + return nil +} + +// WriteExtract writes a DuckDB EXTRACT expression with DOW conversion. +func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error { + isDOW := part == "DOW" + if isDOW { + w.WriteString("(") + } + w.WriteString("EXTRACT(") + w.WriteString(part) + w.WriteString(" FROM ") + if err := writeExpr(); err != nil { + return err + } + if writeTZ != nil { + w.WriteString(" AT TIME ZONE ") + if err := writeTZ(); err != nil { + return err + } + } + w.WriteString(")") + if isDOW { + w.WriteString(" + 6) % 7") + } + return nil +} + +// WriteTimestampArithmetic writes DuckDB timestamp arithmetic. +func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error { + if err := writeTS(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(op) + w.WriteString(" ") + return writeDur() +} + +// --- String Functions --- + +// WriteContains writes CONTAINS(haystack, needle) for DuckDB. +func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error { + w.WriteString("CONTAINS(") + if err := writeHaystack(); err != nil { + return err + } + w.WriteString(", ") + if err := writeNeedle(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteSplit writes DuckDB string split using STRING_SPLIT. +func (d *Dialect) WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error { + w.WriteString("STRING_SPLIT(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteSplitWithLimit writes DuckDB string split with array slice. 
+func (d *Dialect) WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, limit int64) error { + w.WriteString("STRING_SPLIT(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + fmt.Fprintf(w, ")[1:%d]", limit) + return nil +} + +// WriteJoin writes DuckDB array join using ARRAY_TO_STRING. +func (d *Dialect) WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error { + w.WriteString("ARRAY_TO_STRING(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Comprehensions --- + +// WriteUnnest writes DuckDB UNNEST for array unnesting. +func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error { + w.WriteString("UNNEST(") + if err := writeSource(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteArraySubqueryOpen writes ARRAY(SELECT for DuckDB. +func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) { + w.WriteString("ARRAY(SELECT ") +} + +// WriteArraySubqueryExprClose is a no-op for DuckDB (no wrapper around the expression). +func (d *Dialect) WriteArraySubqueryExprClose(_ *strings.Builder) { +} + +// --- Struct --- + +// WriteStructOpen writes the DuckDB struct literal opening. +func (d *Dialect) WriteStructOpen(w *strings.Builder) { + w.WriteString("ROW(") +} + +// WriteStructClose writes the DuckDB struct literal closing. +func (d *Dialect) WriteStructClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Validation --- + +// MaxIdentifierLength returns 0 as DuckDB has no hard identifier length limit. +func (d *Dialect) MaxIdentifierLength() int { + return 0 +} + +// ValidateFieldName validates a field name against DuckDB naming rules. 
+func (d *Dialect) ValidateFieldName(name string) error { + return validateFieldName(name) +} + +// ReservedKeywords returns the set of reserved SQL keywords for DuckDB. +func (d *Dialect) ReservedKeywords() map[string]bool { + return reservedSQLKeywords +} + +// --- Regex --- + +// ConvertRegex converts an RE2 regex pattern to DuckDB-compatible format. +// DuckDB uses RE2 natively, so minimal conversion is needed. +func (d *Dialect) ConvertRegex(re2Pattern string) (string, bool, error) { + return convertRE2ToDuckDB(re2Pattern) +} + +// SupportsRegex returns true as DuckDB supports RE2 regex natively. +func (d *Dialect) SupportsRegex() bool { return true } + +// --- Capabilities --- + +// SupportsNativeArrays returns true as DuckDB has native array (LIST) types. +func (d *Dialect) SupportsNativeArrays() bool { return true } + +// SupportsJSONB returns false as DuckDB has a single JSON type. +func (d *Dialect) SupportsJSONB() bool { return false } + +// SupportsIndexAnalysis returns true as DuckDB index analysis is supported. +func (d *Dialect) SupportsIndexAnalysis() bool { return true } + +// --- Internal helpers --- + +// escapeJSONFieldName escapes special characters in JSON field names for DuckDB. +func escapeJSONFieldName(fieldName string) string { + return strings.ReplaceAll(fieldName, "'", "''") +} diff --git a/dialect/duckdb/index_advisor.go b/dialect/duckdb/index_advisor.go new file mode 100644 index 0000000..9906b32 --- /dev/null +++ b/dialect/duckdb/index_advisor.go @@ -0,0 +1,92 @@ +package duckdb + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// DuckDB index type constants. +const ( + IndexTypeART = "ART" +) + +// RecommendIndex generates a DuckDB-specific index recommendation for the given pattern. +// DuckDB uses ART (Adaptive Radix Tree) indexes by default. Returns nil for unsupported patterns. 
+func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation { + table := pattern.TableHint + if table == "" { + table = "table_name" + } + col := pattern.Column + safeName := sanitizeIndexName(col) + + switch pattern.Pattern { + case dialect.PatternComparison: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeART, + Expression: fmt.Sprintf("CREATE INDEX idx_%s ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Comparison operations on '%s' benefit from an ART index for efficient range queries and equality checks", col), + } + + case dialect.PatternJSONAccess: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeART, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_json ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSON field access on '%s' may benefit from an ART index", col), + } + + case dialect.PatternRegexMatch: + // DuckDB does not have specialized regex indexes + return nil + + case dialect.PatternArrayMembership, dialect.PatternArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeART, + Expression: fmt.Sprintf("CREATE INDEX idx_%s ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Array operations on '%s' may benefit from an ART index", col), + } + + case dialect.PatternJSONArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeART, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_json ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSON array comprehension on '%s' may benefit from an ART index", col), + } + } + + return nil +} + +// SupportedPatterns returns the pattern types supported by DuckDB. 
+func (d *Dialect) SupportedPatterns() []dialect.PatternType { + return []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternArrayMembership, + dialect.PatternArrayComprehension, + dialect.PatternJSONArrayComprehension, + } +} + +// sanitizeIndexName creates a safe index name from a column name. +func sanitizeIndexName(column string) string { + sanitized := strings.ReplaceAll(column, ".", "_") + sanitized = strings.ReplaceAll(sanitized, " ", "_") + sanitized = strings.ReplaceAll(sanitized, "-", "_") + + if len(sanitized) > 50 { + sanitized = sanitized[:50] + } + + return sanitized +} diff --git a/dialect/duckdb/regex.go b/dialect/duckdb/regex.go new file mode 100644 index 0000000..582f83d --- /dev/null +++ b/dialect/duckdb/regex.go @@ -0,0 +1,137 @@ +package duckdb + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). +const ( + maxRegexPatternLength = 500 + maxRegexGroups = 20 + maxRegexNestingDepth = 10 +) + +// convertRE2ToDuckDB converts an RE2 regex pattern to DuckDB-compatible format. +// DuckDB uses RE2 natively, so most patterns pass through unchanged. +// Returns the converted pattern, whether it's case-insensitive, and any error. +func convertRE2ToDuckDB(re2Pattern string) (string, bool, error) { + // 1. Pattern length validation + if len(re2Pattern) > maxRegexPatternLength { + return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) + } + + // 2. Validate pattern compiles + if _, err := regexp.Compile(re2Pattern); err != nil { + return "", false, fmt.Errorf("invalid regex pattern: %w", err) + } + + // 3. Detect unsupported features (lookahead/lookbehind not in RE2 anyway) + if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { + return "", false, errors.New("lookahead assertions (?=...), (?!...) 
are not supported") + } + if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?...) are not supported") + } + + // 4. Detect catastrophic nested quantifiers + if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + + // 5. Check for nested quantifiers in groups + depth := 0 + groupHasQuantifier := make([]bool, 0) + for i := 0; i < len(re2Pattern); i++ { + char := re2Pattern[i] + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch char { + case '(': + depth++ + groupHasQuantifier = append(groupHasQuantifier, false) + case ')': + if depth > 0 { + depth-- + if i+1 < len(re2Pattern) { + nextChar := re2Pattern[i+1] + if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { + if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + } + } + if len(groupHasQuantifier) > 0 { + groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] + } + } + case '*', '+', '?', '{': + for j := range groupHasQuantifier { + groupHasQuantifier[j] = true + } + } + } + + // 6. Check group count limit + groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(") + if groupCount > maxRegexGroups { + return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) + } + + // 7. Check for quantified alternation + quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) + if quantifiedAlternation.MatchString(re2Pattern) { + return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") + } + + // 8. 
Check nesting depth + maxDepthVal := 0 + currentDepth := 0 + for i := 0; i < len(re2Pattern); i++ { + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch re2Pattern[i] { + case '(': + currentDepth++ + if currentDepth > maxDepthVal { + maxDepthVal = currentDepth + } + case ')': + if currentDepth > 0 { + currentDepth-- + } + } + } + if maxDepthVal > maxRegexNestingDepth { + return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) + } + + // Process pattern: DuckDB uses RE2 natively, so minimal conversion needed + caseInsensitive := false + pattern := re2Pattern + + // Handle (?i) flag + if strings.HasPrefix(pattern, "(?i)") { + caseInsensitive = true + pattern = pattern[4:] + } + + // Handle inline flags other than (?i) at start + if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { + return "", false, errors.New("inline flags other than (?i) are not supported in DuckDB regex") + } + + // Convert non-capturing groups (?:...) to regular groups (...) + pattern = strings.ReplaceAll(pattern, "(?:", "(") + + // DuckDB RE2 supports \d, \w, \s, \b natively - no conversion needed + + return pattern, caseInsensitive, nil +} diff --git a/dialect/duckdb/validation.go b/dialect/duckdb/validation.go new file mode 100644 index 0000000..976e304 --- /dev/null +++ b/dialect/duckdb/validation.go @@ -0,0 +1,55 @@ +package duckdb + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +var ( + // fieldNameRegexp validates DuckDB identifier format. + fieldNameRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + + // reservedSQLKeywords contains DuckDB reserved keywords. 
+ reservedSQLKeywords = map[string]bool{ + "all": true, "alter": true, "analyse": true, "analyze": true, "and": true, + "any": true, "array": true, "as": true, "asc": true, "asymmetric": true, + "between": true, "both": true, "case": true, "cast": true, "check": true, + "collate": true, "column": true, "constraint": true, "create": true, + "cross": true, "current_catalog": true, "current_date": true, + "current_role": true, "current_schema": true, "current_time": true, + "current_timestamp": true, "current_user": true, "default": true, + "deferrable": true, "desc": true, "distinct": true, "do": true, + "else": true, "end": true, "except": true, "exists": true, "false": true, + "fetch": true, "for": true, "foreign": true, "from": true, "full": true, + "grant": true, "group": true, "having": true, "in": true, "initially": true, + "inner": true, "intersect": true, "into": true, "is": true, "isnull": true, + "join": true, "lateral": true, "leading": true, "left": true, "like": true, + "limit": true, "localtime": true, "localtimestamp": true, "natural": true, + "not": true, "notnull": true, "null": true, "offset": true, "on": true, + "only": true, "or": true, "order": true, "outer": true, "overlaps": true, + "placing": true, "primary": true, "references": true, "returning": true, + "right": true, "select": true, "session_user": true, "similar": true, + "some": true, "symmetric": true, "table": true, "then": true, "to": true, + "trailing": true, "true": true, "union": true, "unique": true, "using": true, + "variadic": true, "when": true, "where": true, "window": true, "with": true, + } +) + +// validateFieldName validates that a field name follows DuckDB naming conventions. 
+func validateFieldName(name string) error { + if len(name) == 0 { + return errors.New("field name cannot be empty") + } + + if !fieldNameRegexp.MatchString(name) { + return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) + } + + if reservedSQLKeywords[strings.ToLower(name)] { + return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) + } + + return nil +} diff --git a/dialect/index_advisor.go b/dialect/index_advisor.go new file mode 100644 index 0000000..210e62f --- /dev/null +++ b/dialect/index_advisor.go @@ -0,0 +1,61 @@ +// Package dialect defines the IndexAdvisor interface for dialect-specific index recommendations. +package dialect + +// PatternType enumerates detected index-worthy query patterns. +type PatternType int + +// Index-worthy pattern types detected during query analysis. +const ( + PatternComparison PatternType = iota // Equality/range comparisons (==, >, <, >=, <=) + PatternJSONAccess // JSON/JSONB field access + PatternRegexMatch // Regex pattern matching + PatternArrayMembership // Array IN/containment + PatternArrayComprehension // Array comprehension (all, exists, filter, map) + PatternJSONArrayComprehension // JSON array comprehension +) + +// IndexPattern describes a detected query pattern that could benefit from indexing. +type IndexPattern struct { + // Column is the full column name (e.g., "person.metadata"). + Column string + + // Pattern is the type of query pattern detected. + Pattern PatternType + + // TableHint is an optional table name hint for generating CREATE INDEX statements. + // If empty, "table_name" is used as the default placeholder. + TableHint string +} + +// IndexRecommendation represents a database index recommendation. +// It provides actionable guidance for optimizing query performance. +type IndexRecommendation struct { + // Column is the database column that should be indexed. 
+ Column string + + // IndexType specifies the index type (e.g., "BTREE", "GIN", "ART", "CLUSTERING"). + IndexType string + + // Expression is the complete DDL statement that can be executed directly. + Expression string + + // Reason explains why this index is recommended and what query patterns it optimizes. + Reason string +} + +// IndexAdvisor generates dialect-specific index recommendations. +// Dialects that support index analysis implement this interface on their Dialect struct. +type IndexAdvisor interface { + // RecommendIndex generates an IndexRecommendation for the given pattern, + // or returns nil if the dialect has no applicable index for this pattern. + RecommendIndex(pattern IndexPattern) *IndexRecommendation + + // SupportedPatterns returns which PatternTypes this advisor can handle. + SupportedPatterns() []PatternType +} + +// GetIndexAdvisor returns the IndexAdvisor for a dialect, if it implements the interface. +func GetIndexAdvisor(d Dialect) (IndexAdvisor, bool) { + advisor, ok := d.(IndexAdvisor) + return advisor, ok +} diff --git a/dialect/mysql/dialect.go b/dialect/mysql/dialect.go new file mode 100644 index 0000000..01cd97d --- /dev/null +++ b/dialect/mysql/dialect.go @@ -0,0 +1,475 @@ +// Package mysql implements the MySQL SQL dialect for cel2sql. +package mysql + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// Dialect implements dialect.Dialect for MySQL 8.0+. +type Dialect struct{} + +// New creates a new MySQL dialect. +func New() *Dialect { + return &Dialect{} +} + +func init() { + dialect.Register(dialect.MySQL, func() dialect.Dialect { return New() }) +} + +// Ensure Dialect implements dialect.Dialect at compile time. +var _ dialect.Dialect = (*Dialect)(nil) + +// Name returns the dialect name. +func (d *Dialect) Name() dialect.Name { return dialect.MySQL } + +// --- Literals --- + +// WriteStringLiteral writes a MySQL string literal with '' escaping.
+func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) { + escaped := strings.ReplaceAll(value, "'", "''") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") +} + +// WriteBytesLiteral writes a MySQL hex-encoded byte literal (X'...'). +func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error { + w.WriteString("X'") + for _, b := range value { + fmt.Fprintf(w, "%02x", b) + } + w.WriteString("'") + return nil +} + +// WriteParamPlaceholder writes a MySQL positional parameter (?). +func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, _ int) { + w.WriteString("?") +} + +// --- Operators --- + +// WriteStringConcat writes MySQL string concatenation using CONCAT(). +func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error { + w.WriteString("CONCAT(") + if err := writeLHS(); err != nil { + return err + } + w.WriteString(", ") + if err := writeRHS(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteRegexMatch writes a MySQL REGEXP match expression. +func (d *Dialect) WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, _ bool) error { + if err := writeTarget(); err != nil { + return err + } + w.WriteString(" REGEXP ") + escaped := strings.ReplaceAll(pattern, "'", "''") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") + return nil +} + +// WriteLikeEscape writes the MySQL LIKE escape clause. +func (d *Dialect) WriteLikeEscape(w *strings.Builder) { + w.WriteString(" ESCAPE '\\\\'") +} + +// WriteArrayMembership writes a MySQL array membership test using JSON_CONTAINS. 
+func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + w.WriteString("JSON_CONTAINS(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", CAST(") + if err := writeElem(); err != nil { + return err + } + w.WriteString(" AS JSON))") + return nil +} + +// --- Type Casting --- + +// WriteCastToNumeric writes a MySQL numeric cast (CAST(... AS DECIMAL)). +func (d *Dialect) WriteCastToNumeric(w *strings.Builder) { + w.WriteString(" + 0") +} + +// WriteTypeName writes a MySQL type name for CAST expressions. +func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) { + switch celTypeName { + case "bool": + w.WriteString("UNSIGNED") + case "bytes": + w.WriteString("BINARY") + case "double": + w.WriteString("DECIMAL") + case "int": + w.WriteString("SIGNED") + case "string": + w.WriteString("CHAR") + case "uint": + w.WriteString("UNSIGNED") + default: + w.WriteString(strings.ToUpper(celTypeName)) + } +} + +// WriteEpochExtract writes UNIX_TIMESTAMP(expr) for MySQL. +func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error { + w.WriteString("UNIX_TIMESTAMP(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteTimestampCast writes a MySQL CAST to DATETIME. +func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error { + w.WriteString("CAST(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(" AS DATETIME)") + return nil +} + +// --- Arrays --- + +// WriteArrayLiteralOpen writes the MySQL JSON array literal opening. +func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) { + w.WriteString("JSON_ARRAY(") +} + +// WriteArrayLiteralClose writes the MySQL JSON array literal closing. +func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) { + w.WriteString(")") +} + +// WriteArrayLength writes JSON_LENGTH(expr) for MySQL. 
+func (d *Dialect) WriteArrayLength(w *strings.Builder, _ int, writeExpr func() error) error { + w.WriteString("COALESCE(JSON_LENGTH(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteListIndex writes MySQL JSON array index access. +func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error { + w.WriteString("JSON_EXTRACT(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", CONCAT('$[', ") + if err := writeIndex(); err != nil { + return err + } + w.WriteString(", ']'))") + return nil +} + +// WriteListIndexConst writes MySQL JSON constant array index access. +func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error { + w.WriteString("JSON_EXTRACT(") + if err := writeArray(); err != nil { + return err + } + fmt.Fprintf(w, ", '$[%d]')", index) + return nil +} + +// WriteEmptyTypedArray writes an empty MySQL JSON array. +func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, _ string) { + w.WriteString("JSON_ARRAY()") +} + +// --- JSON --- + +// WriteJSONFieldAccess writes MySQL JSON field access using JSON_EXTRACT/JSON_UNQUOTE. +func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error { + if err := writeBase(); err != nil { + return err + } + escaped := escapeJSONFieldName(fieldName) + if isFinal { + // For final access, we need text: use ->> which is JSON_UNQUOTE(JSON_EXTRACT(...)) + w.WriteString("->>'$.") + w.WriteString(escaped) + w.WriteString("'") + } else { + w.WriteString("->'$.") + w.WriteString(escaped) + w.WriteString("'") + } + return nil +} + +// WriteJSONExistence writes a MySQL JSON key existence check. 
+func (d *Dialect) WriteJSONExistence(w *strings.Builder, _ bool, fieldName string, writeBase func() error) error { + w.WriteString("JSON_CONTAINS_PATH(") + if err := writeBase(); err != nil { + return err + } + escaped := escapeJSONFieldName(fieldName) + w.WriteString(", 'one', '$.") + w.WriteString(escaped) + w.WriteString("')") + return nil +} + +// WriteJSONArrayElements writes MySQL JSON array expansion using JSON_TABLE. +func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, _ bool, _ bool, writeExpr func() error) error { + w.WriteString("JSON_TABLE(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(", '$[*]' COLUMNS(value TEXT PATH '$'))") + return nil +} + +// WriteJSONArrayLength writes COALESCE(JSON_LENGTH(expr), 0) for MySQL. +func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error { + w.WriteString("COALESCE(JSON_LENGTH(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteJSONExtractPath writes MySQL JSON path extraction using JSON_CONTAINS_PATH. +func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error { + w.WriteString("JSON_CONTAINS_PATH(") + if err := writeRoot(); err != nil { + return err + } + w.WriteString(", 'one', '$") + for _, segment := range pathSegments { + w.WriteString(".") + w.WriteString(escapeJSONFieldName(segment)) + } + w.WriteString("')") + return nil +} + +// WriteJSONArrayMembership writes MySQL JSON array membership using JSON_CONTAINS. +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { + w.WriteString("JSON_CONTAINS(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(", CAST(? AS JSON))") + return nil +} + +// WriteNestedJSONArrayMembership writes MySQL nested JSON array membership. 
+func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { + w.WriteString("JSON_CONTAINS(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(", CAST(? AS JSON))") + return nil +} + +// --- Timestamps --- + +// WriteDuration writes a MySQL INTERVAL literal. +func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) { + fmt.Fprintf(w, "INTERVAL %d %s", value, unit) +} + +// WriteInterval writes a MySQL INTERVAL expression. +func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error { + w.WriteString("INTERVAL ") + if err := writeValue(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(unit) + return nil +} + +// WriteExtract writes a MySQL EXTRACT expression with DOW conversion. +func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error { + isDOW := part == "DOW" + if isDOW { + // MySQL DAYOFWEEK: 1=Sunday, 2=Monday, ..., 7=Saturday + // CEL getDayOfWeek: 0=Monday, 1=Tuesday, ..., 6=Sunday (ISO 8601) + // Convert: (DAYOFWEEK(x) + 5) % 7 + w.WriteString("(DAYOFWEEK(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(") + 5) % 7") + return nil + } + + w.WriteString("EXTRACT(") + w.WriteString(part) + w.WriteString(" FROM ") + if err := writeExpr(); err != nil { + return err + } + if writeTZ != nil { + w.WriteString(" AT TIME ZONE ") + if err := writeTZ(); err != nil { + return err + } + } + w.WriteString(")") + return nil +} + +// WriteTimestampArithmetic writes MySQL timestamp arithmetic. +func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error { + if err := writeTS(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(op) + w.WriteString(" ") + return writeDur() +} + +// --- String Functions --- + +// WriteContains writes LOCATE(needle, haystack) > 0 for MySQL. 
+func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error { + w.WriteString("LOCATE(") + if err := writeNeedle(); err != nil { + return err + } + w.WriteString(", ") + if err := writeHaystack(); err != nil { + return err + } + w.WriteString(") > 0") + return nil +} + +// WriteSplit writes MySQL string split using SUBSTRING_INDEX pattern. +func (d *Dialect) WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error { + // MySQL doesn't have a direct STRING_TO_ARRAY equivalent. + // Use a JSON approach: convert to JSON array. + w.WriteString("JSON_ARRAY(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(")") + // Note: A full MySQL split implementation would require a more complex approach. + // This is a simplified version. + _ = writeDelim + return nil +} + +// WriteSplitWithLimit writes MySQL string split with limit. +func (d *Dialect) WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, _ int64) error { + // Simplified: delegate to WriteSplit + return d.WriteSplit(w, writeStr, writeDelim) +} + +// WriteJoin writes MySQL array join using JSON_UNQUOTE/GROUP_CONCAT pattern. +func (d *Dialect) WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error { + // MySQL doesn't have ARRAY_TO_STRING; simplified approach + w.WriteString("JSON_UNQUOTE(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(")") + _ = writeDelim + return nil +} + +// --- Comprehensions --- + +// WriteUnnest writes MySQL JSON_TABLE for array unnesting. +func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error { + w.WriteString("JSON_TABLE(") + if err := writeSource(); err != nil { + return err + } + w.WriteString(", '$[*]' COLUMNS(value TEXT PATH '$'))") + return nil +} + +// WriteArraySubqueryOpen writes (SELECT JSON_ARRAYAGG( for MySQL array subqueries. 
+func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) { + w.WriteString("(SELECT JSON_ARRAYAGG(") +} + +// WriteArraySubqueryExprClose closes the JSON_ARRAYAGG aggregate function for MySQL. +func (d *Dialect) WriteArraySubqueryExprClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Struct --- + +// WriteStructOpen writes the MySQL struct literal opening. +func (d *Dialect) WriteStructOpen(w *strings.Builder) { + w.WriteString("ROW(") +} + +// WriteStructClose writes the MySQL struct literal closing. +func (d *Dialect) WriteStructClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Validation --- + +// MaxIdentifierLength returns the MySQL maximum identifier length (64). +func (d *Dialect) MaxIdentifierLength() int { + return maxMySQLIdentifierLength +} + +// ValidateFieldName validates a field name against MySQL naming rules. +func (d *Dialect) ValidateFieldName(name string) error { + return validateFieldName(name) +} + +// ReservedKeywords returns the set of reserved SQL keywords for MySQL. +func (d *Dialect) ReservedKeywords() map[string]bool { + return reservedSQLKeywords +} + +// --- Regex --- + +// ConvertRegex converts an RE2 regex pattern to MySQL-compatible format. +func (d *Dialect) ConvertRegex(re2Pattern string) (string, bool, error) { + return convertRE2ToMySQL(re2Pattern) +} + +// SupportsRegex returns true as MySQL 8.0+ supports ICU regex. +func (d *Dialect) SupportsRegex() bool { return true } + +// --- Capabilities --- + +// SupportsNativeArrays returns false as MySQL uses JSON arrays. +func (d *Dialect) SupportsNativeArrays() bool { return false } + +// SupportsJSONB returns false as MySQL has a single JSON type. +func (d *Dialect) SupportsJSONB() bool { return false } + +// SupportsIndexAnalysis returns true as MySQL index analysis is supported. +func (d *Dialect) SupportsIndexAnalysis() bool { return true } + +// --- Internal helpers --- + +// escapeJSONFieldName escapes special characters in JSON field names for MySQL. 
+func escapeJSONFieldName(fieldName string) string { + return strings.ReplaceAll(fieldName, "'", "''") +} diff --git a/dialect/mysql/index_advisor.go b/dialect/mysql/index_advisor.go new file mode 100644 index 0000000..4ffe9df --- /dev/null +++ b/dialect/mysql/index_advisor.go @@ -0,0 +1,93 @@ +package mysql + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// MySQL index type constants. +const ( + IndexTypeBTree = "BTREE" + IndexTypeFullText = "FULLTEXT" +) + +// RecommendIndex generates a MySQL-specific index recommendation for the given pattern. +// Returns nil if no applicable index exists for this pattern. +func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation { + table := pattern.TableHint + if table == "" { + table = "table_name" + } + col := pattern.Column + safeName := sanitizeIndexName(col) + + switch pattern.Pattern { + case dialect.PatternComparison: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeBTree, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_btree ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Comparison operations on '%s' benefit from B-tree index for efficient range queries and equality checks", col), + } + + case dialect.PatternJSONAccess: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeBTree, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_json ON %s ((CAST(%s->>'$.path' AS CHAR(255))));", + safeName, table, col), + Reason: fmt.Sprintf("JSON field access on '%s' benefits from a functional B-tree index on extracted JSON paths", col), + } + + case dialect.PatternRegexMatch: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeFullText, + Expression: fmt.Sprintf("CREATE FULLTEXT INDEX idx_%s_fulltext ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Regex matching on '%s' may benefit from FULLTEXT index for text search patterns", col), + } + + case 
dialect.PatternArrayMembership, dialect.PatternArrayComprehension: + // MySQL does not have native array types; skip + return nil + + case dialect.PatternJSONArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeBTree, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_json ON %s ((CAST(%s->>'$.path' AS CHAR(255))));", + safeName, table, col), + Reason: fmt.Sprintf("JSON array operations on '%s' may benefit from a functional index on extracted JSON values", col), + } + } + + return nil +} + +// SupportedPatterns returns the pattern types supported by MySQL. +func (d *Dialect) SupportedPatterns() []dialect.PatternType { + return []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternRegexMatch, + dialect.PatternJSONArrayComprehension, + } +} + +// sanitizeIndexName creates a safe index name from a column name. +func sanitizeIndexName(column string) string { + sanitized := strings.ReplaceAll(column, ".", "_") + sanitized = strings.ReplaceAll(sanitized, " ", "_") + sanitized = strings.ReplaceAll(sanitized, "-", "_") + + // MySQL index names are limited to 64 characters + if len(sanitized) > 50 { + sanitized = sanitized[:50] + } + + return sanitized +} diff --git a/dialect/mysql/regex.go b/dialect/mysql/regex.go new file mode 100644 index 0000000..7965097 --- /dev/null +++ b/dialect/mysql/regex.go @@ -0,0 +1,139 @@ +package mysql + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). +const ( + maxRegexPatternLength = 500 + maxRegexGroups = 20 + maxRegexNestingDepth = 10 +) + +// convertRE2ToMySQL converts an RE2 regex pattern to MySQL-compatible format. +// MySQL 8.0+ uses ICU regex which supports most RE2 features. +// Returns the converted pattern, whether it's case-insensitive, and any error. +func convertRE2ToMySQL(re2Pattern string) (string, bool, error) { + // 1. 
Pattern length validation + if len(re2Pattern) > maxRegexPatternLength { + return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength) + } + + // 2. Validate pattern compiles + if _, err := regexp.Compile(re2Pattern); err != nil { + return "", false, fmt.Errorf("invalid regex pattern: %w", err) + } + + // 3. Detect unsupported RE2 features + if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") { + return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported in MySQL regex") + } + if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?<!") { + return "", false, errors.New("lookbehind assertions (?<=...), (?<!...) are not supported in MySQL regex") + } + + // 4. Detect catastrophic nested quantifiers + if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + + // 5. Check for nested quantifiers in groups + depth := 0 + groupHasQuantifier := make([]bool, 0) + for i := 0; i < len(re2Pattern); i++ { + char := re2Pattern[i] + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch char { + case '(': + depth++ + groupHasQuantifier = append(groupHasQuantifier, false) + case ')': + if depth > 0 { + depth-- + if i+1 < len(re2Pattern) { + nextChar := re2Pattern[i+1] + if nextChar == '*' || nextChar == '+' || nextChar == '?' || nextChar == '{' { + if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + } + } + if len(groupHasQuantifier) > 0 { + groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] + } + } + case '*', '+', '?', '{': + for j := range groupHasQuantifier { + groupHasQuantifier[j] = true + } + } + } + + // 6.
Check group count limit + groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, "\\(") + if groupCount > maxRegexGroups { + return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) + } + + // 7. Check for quantified alternation + quantifiedAlternation := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) + if quantifiedAlternation.MatchString(re2Pattern) { + return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") + } + + // 8. Check nesting depth + maxDepthVal := 0 + currentDepth := 0 + for i := 0; i < len(re2Pattern); i++ { + if i > 0 && re2Pattern[i-1] == '\\' { + continue + } + switch re2Pattern[i] { + case '(': + currentDepth++ + if currentDepth > maxDepthVal { + maxDepthVal = currentDepth + } + case ')': + if currentDepth > 0 { + currentDepth-- + } + } + } + if maxDepthVal > maxRegexNestingDepth { + return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) + } + + // Process pattern: extract case-insensitivity, convert features + caseInsensitive := false + pattern := re2Pattern + + // Handle (?i) flag + if strings.HasPrefix(pattern, "(?i)") { + caseInsensitive = true + pattern = pattern[4:] + } + + // Handle inline flags other than (?i) at start + if strings.Contains(pattern, "(?m") || strings.Contains(pattern, "(?s") || strings.Contains(pattern, "(?-") { + return "", false, errors.New("inline flags other than (?i) are not supported in MySQL regex") + } + + // Convert non-capturing groups (?:...) to regular groups (...) 
+ pattern = strings.ReplaceAll(pattern, "(?:", "(") + + // MySQL ICU regex supports \d, \w, \s natively - no conversion needed + // Convert \b word boundary to MySQL's \b (same syntax in ICU) + // No conversion needed for MySQL 8.0+ + + return pattern, caseInsensitive, nil +} diff --git a/dialect/mysql/validation.go b/dialect/mysql/validation.go new file mode 100644 index 0000000..15a20d4 --- /dev/null +++ b/dialect/mysql/validation.go @@ -0,0 +1,91 @@ +package mysql + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +const ( + // maxMySQLIdentifierLength is the maximum length for MySQL identifiers. + maxMySQLIdentifierLength = 64 +) + +var ( + // fieldNameRegexp validates MySQL identifier format. + fieldNameRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + + // reservedSQLKeywords contains MySQL reserved keywords. + reservedSQLKeywords = map[string]bool{ + "accessible": true, "add": true, "all": true, "alter": true, "analyze": true, + "and": true, "as": true, "asc": true, "asensitive": true, "before": true, + "between": true, "bigint": true, "binary": true, "blob": true, "both": true, + "by": true, "call": true, "cascade": true, "case": true, "change": true, + "char": true, "character": true, "check": true, "collate": true, "column": true, + "condition": true, "constraint": true, "continue": true, "convert": true, + "create": true, "cross": true, "current_date": true, "current_time": true, + "current_timestamp": true, "current_user": true, "cursor": true, "database": true, + "databases": true, "day_hour": true, "day_microsecond": true, "day_minute": true, + "day_second": true, "dec": true, "decimal": true, "declare": true, "default": true, + "delayed": true, "delete": true, "desc": true, "describe": true, "deterministic": true, + "distinct": true, "distinctrow": true, "div": true, "double": true, "drop": true, + "dual": true, "each": true, "else": true, "elseif": true, "enclosed": true, + "escaped": true, "exists": true, "exit": true, "explain": 
true, "false": true, + "fetch": true, "float": true, "float4": true, "float8": true, "for": true, + "force": true, "foreign": true, "from": true, "fulltext": true, "grant": true, + "group": true, "having": true, "high_priority": true, "hour_microsecond": true, + "hour_minute": true, "hour_second": true, "if": true, "ignore": true, "in": true, + "index": true, "infile": true, "inner": true, "inout": true, "insensitive": true, + "insert": true, "int": true, "int1": true, "int2": true, "int3": true, + "int4": true, "int8": true, "integer": true, "interval": true, "into": true, + "is": true, "iterate": true, "join": true, "key": true, "keys": true, + "kill": true, "leading": true, "leave": true, "left": true, "like": true, + "limit": true, "linear": true, "lines": true, "load": true, "localtime": true, + "localtimestamp": true, "lock": true, "long": true, "longblob": true, + "longtext": true, "loop": true, "low_priority": true, "match": true, + "mediumblob": true, "mediumint": true, "mediumtext": true, "middleint": true, + "minute_microsecond": true, "minute_second": true, "mod": true, "modifies": true, + "natural": true, "not": true, "null": true, "numeric": true, "on": true, + "optimize": true, "option": true, "optionally": true, "or": true, "order": true, + "out": true, "outer": true, "outfile": true, "precision": true, "primary": true, + "procedure": true, "purge": true, "range": true, "read": true, "reads": true, + "real": true, "references": true, "regexp": true, "release": true, "rename": true, + "repeat": true, "replace": true, "require": true, "restrict": true, "return": true, + "revoke": true, "right": true, "rlike": true, "schema": true, "schemas": true, + "second_microsecond": true, "select": true, "sensitive": true, "separator": true, + "set": true, "show": true, "signal": true, "smallint": true, "spatial": true, + "specific": true, "sql": true, "sqlexception": true, "sqlstate": true, + "sqlwarning": true, "sql_big_result": true, "sql_calc_found_rows": 
true, + "sql_small_result": true, "ssl": true, "starting": true, "straight_join": true, + "table": true, "terminated": true, "then": true, "tinyblob": true, "tinyint": true, + "tinytext": true, "to": true, "trailing": true, "trigger": true, "true": true, + "undo": true, "union": true, "unique": true, "unlock": true, "unsigned": true, + "update": true, "usage": true, "use": true, "using": true, "utc_date": true, + "utc_time": true, "utc_timestamp": true, "values": true, "varbinary": true, + "varchar": true, "varcharacter": true, "varying": true, "when": true, + "where": true, "while": true, "with": true, "write": true, "xor": true, + "year_month": true, "zerofill": true, + } +) + +// validateFieldName validates that a field name follows MySQL naming conventions. +func validateFieldName(name string) error { + if len(name) == 0 { + return errors.New("field name cannot be empty") + } + + if len(name) > maxMySQLIdentifierLength { + return fmt.Errorf("field name %q exceeds MySQL maximum identifier length of %d characters", name, maxMySQLIdentifierLength) + } + + if !fieldNameRegexp.MatchString(name) { + return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) + } + + if reservedSQLKeywords[strings.ToLower(name)] { + return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) + } + + return nil +} diff --git a/dialect/postgres/dialect.go b/dialect/postgres/dialect.go new file mode 100644 index 0000000..8aee559 --- /dev/null +++ b/dialect/postgres/dialect.go @@ -0,0 +1,496 @@ +// Package postgres implements the PostgreSQL SQL dialect for cel2sql. +package postgres + +import ( + "encoding/hex" + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// Dialect implements dialect.Dialect for PostgreSQL. +type Dialect struct{} + +// New creates a new PostgreSQL dialect. 
+func New() *Dialect { + return &Dialect{} +} + +func init() { + dialect.Register(dialect.PostgreSQL, func() dialect.Dialect { return New() }) +} + +// Ensure Dialect implements dialect.Dialect at compile time. +var _ dialect.Dialect = (*Dialect)(nil) + +// Name returns the dialect name. +func (d *Dialect) Name() dialect.Name { return dialect.PostgreSQL } + +// --- Literals --- + +// WriteStringLiteral writes a PostgreSQL string literal with ” escaping. +func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) { + escaped := strings.ReplaceAll(value, "'", "''") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") +} + +// WriteBytesLiteral writes a PostgreSQL hex-encoded byte literal. +func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error { + w.WriteString("'\\x") + w.WriteString(hex.EncodeToString(value)) + w.WriteString("'") + return nil +} + +// WriteParamPlaceholder writes a PostgreSQL positional parameter ($1, $2, ...). +func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, paramIndex int) { + fmt.Fprintf(w, "$%d", paramIndex) +} + +// --- Operators --- + +// WriteStringConcat writes a PostgreSQL string concatenation using ||. +func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error { + if err := writeLHS(); err != nil { + return err + } + w.WriteString(" || ") + return writeRHS() +} + +// WriteRegexMatch writes a PostgreSQL regex match using ~ or ~* operators. +func (d *Dialect) WriteRegexMatch(w *strings.Builder, writeTarget func() error, pattern string, caseInsensitive bool) error { + if err := writeTarget(); err != nil { + return err + } + if caseInsensitive { + w.WriteString(" ~* ") + } else { + w.WriteString(" ~ ") + } + escaped := strings.ReplaceAll(pattern, "'", "''") + w.WriteString("'") + w.WriteString(escaped) + w.WriteString("'") + return nil +} + +// WriteLikeEscape writes the PostgreSQL LIKE escape clause. 
+func (d *Dialect) WriteLikeEscape(w *strings.Builder) { + w.WriteString(" ESCAPE E'\\\\'") +} + +// WriteArrayMembership writes a PostgreSQL array membership test using = ANY(). +func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error { + if err := writeElem(); err != nil { + return err + } + w.WriteString(" = ANY(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Type Casting --- + +// WriteCastToNumeric writes a PostgreSQL numeric cast suffix (::numeric). +func (d *Dialect) WriteCastToNumeric(w *strings.Builder) { + w.WriteString("::numeric") +} + +// WriteTypeName writes a PostgreSQL type name for CAST expressions. +func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) { + switch celTypeName { + case "bool": + w.WriteString("BOOLEAN") + case "bytes": + w.WriteString("BYTEA") + case "double": + w.WriteString("DOUBLE PRECISION") + case "int": + w.WriteString("BIGINT") + case "string": + w.WriteString("TEXT") + case "uint": + w.WriteString("BIGINT") + default: + w.WriteString(strings.ToUpper(celTypeName)) + } +} + +// WriteEpochExtract writes EXTRACT(EPOCH FROM expr)::bigint for PostgreSQL. +func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error { + w.WriteString("EXTRACT(EPOCH FROM ") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")::bigint") + return nil +} + +// WriteTimestampCast writes CAST(expr AS TIMESTAMP WITH TIME ZONE) for PostgreSQL. +func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error { + w.WriteString("CAST(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(" AS TIMESTAMP WITH TIME ZONE)") + return nil +} + +// --- Arrays --- + +// WriteArrayLiteralOpen writes the PostgreSQL array literal opening (ARRAY[). 
+func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) { + w.WriteString("ARRAY[") +} + +// WriteArrayLiteralClose writes the PostgreSQL array literal closing (]). +func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) { + w.WriteString("]") +} + +// WriteArrayLength writes COALESCE(ARRAY_LENGTH(expr, dimension), 0) for PostgreSQL. +func (d *Dialect) WriteArrayLength(w *strings.Builder, dimension int, writeExpr func() error) error { + w.WriteString("COALESCE(ARRAY_LENGTH(") + if err := writeExpr(); err != nil { + return err + } + fmt.Fprintf(w, ", %d), 0)", dimension) + return nil +} + +// WriteListIndex writes a PostgreSQL 1-indexed array access (array[index + 1]). +func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error { + if err := writeArray(); err != nil { + return err + } + w.WriteString("[") + if err := writeIndex(); err != nil { + return err + } + w.WriteString(" + 1]") + return nil +} + +// WriteListIndexConst writes a PostgreSQL constant array index (0-indexed to 1-indexed). +func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error { + if err := writeArray(); err != nil { + return err + } + fmt.Fprintf(w, "[%d]", index+1) + return nil +} + +// WriteEmptyTypedArray writes an empty PostgreSQL typed array (ARRAY[]::type[]). +func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, typeName string) { + fmt.Fprintf(w, "ARRAY[]::%s[]", typeName) +} + +// --- JSON --- + +// WriteJSONFieldAccess writes PostgreSQL JSON field access using -> or ->> operators. 
+func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, isFinal bool) error { + if err := writeBase(); err != nil { + return err + } + escapedField := escapeJSONFieldName(fieldName) + if isFinal { + w.WriteString("->>'") + } else { + w.WriteString("->'") + } + w.WriteString(escapedField) + w.WriteString("'") + return nil +} + +// WriteJSONExistence writes a PostgreSQL JSON key existence check (? or IS NOT NULL). +func (d *Dialect) WriteJSONExistence(w *strings.Builder, isJSONB bool, fieldName string, writeBase func() error) error { + if err := writeBase(); err != nil { + return err + } + escapedField := escapeJSONFieldName(fieldName) + if isJSONB { + w.WriteString(" ? '") + w.WriteString(escapedField) + w.WriteString("'") + } else { + w.WriteString("->'") + w.WriteString(escapedField) + w.WriteString("' IS NOT NULL") + } + return nil +} + +// WriteJSONArrayElements writes a PostgreSQL JSON array expansion function. +func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, isJSONB bool, asText bool, writeExpr func() error) error { + if isJSONB { + if asText { + w.WriteString("jsonb_array_elements_text(") + } else { + w.WriteString("jsonb_array_elements(") + } + } else { + if asText { + w.WriteString("json_array_elements_text(") + } else { + w.WriteString("json_array_elements(") + } + } + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteJSONArrayLength writes COALESCE(jsonb_array_length(expr), 0) for PostgreSQL. +func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error { + w.WriteString("COALESCE(jsonb_array_length(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteJSONExtractPath writes jsonb_extract_path_text() IS NOT NULL for PostgreSQL. 
+func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error { + w.WriteString("jsonb_extract_path_text(") + if err := writeRoot(); err != nil { + return err + } + for _, segment := range pathSegments { + w.WriteString(", '") + w.WriteString(escapeJSONFieldName(segment)) + w.WriteString("'") + } + w.WriteString(") IS NOT NULL") + return nil +} + +// WriteJSONArrayMembership writes ANY(ARRAY(SELECT json_func(expr))) for PostgreSQL. +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, jsonFunc string, writeExpr func() error) error { + w.WriteString("ANY(ARRAY(SELECT ") + w.WriteString(jsonFunc) + w.WriteString("(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")))") + return nil +} + +// WriteNestedJSONArrayMembership writes ANY(ARRAY(SELECT jsonb_array_elements_text(expr))) for PostgreSQL. +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { + w.WriteString("ANY(ARRAY(SELECT jsonb_array_elements_text(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")))") + return nil +} + +// --- Timestamps --- + +// WriteDuration writes a PostgreSQL INTERVAL literal (INTERVAL N UNIT). +func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) { + fmt.Fprintf(w, "INTERVAL %d %s", value, unit) +} + +// WriteInterval writes a PostgreSQL INTERVAL expression (INTERVAL expr UNIT). +func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error { + w.WriteString("INTERVAL ") + if err := writeValue(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(unit) + return nil +} + +// WriteExtract writes a PostgreSQL EXTRACT expression with DOW conversion. 
+func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, writeTZ func() error) error { + // For getDayOfWeek, we need to wrap the entire EXTRACT for modulo operation + isDOW := part == "DOW" + if isDOW { + w.WriteString("(") + } + w.WriteString("EXTRACT(") + w.WriteString(part) + w.WriteString(" FROM ") + if err := writeExpr(); err != nil { + return err + } + if writeTZ != nil { + w.WriteString(" AT TIME ZONE ") + if err := writeTZ(); err != nil { + return err + } + } + w.WriteString(")") + if isDOW { + // PostgreSQL DOW: 0=Sunday, 1=Monday, ..., 6=Saturday + // CEL getDayOfWeek: 0=Monday, 1=Tuesday, ..., 6=Sunday (ISO 8601) + // Convert: (DOW + 6) % 7 + w.WriteString(" + 6) % 7") + } + return nil +} + +// WriteTimestampArithmetic writes PostgreSQL timestamp arithmetic (timestamp +/- interval). +func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error { + if err := writeTS(); err != nil { + return err + } + w.WriteString(" ") + w.WriteString(op) + w.WriteString(" ") + return writeDur() +} + +// --- String Functions --- + +// WriteContains writes POSITION(needle IN haystack) > 0 for PostgreSQL. +func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error { + w.WriteString("POSITION(") + if err := writeNeedle(); err != nil { + return err + } + w.WriteString(" IN ") + if err := writeHaystack(); err != nil { + return err + } + w.WriteString(") > 0") + return nil +} + +// WriteSplit writes STRING_TO_ARRAY(string, delimiter) for PostgreSQL. +func (d *Dialect) WriteSplit(w *strings.Builder, writeStr, writeDelim func() error) error { + w.WriteString("STRING_TO_ARRAY(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteSplitWithLimit writes (STRING_TO_ARRAY(string, delimiter))[1:limit] for PostgreSQL. 
+func (d *Dialect) WriteSplitWithLimit(w *strings.Builder, writeStr, writeDelim func() error, limit int64) error { + w.WriteString("(STRING_TO_ARRAY(") + if err := writeStr(); err != nil { + return err + } + w.WriteString(", ") + if err := writeDelim(); err != nil { + return err + } + fmt.Fprintf(w, "))[1:%d]", limit) + return nil +} + +// WriteJoin writes ARRAY_TO_STRING(array, delimiter, ”) for PostgreSQL. +func (d *Dialect) WriteJoin(w *strings.Builder, writeArray, writeDelim func() error) error { + w.WriteString("ARRAY_TO_STRING(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", ") + if writeDelim != nil { + if err := writeDelim(); err != nil { + return err + } + } else { + w.WriteString("''") + } + w.WriteString(", '')") + return nil +} + +// --- Comprehensions --- + +// WriteUnnest writes UNNEST(source) for PostgreSQL comprehensions. +func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error { + w.WriteString("UNNEST(") + if err := writeSource(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteArraySubqueryOpen writes ARRAY(SELECT for PostgreSQL. +func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) { + w.WriteString("ARRAY(SELECT ") +} + +// WriteArraySubqueryExprClose is a no-op for PostgreSQL (no wrapper around the expression). +func (d *Dialect) WriteArraySubqueryExprClose(_ *strings.Builder) { +} + +// --- Struct --- + +// WriteStructOpen writes the PostgreSQL struct/row literal opening (ROW(). +func (d *Dialect) WriteStructOpen(w *strings.Builder) { + w.WriteString("ROW(") +} + +// WriteStructClose writes the PostgreSQL struct/row literal closing ()). +func (d *Dialect) WriteStructClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Validation --- + +// MaxIdentifierLength returns the PostgreSQL maximum identifier length (63). 
+func (d *Dialect) MaxIdentifierLength() int { + return maxPostgreSQLIdentifierLength +} + +// ValidateFieldName validates a field name against PostgreSQL naming rules. +func (d *Dialect) ValidateFieldName(name string) error { + return validateFieldName(name) +} + +// ReservedKeywords returns the set of reserved SQL keywords for PostgreSQL. +func (d *Dialect) ReservedKeywords() map[string]bool { + return reservedSQLKeywords +} + +// --- Regex --- + +// ConvertRegex converts an RE2 regex pattern to PostgreSQL POSIX format. +func (d *Dialect) ConvertRegex(re2Pattern string) (string, bool, error) { + return convertRE2ToPOSIX(re2Pattern) +} + +// SupportsRegex returns true as PostgreSQL supports POSIX regex matching. +func (d *Dialect) SupportsRegex() bool { return true } + +// --- Capabilities --- + +// SupportsNativeArrays returns true as PostgreSQL has native array types. +func (d *Dialect) SupportsNativeArrays() bool { return true } + +// SupportsJSONB returns true as PostgreSQL has a distinct JSONB type. +func (d *Dialect) SupportsJSONB() bool { return true } + +// SupportsIndexAnalysis returns true as PostgreSQL index analysis is supported. +func (d *Dialect) SupportsIndexAnalysis() bool { return true } + +// --- Internal helpers --- + +// escapeJSONFieldName escapes single quotes in JSON field names for safe use in PostgreSQL JSON path operators. +func escapeJSONFieldName(fieldName string) string { + return strings.ReplaceAll(fieldName, "'", "''") +} diff --git a/dialect/postgres/index_advisor.go b/dialect/postgres/index_advisor.go new file mode 100644 index 0000000..286d24b --- /dev/null +++ b/dialect/postgres/index_advisor.go @@ -0,0 +1,110 @@ +package postgres + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// PostgreSQL index type constants. 
+const ( + IndexTypeBTree = "BTREE" + IndexTypeGIN = "GIN" + IndexTypeGIST = "GIST" +) + +// RecommendIndex generates a PostgreSQL-specific index recommendation for the given pattern. +// Returns nil if no applicable index exists for this pattern. +func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation { + table := pattern.TableHint + if table == "" { + table = "table_name" + } + col := pattern.Column + safeName := sanitizeIndexName(col) + + switch pattern.Pattern { + case dialect.PatternComparison: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeBTree, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_btree ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Comparison operations on '%s' benefit from B-tree index for efficient range queries and equality checks", col), + } + + case dialect.PatternJSONAccess: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeGIN, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON %s USING GIN (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSON path operations on '%s' benefit from GIN index for efficient nested field access", col), + } + + case dialect.PatternRegexMatch: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeGIN, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin_trgm ON %s USING GIN (%s gin_trgm_ops);", + safeName, table, col), + Reason: fmt.Sprintf("Regex matching on '%s' benefits from GIN index with pg_trgm extension for pattern matching", col), + } + + case dialect.PatternArrayMembership: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeGIN, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON %s USING GIN (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Array membership tests on '%s' benefit from GIN index for efficient element lookups", col), + } + + case dialect.PatternArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + 
IndexType: IndexTypeGIN, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON %s USING GIN (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Array comprehension on '%s' benefits from GIN index for efficient array operations", col), + } + + case dialect.PatternJSONArrayComprehension: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeGIN, + Expression: fmt.Sprintf("CREATE INDEX idx_%s_gin ON %s USING GIN (%s);", + safeName, table, col), + Reason: fmt.Sprintf("JSONB array comprehension on '%s' benefits from GIN index for efficient array element access", col), + } + } + + return nil +} + +// SupportedPatterns returns all pattern types supported by PostgreSQL. +func (d *Dialect) SupportedPatterns() []dialect.PatternType { + return []dialect.PatternType{ + dialect.PatternComparison, + dialect.PatternJSONAccess, + dialect.PatternRegexMatch, + dialect.PatternArrayMembership, + dialect.PatternArrayComprehension, + dialect.PatternJSONArrayComprehension, + } +} + +// sanitizeIndexName creates a safe index name from a column name. +func sanitizeIndexName(column string) string { + sanitized := strings.ReplaceAll(column, ".", "_") + sanitized = strings.ReplaceAll(sanitized, " ", "_") + sanitized = strings.ReplaceAll(sanitized, "-", "_") + + // PostgreSQL index names are limited to 63 characters + if len(sanitized) > 50 { + sanitized = sanitized[:50] + } + + return sanitized +} diff --git a/dialect/postgres/regex.go b/dialect/postgres/regex.go new file mode 100644 index 0000000..1cc4a4b --- /dev/null +++ b/dialect/postgres/regex.go @@ -0,0 +1,143 @@ +package postgres + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +// Regex pattern complexity limits to prevent ReDoS attacks (CWE-1333). +const ( + maxRegexPatternLength = 500 + maxRegexGroups = 20 + maxRegexNestingDepth = 10 +) + +// convertRE2ToPOSIX converts an RE2 regex pattern to POSIX ERE format for PostgreSQL. +// It performs security validation to prevent ReDoS attacks (CWE-1333). 
+// Returns: (posixPattern, caseInsensitive, error)
+func convertRE2ToPOSIX(re2Pattern string) (string, bool, error) {
+	// 1. Check pattern length to prevent processing extremely long patterns
+	if len(re2Pattern) > maxRegexPatternLength {
+		return "", false, fmt.Errorf("pattern length %d exceeds limit of %d characters", len(re2Pattern), maxRegexPatternLength)
+	}
+
+	// 2. Extract case-insensitive flag if present
+	caseInsensitive := false
+	if strings.HasPrefix(re2Pattern, "(?i)") {
+		caseInsensitive = true
+		re2Pattern = strings.TrimPrefix(re2Pattern, "(?i)")
+	}
+
+	// 3. Detect unsupported RE2 features and return errors
+	if strings.Contains(re2Pattern, "(?=") || strings.Contains(re2Pattern, "(?!") {
+		return "", false, errors.New("lookahead assertions (?=...), (?!...) are not supported in PostgreSQL POSIX regex")
+	}
+	if strings.Contains(re2Pattern, "(?<=") || strings.Contains(re2Pattern, "(?<!") {
+		return "", false, errors.New("lookbehind assertions (?<=...), (?<!...) are not supported in PostgreSQL POSIX regex")
+	}
+	if strings.Contains(re2Pattern, "(?m") || strings.Contains(re2Pattern, "(?s") || strings.Contains(re2Pattern, "(?-") {
+		return "", false, errors.New("inline flags other than (?i) are not supported in PostgreSQL POSIX regex")
+	}
+
+	// 4. Detect catastrophic nested quantifiers
+	if matched, _ := regexp.MatchString(`[*+][*+]`, re2Pattern); matched {
+		return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS")
+	}
+
+	// Check for groups that contain quantifiers and are themselves quantified
+	depth := 0
+	groupHasQuantifier := make([]bool, 0)
+
+	for i := 0; i < len(re2Pattern); i++ {
+		char := re2Pattern[i]
+
+		// Skip escaped characters
+		if i > 0 && re2Pattern[i-1] == '\\' {
+			continue
+		}
+
+		switch char {
+		case '(':
+			depth++
+			groupHasQuantifier = append(groupHasQuantifier, false)
+		case ')':
+			if depth > 0 {
+				depth--
+				if i+1 < len(re2Pattern) {
+					nextChar := re2Pattern[i+1]
+					if nextChar == '*' || nextChar == '+' || nextChar == '?'
|| nextChar == '{' { + if len(groupHasQuantifier) > 0 && groupHasQuantifier[len(groupHasQuantifier)-1] { + return "", false, errors.New("regex contains catastrophic nested quantifiers that could cause ReDoS") + } + } + } + if len(groupHasQuantifier) > 0 { + if len(groupHasQuantifier) > 1 { + if groupHasQuantifier[len(groupHasQuantifier)-1] { + groupHasQuantifier[len(groupHasQuantifier)-2] = true + } + } + groupHasQuantifier = groupHasQuantifier[:len(groupHasQuantifier)-1] + } + } + case '*', '+', '?': + if len(groupHasQuantifier) > 0 { + groupHasQuantifier[len(groupHasQuantifier)-1] = true + } + case '{': + if len(groupHasQuantifier) > 0 { + groupHasQuantifier[len(groupHasQuantifier)-1] = true + } + } + } + + // 5. Count and limit capture groups + groupCount := strings.Count(re2Pattern, "(") - strings.Count(re2Pattern, `\(`) + if groupCount > maxRegexGroups { + return "", false, fmt.Errorf("regex contains %d capture groups, exceeds limit of %d", groupCount, maxRegexGroups) + } + + // 6. Detect exponential alternation patterns + alternationPattern := regexp.MustCompile(`\([^)]*\|[^)]*\)[*+]`) + if alternationPattern.MatchString(re2Pattern) { + return "", false, errors.New("regex contains quantified alternation that could cause ReDoS") + } + + // 7. 
Check nesting depth + maxDepthVal := 0 + currentDepth := 0 + for _, char := range re2Pattern { + if char == '(' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { + currentDepth++ + if currentDepth > maxDepthVal { + maxDepthVal = currentDepth + } + } else if char == ')' && !strings.HasSuffix(re2Pattern[:strings.LastIndex(re2Pattern, string(char))], `\`) { + currentDepth-- + } + } + if maxDepthVal > maxRegexNestingDepth { + return "", false, fmt.Errorf("nesting depth %d exceeds limit of %d", maxDepthVal, maxRegexNestingDepth) + } + + // Passed all security checks - proceed with conversion + posixPattern := re2Pattern + + // Convert RE2 patterns to POSIX equivalents + posixPattern = strings.ReplaceAll(posixPattern, `\b`, `\y`) + posixPattern = strings.ReplaceAll(posixPattern, `\B`, `[^[:alnum:]_]`) + posixPattern = strings.ReplaceAll(posixPattern, `\d`, `[[:digit:]]`) + posixPattern = strings.ReplaceAll(posixPattern, `\D`, `[^[:digit:]]`) + posixPattern = strings.ReplaceAll(posixPattern, `\w`, `[[:alnum:]_]`) + posixPattern = strings.ReplaceAll(posixPattern, `\W`, `[^[:alnum:]_]`) + posixPattern = strings.ReplaceAll(posixPattern, `\s`, `[[:space:]]`) + posixPattern = strings.ReplaceAll(posixPattern, `\S`, `[^[:space:]]`) + posixPattern = strings.ReplaceAll(posixPattern, `(?:`, `(`) + + return posixPattern, caseInsensitive, nil +} diff --git a/dialect/postgres/validation.go b/dialect/postgres/validation.go new file mode 100644 index 0000000..162da68 --- /dev/null +++ b/dialect/postgres/validation.go @@ -0,0 +1,66 @@ +package postgres + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +const ( + // maxPostgreSQLIdentifierLength is the maximum length for PostgreSQL identifiers + // PostgreSQL's NAMEDATALEN is 64 bytes (including null terminator), so max usable length is 63 + maxPostgreSQLIdentifierLength = 63 +) + +var ( + // fieldNameRegexp validates PostgreSQL identifier format + fieldNameRegexp = 
regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + + // reservedSQLKeywords contains SQL keywords that should not be used as unquoted identifiers + reservedSQLKeywords = map[string]bool{ + "all": true, "analyse": true, "analyze": true, "and": true, "any": true, + "array": true, "as": true, "asc": true, "asymmetric": true, "both": true, + "case": true, "cast": true, "check": true, "collate": true, "column": true, + "constraint": true, "create": true, "cross": true, "current_catalog": true, + "current_date": true, "current_role": true, "current_time": true, + "current_timestamp": true, "current_user": true, "default": true, + "deferrable": true, "desc": true, "distinct": true, "do": true, "else": true, + "end": true, "except": true, "false": true, "fetch": true, "for": true, + "foreign": true, "from": true, "grant": true, "group": true, "having": true, + "in": true, "initially": true, "inner": true, "intersect": true, "into": true, + "is": true, "join": true, "leading": true, "left": true, "like": true, + "limit": true, "localtime": true, "localtimestamp": true, "natural": true, + "not": true, "null": true, "offset": true, "on": true, "only": true, + "or": true, "order": true, "outer": true, "overlaps": true, "placing": true, + "primary": true, "references": true, "returning": true, "right": true, + "select": true, "session_user": true, "similar": true, "some": true, + "symmetric": true, "table": true, "then": true, "to": true, "trailing": true, + "true": true, "union": true, "unique": true, "user": true, "using": true, + "variadic": true, "when": true, "where": true, "window": true, "with": true, + // Additional keywords that commonly cause issues + "alter": true, "delete": true, "drop": true, "insert": true, "update": true, + } +) + +// validateFieldName validates that a field name follows PostgreSQL naming conventions +// and is safe to use in SQL queries without quoting. 
+func validateFieldName(name string) error { + if len(name) == 0 { + return errors.New("field name cannot be empty") + } + + if len(name) > maxPostgreSQLIdentifierLength { + return fmt.Errorf("field name %q exceeds PostgreSQL maximum identifier length of %d characters", name, maxPostgreSQLIdentifierLength) + } + + if !fieldNameRegexp.MatchString(name) { + return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) + } + + if reservedSQLKeywords[strings.ToLower(name)] { + return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) + } + + return nil +} diff --git a/dialect/registry.go b/dialect/registry.go new file mode 100644 index 0000000..095af2d --- /dev/null +++ b/dialect/registry.go @@ -0,0 +1,42 @@ +package dialect + +import ( + "fmt" + "sync" +) + +var ( + registryMu sync.RWMutex + registry = make(map[Name]func() Dialect) +) + +// Register registers a dialect factory function by name. +// This is typically called in an init() function by each dialect package. +func Register(name Name, factory func() Dialect) { + registryMu.Lock() + defer registryMu.Unlock() + registry[name] = factory +} + +// Get returns a new dialect instance by name. +// Returns an error if the dialect is not registered. +func Get(name Name) (Dialect, error) { + registryMu.RLock() + defer registryMu.RUnlock() + factory, ok := registry[name] + if !ok { + return nil, fmt.Errorf("%w: dialect %q is not registered", ErrUnsupportedFeature, name) + } + return factory(), nil +} + +// Registered returns the names of all registered dialects. 
+func Registered() []Name {
+	registryMu.RLock()
+	defer registryMu.RUnlock()
+	names := make([]Name, 0, len(registry))
+	for name := range registry {
+		names = append(names, name)
+	}
+	return names
+}
diff --git a/dialect/sqlite/dialect.go b/dialect/sqlite/dialect.go
new file mode 100644
index 0000000..5e97d1f
--- /dev/null
+++ b/dialect/sqlite/dialect.go
@@ -0,0 +1,462 @@
+// Package sqlite implements the SQLite SQL dialect for cel2sql.
+package sqlite
+
+import (
+	"fmt"
+	"strings"
+
+	"github.com/spandigital/cel2sql/v3/dialect"
+)
+
+// Dialect implements dialect.Dialect for SQLite.
+type Dialect struct{}
+
+// New creates a new SQLite dialect.
+func New() *Dialect {
+	return &Dialect{}
+}
+
+func init() {
+	dialect.Register(dialect.SQLite, func() dialect.Dialect { return New() })
+}
+
+// Ensure Dialect implements dialect.Dialect at compile time.
+var _ dialect.Dialect = (*Dialect)(nil)
+
+// Name returns the dialect name.
+func (d *Dialect) Name() dialect.Name { return dialect.SQLite }
+
+// --- Literals ---
+
+// WriteStringLiteral writes a SQLite string literal, escaping embedded single
+// quotes by doubling them ('' — standard SQL escaping).
+func (d *Dialect) WriteStringLiteral(w *strings.Builder, value string) {
+	escaped := strings.ReplaceAll(value, "'", "''")
+	w.WriteString("'")
+	w.WriteString(escaped)
+	w.WriteString("'")
+}
+
+// WriteBytesLiteral writes a SQLite hex-encoded byte literal (X'...').
+func (d *Dialect) WriteBytesLiteral(w *strings.Builder, value []byte) error {
+	w.WriteString("X'")
+	for _, b := range value {
+		fmt.Fprintf(w, "%02x", b)
+	}
+	w.WriteString("'")
+	return nil
+}
+
+// WriteParamPlaceholder writes a SQLite positional parameter (?).
+func (d *Dialect) WriteParamPlaceholder(w *strings.Builder, _ int) {
+	w.WriteString("?")
+}
+
+// --- Operators ---
+
+// WriteStringConcat writes SQLite string concatenation using ||.
+func (d *Dialect) WriteStringConcat(w *strings.Builder, writeLHS, writeRHS func() error) error {
+	if err := writeLHS(); err != nil {
+		return err
+	}
+	w.WriteString(" || ")
+	return writeRHS()
+}
+
+// WriteRegexMatch returns an error as SQLite does not natively support regex.
+func (d *Dialect) WriteRegexMatch(_ *strings.Builder, _ func() error, _ string, _ bool) error {
+	return fmt.Errorf("%w: regex matching", dialect.ErrUnsupportedFeature)
+}
+
+// WriteLikeEscape writes the SQLite LIKE escape clause.
+// SQLite does not use backslash escaping in string literals, so '\' is a single character.
+func (d *Dialect) WriteLikeEscape(w *strings.Builder) {
+	w.WriteString(" ESCAPE '\\'")
+}
+
+// WriteArrayMembership writes a SQLite array membership test using json_each.
+func (d *Dialect) WriteArrayMembership(w *strings.Builder, writeElem func() error, writeArray func() error) error {
+	if err := writeElem(); err != nil {
+		return err
+	}
+	w.WriteString(" IN (SELECT value FROM json_each(")
+	if err := writeArray(); err != nil {
+		return err
+	}
+	w.WriteString("))")
+	return nil
+}
+
+// --- Type Casting ---
+
+// WriteCastToNumeric coerces the preceding expression to numeric by appending
+// " + 0", relying on SQLite's numeric affinity for arithmetic operands.
+// NOTE(review): this does not emit an explicit CAST(... AS REAL); confirm
+// callers don't depend on CAST semantics from other dialects.
+func (d *Dialect) WriteCastToNumeric(w *strings.Builder) {
+	w.WriteString(" + 0")
+}
+
+// WriteTypeName writes a SQLite type name for CAST expressions.
+func (d *Dialect) WriteTypeName(w *strings.Builder, celTypeName string) {
+	switch celTypeName {
+	case "bool":
+		w.WriteString("INTEGER")
+	case "bytes":
+		w.WriteString("BLOB")
+	case "double":
+		w.WriteString("REAL")
+	case "int":
+		w.WriteString("INTEGER")
+	case "string":
+		w.WriteString("TEXT")
+	case "uint":
+		w.WriteString("INTEGER")
+	default:
+		w.WriteString(strings.ToUpper(celTypeName))
+	}
+}
+
+// WriteEpochExtract writes strftime('%s', expr) for SQLite.
+func (d *Dialect) WriteEpochExtract(w *strings.Builder, writeExpr func() error) error { + w.WriteString("CAST(strftime('%s', ") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(") AS INTEGER)") + return nil +} + +// WriteTimestampCast writes a SQLite datetime cast. +func (d *Dialect) WriteTimestampCast(w *strings.Builder, writeExpr func() error) error { + w.WriteString("datetime(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// --- Arrays --- + +// WriteArrayLiteralOpen writes the SQLite JSON array literal opening. +func (d *Dialect) WriteArrayLiteralOpen(w *strings.Builder) { + w.WriteString("json_array(") +} + +// WriteArrayLiteralClose writes the SQLite JSON array literal closing. +func (d *Dialect) WriteArrayLiteralClose(w *strings.Builder) { + w.WriteString(")") +} + +// WriteArrayLength writes json_array_length(expr) for SQLite. +func (d *Dialect) WriteArrayLength(w *strings.Builder, _ int, writeExpr func() error) error { + w.WriteString("COALESCE(json_array_length(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("), 0)") + return nil +} + +// WriteListIndex writes SQLite JSON array index access. +func (d *Dialect) WriteListIndex(w *strings.Builder, writeArray func() error, writeIndex func() error) error { + w.WriteString("json_extract(") + if err := writeArray(); err != nil { + return err + } + w.WriteString(", '$[' || ") + if err := writeIndex(); err != nil { + return err + } + w.WriteString(" || ']')") + return nil +} + +// WriteListIndexConst writes SQLite JSON constant array index access. +func (d *Dialect) WriteListIndexConst(w *strings.Builder, writeArray func() error, index int64) error { + w.WriteString("json_extract(") + if err := writeArray(); err != nil { + return err + } + fmt.Fprintf(w, ", '$[%d]')", index) + return nil +} + +// WriteEmptyTypedArray writes an empty SQLite JSON array. 
+func (d *Dialect) WriteEmptyTypedArray(w *strings.Builder, _ string) {
+	w.WriteString("json_array()")
+}
+
+// --- JSON ---
+
+// WriteJSONFieldAccess writes SQLite JSON field access using json_extract.
+// NOTE(review): escapeJSONFieldName only doubles single quotes; a field name
+// containing '.' or '"' would change the JSON path — confirm callers validate
+// field names before reaching this point.
+func (d *Dialect) WriteJSONFieldAccess(w *strings.Builder, writeBase func() error, fieldName string, _ bool) error {
+	w.WriteString("json_extract(")
+	if err := writeBase(); err != nil {
+		return err
+	}
+	escaped := escapeJSONFieldName(fieldName)
+	w.WriteString(", '$.")
+	w.WriteString(escaped)
+	w.WriteString("')")
+	return nil
+}
+
+// WriteJSONExistence writes a SQLite JSON key existence check.
+func (d *Dialect) WriteJSONExistence(w *strings.Builder, _ bool, fieldName string, writeBase func() error) error {
+	w.WriteString("json_type(")
+	if err := writeBase(); err != nil {
+		return err
+	}
+	escaped := escapeJSONFieldName(fieldName)
+	w.WriteString(", '$.")
+	w.WriteString(escaped)
+	w.WriteString("') IS NOT NULL")
+	return nil
+}
+
+// WriteJSONArrayElements writes SQLite JSON array expansion using json_each.
+func (d *Dialect) WriteJSONArrayElements(w *strings.Builder, _ bool, _ bool, writeExpr func() error) error {
+	w.WriteString("json_each(")
+	if err := writeExpr(); err != nil {
+		return err
+	}
+	w.WriteString(")")
+	return nil
+}
+
+// WriteJSONArrayLength writes COALESCE(json_array_length(expr), 0) for SQLite.
+func (d *Dialect) WriteJSONArrayLength(w *strings.Builder, writeExpr func() error) error {
+	w.WriteString("COALESCE(json_array_length(")
+	if err := writeExpr(); err != nil {
+		return err
+	}
+	w.WriteString("), 0)")
+	return nil
+}
+
+// WriteJSONExtractPath writes a SQLite JSON path existence check: it emits
+// json_type(root, '$.a.b') IS NOT NULL rather than returning the value at
+// the path.
+func (d *Dialect) WriteJSONExtractPath(w *strings.Builder, pathSegments []string, writeRoot func() error) error { + w.WriteString("json_type(") + if err := writeRoot(); err != nil { + return err + } + w.WriteString(", '$") + for _, segment := range pathSegments { + w.WriteString(".") + w.WriteString(escapeJSONFieldName(segment)) + } + w.WriteString("') IS NOT NULL") + return nil +} + +// WriteJSONArrayMembership writes SQLite JSON array membership using json_each. +func (d *Dialect) WriteJSONArrayMembership(w *strings.Builder, _ string, writeExpr func() error) error { + w.WriteString("(SELECT value FROM json_each(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// WriteNestedJSONArrayMembership writes SQLite nested JSON array membership. +func (d *Dialect) WriteNestedJSONArrayMembership(w *strings.Builder, writeExpr func() error) error { + w.WriteString("(SELECT value FROM json_each(") + if err := writeExpr(); err != nil { + return err + } + w.WriteString("))") + return nil +} + +// --- Timestamps --- + +// WriteDuration writes a SQLite duration as a string modifier. +func (d *Dialect) WriteDuration(w *strings.Builder, value int64, unit string) { + // SQLite uses datetime modifiers like '+N seconds', '+N minutes', etc. + fmt.Fprintf(w, "'%+d %s'", value, strings.ToLower(unit)+"s") +} + +// WriteInterval writes a SQLite interval expression. +func (d *Dialect) WriteInterval(w *strings.Builder, writeValue func() error, unit string) error { + w.WriteString("'+'||") + if err := writeValue(); err != nil { + return err + } + fmt.Fprintf(w, "||' %s'", strings.ToLower(unit)+"s") + return nil +} + +// WriteExtract writes a SQLite strftime extraction expression. 
+func (d *Dialect) WriteExtract(w *strings.Builder, part string, writeExpr func() error, _ func() error) error {
+	// Map the SQL EXTRACT part (YEAR, MONTH, ...) to a strftime format and
+	// cast the text result to INTEGER.
+	// NOTE(review): for MILLISECONDS, sqliteExtractFormat yields '%f'
+	// (fractional seconds, SS.SSS); CAST AS INTEGER truncates the fraction —
+	// confirm this is the intended result.
+	format := sqliteExtractFormat(part)
+	w.WriteString("CAST(strftime('")
+	w.WriteString(format)
+	w.WriteString("', ")
+	if err := writeExpr(); err != nil {
+		return err
+	}
+	w.WriteString(") AS INTEGER)")
+	return nil
+}
+
+// WriteTimestampArithmetic writes SQLite timestamp arithmetic using datetime().
+// The duration operand is expected to be a SQLite datetime modifier string
+// (e.g. '+5 seconds'); for subtraction a '-' is concatenated in front of it.
+//
+// NOTE(review): WriteDuration already emits a signed modifier ('%+d ...'), so
+// the subtraction path produces a doubled sign like '-+5 seconds' — confirm
+// SQLite accepts this, or normalize the sign in exactly one place.
+func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeTS, writeDur func() error) error {
+	// Both branches emit datetime(ts, <modifier>); only the modifier prefix
+	// differs, so build the call once instead of duplicating it per branch.
+	w.WriteString("datetime(")
+	if err := writeTS(); err != nil {
+		return err
+	}
+	if op == "-" {
+		// Negate the duration by prefixing the modifier string.
+		w.WriteString(", '-'||")
+	} else {
+		w.WriteString(", ")
+	}
+	if err := writeDur(); err != nil {
+		return err
+	}
+	w.WriteString(")")
+	return nil
+}
+
+// --- String Functions ---
+
+// WriteContains writes INSTR(haystack, needle) > 0 for SQLite.
+func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error {
+	w.WriteString("INSTR(")
+	if err := writeHaystack(); err != nil {
+		return err
+	}
+	w.WriteString(", ")
+	if err := writeNeedle(); err != nil {
+		return err
+	}
+	w.WriteString(") > 0")
+	return nil
+}
+
+// WriteSplit returns an error as SQLite does not have a native string split.
+func (d *Dialect) WriteSplit(_ *strings.Builder, _, _ func() error) error {
+	return fmt.Errorf("%w: string split", dialect.ErrUnsupportedFeature)
+}
+
+// WriteSplitWithLimit returns an error as SQLite does not have a native string split.
+func (d *Dialect) WriteSplitWithLimit(_ *strings.Builder, _, _ func() error, _ int64) error {
+	return fmt.Errorf("%w: string split with limit", dialect.ErrUnsupportedFeature)
+}
+
+// WriteJoin returns an error as SQLite does not have a native array join.
+func (d *Dialect) WriteJoin(_ *strings.Builder, _, _ func() error) error { + return fmt.Errorf("%w: array join", dialect.ErrUnsupportedFeature) +} + +// --- Comprehensions --- + +// WriteUnnest writes SQLite json_each for array unnesting. +func (d *Dialect) WriteUnnest(w *strings.Builder, writeSource func() error) error { + w.WriteString("json_each(") + if err := writeSource(); err != nil { + return err + } + w.WriteString(")") + return nil +} + +// WriteArraySubqueryOpen writes (SELECT json_group_array( for SQLite array subqueries. +func (d *Dialect) WriteArraySubqueryOpen(w *strings.Builder) { + w.WriteString("(SELECT json_group_array(") +} + +// WriteArraySubqueryExprClose closes the json_group_array aggregate function for SQLite. +func (d *Dialect) WriteArraySubqueryExprClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Struct --- + +// WriteStructOpen writes the SQLite struct literal opening. +func (d *Dialect) WriteStructOpen(w *strings.Builder) { + w.WriteString("json_object(") +} + +// WriteStructClose writes the SQLite struct literal closing. +func (d *Dialect) WriteStructClose(w *strings.Builder) { + w.WriteString(")") +} + +// --- Validation --- + +// MaxIdentifierLength returns 0 as SQLite has no hard identifier length limit. +func (d *Dialect) MaxIdentifierLength() int { + return 0 +} + +// ValidateFieldName validates a field name against SQLite naming rules. +func (d *Dialect) ValidateFieldName(name string) error { + return validateFieldName(name) +} + +// ReservedKeywords returns the set of reserved SQL keywords for SQLite. +func (d *Dialect) ReservedKeywords() map[string]bool { + return reservedSQLKeywords +} + +// --- Regex --- + +// ConvertRegex returns an error as SQLite does not natively support regex. +func (d *Dialect) ConvertRegex(_ string) (string, bool, error) { + return "", false, fmt.Errorf("%w: regex matching", dialect.ErrUnsupportedFeature) +} + +// SupportsRegex returns false as SQLite does not natively support regex. 
+func (d *Dialect) SupportsRegex() bool { return false } + +// --- Capabilities --- + +// SupportsNativeArrays returns false as SQLite uses JSON arrays. +func (d *Dialect) SupportsNativeArrays() bool { return false } + +// SupportsJSONB returns false as SQLite has a single JSON type. +func (d *Dialect) SupportsJSONB() bool { return false } + +// SupportsIndexAnalysis returns true as SQLite index analysis is supported. +func (d *Dialect) SupportsIndexAnalysis() bool { return true } + +// --- Internal helpers --- + +// escapeJSONFieldName escapes special characters in JSON field names for SQLite. +func escapeJSONFieldName(fieldName string) string { + return strings.ReplaceAll(fieldName, "'", "''") +} + +// sqliteExtractFormat maps SQL EXTRACT parts to SQLite strftime format strings. +func sqliteExtractFormat(part string) string { + switch part { + case "YEAR": + return "%Y" + case "MONTH": + return "%m" + case "DAY": + return "%d" + case "HOUR": + return "%H" + case "MINUTE": + return "%M" + case "SECOND": + return "%S" + case "DOY": + return "%j" + case "DOW": + return "%w" + case "MILLISECONDS": + return "%f" + default: + return "%Y" + } +} diff --git a/dialect/sqlite/index_advisor.go b/dialect/sqlite/index_advisor.go new file mode 100644 index 0000000..107ce11 --- /dev/null +++ b/dialect/sqlite/index_advisor.go @@ -0,0 +1,73 @@ +package sqlite + +import ( + "fmt" + "strings" + + "github.com/spandigital/cel2sql/v3/dialect" +) + +// SQLite index type constants. +const ( + IndexTypeBTree = "BTREE" +) + +// RecommendIndex generates a SQLite-specific index recommendation for the given pattern. +// SQLite only supports standard B-tree indexes. Returns nil for unsupported patterns. 
+func (d *Dialect) RecommendIndex(pattern dialect.IndexPattern) *dialect.IndexRecommendation { + table := pattern.TableHint + if table == "" { + table = "table_name" + } + col := pattern.Column + safeName := sanitizeIndexName(col) + + switch pattern.Pattern { + case dialect.PatternComparison: + return &dialect.IndexRecommendation{ + Column: col, + IndexType: IndexTypeBTree, + Expression: fmt.Sprintf("CREATE INDEX idx_%s ON %s (%s);", + safeName, table, col), + Reason: fmt.Sprintf("Comparison operations on '%s' benefit from an index for efficient range queries and equality checks", col), + } + + case dialect.PatternJSONAccess: + // SQLite does not support indexes on JSON expressions directly + return nil + + case dialect.PatternRegexMatch: + // SQLite does not support native regex; no index recommendation + return nil + + case dialect.PatternArrayMembership, dialect.PatternArrayComprehension: + // SQLite does not have native array types + return nil + + case dialect.PatternJSONArrayComprehension: + // SQLite does not support indexes on JSON array operations + return nil + } + + return nil +} + +// SupportedPatterns returns the pattern types supported by SQLite. +func (d *Dialect) SupportedPatterns() []dialect.PatternType { + return []dialect.PatternType{ + dialect.PatternComparison, + } +} + +// sanitizeIndexName creates a safe index name from a column name. +func sanitizeIndexName(column string) string { + sanitized := strings.ReplaceAll(column, ".", "_") + sanitized = strings.ReplaceAll(sanitized, " ", "_") + sanitized = strings.ReplaceAll(sanitized, "-", "_") + + if len(sanitized) > 50 { + sanitized = sanitized[:50] + } + + return sanitized +} diff --git a/dialect/sqlite/validation.go b/dialect/sqlite/validation.go new file mode 100644 index 0000000..805c06c --- /dev/null +++ b/dialect/sqlite/validation.go @@ -0,0 +1,68 @@ +package sqlite + +import ( + "errors" + "fmt" + "regexp" + "strings" +) + +var ( + // fieldNameRegexp validates SQLite identifier format. 
+ fieldNameRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + + // reservedSQLKeywords contains SQLite reserved keywords. + reservedSQLKeywords = map[string]bool{ + "abort": true, "action": true, "add": true, "after": true, "all": true, + "alter": true, "always": true, "analyze": true, "and": true, "as": true, + "asc": true, "attach": true, "autoincrement": true, "before": true, + "begin": true, "between": true, "by": true, "cascade": true, "case": true, + "cast": true, "check": true, "collate": true, "column": true, "commit": true, + "conflict": true, "constraint": true, "create": true, "cross": true, + "current": true, "current_date": true, "current_time": true, + "current_timestamp": true, "database": true, "default": true, + "deferrable": true, "deferred": true, "delete": true, "desc": true, + "detach": true, "distinct": true, "do": true, "drop": true, "each": true, + "else": true, "end": true, "escape": true, "except": true, "exclude": true, + "exclusive": true, "exists": true, "explain": true, "fail": true, + "filter": true, "first": true, "following": true, "for": true, + "foreign": true, "from": true, "full": true, "glob": true, "group": true, + "groups": true, "having": true, "if": true, "ignore": true, "immediate": true, + "in": true, "index": true, "indexed": true, "initially": true, "inner": true, + "insert": true, "instead": true, "intersect": true, "into": true, "is": true, + "isnull": true, "join": true, "key": true, "last": true, "left": true, + "like": true, "limit": true, "match": true, "materialized": true, + "natural": true, "no": true, "not": true, "nothing": true, "notnull": true, + "null": true, "nulls": true, "of": true, "offset": true, "on": true, + "or": true, "order": true, "others": true, "outer": true, "over": true, + "partition": true, "plan": true, "pragma": true, "preceding": true, + "primary": true, "query": true, "raise": true, "range": true, + "recursive": true, "references": true, "regexp": true, "reindex": true, + 
"release": true, "rename": true, "replace": true, "restrict": true, + "returning": true, "right": true, "rollback": true, "row": true, + "rows": true, "savepoint": true, "select": true, "set": true, "table": true, + "temp": true, "temporary": true, "then": true, "ties": true, "to": true, + "transaction": true, "trigger": true, "unbounded": true, "union": true, + "unique": true, "update": true, "using": true, "vacuum": true, "values": true, + "view": true, "virtual": true, "when": true, "where": true, "window": true, + "with": true, "without": true, + } +) + +// validateFieldName validates that a field name follows SQLite naming conventions. +func validateFieldName(name string) error { + if len(name) == 0 { + return errors.New("field name cannot be empty") + } + + // SQLite has no hard limit on identifier length but we use a reasonable limit + if !fieldNameRegexp.MatchString(name) { + return fmt.Errorf("field name %q must start with a letter or underscore and contain only alphanumeric characters and underscores", name) + } + + if reservedSQLKeywords[strings.ToLower(name)] { + return fmt.Errorf("field name %q is a reserved SQL keyword and cannot be used without quoting", name) + } + + return nil +} diff --git a/duckdb/provider.go b/duckdb/provider.go new file mode 100644 index 0000000..e1b73b9 --- /dev/null +++ b/duckdb/provider.go @@ -0,0 +1,251 @@ +// Package duckdb provides DuckDB type provider for CEL type system integration. +package duckdb + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/schema" +) + +// Sentinel errors for the duckdb package. +var ( + // ErrInvalidSchema indicates a problem with the provided schema or database introspection. 
+ ErrInvalidSchema = errors.New("invalid schema") +) + +// FieldSchema is an alias for schema.FieldSchema. +type FieldSchema = schema.FieldSchema + +// Schema is an alias for schema.Schema. +type Schema = schema.Schema + +// NewSchema creates a new Schema. This is an alias for schema.NewSchema. +var NewSchema = schema.NewSchema + +// TypeProvider interface for DuckDB type providers. +type TypeProvider interface { + types.Provider + LoadTableSchema(ctx context.Context, tableName string) error + GetSchemas() map[string]Schema + Close() +} + +type typeProvider struct { + schemas map[string]Schema + db *sql.DB +} + +// NewTypeProvider creates a new DuckDB type provider with pre-defined schemas. +func NewTypeProvider(schemas map[string]Schema) TypeProvider { + return &typeProvider{schemas: schemas} +} + +// NewTypeProviderWithConnection creates a new DuckDB type provider that can introspect database schemas. +// The caller owns the *sql.DB and is responsible for closing it. +// This works with any DuckDB driver that implements database/sql (e.g., github.com/marcboeker/go-duckdb). +func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, error) { + if db == nil { + return nil, fmt.Errorf("%w: db connection must not be nil", ErrInvalidSchema) + } + + return &typeProvider{ + schemas: make(map[string]Schema), + db: db, + }, nil +} + +// LoadTableSchema loads schema information for a table from the database. +func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) error { + if tp.db == nil { + return fmt.Errorf("%w: no database connection available", ErrInvalidSchema) + } + + query := ` + SELECT column_name, data_type, is_nullable + FROM information_schema.columns + WHERE table_name = ? 
+ ORDER BY ordinal_position + ` + + rows, err := tp.db.QueryContext(ctx, query, tableName) + if err != nil { + return fmt.Errorf("%w: failed to query table schema", ErrInvalidSchema) + } + defer func() { _ = rows.Close() }() + + var fields []FieldSchema + for rows.Next() { + var columnName, dataType, isNullable string + + if err := rows.Scan(&columnName, &dataType, &isNullable); err != nil { + return fmt.Errorf("%w: failed to scan row", ErrInvalidSchema) + } + + field := duckdbColumnToFieldSchema(columnName, dataType) + fields = append(fields, field) + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("%w: error iterating rows", ErrInvalidSchema) + } + + if len(fields) == 0 { + return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName) + } + + tp.schemas[tableName] = NewSchema(fields) + return nil +} + +// duckdbColumnToFieldSchema converts DuckDB column metadata to a FieldSchema. +func duckdbColumnToFieldSchema(columnName, dataType string) FieldSchema { + // DuckDB array types appear as "INTEGER[]", "VARCHAR[]", etc. + isArray, elementType, dimensions := detectDuckDBArray(dataType) + isJSON := strings.EqualFold(dataType, "json") + + if isArray { + return FieldSchema{ + Name: columnName, + Type: strings.ToLower(elementType), + Repeated: true, + Dimensions: dimensions, + ElementType: strings.ToLower(elementType), + } + } + + return FieldSchema{ + Name: columnName, + Type: normalizeDuckDBType(dataType), + IsJSON: isJSON, + } +} + +// detectDuckDBArray detects if a DuckDB data type is an array and returns element type and dimensions. 
+func detectDuckDBArray(dataType string) (isArray bool, elementType string, dimensions int) { + // Count trailing [] pairs + remaining := dataType + dims := 0 + for strings.HasSuffix(remaining, "[]") { + dims++ + remaining = strings.TrimSuffix(remaining, "[]") + } + + if dims > 0 { + return true, remaining, dims + } + return false, "", 0 +} + +// normalizeDuckDBType normalizes a DuckDB type name to lowercase. +func normalizeDuckDBType(dataType string) string { + return strings.ToLower(dataType) +} + +// Close is a no-op since we don't own the *sql.DB. +func (tp *typeProvider) Close() { + // No-op: caller owns the *sql.DB connection +} + +// GetSchemas returns the schemas known to this type provider. +func (tp *typeProvider) GetSchemas() map[string]Schema { + return tp.schemas +} + +// EnumValue implements types.Provider. +func (tp *typeProvider) EnumValue(_ string) ref.Val { + return types.NewErr("unknown enum value") +} + +// FindIdent implements types.Provider. +func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { + return nil, false +} + +// FindStructType implements types.Provider. +func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { + if _, ok := tp.schemas[structType]; ok { + return types.NewObjectType(structType), true + } + return nil, false +} + +// FindStructFieldNames implements types.Provider. +func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + fields := s.Fields() + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + return names, true +} + +// FindStructFieldType implements types.Provider. 
+func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + field, found := s.FindField(fieldName) + if !found { + return nil, false + } + + exprType := duckdbTypeToCELExprType(field) + celType, err := types.ExprTypeToType(exprType) + if err != nil { + return nil, false + } + + return &types.FieldType{ + Type: celType, + }, true +} + +// NewValue implements types.Provider. +func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { + return types.NewErr("unknown type in schema") +} + +// duckdbTypeToCELExprType converts a DuckDB field schema to a CEL expression type. +func duckdbTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { + baseType := duckdbBaseTypeToCEL(field.Type) + if field.Repeated { + return decls.NewListType(baseType) + } + return baseType +} + +// duckdbBaseTypeToCEL converts a DuckDB type name to a CEL expression type. +func duckdbBaseTypeToCEL(typeName string) *exprpb.Type { + switch typeName { + case "varchar", "text", "char", "bpchar", "name": + return decls.String + case "bigint", "integer", "int", "int4", "int8", "smallint", "int2", "tinyint", "hugeint": + return decls.Int + case "double", "float", "real", "float4", "float8", "numeric", "decimal": + return decls.Double + case "boolean", "bool": + return decls.Bool + case "blob", "bytea": + return decls.Bytes + case "json": + return decls.Dyn + case "timestamp", "timestamptz", "timestamp with time zone", "timestamp without time zone": + return decls.Timestamp + default: + return decls.Dyn + } +} diff --git a/duckdb/provider_test.go b/duckdb/provider_test.go new file mode 100644 index 0000000..98e47f0 --- /dev/null +++ b/duckdb/provider_test.go @@ -0,0 +1,157 @@ +package duckdb_test + +import ( + "context" + "testing" + + "github.com/google/cel-go/common/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + 
"github.com/spandigital/cel2sql/v3/duckdb" +) + +func TestNewTypeProvider(t *testing.T) { + schemas := map[string]duckdb.Schema{ + "users": duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "varchar"}, + }), + } + + provider := duckdb.NewTypeProvider(schemas) + require.NotNil(t, provider) + + got := provider.GetSchemas() + assert.Len(t, got, 1) + assert.Contains(t, got, "users") +} + +func TestNewTypeProviderWithConnection_NilDB(t *testing.T) { + _, err := duckdb.NewTypeProviderWithConnection(context.Background(), nil) + require.Error(t, err) + assert.ErrorIs(t, err, duckdb.ErrInvalidSchema) +} + +func TestLoadTableSchema_NoDB(t *testing.T) { + provider := duckdb.NewTypeProvider(nil) + err := provider.LoadTableSchema(context.Background(), "test") + require.Error(t, err) + assert.ErrorIs(t, err, duckdb.ErrInvalidSchema) +} + +func TestTypeProvider_FindStructType(t *testing.T) { + schemas := map[string]duckdb.Schema{ + "users": duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "varchar"}, + }), + } + provider := duckdb.NewTypeProvider(schemas) + + typ, found := provider.FindStructType("users") + assert.True(t, found) + assert.NotNil(t, typ) + + _, found = provider.FindStructType("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldNames(t *testing.T) { + schemas := map[string]duckdb.Schema{ + "users": duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "varchar"}, + {Name: "email", Type: "text"}, + }), + } + provider := duckdb.NewTypeProvider(schemas) + + names, found := provider.FindStructFieldNames("users") + assert.True(t, found) + assert.ElementsMatch(t, []string{"id", "name", "email"}, names) + + _, found = provider.FindStructFieldNames("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldType(t *testing.T) { + schemas := map[string]duckdb.Schema{ + "test_table": 
duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "str_field", Type: "varchar"}, + {Name: "text_field", Type: "text"}, + {Name: "int_field", Type: "integer"}, + {Name: "bigint_field", Type: "bigint"}, + {Name: "smallint_field", Type: "smallint"}, + {Name: "tinyint_field", Type: "tinyint"}, + {Name: "hugeint_field", Type: "hugeint"}, + {Name: "double_field", Type: "double"}, + {Name: "float_field", Type: "float"}, + {Name: "bool_field", Type: "boolean"}, + {Name: "blob_field", Type: "blob"}, + {Name: "json_field", Type: "json"}, + {Name: "ts_field", Type: "timestamp"}, + {Name: "array_field", Type: "integer", Repeated: true}, + }), + } + provider := duckdb.NewTypeProvider(schemas) + + tests := []struct { + fieldName string + wantType *types.Type + wantFound bool + }{ + {"str_field", types.StringType, true}, + {"text_field", types.StringType, true}, + {"int_field", types.IntType, true}, + {"bigint_field", types.IntType, true}, + {"smallint_field", types.IntType, true}, + {"tinyint_field", types.IntType, true}, + {"hugeint_field", types.IntType, true}, + {"double_field", types.DoubleType, true}, + {"float_field", types.DoubleType, true}, + {"bool_field", types.BoolType, true}, + {"blob_field", types.BytesType, true}, + {"json_field", types.DynType, true}, + {"ts_field", types.TimestampType, true}, + {"array_field", types.NewListType(types.IntType), true}, + {"nonexistent", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("test_table", tt.fieldName) + assert.Equal(t, tt.wantFound, found) + if found { + assert.Equal(t, tt.wantType, got.Type) + } + }) + } +} + +func TestTypeProvider_Close(_ *testing.T) { + provider := duckdb.NewTypeProvider(nil) + // Close should not panic + provider.Close() +} + +func TestTypeProvider_ArrayDetection(t *testing.T) { + // Test that arrays defined manually with Repeated=true work correctly + schemas := map[string]duckdb.Schema{ + "test_table": 
duckdb.NewSchema([]duckdb.FieldSchema{ + {Name: "tags", Type: "varchar", Repeated: true, Dimensions: 1}, + {Name: "matrix", Type: "integer", Repeated: true, Dimensions: 2}, + }), + } + provider := duckdb.NewTypeProvider(schemas) + + // tags should be list of strings + got, found := provider.FindStructFieldType("test_table", "tags") + assert.True(t, found) + assert.Equal(t, types.NewListType(types.StringType), got.Type) + + // matrix should be list of integers (CEL sees all array dims as list) + got, found = provider.FindStructFieldType("test_table", "matrix") + assert.True(t, found) + assert.Equal(t, types.NewListType(types.IntType), got.Type) +} diff --git a/errors.go b/errors.go index 92115f3..4f6b585 100644 --- a/errors.go +++ b/errors.go @@ -58,6 +58,9 @@ var ( // ErrInvalidByteArrayLength indicates byte array exceeds maximum length ErrInvalidByteArrayLength = errors.New("byte array exceeds maximum length") + + // ErrUnsupportedDialectFeature indicates a feature not supported by the selected dialect + ErrUnsupportedDialectFeature = errors.New("unsupported dialect feature") ) // ConversionError represents an error that occurred during CEL to SQL conversion. 
diff --git a/examples/index_analysis/main.go b/examples/index_analysis/main.go index 73f744b..8c4be94 100644 --- a/examples/index_analysis/main.go +++ b/examples/index_analysis/main.go @@ -7,6 +7,12 @@ import ( "github.com/google/cel-go/cel" "github.com/spandigital/cel2sql/v3" + "github.com/spandigital/cel2sql/v3/dialect" + dialectbq "github.com/spandigital/cel2sql/v3/dialect/bigquery" + dialectduckdb "github.com/spandigital/cel2sql/v3/dialect/duckdb" + dialectmysql "github.com/spandigital/cel2sql/v3/dialect/mysql" + dialectpg "github.com/spandigital/cel2sql/v3/dialect/postgres" + dialectsqlite "github.com/spandigital/cel2sql/v3/dialect/sqlite" "github.com/spandigital/cel2sql/v3/pg" ) @@ -77,59 +83,114 @@ func main() { }, } - // Analyze each query and display recommendations + // Analyze each query and display recommendations (PostgreSQL default) + fmt.Println("\n--- PostgreSQL (default) ---") for i, ex := range examples { - fmt.Printf("%d. %s\n", i+1, ex.name) - fmt.Printf(" Description: %s\n", ex.description) - fmt.Printf(" CEL Expression: %s\n\n", ex.expression) - - // Compile the CEL expression - ast, issues := env.Compile(ex.expression) - if issues != nil && issues.Err() != nil { - log.Printf(" ERROR: Failed to compile: %v\n\n", issues.Err()) - continue - } + analyzeExample(env, provider, i, ex.name, ex.description, ex.expression) + } + + // Multi-dialect examples + fmt.Println("\n===================================") + fmt.Println("Multi-Dialect Index Recommendations") + fmt.Println("===================================") + + // Use a simple comparison query to show dialect differences + comparisonExpr := `users.age > 21 && users.metadata.verified == true` + + dialectExamples := []struct { + name string + dialect dialect.Dialect + }{ + {"PostgreSQL", dialectpg.New()}, + {"MySQL", dialectmysql.New()}, + {"SQLite", dialectsqlite.New()}, + {"DuckDB", dialectduckdb.New()}, + {"BigQuery", dialectbq.New()}, + } + + ast, issues := env.Compile(comparisonExpr) + if 
issues != nil && issues.Err() != nil { + log.Fatalf("Failed to compile expression: %v", issues.Err()) + } - // Analyze the query - sql, recommendations, err := cel2sql.AnalyzeQuery(ast, - cel2sql.WithSchemas(provider.GetSchemas())) + for _, de := range dialectExamples { + fmt.Printf("\n--- %s ---\n", de.name) + fmt.Printf(" CEL Expression: %s\n\n", comparisonExpr) + + _, recommendations, err := cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(provider.GetSchemas()), + cel2sql.WithDialect(de.dialect)) if err != nil { log.Printf(" ERROR: Failed to analyze: %v\n\n", err) continue } - // Display the generated SQL - fmt.Printf(" Generated SQL:\n %s\n\n", sql) - - // Display index recommendations if len(recommendations) == 0 { - fmt.Printf(" No index recommendations (query uses constants or simple conditions)\n\n") + fmt.Printf(" No index recommendations\n") } else { - fmt.Printf(" Index Recommendations (%d):\n", len(recommendations)) for j, rec := range recommendations { fmt.Printf(" [%d] Column: %s\n", j+1, rec.Column) fmt.Printf(" Type: %s\n", rec.IndexType) fmt.Printf(" Reason: %s\n", rec.Reason) - fmt.Printf(" SQL: %s\n", rec.Expression) + fmt.Printf(" DDL: %s\n", rec.Expression) fmt.Println() } } - - fmt.Println(" " + string(make([]byte, 60))) - fmt.Println() } // Summary fmt.Println("\nSummary") fmt.Println("=======") - fmt.Println("Index recommendations help optimize query performance by:") - fmt.Println(" • B-tree indexes: Fast equality and range queries on scalar columns") - fmt.Println(" • GIN indexes: Efficient JSON path access and array operations") - fmt.Println(" • GIN with pg_trgm: Fast regex pattern matching on text columns") + fmt.Println("Index recommendations are dialect-aware:") + fmt.Println(" PostgreSQL: B-tree, GIN, GIN with pg_trgm") + fmt.Println(" MySQL: B-tree, FULLTEXT, functional JSON indexes") + fmt.Println(" SQLite: B-tree (limited index types)") + fmt.Println(" DuckDB: ART (Adaptive Radix Tree)") + fmt.Println(" BigQuery: Clustering keys, 
Search indexes") + fmt.Println() + fmt.Println("Use WithDialect() to get dialect-specific recommendations:") + fmt.Println(" sql, recs, err := cel2sql.AnalyzeQuery(ast,") + fmt.Println(" cel2sql.WithDialect(mysql.New()),") + fmt.Println(" cel2sql.WithSchemas(schemas))") +} + +func analyzeExample(env *cel.Env, provider pg.TypeProvider, idx int, name, description, expression string) { + fmt.Printf("%d. %s\n", idx+1, name) + fmt.Printf(" Description: %s\n", description) + fmt.Printf(" CEL Expression: %s\n\n", expression) + + // Compile the CEL expression + ast, issues := env.Compile(expression) + if issues != nil && issues.Err() != nil { + log.Printf(" ERROR: Failed to compile: %v\n\n", issues.Err()) + return + } + + // Analyze the query + sql, recommendations, err := cel2sql.AnalyzeQuery(ast, + cel2sql.WithSchemas(provider.GetSchemas())) + if err != nil { + log.Printf(" ERROR: Failed to analyze: %v\n\n", err) + return + } + + // Display the generated SQL + fmt.Printf(" Generated SQL:\n %s\n\n", sql) + + // Display index recommendations + if len(recommendations) == 0 { + fmt.Printf(" No index recommendations (query uses constants or simple conditions)\n\n") + } else { + fmt.Printf(" Index Recommendations (%d):\n", len(recommendations)) + for j, rec := range recommendations { + fmt.Printf(" [%d] Column: %s\n", j+1, rec.Column) + fmt.Printf(" Type: %s\n", rec.IndexType) + fmt.Printf(" Reason: %s\n", rec.Reason) + fmt.Printf(" SQL: %s\n", rec.Expression) + fmt.Println() + } + } + + fmt.Println(" " + string(make([]byte, 60))) fmt.Println() - fmt.Println("To apply recommendations:") - fmt.Println(" 1. Review each recommendation and its reason") - fmt.Println(" 2. Adjust table_name to your actual table name") - fmt.Println(" 3. Execute the CREATE INDEX statements on your database") - fmt.Println(" 4. 
Monitor query performance improvements") } diff --git a/go.mod b/go.mod index b383cc0..58118e2 100644 --- a/go.mod +++ b/go.mod @@ -3,21 +3,35 @@ module github.com/spandigital/cel2sql/v3 go 1.25.7 require ( + cloud.google.com/go/bigquery v1.73.1 + github.com/go-sql-driver/mysql v1.9.3 github.com/google/cel-go v0.27.0 github.com/jackc/pgx/v5 v5.8.0 github.com/lib/pq v1.11.2 github.com/stretchr/testify v1.11.1 github.com/testcontainers/testcontainers-go v0.40.0 + github.com/testcontainers/testcontainers-go/modules/gcloud v0.40.0 + github.com/testcontainers/testcontainers-go/modules/mysql v0.40.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 - google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822 + google.golang.org/api v0.268.0 + google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 + google.golang.org/grpc v1.79.1 + modernc.org/sqlite v1.46.1 ) require ( cel.dev/expr v0.25.1 // indirect + cloud.google.com/go v0.123.0 // indirect + cloud.google.com/go/auth v0.18.1 // indirect + cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect + cloud.google.com/go/compute/metadata v0.9.0 // indirect + cloud.google.com/go/iam v1.5.3 // indirect dario.cat/mergo v1.0.2 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/antlr4-go/antlr/v4 v4.13.1 // indirect + github.com/apache/arrow/go/v15 v15.0.2 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/containerd/errdefs v1.0.0 // indirect @@ -30,19 +44,27 @@ require ( github.com/docker/docker v28.5.2+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.8.4 // indirect github.com/felixge/httpsnoop v1.0.4 // 
indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/google/flatbuffers v23.5.26+incompatible // indirect + github.com/google/s2a-go v0.1.9 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.11 // indirect + github.com/googleapis/gax-go/v2 v2.17.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/cpuid/v2 v2.2.5 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.10 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/go-archive v0.1.0 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -51,28 +73,43 @@ require ( github.com/moby/sys/userns v0.1.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/morikuni/aec v1.0.0 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/pierrec/lz4/v4 v4.1.18 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/shirou/gopsutil/v4 v4.25.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect + 
github.com/zeebo/xxh3 v1.0.2 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 // indirect go.opentelemetry.io/otel v1.40.0 // indirect go.opentelemetry.io/otel/metric v1.40.0 // indirect go.opentelemetry.io/otel/trace v1.40.0 // indirect - golang.org/x/crypto v0.46.0 // indirect - golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect + golang.org/x/crypto v0.47.0 // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/mod v0.31.0 // indirect + golang.org/x/net v0.49.0 // indirect + golang.org/x/oauth2 v0.35.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect - golang.org/x/text v0.32.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect - google.golang.org/grpc v1.73.0 // indirect - google.golang.org/protobuf v1.36.10 // indirect + golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc // indirect + golang.org/x/text v0.33.0 // indirect + golang.org/x/time v0.14.0 // indirect + golang.org/x/tools v0.40.0 // indirect + golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect + google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 // indirect + google.golang.org/protobuf v1.36.11 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect ) diff --git a/go.sum b/go.sum index 095dcf6..1d392a9 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,51 @@ cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4= cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +cloud.google.com/go v0.123.0 
h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= +cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= +cloud.google.com/go/auth v0.18.1 h1:IwTEx92GFUo2pJ6Qea0EU3zYvKnTAeRCODxfA/G5UWs= +cloud.google.com/go/auth v0.18.1/go.mod h1:GfTYoS9G3CWpRA3Va9doKN9mjPGRS+v41jmZAhBzbrA= +cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= +cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= +cloud.google.com/go/bigquery v1.73.1 h1:v//GZwdhtmCbZ87rOnxz7pectOGFS1GNRvrGTvLzka4= +cloud.google.com/go/bigquery v1.73.1/go.mod h1:KSLx1mKP/yGiA8U+ohSrqZM1WknUnjZAxHAQZ51/b1k= +cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +cloud.google.com/go/datacatalog v1.26.1 h1:bCRKA8uSQN8wGW3Tw0gwko4E9a64GRmbW1nCblhgC2k= +cloud.google.com/go/datacatalog v1.26.1/go.mod h1:2Qcq8vsHNxMDgjgadRFmFG47Y+uuIVsyEGUrlrKEdrg= +cloud.google.com/go/iam v1.5.3 h1:+vMINPiDF2ognBJ97ABAYYwRgsaqxPbQDlMnbHMjolc= +cloud.google.com/go/iam v1.5.3/go.mod h1:MR3v9oLkZCTlaqljW6Eb2d3HGDGK5/bDv93jhfISFvU= +cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= +cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= +cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= +cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI= +cloud.google.com/go/storage v1.59.0 h1:9p3yDzEN9Vet4JnbN90FECIw6n4FCXcKBK1scxtQnw8= +cloud.google.com/go/storage v1.59.0/go.mod h1:cMWbtM+anpC74gn6qjLh+exqYcfmB9Hqe5z6adx+CLI= dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= 
+filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk= github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 h1:sBEjpZlNHzK1voKq9695PJSX2o5NEXl7/OL3coiIY0c= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.54.0 h1:lhhYARPUu3LmHysQ/igznQphfzynnqI3D75oUyw1HXk= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.54.0/go.mod h1:l9rva3ApbBpEJxSNYnwT9N4CDLrWgtq3u8736C5hyJw= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 h1:s0WlVbf9qpvkh1c/uDAPElam0WrL7fHRIidgZJ7UqZI= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= +github.com/apache/arrow/go/v15 v15.0.2 h1:60IliRbiyTWCWjERBCkO1W4Qun9svcYoZrSLcyOsMLE= +github.com/apache/arrow/go/v15 v15.0.2/go.mod h1:DGXsR3ajT524njufqf95822i+KTh+yea1jass9YXgjA= github.com/cenkalti/backoff/v4 v4.3.0 
h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5 h1:6xNmx7iTtyBRev0+D/Tv1FZd4SCg8axKApyNyRsAt/w= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= @@ -37,10 +69,19 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= +github.com/envoyproxy/go-control-plane v0.14.0 h1:hbG2kr4RuFj222B6+7T83thSPqLjwBIfQawTkC++2HA= +github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= +github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= +github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod 
h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -48,15 +89,35 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo= +github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/cel-go v0.27.0 h1:e7ih85+4qVrBuqQWTW4FKSqZYokVuc3HnhH5keboFTo= github.com/google/cel-go v0.27.0/go.mod h1:tTJ11FWqnhw5KKpnWpvW9CJC3Y9GK4EIS0WXnBbebzw= +github.com/google/flatbuffers v23.5.26+incompatible h1:M9dgRyhJemaM4Sw8+66GHBu8ioaQmyPLg1b8VwK5WJg= +github.com/google/flatbuffers v23.5.26+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/martian/v3 v3.3.3 h1:DIhPTQrbPkgs2yJYdXU/eNACCG5DVQjySNRNlflZ9Fc= +github.com/google/martian/v3 v3.3.3/go.mod h1:iEPrYcgCF7jA9OtScMFQyAlZZ4YXTKEtJ1E6RWzmBA0= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= +github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/googleapis/enterprise-certificate-proxy v0.3.11 h1:vAe81Msw+8tKUxi2Dqh/NZMz7475yUvmRIkXr4oN2ao= +github.com/googleapis/enterprise-certificate-proxy v0.3.11/go.mod h1:RFV7MUdlb7AgEq2v7FmMCfeSMCllAzWxFgRdusoGks8= +github.com/googleapis/gax-go/v2 v2.17.0 h1:RksgfBpxqff0EZkDWYuz9q/uWsTVz+kf43LsZ1J6SMc= +github.com/googleapis/gax-go/v2 v2.17.0/go.mod h1:mzaqghpQp4JDh3HvADwrat+6M3MOIDp5YKHhb9PAgDY= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1 h1:X5VWvz21y3gzm9Nw/kaUeku/1+uBhcekkmy4IkffJww= github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.1/go.mod h1:Zanoh4+gvIgluNqcfMVTJueD4wSS5hT7zTt4Mrutd90= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -67,6 
+128,8 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= +github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -77,6 +140,8 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 h1:6E+4a0GO5zZEnZ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= github.com/mdelapenya/tlscert v0.2.0/go.mod h1:O4njj3ELLnJjGdkN7M/vIVCpZ+Cf0L6muqOG4tLSl8o= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= @@ -97,22 +162,32 @@ github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/ncruces/go-strftime v1.0.0 
h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/pierrec/lz4/v4 v4.1.18 h1:xaKrnTkyoqfh1YItXl56+6KJNVYWlEEPuAQW9xsplYQ= +github.com/pierrec/lz4/v4 v4.1.18/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c h1:ncq/mPwQF4JjgDlrVEn3C11VoGHZN7m8qihwgMEtzYw= github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= github.com/shirou/gopsutil/v4 v4.25.6 
h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs= github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/spiffe/go-spiffe/v2 v2.6.0 h1:l+DolpxNWYgruGQVV0xsfeya3CsC7m8iBzDnMpsbLuo= +github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= @@ -122,6 +197,10 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/testcontainers/testcontainers-go v0.40.0 h1:pSdJYLOVgLE8YdUY2FHQ1Fxu+aMnb6JfVz1mxk7OeMU= github.com/testcontainers/testcontainers-go v0.40.0/go.mod h1:FSXV5KQtX2HAMlm7U3APNyLkkap35zNLxukw9oBi/MY= +github.com/testcontainers/testcontainers-go/modules/gcloud v0.40.0 h1:9Q7AnMCHmLArYtWe0i06hHnmVylJw2FNkJX/Sm0Rpf0= +github.com/testcontainers/testcontainers-go/modules/gcloud v0.40.0/go.mod h1:SyFaMHm4IaOBL8DoNUZ2ov4vlQuU7qBRAcJuUNYw2OA= +github.com/testcontainers/testcontainers-go/modules/mysql v0.40.0 h1:P9Txfy5Jothx2wFdcus0QoSmX/PKSIXZxrTbZPVJswA= +github.com/testcontainers/testcontainers-go/modules/mysql v0.40.0/go.mod h1:oZPHHqJqXG7FD8OB/yWH7gLnDvZUlFHAVJNrGftL+eg= github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 h1:s2bIayFXlbDFexo96y+htn7FzuhpXLYJNnIuglNKqOk= github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0/go.mod h1:h+u/2KoREGTnTl9UwrQ/g+XhasAT8E6dClclAADeXoQ= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= @@ -130,8 +209,16 @@ github.com/tklauser/numcpus 
v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+F github.com/tklauser/numcpus v0.6.1/go.mod h1:1XfjsgE2zo8GVw7POkMbHENHzVg3GzmoZ9fESEdAacY= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= +github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= +github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0 h1:kWRNZMsfBHZ+uHjiH4y7Etn2FK26LAGkNFw7RHv1DhE= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0/go.mod h1:t/OGqzHBa5v6RHZwrDBJ2OirWc+4q/w2fTbLZwAKjTk= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 h1:YH4g8lQroajqUwWbq/tr2QX1JFmEXaDLgG+ew9bLMWo= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0 h1:7iP2uCb7sGddAr30RRS6xjKy7AZ2JtTOPA3oolgVSw8= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.65.0/go.mod h1:c7hN3ddxs/z6q9xwvfLPk+UHlWRQyaeR1LdgfL/66l0= go.opentelemetry.io/otel v1.40.0 h1:oA5YeOcpRTXq6NN7frwmwFR0Cn3RhTVZvXsP4duvCms= @@ -152,37 +239,55 @@ go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lI go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= -golang.org/x/crypto 
v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA= -golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= +golang.org/x/oauth2 v0.35.0 h1:Mv2mzuHuZuY2+bkyWXIHMfhNdJAdwW3FuWeCPYN5GVQ= +golang.org/x/oauth2 v0.35.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= -golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc h1:bH6xUXay0AIFMElXG2rQ4uiE+7ncwtiOdPfYK1NK2XA= +golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ= +golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= +golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -google.golang.org/genproto/googleapis/api 
v0.0.0-20250603155806-513f23925822 h1:oWVWY3NzT7KJppx2UKhKmzPq4SRe0LdCijVRwvGeikY= -google.golang.org/genproto/googleapis/api v0.0.0-20250603155806-513f23925822/go.mod h1:h3c4v36UTKzUiuaOKQ6gr3S+0hovBtUrXzTG/i3+XEc= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 h1:fc6jSaCT0vBduLYZHYrBBNY4dsWuvgyff9noRNDdBeE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= -google.golang.org/grpc v1.73.0 h1:VIWSmpI2MegBtTuFt5/JWy2oXxtjJ/e89Z70ImfD2ok= -google.golang.org/grpc v1.73.0/go.mod h1:50sbHOUqWoCQGI8V2HQLJM0B+LMlIUjNSZmow7EVBQc= -google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= -google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= +golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/api v0.268.0 h1:hgA3aS4lt9rpF5RCCkX0Q2l7DvHgvlb53y4T4u6iKkA= +google.golang.org/api v0.268.0/go.mod h1:HXMyMH496wz+dAJwD/GkAPLd3ZL33Kh0zEG32eNvy9w= +google.golang.org/genproto v0.0.0-20260128011058-8636f8732409 h1:VQZ/yAbAtjkHgH80teYd2em3xtIkkHd7ZhqfH2N9CsM= +google.golang.org/genproto v0.0.0-20260128011058-8636f8732409/go.mod h1:rxKD3IEILWEu3P44seeNOAwZN4SaoKaQ/2eTg4mM6EM= +google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409 h1:merA0rdPeUV3YIIfHHcH4qBkiQAc1nfCKSI7lB4cV2M= +google.golang.org/genproto/googleapis/api v0.0.0-20260128011058-8636f8732409/go.mod h1:fl8J1IvUjCilwZzQowmw2b7HQB2eAuYBabMXzWurF+I= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20 h1:Jr5R2J6F6qWyzINc+4AM8t5pfUz6beZpHp678GNrMbE= 
+google.golang.org/genproto/googleapis/rpc v0.0.0-20260203192932-546029d2fa20/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY= +google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= @@ -191,3 +296,31 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= 
+modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/json.go b/json.go index 5867bf8..a1351f3 100644 --- a/json.go +++ b/json.go @@ -131,14 +131,10 @@ func (con *converter) buildJSONPathForArray(expr *exprpb.Expr) error { if operandSelect := operand.GetSelectExpr(); operandSelect != nil { // This is nested access - recursively build the path for the operand if con.hasJSONFieldInChain(operand) { - if err := con.buildJSONPathForArray(operand); err != nil { - return err - } - // Add intermediate JSON path operator (always -> for arrays) - con.str.WriteString("->'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + // Add intermediate JSON path 
operator (always non-final for arrays) + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + return con.buildJSONPathForArray(operand) + }, field, false) } } @@ -153,13 +149,9 @@ func (con *converter) buildJSONPathForArray(expr *exprpb.Expr) error { } // For other cases, visit the operand and add JSON operator - if err := con.visit(operand); err != nil { - return err - } - con.str.WriteString("->'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + return con.visit(operand) + }, field, false) } // isJSONObjectFieldAccess determines if this is a JSON object field access in comprehensions @@ -303,61 +295,34 @@ func (con *converter) buildJSONPathInternal(expr *exprpb.Expr, isFinalField bool // If so, we should NOT apply JSON operators to this level if tableName, columnName, ok := con.getTableAndFieldFromSelectChain(operand); ok { // This is table.column where column is JSON/JSONB - // Render as table.column without JSON operators - con.str.WriteString(tableName) - con.str.WriteString(".") - con.str.WriteString(columnName) - // Now add JSON operator for the current field - if isFinalField { - con.str.WriteString("->>'") // Final field: extract as text - } else { - con.str.WriteString("->'") // Intermediate field: keep as JSON - } - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + // Render as table.column without JSON operators, then add JSON operator for the current field + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + con.str.WriteString(tableName) + con.str.WriteString(".") + con.str.WriteString(columnName) + return nil + }, field, isFinalField) } // This is deeper nesting like table.jsonfield.subfield.finalfield // We need to determine if the operand is JSON-related if con.shouldUseJSONPath(operandSelect.GetOperand(), operandSelect.GetField()) { - // Recursively build the path for 
the operand (not final since we have more fields) - if err := con.buildJSONPathInternal(operand, false); err != nil { - return err - } // Add appropriate JSON path operator based on whether this is the final field - if isFinalField { - con.str.WriteString("->>'") // Final field: extract as text - } else { - con.str.WriteString("->'") // Intermediate field: keep as JSON - } - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + // Recursively build the path for the operand (not final since we have more fields) + return con.buildJSONPathInternal(operand, false) + }, field, isFinalField) } } - // Visit the base operand (like table.jsonfield) - if err := con.visit(operand); err != nil { - return err - } - - // Add the appropriate JSON path operator based on whether this is the final field - operator := "->>" - if !isFinalField { - operator = "->" - } - con.logger.LogAttrs(context.Background(), slog.LevelDebug, "JSON path operator selection", slog.String("field", field), - slog.String("operator", operator), slog.Bool("is_final", isFinalField), ) - con.str.WriteString(operator) - con.str.WriteString("'") - con.str.WriteString(escapeJSONFieldName(field)) - con.str.WriteString("'") - return nil + // Visit the base operand (like table.jsonfield) + return con.dialect.WriteJSONFieldAccess(&con.str, func() error { + return con.visit(operand) + }, field, isFinalField) } diff --git a/mysql/provider.go b/mysql/provider.go new file mode 100644 index 0000000..781f40b --- /dev/null +++ b/mysql/provider.go @@ -0,0 +1,226 @@ +// Package mysql provides MySQL type provider for CEL type system integration. 
+package mysql + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/schema" + "github.com/spandigital/cel2sql/v3/sqltypes" +) + +// Sentinel errors for the mysql package. +var ( + // ErrInvalidSchema indicates a problem with the provided schema or database introspection. + ErrInvalidSchema = errors.New("invalid schema") +) + +// FieldSchema is an alias for schema.FieldSchema. +type FieldSchema = schema.FieldSchema + +// Schema is an alias for schema.Schema. +type Schema = schema.Schema + +// NewSchema creates a new Schema. This is an alias for schema.NewSchema. +var NewSchema = schema.NewSchema + +// TypeProvider interface for MySQL type providers. +type TypeProvider interface { + types.Provider + LoadTableSchema(ctx context.Context, tableName string) error + GetSchemas() map[string]Schema + Close() +} + +type typeProvider struct { + schemas map[string]Schema + db *sql.DB +} + +// NewTypeProvider creates a new MySQL type provider with pre-defined schemas. +func NewTypeProvider(schemas map[string]Schema) TypeProvider { + return &typeProvider{schemas: schemas} +} + +// NewTypeProviderWithConnection creates a new MySQL type provider that can introspect database schemas. +// The caller owns the *sql.DB and is responsible for closing it. +func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, error) { + if db == nil { + return nil, fmt.Errorf("%w: db connection must not be nil", ErrInvalidSchema) + } + + return &typeProvider{ + schemas: make(map[string]Schema), + db: db, + }, nil +} + +// LoadTableSchema loads schema information for a table from the database. 
+func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) error { + if tp.db == nil { + return fmt.Errorf("%w: no database connection available", ErrInvalidSchema) + } + + query := ` + SELECT column_name, data_type, column_type, is_nullable + FROM information_schema.columns + WHERE table_schema = DATABASE() AND table_name = ? + ORDER BY ordinal_position + ` + + rows, err := tp.db.QueryContext(ctx, query, tableName) + if err != nil { + return fmt.Errorf("%w: failed to query table schema", ErrInvalidSchema) + } + defer func() { _ = rows.Close() }() + + var fields []FieldSchema + for rows.Next() { + var columnName, dataType, columnType, isNullable string + + if err := rows.Scan(&columnName, &dataType, &columnType, &isNullable); err != nil { + return fmt.Errorf("%w: failed to scan row", ErrInvalidSchema) + } + + field := mysqlColumnToFieldSchema(columnName, dataType, columnType) + fields = append(fields, field) + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("%w: error iterating rows", ErrInvalidSchema) + } + + if len(fields) == 0 { + return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName) + } + + tp.schemas[tableName] = NewSchema(fields) + return nil +} + +// mysqlColumnToFieldSchema converts MySQL column metadata to a FieldSchema. +func mysqlColumnToFieldSchema(columnName, dataType, _ string) FieldSchema { + // Normalize data type to lowercase + dataType = strings.ToLower(dataType) + + isJSON := dataType == "json" + + return FieldSchema{ + Name: columnName, + Type: dataType, + IsJSON: isJSON, + } +} + +// Close is a no-op since we don't own the *sql.DB. +func (tp *typeProvider) Close() { + // No-op: caller owns the *sql.DB connection +} + +// GetSchemas returns the schemas known to this type provider. +func (tp *typeProvider) GetSchemas() map[string]Schema { + return tp.schemas +} + +// EnumValue implements types.Provider. 
+func (tp *typeProvider) EnumValue(_ string) ref.Val { + return types.NewErr("unknown enum value") +} + +// FindIdent implements types.Provider. +func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { + return nil, false +} + +// FindStructType implements types.Provider. +func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { + if _, ok := tp.schemas[structType]; ok { + return types.NewObjectType(structType), true + } + return nil, false +} + +// FindStructFieldNames implements types.Provider. +func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + fields := s.Fields() + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + return names, true +} + +// FindStructFieldType implements types.Provider. +func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + field, found := s.FindField(fieldName) + if !found { + return nil, false + } + + exprType := mysqlTypeToCELExprType(field) + + celType, err := types.ExprTypeToType(exprType) + if err != nil { + return nil, false + } + + return &types.FieldType{ + Type: celType, + }, true +} + +// NewValue implements types.Provider. +func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { + return types.NewErr("unknown type in schema") +} + +// mysqlTypeToCELExprType converts a MySQL field schema to a CEL expression type. +func mysqlTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { + baseType := mysqlBaseTypeToCEL(field.Type) + if field.Repeated { + return decls.NewListType(baseType) + } + return baseType +} + +// mysqlBaseTypeToCEL converts a MySQL type name to a CEL expression type. 
+func mysqlBaseTypeToCEL(typeName string) *exprpb.Type { + switch typeName { + case "varchar", "char", "text", "tinytext", "mediumtext", "longtext", "enum", "set": + return decls.String + case "int", "integer", "tinyint", "smallint", "mediumint", "bigint": + return decls.Int + case "float", "double", "decimal", "numeric", "real": + return decls.Double + case "boolean", "bool": + return decls.Bool + case "blob", "binary", "varbinary", "tinyblob", "mediumblob", "longblob": + return decls.Bytes + case "json": + return decls.Dyn + case "datetime", "timestamp": + return decls.Timestamp + case "date": + return sqltypes.Date + case "time": + return sqltypes.Time + default: + return decls.Dyn + } +} diff --git a/mysql/provider_test.go b/mysql/provider_test.go new file mode 100644 index 0000000..5a73099 --- /dev/null +++ b/mysql/provider_test.go @@ -0,0 +1,280 @@ +package mysql_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/google/cel-go/common/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + tcmysql "github.com/testcontainers/testcontainers-go/modules/mysql" + + "github.com/spandigital/cel2sql/v3/mysql" +) + +func TestNewTypeProvider(t *testing.T) { + schemas := map[string]mysql.Schema{ + "users": mysql.NewSchema([]mysql.FieldSchema{ + {Name: "id", Type: "int"}, + {Name: "name", Type: "varchar"}, + {Name: "email", Type: "text"}, + }), + } + + provider := mysql.NewTypeProvider(schemas) + require.NotNil(t, provider) + + got := provider.GetSchemas() + assert.Len(t, got, 1) + assert.Contains(t, got, "users") +} + +func TestNewTypeProviderWithConnection_NilDB(t *testing.T) { + _, err := mysql.NewTypeProviderWithConnection(context.Background(), nil) + require.Error(t, err) + assert.ErrorIs(t, err, mysql.ErrInvalidSchema) +} + +func TestLoadTableSchema_NoDB(t *testing.T) { + provider := mysql.NewTypeProvider(nil) + err := 
provider.LoadTableSchema(context.Background(), "test") + require.Error(t, err) + assert.ErrorIs(t, err, mysql.ErrInvalidSchema) +} + +func TestTypeProvider_FindStructType(t *testing.T) { + schemas := map[string]mysql.Schema{ + "users": mysql.NewSchema([]mysql.FieldSchema{ + {Name: "id", Type: "int"}, + {Name: "name", Type: "varchar"}, + }), + } + provider := mysql.NewTypeProvider(schemas) + + typ, found := provider.FindStructType("users") + assert.True(t, found) + assert.NotNil(t, typ) + + _, found = provider.FindStructType("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldNames(t *testing.T) { + schemas := map[string]mysql.Schema{ + "users": mysql.NewSchema([]mysql.FieldSchema{ + {Name: "id", Type: "int"}, + {Name: "name", Type: "varchar"}, + {Name: "email", Type: "text"}, + }), + } + provider := mysql.NewTypeProvider(schemas) + + names, found := provider.FindStructFieldNames("users") + assert.True(t, found) + assert.ElementsMatch(t, []string{"id", "name", "email"}, names) + + _, found = provider.FindStructFieldNames("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldType(t *testing.T) { + schemas := map[string]mysql.Schema{ + "test_table": mysql.NewSchema([]mysql.FieldSchema{ + {Name: "str_field", Type: "varchar"}, + {Name: "int_field", Type: "int"}, + {Name: "bigint_field", Type: "bigint"}, + {Name: "float_field", Type: "float"}, + {Name: "double_field", Type: "double"}, + {Name: "decimal_field", Type: "decimal"}, + {Name: "bool_field", Type: "boolean"}, + {Name: "blob_field", Type: "blob"}, + {Name: "json_field", Type: "json"}, + {Name: "datetime_field", Type: "datetime"}, + {Name: "timestamp_field", Type: "timestamp"}, + {Name: "text_field", Type: "text"}, + {Name: "enum_field", Type: "enum"}, + }), + } + provider := mysql.NewTypeProvider(schemas) + + tests := []struct { + fieldName string + wantType *types.Type + wantFound bool + }{ + {"str_field", types.StringType, true}, + {"int_field", 
types.IntType, true}, + {"bigint_field", types.IntType, true}, + {"float_field", types.DoubleType, true}, + {"double_field", types.DoubleType, true}, + {"decimal_field", types.DoubleType, true}, + {"bool_field", types.BoolType, true}, + {"blob_field", types.BytesType, true}, + {"json_field", types.DynType, true}, + {"datetime_field", types.TimestampType, true}, + {"timestamp_field", types.TimestampType, true}, + {"text_field", types.StringType, true}, + {"enum_field", types.StringType, true}, + {"nonexistent", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("test_table", tt.fieldName) + assert.Equal(t, tt.wantFound, found) + if found { + assert.Equal(t, tt.wantType, got.Type) + } + }) + } +} + +func TestTypeProvider_Close(_ *testing.T) { + provider := mysql.NewTypeProvider(nil) + // Close should not panic + provider.Close() +} + +// setupMySQLContainer starts a MySQL 8 container and returns a database connection. 
+func setupMySQLContainer(ctx context.Context, t *testing.T) (*tcmysql.MySQLContainer, *sql.DB) { + t.Helper() + + container, err := tcmysql.Run(ctx, "mysql:8.0", + tcmysql.WithDatabase("testdb"), + tcmysql.WithUsername("testuser"), + tcmysql.WithPassword("testpass"), + ) + require.NoError(t, err, "Failed to start MySQL container") + + host, err := container.Host(ctx) + require.NoError(t, err) + port, err := container.MappedPort(ctx, "3306") + require.NoError(t, err) + + connStr := fmt.Sprintf("testuser:testpass@tcp(%s:%s)/testdb?parseTime=true", + host, port.Port()) + db, err := sql.Open("mysql", connStr) + require.NoError(t, err, "Failed to connect to MySQL database") + + err = db.Ping() + require.NoError(t, err, "Failed to ping MySQL database") + + return container, db +} + +func TestLoadTableSchema_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, db := setupMySQLContainer(ctx, t) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + if err := testcontainers.TerminateContainer(container); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + // Create test table with various MySQL types + _, err := db.ExecContext(ctx, ` + CREATE TABLE schema_test ( + id INT PRIMARY KEY AUTO_INCREMENT, + name VARCHAR(255) NOT NULL, + description TEXT, + age INT, + score DOUBLE, + price DECIMAL(10,2), + is_active BOOLEAN, + avatar BLOB, + metadata JSON, + created_at DATETIME, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + `) + require.NoError(t, err) + + provider, err := mysql.NewTypeProviderWithConnection(ctx, db) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "schema_test") + require.NoError(t, err) + + // Verify schema was loaded + schemas := provider.GetSchemas() + assert.Contains(t, schemas, "schema_test") + + // Verify FindStructType + typ, found := 
provider.FindStructType("schema_test") + assert.True(t, found) + assert.NotNil(t, typ) + + // Verify FindStructFieldNames + names, found := provider.FindStructFieldNames("schema_test") + assert.True(t, found) + assert.Contains(t, names, "id") + assert.Contains(t, names, "name") + assert.Contains(t, names, "metadata") + + // Verify type mappings from loaded schema + tests := []struct { + fieldName string + wantType *types.Type + }{ + {"id", types.IntType}, + {"name", types.StringType}, + {"description", types.StringType}, + {"age", types.IntType}, + {"score", types.DoubleType}, + {"price", types.DoubleType}, + {"is_active", types.IntType}, // MySQL BOOLEAN is TINYINT(1), data_type shows "tinyint" + {"metadata", types.DynType}, + {"created_at", types.TimestampType}, + } + + for _, tt := range tests { + t.Run("type_"+tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("schema_test", tt.fieldName) + assert.True(t, found, "field %q should be found", tt.fieldName) + if found { + assert.Equal(t, tt.wantType, got.Type, "field %q type mismatch", tt.fieldName) + } + }) + } + + // Verify JSON detection + schemaObj := schemas["schema_test"] + metadataField, found := schemaObj.FindField("metadata") + assert.True(t, found) + assert.True(t, metadataField.IsJSON, "metadata should be detected as JSON") +} + +func TestLoadTableSchema_NonexistentTable(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, db := setupMySQLContainer(ctx, t) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + if err := testcontainers.TerminateContainer(container); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + provider, err := mysql.NewTypeProviderWithConnection(ctx, db) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "nonexistent_table") + require.Error(t, err) + assert.ErrorIs(t, 
err, mysql.ErrInvalidSchema) +} diff --git a/mysql_integration_test.go b/mysql_integration_test.go new file mode 100644 index 0000000..768dbb9 --- /dev/null +++ b/mysql_integration_test.go @@ -0,0 +1,447 @@ +package cel2sql_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/google/cel-go/cel" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + tcmysql "github.com/testcontainers/testcontainers-go/modules/mysql" + + "github.com/spandigital/cel2sql/v3" + mysqlDialect "github.com/spandigital/cel2sql/v3/dialect/mysql" + "github.com/spandigital/cel2sql/v3/pg" +) + +// setupMySQLContainer starts a MySQL 8 container and returns a database connection. +func setupMySQLContainer(ctx context.Context, t *testing.T) (testcontainers.Container, *sql.DB) { + t.Helper() + + container, err := tcmysql.Run(ctx, "mysql:8.0", + tcmysql.WithDatabase("testdb"), + tcmysql.WithUsername("testuser"), + tcmysql.WithPassword("testpass"), + ) + require.NoError(t, err, "Failed to start MySQL container") + + // Get connection string + host, err := container.Host(ctx) + require.NoError(t, err) + port, err := container.MappedPort(ctx, "3306") + require.NoError(t, err) + + connStr := fmt.Sprintf("testuser:testpass@tcp(%s:%s)/testdb?parseTime=true", + host, port.Port()) + db, err := sql.Open("mysql", connStr) + require.NoError(t, err, "Failed to connect to MySQL database") + + err = db.Ping() + require.NoError(t, err, "Failed to ping MySQL database") + + return container, db +} + +// TestMySQLOperatorsIntegration validates operator conversions against a real MySQL database. 
+func TestMySQLOperatorsIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, db := setupMySQLContainer(ctx, t) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + // Create test table + _, err := db.Exec(` + CREATE TABLE test_data ( + id INTEGER PRIMARY KEY, + text_val TEXT, + int_val INTEGER, + float_val DOUBLE, + bool_val BOOLEAN, + nullable_text TEXT, + nullable_int INTEGER + ) + `) + require.NoError(t, err) + + // Insert test data + _, err = db.Exec(` + INSERT INTO test_data VALUES + (1, 'hello', 10, 10.5, true, 'present', 100), + (2, 'world', 20, 20.5, false, NULL, NULL), + (3, 'test', 30, 30.5, true, 'here', 200), + (4, 'hello world', 5, 5.5, false, 'value', 50), + (5, 'testing', 15, 15.5, true, 'test', 150) + `) + require.NoError(t, err) + + // Set up CEL environment with simple variables + env, err := cel.NewEnv( + cel.Variable("id", cel.IntType), + cel.Variable("text_val", cel.StringType), + cel.Variable("int_val", cel.IntType), + cel.Variable("float_val", cel.DoubleType), + cel.Variable("bool_val", cel.BoolType), + cel.Variable("nullable_text", cel.StringType), + cel.Variable("nullable_int", cel.IntType), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(mysqlDialect.New()) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + // Comparison operators + { + name: "Equality string", + celExpr: `text_val == "hello"`, + expectedRows: 1, + description: "String equality comparison", + }, + { + name: "Equality integer", + celExpr: `int_val == 20`, + expectedRows: 1, + description: "Integer equality comparison", + }, + { + name: "Equality float", + celExpr: `float_val == 10.5`, + expectedRows: 1, + description: "Float equality 
comparison", + }, + { + name: "Equality boolean", + celExpr: `bool_val == true`, + expectedRows: 3, + description: "Boolean equality comparison", + }, + { + name: "Not equal", + celExpr: `text_val != "hello"`, + expectedRows: 4, + description: "Not equal comparison", + }, + { + name: "Less than", + celExpr: `int_val < 15`, + expectedRows: 2, // 10, 5 + description: "Less than comparison", + }, + { + name: "Less than or equal", + celExpr: `int_val <= 15`, + expectedRows: 3, // 10, 5, 15 + description: "Less than or equal comparison", + }, + { + name: "Greater than", + celExpr: `int_val > 15`, + expectedRows: 2, // 20, 30 + description: "Greater than comparison", + }, + { + name: "Greater than or equal", + celExpr: `int_val >= 15`, + expectedRows: 3, // 20, 30, 15 + description: "Greater than or equal comparison", + }, + + // Logical operators + { + name: "Logical AND", + celExpr: `int_val > 10 && bool_val == true`, + expectedRows: 2, // rows 3 (30,true) and 5 (15,true) + description: "Logical AND operator", + }, + { + name: "Logical OR", + celExpr: `int_val < 10 || bool_val == false`, + expectedRows: 2, // rows 2 (20,false) and 4 (5,false) + description: "Logical OR operator", + }, + { + name: "Logical NOT", + celExpr: `!bool_val`, + expectedRows: 2, // rows 2 and 4 + description: "Logical NOT operator", + }, + { + name: "Complex logical expression", + celExpr: `(int_val > 10 && bool_val) || int_val < 10`, + expectedRows: 3, // rows 3, 5, 4 + description: "Complex nested logical operators", + }, + + // Arithmetic operators + { + name: "Addition", + celExpr: `int_val + 10 == 20`, + expectedRows: 1, // 10 + 10 = 20 + description: "Addition operator", + }, + { + name: "Subtraction", + celExpr: `int_val - 5 == 15`, + expectedRows: 1, // 20 - 5 = 15 + description: "Subtraction operator", + }, + { + name: "Multiplication", + celExpr: `int_val * 2 == 20`, + expectedRows: 1, // 10 * 2 = 20 + description: "Multiplication operator", + }, + { + name: "Division", + celExpr: 
`int_val / 2 == 10`, + expectedRows: 1, // 20 / 2 = 10 + description: "Division operator", + }, + { + name: "Modulo", + celExpr: `int_val % 10 == 0`, + expectedRows: 3, // 10, 20, 30 + description: "Modulo operator", + }, + { + name: "Complex arithmetic", + celExpr: `(int_val * 2) + 5 > 30`, + expectedRows: 3, // (20*2)+5=45, (30*2)+5=65, (15*2)+5=35 + description: "Complex arithmetic expression", + }, + + // String operators + { + name: "String concatenation", + celExpr: `text_val + "!" == "hello!"`, + expectedRows: 1, + description: "String concatenation (CONCAT)", + }, + { + name: "String contains", + celExpr: `text_val.contains("world")`, + expectedRows: 2, // "world", "hello world" + description: "String contains function (LOCATE)", + }, + { + name: "String startsWith", + celExpr: `text_val.startsWith("hello")`, + expectedRows: 2, // "hello", "hello world" + description: "String startsWith function (LIKE)", + }, + { + name: "String endsWith", + celExpr: `text_val.endsWith("world")`, + expectedRows: 2, // "world", "hello world" + description: "String endsWith function (LIKE)", + }, + + // Regex (MySQL 8.0+ supports REGEXP) + { + name: "Regex match", + celExpr: `text_val.matches(r"^hello")`, + expectedRows: 2, // "hello", "hello world" + description: "Regex match (REGEXP)", + }, + { + name: "Regex word boundary", + celExpr: `text_val.matches(r"test")`, + expectedRows: 2, // "test", "testing" + description: "Regex simple pattern", + }, + + // Complex combined operators + { + name: "Complex multi-operator expression", + celExpr: `int_val > 10 && bool_val && text_val.contains("test")`, + expectedRows: 2, // rows 3 and 5 + description: "Complex expression with multiple operator types", + }, + { + name: "Nested parenthesized operators", + celExpr: `((int_val + 5) * 2 > 30) && (text_val.contains("test") || bool_val)`, + expectedRows: 2, // rows 3 and 5 + description: "Deeply nested operators with parentheses", + }, + { + name: "Triple negation", + celExpr: 
`!!!bool_val`, + expectedRows: 2, + description: "Multiple NOT operators", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Compile CEL expression + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + // Convert to SQL with MySQL dialect + sqlCondition, err := cel2sql.Convert(ast, dialectOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // Execute query and count results + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := "SELECT COUNT(*) FROM test_data WHERE " + sqlCondition + t.Logf("Full SQL Query: %s", query) + + var actualRows int + err = db.QueryRow(query).Scan(&actualRows) + require.NoError(t, err, "Generated SQL should execute successfully. %s\nSQL: %s", + tt.description, sqlCondition) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s (expected %d rows, got %d rows)", + tt.description, tt.expectedRows, actualRows) + }) + } +} + +// TestMySQLJSONIntegration validates JSON operations against a real MySQL database. 
+func TestMySQLJSONIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + ctx := context.Background() + container, db := setupMySQLContainer(ctx, t) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + if err := container.Terminate(ctx); err != nil { + t.Logf("failed to terminate container: %v", err) + } + }() + + // Create test table with JSON column + _, err := db.Exec(` + CREATE TABLE products ( + id INTEGER PRIMARY KEY, + name TEXT, + price DOUBLE, + metadata JSON + ) + `) + require.NoError(t, err) + + // Insert test data with JSON metadata + _, err = db.Exec(` + INSERT INTO products VALUES + (1, 'Widget', 19.99, '{"brand": "Acme", "color": "red", "specs": {"weight": 100}}'), + (2, 'Gadget', 29.99, '{"brand": "Beta", "color": "blue", "specs": {"weight": 200}}'), + (3, 'Doohickey', 39.99, '{"brand": "Acme", "color": "green", "specs": {"weight": 150}}') + `) + require.NoError(t, err) + + // Set up CEL environment with schema for JSON detection + productSchema := pg.NewSchema([]pg.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + {Name: "price", Type: "double precision"}, + {Name: "metadata", Type: "json", IsJSON: true}, + }) + + schemas := map[string]pg.Schema{ + "product": productSchema, + } + + provider := pg.NewTypeProvider(schemas) + + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("product", cel.ObjectType("product")), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(mysqlDialect.New()) + schemaOpt := cel2sql.WithSchemas(schemas) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + { + name: "JSON field access", + celExpr: `product.metadata.brand == "Acme"`, + expectedRows: 2, + description: "JSON field access with ->>", + }, + { + name: "JSON field access different value", + celExpr: `product.metadata.color == 
"blue"`, + expectedRows: 1, + description: "JSON field access with different value", + }, + { + name: "JSON with regular field", + celExpr: `product.metadata.brand == "Acme" && product.price > 30.0`, + expectedRows: 1, // Doohickey (Acme, 39.99) + description: "JSON field combined with regular field comparison", + }, + { + name: "JSON field existence", + celExpr: `has(product.metadata.brand)`, + expectedRows: 3, // All rows have 'brand' + description: "JSON field existence check", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + sqlCondition, err := cel2sql.Convert(ast, dialectOpt, schemaOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := "SELECT COUNT(*) FROM products product WHERE " + sqlCondition + t.Logf("Full SQL Query: %s", query) + + var actualRows int + err = db.QueryRow(query).Scan(&actualRows) + require.NoError(t, err, "Generated SQL should execute successfully. %s\nSQL: %s", + tt.description, sqlCondition) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s", tt.description) + }) + } +} diff --git a/sqlite/provider.go b/sqlite/provider.go new file mode 100644 index 0000000..0877074 --- /dev/null +++ b/sqlite/provider.go @@ -0,0 +1,280 @@ +// Package sqlite provides SQLite type provider for CEL type system integration. 
+package sqlite + +import ( + "context" + "database/sql" + "errors" + "fmt" + "regexp" + "strings" + + "github.com/google/cel-go/checker/decls" + "github.com/google/cel-go/common/types" + "github.com/google/cel-go/common/types/ref" + exprpb "google.golang.org/genproto/googleapis/api/expr/v1alpha1" + + "github.com/spandigital/cel2sql/v3/schema" +) + +// Sentinel errors for the sqlite package. +var ( + // ErrInvalidSchema indicates a problem with the provided schema or database introspection. + ErrInvalidSchema = errors.New("invalid schema") +) + +// validTableName matches safe SQLite table names (letters, digits, underscores). +var validTableName = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +// FieldSchema is an alias for schema.FieldSchema. +type FieldSchema = schema.FieldSchema + +// Schema is an alias for schema.Schema. +type Schema = schema.Schema + +// NewSchema creates a new Schema. This is an alias for schema.NewSchema. +var NewSchema = schema.NewSchema + +// TypeProvider interface for SQLite type providers. +type TypeProvider interface { + types.Provider + LoadTableSchema(ctx context.Context, tableName string) error + GetSchemas() map[string]Schema + Close() +} + +type typeProvider struct { + schemas map[string]Schema + db *sql.DB +} + +// NewTypeProvider creates a new SQLite type provider with pre-defined schemas. +func NewTypeProvider(schemas map[string]Schema) TypeProvider { + return &typeProvider{schemas: schemas} +} + +// NewTypeProviderWithConnection creates a new SQLite type provider that can introspect database schemas. +// The caller owns the *sql.DB and is responsible for closing it. 
+func NewTypeProviderWithConnection(_ context.Context, db *sql.DB) (TypeProvider, error) { + if db == nil { + return nil, fmt.Errorf("%w: db connection must not be nil", ErrInvalidSchema) + } + + return &typeProvider{ + schemas: make(map[string]Schema), + db: db, + }, nil +} + +// LoadTableSchema loads schema information for a table from the database using PRAGMA table_info. +func (tp *typeProvider) LoadTableSchema(ctx context.Context, tableName string) error { + if tp.db == nil { + return fmt.Errorf("%w: no database connection available", ErrInvalidSchema) + } + + // Validate table name to prevent SQL injection (PRAGMA doesn't support parameterized queries) + if !validTableName.MatchString(tableName) { + return fmt.Errorf("%w: invalid table name %q", ErrInvalidSchema, tableName) + } + + // PRAGMA doesn't support parameterized queries, but we've validated the table name above + // #nosec G202 - table name is validated against strict regex pattern + query := fmt.Sprintf("PRAGMA table_info(%s)", tableName) + + rows, err := tp.db.QueryContext(ctx, query) + if err != nil { + return fmt.Errorf("%w: failed to query table schema", ErrInvalidSchema) + } + defer func() { _ = rows.Close() }() + + var fields []FieldSchema + for rows.Next() { + var cid int + var name, colType string + var notnull int + var dfltValue *string + var pk int + + if err := rows.Scan(&cid, &name, &colType, ¬null, &dfltValue, &pk); err != nil { + return fmt.Errorf("%w: failed to scan row", ErrInvalidSchema) + } + + field := sqliteColumnToFieldSchema(name, colType) + fields = append(fields, field) + } + + if err := rows.Err(); err != nil { + return fmt.Errorf("%w: error iterating rows", ErrInvalidSchema) + } + + if len(fields) == 0 { + return fmt.Errorf("%w: table %q has no columns or does not exist", ErrInvalidSchema, tableName) + } + + tp.schemas[tableName] = NewSchema(fields) + return nil +} + +// sqliteColumnToFieldSchema converts SQLite column metadata to a FieldSchema. 
+func sqliteColumnToFieldSchema(name, colType string) FieldSchema { + normalizedType := normalizeSQLiteType(colType) + isJSON := strings.EqualFold(colType, "json") || strings.EqualFold(colType, "jsonb") + + return FieldSchema{ + Name: name, + Type: normalizedType, + IsJSON: isJSON, + } +} + +// Normalized type constants used by normalizeSQLiteType. +const ( + sqliteTypeText = "text" + sqliteTypeInteger = "integer" + sqliteTypeReal = "real" + sqliteTypeBlob = "blob" + sqliteTypeJSON = "json" + sqliteTypeBool = "boolean" + sqliteTypeDatetime = "datetime" +) + +// normalizeSQLiteType converts a SQLite column type declaration to a normalized type name. +// SQLite uses type affinity, so we map common type names to our internal types. +func normalizeSQLiteType(colType string) string { + upper := strings.ToUpper(strings.TrimSpace(colType)) + + // Check for exact matches first + switch upper { + case "TEXT", "VARCHAR", "CHAR", "CLOB": + return sqliteTypeText + case "INTEGER", "INT", "TINYINT", "SMALLINT", "MEDIUMINT", "BIGINT": + return sqliteTypeInteger + case "REAL", "DOUBLE", "FLOAT", "NUMERIC", "DECIMAL": + return sqliteTypeReal + case "BOOLEAN", "BOOL": + return sqliteTypeBool + case "BLOB": + return sqliteTypeBlob + case "JSON", "JSONB": + return sqliteTypeJSON + case "DATETIME", "TIMESTAMP": + return sqliteTypeDatetime + } + + // Check for type names that contain known keywords (e.g., "VARCHAR(255)") + if strings.Contains(upper, "INT") { + return sqliteTypeInteger + } + if strings.Contains(upper, "CHAR") || strings.Contains(upper, "CLOB") || strings.Contains(upper, "TEXT") { + return sqliteTypeText + } + if strings.Contains(upper, "BLOB") { + return sqliteTypeBlob + } + if strings.Contains(upper, "REAL") || strings.Contains(upper, "FLOA") || strings.Contains(upper, "DOUBLE") { + return sqliteTypeReal + } + + // Default to text for unknown types (SQLite's flexible typing) + return sqliteTypeText +} + +// Close is a no-op since we don't own the *sql.DB. 
+func (tp *typeProvider) Close() { + // No-op: caller owns the *sql.DB connection +} + +// GetSchemas returns the schemas known to this type provider. +func (tp *typeProvider) GetSchemas() map[string]Schema { + return tp.schemas +} + +// EnumValue implements types.Provider. +func (tp *typeProvider) EnumValue(_ string) ref.Val { + return types.NewErr("unknown enum value") +} + +// FindIdent implements types.Provider. +func (tp *typeProvider) FindIdent(_ string) (ref.Val, bool) { + return nil, false +} + +// FindStructType implements types.Provider. +func (tp *typeProvider) FindStructType(structType string) (*types.Type, bool) { + if _, ok := tp.schemas[structType]; ok { + return types.NewObjectType(structType), true + } + return nil, false +} + +// FindStructFieldNames implements types.Provider. +func (tp *typeProvider) FindStructFieldNames(structType string) ([]string, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + fields := s.Fields() + names := make([]string, len(fields)) + for i, f := range fields { + names[i] = f.Name + } + return names, true +} + +// FindStructFieldType implements types.Provider. +func (tp *typeProvider) FindStructFieldType(structType, fieldName string) (*types.FieldType, bool) { + s, ok := tp.schemas[structType] + if !ok { + return nil, false + } + field, found := s.FindField(fieldName) + if !found { + return nil, false + } + + exprType := sqliteTypeToCELExprType(field) + celType, err := types.ExprTypeToType(exprType) + if err != nil { + return nil, false + } + + return &types.FieldType{ + Type: celType, + }, true +} + +// NewValue implements types.Provider. +func (tp *typeProvider) NewValue(_ string, _ map[string]ref.Val) ref.Val { + return types.NewErr("unknown type in schema") +} + +// sqliteTypeToCELExprType converts a SQLite field schema to a CEL expression type. 
+func sqliteTypeToCELExprType(field *schema.FieldSchema) *exprpb.Type { + baseType := sqliteBaseTypeToCEL(field.Type) + if field.Repeated { + return decls.NewListType(baseType) + } + return baseType +} + +// sqliteBaseTypeToCEL converts a SQLite type name to a CEL expression type. +func sqliteBaseTypeToCEL(typeName string) *exprpb.Type { + switch typeName { + case "text", "varchar", "char", "clob": + return decls.String + case "integer", "int", "tinyint", "smallint", "mediumint", "bigint": + return decls.Int + case "real", "double", "float", "numeric", "decimal": + return decls.Double + case "boolean", "bool": + return decls.Bool + case "blob": + return decls.Bytes + case "json": + return decls.Dyn + case "datetime", "timestamp": + return decls.Timestamp + default: + return decls.Dyn + } +} diff --git a/sqlite/provider_test.go b/sqlite/provider_test.go new file mode 100644 index 0000000..e4e47f8 --- /dev/null +++ b/sqlite/provider_test.go @@ -0,0 +1,348 @@ +package sqlite_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/google/cel-go/common/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + _ "modernc.org/sqlite" + + "github.com/spandigital/cel2sql/v3/sqlite" +) + +func TestNewTypeProvider(t *testing.T) { + schemas := map[string]sqlite.Schema{ + "users": sqlite.NewSchema([]sqlite.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + }), + } + + provider := sqlite.NewTypeProvider(schemas) + require.NotNil(t, provider) + + got := provider.GetSchemas() + assert.Len(t, got, 1) + assert.Contains(t, got, "users") +} + +func TestNewTypeProviderWithConnection_NilDB(t *testing.T) { + _, err := sqlite.NewTypeProviderWithConnection(context.Background(), nil) + require.Error(t, err) + assert.ErrorIs(t, err, sqlite.ErrInvalidSchema) +} + +func TestLoadTableSchema_NoDB(t *testing.T) { + provider := sqlite.NewTypeProvider(nil) + err := provider.LoadTableSchema(context.Background(), "test") + 
require.Error(t, err) + assert.ErrorIs(t, err, sqlite.ErrInvalidSchema) +} + +func TestLoadTableSchema_InvalidTableName(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer func() { _ = db.Close() }() + + ctx := context.Background() + provider, err := sqlite.NewTypeProviderWithConnection(ctx, db) + require.NoError(t, err) + + // SQL injection attempts should be rejected + invalidNames := []string{ + "table; DROP TABLE users", + "table'name", + "table-name", + "table.name", + "123table", + "", + "table name", + } + + for _, name := range invalidNames { + t.Run(name, func(t *testing.T) { + err := provider.LoadTableSchema(ctx, name) + require.Error(t, err) + assert.ErrorIs(t, err, sqlite.ErrInvalidSchema) + }) + } +} + +func TestTypeProvider_FindStructType(t *testing.T) { + schemas := map[string]sqlite.Schema{ + "users": sqlite.NewSchema([]sqlite.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + }), + } + provider := sqlite.NewTypeProvider(schemas) + + typ, found := provider.FindStructType("users") + assert.True(t, found) + assert.NotNil(t, typ) + + _, found = provider.FindStructType("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldNames(t *testing.T) { + schemas := map[string]sqlite.Schema{ + "users": sqlite.NewSchema([]sqlite.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + {Name: "email", Type: "text"}, + }), + } + provider := sqlite.NewTypeProvider(schemas) + + names, found := provider.FindStructFieldNames("users") + assert.True(t, found) + assert.ElementsMatch(t, []string{"id", "name", "email"}, names) + + _, found = provider.FindStructFieldNames("nonexistent") + assert.False(t, found) +} + +func TestTypeProvider_FindStructFieldType(t *testing.T) { + schemas := map[string]sqlite.Schema{ + "test_table": sqlite.NewSchema([]sqlite.FieldSchema{ + {Name: "text_field", Type: "text"}, + {Name: "int_field", Type: "integer"}, + {Name: 
"real_field", Type: "real"}, + {Name: "bool_field", Type: "boolean"}, + {Name: "blob_field", Type: "blob"}, + {Name: "json_field", Type: "json"}, + {Name: "datetime_field", Type: "datetime"}, + }), + } + provider := sqlite.NewTypeProvider(schemas) + + tests := []struct { + fieldName string + wantType *types.Type + wantFound bool + }{ + {"text_field", types.StringType, true}, + {"int_field", types.IntType, true}, + {"real_field", types.DoubleType, true}, + {"bool_field", types.BoolType, true}, + {"blob_field", types.BytesType, true}, + {"json_field", types.DynType, true}, + {"datetime_field", types.TimestampType, true}, + {"nonexistent", nil, false}, + } + + for _, tt := range tests { + t.Run(tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("test_table", tt.fieldName) + assert.Equal(t, tt.wantFound, found) + if found { + assert.Equal(t, tt.wantType, got.Type) + } + }) + } +} + +func TestTypeProvider_Close(_ *testing.T) { + provider := sqlite.NewTypeProvider(nil) + // Close should not panic + provider.Close() +} + +func TestLoadTableSchema_Integration(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer func() { _ = db.Close() }() + + ctx := context.Background() + + // Create test table with various SQLite types + _, err = db.ExecContext(ctx, ` + CREATE TABLE schema_test ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + description VARCHAR(255), + age INTEGER, + score REAL, + is_active BOOLEAN, + avatar BLOB, + metadata JSON, + created_at DATETIME + ) + `) + require.NoError(t, err) + + provider, err := sqlite.NewTypeProviderWithConnection(ctx, db) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "schema_test") + require.NoError(t, err) + + // Verify schema was loaded + schemas := provider.GetSchemas() + assert.Contains(t, schemas, "schema_test") + + // Verify FindStructType + typ, found := provider.FindStructType("schema_test") + assert.True(t, found) + assert.NotNil(t, typ) + 
+ // Verify FindStructFieldNames + names, found := provider.FindStructFieldNames("schema_test") + assert.True(t, found) + assert.Contains(t, names, "id") + assert.Contains(t, names, "name") + assert.Contains(t, names, "metadata") + + // Verify type mappings from loaded schema + tests := []struct { + fieldName string + wantType *types.Type + }{ + {"id", types.IntType}, + {"name", types.StringType}, + {"description", types.StringType}, + {"age", types.IntType}, + {"score", types.DoubleType}, + {"is_active", types.BoolType}, + {"avatar", types.BytesType}, + {"metadata", types.DynType}, + {"created_at", types.TimestampType}, + } + + for _, tt := range tests { + t.Run("type_"+tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("schema_test", tt.fieldName) + assert.True(t, found, "field %q should be found", tt.fieldName) + if found { + assert.Equal(t, tt.wantType, got.Type, "field %q type mismatch", tt.fieldName) + } + }) + } + + // Verify JSON detection + schemaObj := schemas["schema_test"] + metadataField, found := schemaObj.FindField("metadata") + assert.True(t, found) + assert.True(t, metadataField.IsJSON, "metadata should be detected as JSON") +} + +func TestLoadTableSchema_NonexistentTable(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer func() { _ = db.Close() }() + + ctx := context.Background() + provider, err := sqlite.NewTypeProviderWithConnection(ctx, db) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "nonexistent_table") + require.Error(t, err) + assert.ErrorIs(t, err, sqlite.ErrInvalidSchema) +} + +func TestLoadTableSchema_MultipleTablesIntegration(t *testing.T) { + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer func() { _ = db.Close() }() + + ctx := context.Background() + + // Create two tables + _, err = db.ExecContext(ctx, `CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, email TEXT)`) + require.NoError(t, err) + _, err = 
db.ExecContext(ctx, `CREATE TABLE products (id INTEGER PRIMARY KEY, title TEXT, price REAL)`) + require.NoError(t, err) + + provider, err := sqlite.NewTypeProviderWithConnection(ctx, db) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "users") + require.NoError(t, err) + err = provider.LoadTableSchema(ctx, "products") + require.NoError(t, err) + + schemas := provider.GetSchemas() + assert.Len(t, schemas, 2) + assert.Contains(t, schemas, "users") + assert.Contains(t, schemas, "products") + + // Verify both schemas are independent + userNames, found := provider.FindStructFieldNames("users") + assert.True(t, found) + assert.ElementsMatch(t, []string{"id", "name", "email"}, userNames) + + productNames, found := provider.FindStructFieldNames("products") + assert.True(t, found) + assert.ElementsMatch(t, []string{"id", "title", "price"}, productNames) +} + +func TestNormalizeSQLiteType(t *testing.T) { + // Test via LoadTableSchema that various SQLite type declarations are normalized correctly + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer func() { _ = db.Close() }() + + ctx := context.Background() + + _, err = db.ExecContext(ctx, ` + CREATE TABLE type_test ( + col_int INTEGER, + col_varchar VARCHAR(255), + col_char CHAR(10), + col_text TEXT, + col_real REAL, + col_float FLOAT, + col_double DOUBLE, + col_numeric NUMERIC, + col_blob BLOB, + col_bool BOOLEAN, + col_datetime DATETIME, + col_timestamp TIMESTAMP, + col_bigint BIGINT, + col_smallint SMALLINT, + col_tinyint TINYINT, + col_json JSON + ) + `) + require.NoError(t, err) + + provider, err := sqlite.NewTypeProviderWithConnection(ctx, db) + require.NoError(t, err) + + err = provider.LoadTableSchema(ctx, "type_test") + require.NoError(t, err) + + tests := []struct { + fieldName string + wantType *types.Type + }{ + {"col_int", types.IntType}, + {"col_varchar", types.StringType}, + {"col_char", types.StringType}, + {"col_text", types.StringType}, + {"col_real", 
types.DoubleType}, + {"col_float", types.DoubleType}, + {"col_double", types.DoubleType}, + {"col_numeric", types.DoubleType}, + {"col_blob", types.BytesType}, + {"col_bool", types.BoolType}, + {"col_datetime", types.TimestampType}, + {"col_timestamp", types.TimestampType}, + {"col_bigint", types.IntType}, + {"col_smallint", types.IntType}, + {"col_tinyint", types.IntType}, + {"col_json", types.DynType}, + } + + for _, tt := range tests { + t.Run(tt.fieldName, func(t *testing.T) { + got, found := provider.FindStructFieldType("type_test", tt.fieldName) + require.True(t, found, "field %q should be found", tt.fieldName) + assert.Equal(t, tt.wantType, got.Type, "field %q type mismatch", tt.fieldName) + }) + } +} diff --git a/sqlite_integration_test.go b/sqlite_integration_test.go new file mode 100644 index 0000000..f70b5d7 --- /dev/null +++ b/sqlite_integration_test.go @@ -0,0 +1,518 @@ +package cel2sql_test + +import ( + "database/sql" + "testing" + + "github.com/google/cel-go/cel" + "github.com/stretchr/testify/require" + _ "modernc.org/sqlite" + + "github.com/spandigital/cel2sql/v3" + sqliteDialect "github.com/spandigital/cel2sql/v3/dialect/sqlite" + "github.com/spandigital/cel2sql/v3/pg" +) + +// TestSQLiteOperatorsIntegration validates operator conversions against a real SQLite database. +// This uses an in-memory SQLite database (no Docker required). 
+func TestSQLiteOperatorsIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Open in-memory SQLite database + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + }() + + // Create test table (SQLite uses INTEGER for booleans, REAL for floats) + _, err = db.Exec(` + CREATE TABLE test_data ( + id INTEGER PRIMARY KEY, + text_val TEXT, + int_val INTEGER, + float_val REAL, + bool_val INTEGER, + nullable_text TEXT, + nullable_int INTEGER + ) + `) + require.NoError(t, err) + + // Insert test data (using 1/0 for boolean values) + _, err = db.Exec(` + INSERT INTO test_data VALUES + (1, 'hello', 10, 10.5, 1, 'present', 100), + (2, 'world', 20, 20.5, 0, NULL, NULL), + (3, 'test', 30, 30.5, 1, 'here', 200), + (4, 'hello world', 5, 5.5, 0, 'value', 50), + (5, 'testing', 15, 15.5, 1, 'test', 150) + `) + require.NoError(t, err) + + // Set up CEL environment with simple variables + env, err := cel.NewEnv( + cel.Variable("id", cel.IntType), + cel.Variable("text_val", cel.StringType), + cel.Variable("int_val", cel.IntType), + cel.Variable("float_val", cel.DoubleType), + cel.Variable("bool_val", cel.BoolType), + cel.Variable("nullable_text", cel.StringType), + cel.Variable("nullable_int", cel.IntType), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(sqliteDialect.New()) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + // Comparison operators + { + name: "Equality string", + celExpr: `text_val == "hello"`, + expectedRows: 1, + description: "String equality comparison", + }, + { + name: "Equality integer", + celExpr: `int_val == 20`, + expectedRows: 1, + description: "Integer equality comparison", + }, + { + name: "Equality float", + celExpr: `float_val == 10.5`, + expectedRows: 1, + description: "Float equality comparison", 
+ }, + { + name: "Equality boolean", + celExpr: `bool_val == true`, + expectedRows: 3, + description: "Boolean equality comparison", + }, + { + name: "Not equal", + celExpr: `text_val != "hello"`, + expectedRows: 4, + description: "Not equal comparison", + }, + { + name: "Less than", + celExpr: `int_val < 15`, + expectedRows: 2, // 10, 5 + description: "Less than comparison", + }, + { + name: "Less than or equal", + celExpr: `int_val <= 15`, + expectedRows: 3, // 10, 5, 15 + description: "Less than or equal comparison", + }, + { + name: "Greater than", + celExpr: `int_val > 15`, + expectedRows: 2, // 20, 30 + description: "Greater than comparison", + }, + { + name: "Greater than or equal", + celExpr: `int_val >= 15`, + expectedRows: 3, // 20, 30, 15 + description: "Greater than or equal comparison", + }, + + // Logical operators + { + name: "Logical AND", + celExpr: `int_val > 10 && bool_val == true`, + expectedRows: 2, // rows 3 (30,true) and 5 (15,true) + description: "Logical AND operator", + }, + { + name: "Logical OR", + celExpr: `int_val < 10 || bool_val == false`, + expectedRows: 2, // rows 2 (20,false) and 4 (5,false) + description: "Logical OR operator", + }, + { + name: "Logical NOT", + celExpr: `!bool_val`, + expectedRows: 2, // rows 2 and 4 (bool_val == false) + description: "Logical NOT operator", + }, + { + name: "Complex logical expression", + celExpr: `(int_val > 10 && bool_val) || int_val < 10`, + expectedRows: 3, // rows 3, 5 (>10 && true), row 4 (<10) + description: "Complex nested logical operators", + }, + + // Arithmetic operators + { + name: "Addition", + celExpr: `int_val + 10 == 20`, + expectedRows: 1, // 10 + 10 = 20 + description: "Addition operator", + }, + { + name: "Subtraction", + celExpr: `int_val - 5 == 15`, + expectedRows: 1, // 20 - 5 = 15 + description: "Subtraction operator", + }, + { + name: "Multiplication", + celExpr: `int_val * 2 == 20`, + expectedRows: 1, // 10 * 2 = 20 + description: "Multiplication operator", + }, + { + 
name: "Division", + celExpr: `int_val / 2 == 10`, + expectedRows: 1, // 20 / 2 = 10 + description: "Division operator", + }, + { + name: "Modulo", + celExpr: `int_val % 10 == 0`, + expectedRows: 3, // 10, 20, 30 + description: "Modulo operator", + }, + { + name: "Complex arithmetic", + celExpr: `(int_val * 2) + 5 > 30`, + expectedRows: 3, // (20*2)+5=45, (30*2)+5=65, (15*2)+5=35 + description: "Complex arithmetic expression", + }, + + // String operators + { + name: "String concatenation", + celExpr: `text_val + "!" == "hello!"`, + expectedRows: 1, + description: "String concatenation operator (||)", + }, + { + name: "String contains", + celExpr: `text_val.contains("world")`, + expectedRows: 2, // "world", "hello world" + description: "String contains function (INSTR)", + }, + { + name: "String startsWith", + celExpr: `text_val.startsWith("hello")`, + expectedRows: 2, // "hello", "hello world" + description: "String startsWith function (LIKE)", + }, + { + name: "String endsWith", + celExpr: `text_val.endsWith("world")`, + expectedRows: 2, // "world", "hello world" + description: "String endsWith function (LIKE)", + }, + + // Complex combined operators + { + name: "Complex multi-operator expression", + celExpr: `int_val > 10 && bool_val && text_val.contains("test")`, + expectedRows: 2, // rows 3 (30, true, "test") and 5 (15, true, "testing") + description: "Complex expression with multiple operator types", + }, + { + name: "Nested parenthesized operators", + celExpr: `((int_val + 5) * 2 > 30) && (text_val.contains("test") || bool_val)`, + expectedRows: 2, // rows 3 and 5 + description: "Deeply nested operators with parentheses", + }, + { + name: "Triple negation", + celExpr: `!!!bool_val`, + expectedRows: 2, // rows 2 and 4 (bool_val == false) + description: "Multiple NOT operators", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Compile CEL expression + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != 
nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + // Convert to SQL with SQLite dialect + sqlCondition, err := cel2sql.Convert(ast, dialectOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // Execute query and count results + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := "SELECT COUNT(*) FROM test_data WHERE " + sqlCondition + t.Logf("Full SQL Query: %s", query) + + var actualRows int + err = db.QueryRow(query).Scan(&actualRows) + require.NoError(t, err, "Generated SQL should execute successfully. %s\nSQL: %s", + tt.description, sqlCondition) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s (expected %d rows, got %d rows)", + tt.description, tt.expectedRows, actualRows) + }) + } +} + +// TestSQLiteEdgeCasesIntegration validates edge cases against a real SQLite database. 
+func TestSQLiteEdgeCasesIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + }() + + // Create test table with edge case values + _, err = db.Exec(` + CREATE TABLE edge_cases ( + id INTEGER PRIMARY KEY, + empty_string TEXT, + zero_int INTEGER, + zero_float REAL, + negative_int INTEGER, + negative_float REAL, + large_int INTEGER + ) + `) + require.NoError(t, err) + + _, err = db.Exec(` + INSERT INTO edge_cases VALUES + (1, '', 0, 0.0, -10, -5.5, 9223372036854775807), + (2, 'value', 1, 1.0, -1, -0.1, 123456789), + (3, 'another', 0, 0.0, 0, 0.0, 0) + `) + require.NoError(t, err) + + env, err := cel.NewEnv( + cel.Variable("id", cel.IntType), + cel.Variable("empty_string", cel.StringType), + cel.Variable("zero_int", cel.IntType), + cel.Variable("zero_float", cel.DoubleType), + cel.Variable("negative_int", cel.IntType), + cel.Variable("negative_float", cel.DoubleType), + cel.Variable("large_int", cel.IntType), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(sqliteDialect.New()) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + { + name: "Empty string equality", + celExpr: `empty_string == ""`, + expectedRows: 1, + description: "Empty string should be handled correctly", + }, + { + name: "Zero integer equality", + celExpr: `zero_int == 0`, + expectedRows: 2, + description: "Zero should be handled correctly", + }, + { + name: "Negative integer comparison", + celExpr: `negative_int < 0`, + expectedRows: 2, // -10 and -1 + description: "Negative numbers should work correctly", + }, + { + name: "Large integer comparison", + celExpr: `large_int > 1000000`, + expectedRows: 2, // 9223372036854775807 and 123456789 + description: "Large integers should be handled 
correctly", + }, + { + name: "Zero float equality", + celExpr: `zero_float == 0.0`, + expectedRows: 2, + description: "Zero float should be handled correctly", + }, + { + name: "Negative float comparison", + celExpr: `negative_float < 0.0`, + expectedRows: 2, // -5.5 and -0.1 + description: "Negative floats should work correctly", + }, + { + name: "Arithmetic with zero", + celExpr: `zero_int + 10 == 10`, + expectedRows: 2, // 0 + 10 = 10 + description: "Arithmetic with zero should work", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + sqlCondition, err := cel2sql.Convert(ast, dialectOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := "SELECT COUNT(*) FROM edge_cases WHERE " + sqlCondition + + var actualRows int + err = db.QueryRow(query).Scan(&actualRows) + require.NoError(t, err, "Generated SQL should execute successfully. %s", + tt.description) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s", tt.description) + + t.Logf("OK: %s", tt.description) + }) + } +} + +// TestSQLiteJSONIntegration validates JSON operations against a real SQLite database. 
+func TestSQLiteJSONIntegration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + db, err := sql.Open("sqlite", ":memory:") + require.NoError(t, err) + defer func() { + if closeErr := db.Close(); closeErr != nil { + t.Logf("failed to close db: %v", closeErr) + } + }() + + // Create test table with JSON column (stored as TEXT in SQLite) + _, err = db.Exec(` + CREATE TABLE products ( + id INTEGER PRIMARY KEY, + name TEXT, + price REAL, + metadata TEXT + ) + `) + require.NoError(t, err) + + _, err = db.Exec(` + INSERT INTO products VALUES + (1, 'Widget', 19.99, '{"brand": "Acme", "color": "red", "specs": {"weight": 100}}'), + (2, 'Gadget', 29.99, '{"brand": "Beta", "color": "blue", "specs": {"weight": 200}}'), + (3, 'Doohickey', 39.99, '{"brand": "Acme", "color": "green", "specs": {"weight": 150}}') + `) + require.NoError(t, err) + + // Set up CEL environment with schema for JSON detection + productSchema := pg.NewSchema([]pg.FieldSchema{ + {Name: "id", Type: "integer"}, + {Name: "name", Type: "text"}, + {Name: "price", Type: "double precision"}, + {Name: "metadata", Type: "json", IsJSON: true}, + }) + + schemas := map[string]pg.Schema{ + "product": productSchema, + } + + provider := pg.NewTypeProvider(schemas) + + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("product", cel.ObjectType("product")), + ) + require.NoError(t, err) + + dialectOpt := cel2sql.WithDialect(sqliteDialect.New()) + schemaOpt := cel2sql.WithSchemas(schemas) + + tests := []struct { + name string + celExpr string + expectedRows int + description string + }{ + { + name: "JSON field access", + celExpr: `product.metadata.brand == "Acme"`, + expectedRows: 2, + description: "JSON field access with json_extract", + }, + { + name: "JSON field access different value", + celExpr: `product.metadata.color == "blue"`, + expectedRows: 1, + description: "JSON field access with different value", + }, + { + name: "JSON nested field 
access", + celExpr: `product.metadata.brand == "Acme" && product.price > 30.0`, + expectedRows: 1, // Doohickey (Acme, 39.99) + description: "JSON field combined with regular field comparison", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ast, issues := env.Compile(tt.celExpr) + if issues != nil && issues.Err() != nil { + t.Fatalf("CEL compilation failed: %v", issues.Err()) + } + + sqlCondition, err := cel2sql.Convert(ast, dialectOpt, schemaOpt) + require.NoError(t, err, "Conversion should succeed for: %s", tt.description) + + t.Logf("CEL Expression: %s", tt.celExpr) + t.Logf("Generated SQL WHERE clause: %s", sqlCondition) + + // #nosec G202 - This is a test validating SQL generation, not a security risk + query := "SELECT COUNT(*) FROM products product WHERE " + sqlCondition + t.Logf("Full SQL Query: %s", query) + + var actualRows int + err = db.QueryRow(query).Scan(&actualRows) + require.NoError(t, err, "Generated SQL should execute successfully. %s\nSQL: %s", + tt.description, sqlCondition) + + require.Equal(t, tt.expectedRows, actualRows, + "Query should return expected number of rows. %s\nCEL: %s\nSQL: %s", + tt.description, tt.celExpr, sqlCondition) + + t.Logf("OK: %s", tt.description) + }) + } +} diff --git a/testcases/array_tests.go b/testcases/array_tests.go new file mode 100644 index 0000000..3a92d35 --- /dev/null +++ b/testcases/array_tests.go @@ -0,0 +1,69 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// ArrayTests returns test cases for array operations. 
+func ArrayTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "list_index_literal", + CELExpr: `[1, 2, 3][0] == 1`, + Category: CategoryArray, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "ARRAY[1, 2, 3][1] = 1", + dialect.DuckDB: "[1, 2, 3][1] = 1", + dialect.BigQuery: "[1, 2, 3][OFFSET(0)] = 1", + }, + }, + { + Name: "list_var_index", + CELExpr: `string_list[0] == "a"`, + Category: CategoryArray, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "string_list[1] = 'a'", + dialect.DuckDB: "string_list[1] = 'a'", + dialect.BigQuery: "string_list[OFFSET(0)] = 'a'", + }, + }, + { + Name: "size_list", + CELExpr: `size(string_list)`, + Category: CategoryArray, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "COALESCE(ARRAY_LENGTH(string_list, 1), 0)", + dialect.DuckDB: "COALESCE(array_length(string_list), 0)", + dialect.BigQuery: "ARRAY_LENGTH(string_list)", + }, + }, + { + Name: "size_list_comparison", + CELExpr: `size(string_list) > 0`, + Category: CategoryArray, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "COALESCE(ARRAY_LENGTH(string_list, 1), 0) > 0", + dialect.DuckDB: "COALESCE(array_length(string_list), 0) > 0", + dialect.BigQuery: "ARRAY_LENGTH(string_list) > 0", + }, + }, + { + Name: "array_index_overflow", + CELExpr: `string_list[9223372036854775807]`, + Category: CategoryArray, + WantErr: map[dialect.Name]bool{ + dialect.PostgreSQL: true, + dialect.DuckDB: true, + dialect.BigQuery: true, + }, + }, + { + Name: "array_index_negative", + CELExpr: `string_list[-1]`, + Category: CategoryArray, + WantErr: map[dialect.Name]bool{ + dialect.PostgreSQL: true, + dialect.DuckDB: true, + dialect.BigQuery: true, + }, + }, + } +} diff --git a/testcases/basic_tests.go b/testcases/basic_tests.go new file mode 100644 index 0000000..a04ed70 --- /dev/null +++ b/testcases/basic_tests.go @@ -0,0 +1,141 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// BasicTests returns test cases for 
basic comparisons and expressions. +func BasicTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "equality_string", + CELExpr: `name == "a"`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = 'a'", + dialect.MySQL: "name = 'a'", + dialect.SQLite: "name = 'a'", + dialect.DuckDB: "name = 'a'", + dialect.BigQuery: "name = 'a'", + }, + }, + { + Name: "inequality_int", + CELExpr: `age != 20`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "age != 20", + dialect.MySQL: "age != 20", + dialect.SQLite: "age != 20", + dialect.DuckDB: "age != 20", + dialect.BigQuery: "age != 20", + }, + }, + { + Name: "less_than", + CELExpr: `age < 20`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "age < 20", + dialect.MySQL: "age < 20", + dialect.SQLite: "age < 20", + dialect.DuckDB: "age < 20", + dialect.BigQuery: "age < 20", + }, + }, + { + Name: "greater_equal_float", + CELExpr: `height >= 1.6180339887`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "height >= 1.6180339887", + dialect.MySQL: "height >= 1.6180339887", + dialect.SQLite: "height >= 1.6180339887", + dialect.DuckDB: "height >= 1.6180339887", + dialect.BigQuery: "height >= 1.6180339887", + }, + }, + { + Name: "is_null", + CELExpr: `null_var == null`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "null_var IS NULL", + dialect.MySQL: "null_var IS NULL", + dialect.SQLite: "null_var IS NULL", + dialect.DuckDB: "null_var IS NULL", + dialect.BigQuery: "null_var IS NULL", + }, + }, + { + Name: "is_not_true", + CELExpr: `adult != true`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "adult IS NOT TRUE", + dialect.MySQL: "adult IS NOT TRUE", + dialect.SQLite: "adult IS NOT TRUE", + dialect.DuckDB: "adult IS NOT TRUE", + dialect.BigQuery: "adult IS NOT TRUE", + }, + }, + { + Name: "not", 
+ CELExpr: `!adult`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "NOT adult", + dialect.MySQL: "NOT adult", + dialect.SQLite: "NOT adult", + dialect.DuckDB: "NOT adult", + dialect.BigQuery: "NOT adult", + }, + }, + { + Name: "negative_int", + CELExpr: `-1`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "-1", + dialect.MySQL: "-1", + dialect.SQLite: "-1", + dialect.DuckDB: "-1", + dialect.BigQuery: "-1", + }, + }, + { + Name: "ternary", + CELExpr: `name == "a" ? "a" : "b"`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + dialect.MySQL: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + dialect.SQLite: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + dialect.DuckDB: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + dialect.BigQuery: "CASE WHEN name = 'a' THEN 'a' ELSE 'b' END", + }, + }, + { + Name: "field_select", + CELExpr: `page.title == "test"`, + Category: CategoryFieldAccess, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "page.title = 'test'", + dialect.MySQL: "page.title = 'test'", + dialect.SQLite: "page.title = 'test'", + dialect.DuckDB: "page.title = 'test'", + dialect.BigQuery: "page.title = 'test'", + }, + }, + { + Name: "in_list", + CELExpr: `name in ["a", "b", "c"]`, + Category: CategoryBasic, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = ANY(ARRAY['a', 'b', 'c'])", + dialect.MySQL: "JSON_CONTAINS(JSON_ARRAY('a', 'b', 'c'), CAST(name AS JSON))", + dialect.SQLite: "name IN (SELECT value FROM json_each(json_array('a', 'b', 'c')))", + dialect.DuckDB: "name = ANY(['a', 'b', 'c'])", + dialect.BigQuery: "name IN UNNEST(['a', 'b', 'c'])", + }, + }, + } +} diff --git a/testcases/cast_tests.go b/testcases/cast_tests.go new file mode 100644 index 0000000..bcf85a5 --- /dev/null +++ b/testcases/cast_tests.go @@ -0,0 +1,81 @@ +package testcases + +import 
"github.com/spandigital/cel2sql/v3/dialect" + +// CastTests returns test cases for type casting operations. +func CastTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "cast_bool", + CELExpr: `bool(0) == false`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST(0 AS BOOLEAN) IS FALSE", + dialect.MySQL: "CAST(0 AS UNSIGNED) IS FALSE", + dialect.SQLite: "CAST(0 AS INTEGER) IS FALSE", + dialect.DuckDB: "CAST(0 AS BOOLEAN) IS FALSE", + dialect.BigQuery: "CAST(0 AS BOOL) IS FALSE", + }, + }, + { + Name: "cast_bytes", + CELExpr: `bytes("test")`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST('test' AS BYTEA)", + dialect.MySQL: "CAST('test' AS BINARY)", + dialect.SQLite: "CAST('test' AS BLOB)", + dialect.DuckDB: "CAST('test' AS BLOB)", + dialect.BigQuery: "CAST('test' AS BYTES)", + }, + }, + { + Name: "cast_int", + CELExpr: `int(true) == 1`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST(TRUE AS BIGINT) = 1", + dialect.MySQL: "CAST(TRUE AS SIGNED) = 1", + dialect.SQLite: "CAST(TRUE AS INTEGER) = 1", + dialect.DuckDB: "CAST(TRUE AS BIGINT) = 1", + dialect.BigQuery: "CAST(TRUE AS INT64) = 1", + }, + }, + { + Name: "cast_string", + CELExpr: `string(true) == "true"`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST(TRUE AS TEXT) = 'true'", + dialect.MySQL: "CAST(TRUE AS CHAR) = 'true'", + dialect.SQLite: "CAST(TRUE AS TEXT) = 'true'", + dialect.DuckDB: "CAST(TRUE AS VARCHAR) = 'true'", + dialect.BigQuery: "CAST(TRUE AS STRING) = 'true'", + }, + }, + { + Name: "cast_string_from_timestamp", + CELExpr: `string(created_at)`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "CAST(created_at AS TEXT)", + dialect.MySQL: "CAST(created_at AS CHAR)", + dialect.SQLite: "CAST(created_at AS TEXT)", + dialect.DuckDB: "CAST(created_at AS VARCHAR)", + dialect.BigQuery: 
"CAST(created_at AS STRING)", + }, + }, + { + Name: "cast_int_epoch", + CELExpr: `int(created_at)`, + Category: CategoryCast, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(EPOCH FROM created_at)::bigint", + dialect.MySQL: "UNIX_TIMESTAMP(created_at)", + dialect.SQLite: "CAST(strftime('%s', created_at) AS INTEGER)", + dialect.DuckDB: "EXTRACT(EPOCH FROM created_at)::BIGINT", + dialect.BigQuery: "UNIX_SECONDS(created_at)", + }, + }, + } +} diff --git a/testcases/comprehension_tests.go b/testcases/comprehension_tests.go new file mode 100644 index 0000000..9eebd0d --- /dev/null +++ b/testcases/comprehension_tests.go @@ -0,0 +1,64 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// ComprehensionTests returns test cases for CEL comprehension operations. +func ComprehensionTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "all", + CELExpr: `string_list.all(x, x != "bad")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "NOT EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE NOT (x != 'bad'))", + dialect.SQLite: "NOT EXISTS (SELECT 1 FROM json_each(string_list) AS x WHERE NOT (x != 'bad'))", + dialect.DuckDB: "NOT EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE NOT (x != 'bad'))", + dialect.BigQuery: "NOT EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE NOT (x != 'bad'))", + }, + }, + { + Name: "exists", + CELExpr: `string_list.exists(x, x == "good")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE x = 'good')", + dialect.SQLite: "EXISTS (SELECT 1 FROM json_each(string_list) AS x WHERE x = 'good')", + dialect.DuckDB: "EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE x = 'good')", + dialect.BigQuery: "EXISTS (SELECT 1 FROM UNNEST(string_list) AS x WHERE x = 'good')", + }, + }, + { + Name: "exists_one", + CELExpr: `string_list.exists_one(x, x == 
"unique")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "(SELECT COUNT(*) FROM UNNEST(string_list) AS x WHERE x = 'unique') = 1", + dialect.SQLite: "(SELECT COUNT(*) FROM json_each(string_list) AS x WHERE x = 'unique') = 1", + dialect.DuckDB: "(SELECT COUNT(*) FROM UNNEST(string_list) AS x WHERE x = 'unique') = 1", + dialect.BigQuery: "(SELECT COUNT(*) FROM UNNEST(string_list) AS x WHERE x = 'unique') = 1", + }, + }, + { + Name: "filter", + CELExpr: `string_list.filter(x, x != "bad")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "ARRAY(SELECT x FROM UNNEST(string_list) AS x WHERE x != 'bad')", + dialect.SQLite: "(SELECT json_group_array(x) FROM json_each(string_list) AS x WHERE x != 'bad')", + dialect.DuckDB: "ARRAY(SELECT x FROM UNNEST(string_list) AS x WHERE x != 'bad')", + dialect.BigQuery: "ARRAY(SELECT x FROM UNNEST(string_list) AS x WHERE x != 'bad')", + }, + }, + { + Name: "map_transform", + CELExpr: `string_list.map(x, x + "_suffix")`, + Category: CategoryComprehension, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "ARRAY(SELECT x || '_suffix' FROM UNNEST(string_list) AS x)", + dialect.SQLite: "(SELECT json_group_array(x || '_suffix') FROM json_each(string_list) AS x)", + dialect.DuckDB: "ARRAY(SELECT x || '_suffix' FROM UNNEST(string_list) AS x)", + dialect.BigQuery: "ARRAY(SELECT x || '_suffix' FROM UNNEST(string_list) AS x)", + }, + }, + } +} diff --git a/testcases/fixtures.go b/testcases/fixtures.go new file mode 100644 index 0000000..e33e904 --- /dev/null +++ b/testcases/fixtures.go @@ -0,0 +1,93 @@ +package testcases + +import ( + "github.com/spandigital/cel2sql/v3/pg" + "github.com/spandigital/cel2sql/v3/schema" +) + +// EnvDefault is the default environment setup name (basic types, no schema). +const EnvDefault = "" + +// EnvWithSchema is an environment with a schema-based type provider. 
+const EnvWithSchema = "schema" + +// EnvWithJSON is an environment with JSON/JSONB schema fields. +const EnvWithJSON = "json_schema" + +// EnvWithTimestamp is an environment for timestamp operations. +const EnvWithTimestamp = "timestamp" + +// NewPersonSchema returns a dialect-agnostic schema for the "person" table, +// suitable for basic, operator, and string tests. +func NewPersonSchema() schema.Schema { + return schema.NewSchema([]schema.FieldSchema{ + {Name: "name", Type: "text"}, + {Name: "age", Type: "integer"}, + {Name: "adult", Type: "boolean"}, + {Name: "height", Type: "double precision"}, + {Name: "email", Type: "text"}, + {Name: "tags", Type: "text", Repeated: true}, + {Name: "scores", Type: "integer", Repeated: true}, + }) +} + +// NewPersonPGSchema returns a PostgreSQL-specific schema for the "person" table. +func NewPersonPGSchema() pg.Schema { + return pg.NewSchema([]pg.FieldSchema{ + {Name: "name", Type: "text"}, + {Name: "age", Type: "integer"}, + {Name: "adult", Type: "boolean"}, + {Name: "height", Type: "double precision"}, + {Name: "email", Type: "text"}, + {Name: "tags", Type: "text", Repeated: true}, + {Name: "scores", Type: "integer", Repeated: true}, + }) +} + +// NewProductSchema returns a dialect-agnostic schema for the "product" table, +// with JSON/JSONB fields for JSON-related tests. +func NewProductSchema() schema.Schema { + return schema.NewSchema([]schema.FieldSchema{ + {Name: "name", Type: "text"}, + {Name: "price", Type: "double precision"}, + {Name: "metadata", Type: "jsonb", IsJSON: true, IsJSONB: true}, + {Name: "attributes", Type: "json", IsJSON: true}, + {Name: "tags", Type: "jsonb", IsJSON: true, IsJSONB: true, Repeated: true, ElementType: "text"}, + }) +} + +// NewProductPGSchema returns a PostgreSQL-specific schema for the "product" table. 
+func NewProductPGSchema() pg.Schema { + return pg.NewSchema([]pg.FieldSchema{ + {Name: "name", Type: "text"}, + {Name: "price", Type: "double precision"}, + {Name: "metadata", Type: "jsonb", IsJSON: true, IsJSONB: true}, + {Name: "attributes", Type: "json", IsJSON: true}, + {Name: "tags", Type: "jsonb", IsJSON: true, IsJSONB: true, Repeated: true, ElementType: "text"}, + }) +} + +// NewOrderSchema returns a dialect-agnostic schema for the "orders" table, +// with array and timestamp fields. +func NewOrderSchema() schema.Schema { + return schema.NewSchema([]schema.FieldSchema{ + {Name: "order_id", Type: "bigint"}, + {Name: "customer_name", Type: "text"}, + {Name: "total", Type: "double precision"}, + {Name: "items", Type: "text", Repeated: true}, + {Name: "created_at", Type: "timestamp with time zone"}, + {Name: "status", Type: "text"}, + }) +} + +// NewOrderPGSchema returns a PostgreSQL-specific schema for the "orders" table. +func NewOrderPGSchema() pg.Schema { + return pg.NewSchema([]pg.FieldSchema{ + {Name: "order_id", Type: "bigint"}, + {Name: "customer_name", Type: "text"}, + {Name: "total", Type: "double precision"}, + {Name: "items", Type: "text", Repeated: true}, + {Name: "created_at", Type: "timestamp with time zone"}, + {Name: "status", Type: "text"}, + }) +} diff --git a/testcases/json_tests.go b/testcases/json_tests.go new file mode 100644 index 0000000..7495a02 --- /dev/null +++ b/testcases/json_tests.go @@ -0,0 +1,49 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// JSONTests returns test cases for JSON/JSONB field access and operations. +// These tests require the "json_schema" environment setup. 
+func JSONTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "json_field_access", + CELExpr: `product.metadata.brand == "Acme"`, + Category: CategoryJSON, + EnvSetup: EnvWithJSON, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "product.metadata->>'brand' = 'Acme'", + dialect.MySQL: "product.metadata->>'$.brand' = 'Acme'", + dialect.SQLite: "json_extract(product.metadata, '$.brand') = 'Acme'", + dialect.DuckDB: "product.metadata->>'brand' = 'Acme'", + dialect.BigQuery: "JSON_VALUE(product.metadata, '$.brand') = 'Acme'", + }, + }, + { + Name: "json_nested_access", + CELExpr: `product.metadata.specs.color == "red"`, + Category: CategoryJSON, + EnvSetup: EnvWithJSON, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "product.metadata->'specs'->>'color' = 'red'", + dialect.MySQL: "product.metadata->'$.specs'->>'$.color' = 'red'", + dialect.SQLite: "json_extract(json_extract(product.metadata, '$.specs'), '$.color') = 'red'", + dialect.DuckDB: "product.metadata->'specs'->>'color' = 'red'", + dialect.BigQuery: "JSON_VALUE(JSON_QUERY(product.metadata, '$.specs'), '$.color') = 'red'", + }, + }, + { + Name: "json_has_field", + CELExpr: `has(product.metadata.brand)`, + Category: CategoryJSON, + EnvSetup: EnvWithJSON, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "product.metadata ? 'brand'", + dialect.MySQL: "JSON_CONTAINS_PATH(product.metadata, 'one', '$.brand')", + dialect.SQLite: "json_type(product.metadata, '$.brand') IS NOT NULL", + dialect.DuckDB: "json_exists(product.metadata, '$.brand')", + dialect.BigQuery: "JSON_VALUE(product.metadata, '$.brand') IS NOT NULL", + }, + }, + } +} diff --git a/testcases/operator_tests.go b/testcases/operator_tests.go new file mode 100644 index 0000000..c80b779 --- /dev/null +++ b/testcases/operator_tests.go @@ -0,0 +1,91 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// OperatorTests returns test cases for logical and arithmetic operators. 
+func OperatorTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "logical_and", + CELExpr: `name == "a" && age > 20`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = 'a' AND age > 20", + dialect.MySQL: "name = 'a' AND age > 20", + dialect.SQLite: "name = 'a' AND age > 20", + dialect.DuckDB: "name = 'a' AND age > 20", + dialect.BigQuery: "name = 'a' AND age > 20", + }, + }, + { + Name: "logical_or", + CELExpr: `name == "a" || age > 20`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = 'a' OR age > 20", + dialect.MySQL: "name = 'a' OR age > 20", + dialect.SQLite: "name = 'a' OR age > 20", + dialect.DuckDB: "name = 'a' OR age > 20", + dialect.BigQuery: "name = 'a' OR age > 20", + }, + }, + { + Name: "parenthesized", + CELExpr: `age >= 10 && (name == "a" || name == "b")`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "age >= 10 AND (name = 'a' OR name = 'b')", + dialect.MySQL: "age >= 10 AND (name = 'a' OR name = 'b')", + dialect.SQLite: "age >= 10 AND (name = 'a' OR name = 'b')", + dialect.DuckDB: "age >= 10 AND (name = 'a' OR name = 'b')", + dialect.BigQuery: "age >= 10 AND (name = 'a' OR name = 'b')", + }, + }, + { + Name: "addition", + CELExpr: `1 + 2 == 3`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "1 + 2 = 3", + dialect.MySQL: "1 + 2 = 3", + dialect.SQLite: "1 + 2 = 3", + dialect.DuckDB: "1 + 2 = 3", + dialect.BigQuery: "1 + 2 = 3", + }, + }, + { + Name: "modulo", + CELExpr: `5 % 3 == 2`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "MOD(5, 3) = 2", + dialect.MySQL: "MOD(5, 3) = 2", + dialect.SQLite: "MOD(5, 3) = 2", + dialect.DuckDB: "MOD(5, 3) = 2", + dialect.BigQuery: "MOD(5, 3) = 2", + }, + }, + { + Name: "string_concat", + CELExpr: `"a" + "b" == "ab"`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + 
dialect.PostgreSQL: "'a' || 'b' = 'ab'", + dialect.MySQL: "CONCAT('a', 'b') = 'ab'", + dialect.SQLite: "'a' || 'b' = 'ab'", + dialect.DuckDB: "'a' || 'b' = 'ab'", + dialect.BigQuery: "'a' || 'b' = 'ab'", + }, + }, + { + Name: "list_concat_in", + CELExpr: `1 in [1] + [2, 3]`, + Category: CategoryOperator, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "1 = ANY(ARRAY[1] || ARRAY[2, 3])", + dialect.DuckDB: "1 = ANY([1] || [2, 3])", + dialect.BigQuery: "1 IN UNNEST([1] || [2, 3])", + }, + }, + } +} diff --git a/testcases/parameterized_tests.go b/testcases/parameterized_tests.go new file mode 100644 index 0000000..88441ba --- /dev/null +++ b/testcases/parameterized_tests.go @@ -0,0 +1,111 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// ParameterizedTests returns test cases for parameterized SQL conversion. +func ParameterizedTests() []ParameterizedTestCase { + return []ParameterizedTestCase{ + { + Name: "simple_string_equality", + CELExpr: `name == "John"`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = $1", + dialect.SQLite: "name = ?", + dialect.DuckDB: "name = $1", + dialect.BigQuery: "name = @p1", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {"John"}, + dialect.SQLite: {"John"}, + dialect.DuckDB: {"John"}, + dialect.BigQuery: {"John"}, + }, + }, + { + Name: "multiple_string_params", + CELExpr: `name == "John" && name != "Jane"`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name = $1 AND name != $2", + dialect.SQLite: "name = ? 
AND name != ?", + dialect.DuckDB: "name = $1 AND name != $2", + dialect.BigQuery: "name = @p1 AND name != @p2", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {"John", "Jane"}, + dialect.SQLite: {"John", "Jane"}, + dialect.DuckDB: {"John", "Jane"}, + dialect.BigQuery: {"John", "Jane"}, + }, + }, + { + Name: "integer_equality", + CELExpr: `age == 18`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "age = $1", + dialect.SQLite: "age = ?", + dialect.DuckDB: "age = $1", + dialect.BigQuery: "age = @p1", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {int64(18)}, + dialect.SQLite: {int64(18)}, + dialect.DuckDB: {int64(18)}, + dialect.BigQuery: {int64(18)}, + }, + }, + { + Name: "integer_range", + CELExpr: `age > 21 && age < 65`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "age > $1 AND age < $2", + dialect.SQLite: "age > ? AND age < ?", + dialect.DuckDB: "age > $1 AND age < $2", + dialect.BigQuery: "age > @p1 AND age < @p2", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {int64(21), int64(65)}, + dialect.SQLite: {int64(21), int64(65)}, + dialect.DuckDB: {int64(21), int64(65)}, + dialect.BigQuery: {int64(21), int64(65)}, + }, + }, + { + Name: "double_equality", + CELExpr: `salary == 50000.50`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "salary = $1", + dialect.SQLite: "salary = ?", + dialect.DuckDB: "salary = $1", + dialect.BigQuery: "salary = @p1", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {50000.50}, + dialect.SQLite: {50000.50}, + dialect.DuckDB: {50000.50}, + dialect.BigQuery: {50000.50}, + }, + }, + { + Name: "boolean_true_inline", + CELExpr: `active == true`, + Category: CategoryParameterized, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "active IS TRUE", + dialect.SQLite: "active IS TRUE", + dialect.DuckDB: "active IS TRUE", + 
dialect.BigQuery: "active IS TRUE", + }, + WantParams: map[dialect.Name][]any{ + dialect.PostgreSQL: {}, + dialect.SQLite: {}, + dialect.DuckDB: {}, + dialect.BigQuery: {}, + }, + }, + } +} diff --git a/testcases/regex_tests.go b/testcases/regex_tests.go new file mode 100644 index 0000000..aeded0e --- /dev/null +++ b/testcases/regex_tests.go @@ -0,0 +1,93 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// RegexTests returns test cases for regex pattern matching. +func RegexTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "simple_match", + CELExpr: `name.matches("a+")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ 'a+'", + dialect.MySQL: "name REGEXP 'a+'", + dialect.DuckDB: "name ~ 'a+'", + dialect.BigQuery: "REGEXP_CONTAINS(name, 'a+')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "function_style", + CELExpr: `matches(name, "^[0-9]+$")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ '^[0-9]+$'", + dialect.MySQL: "name REGEXP '^[0-9]+$'", + dialect.DuckDB: "name ~ '^[0-9]+$'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '^[0-9]+$')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "word_boundary", + CELExpr: `name.matches("\\btest\\b")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ '\\ytest\\y'", + dialect.MySQL: "name REGEXP '\\btest\\b'", + dialect.DuckDB: "name ~ '\\btest\\b'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '\\btest\\b')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "digit_class", + CELExpr: `name.matches("\\d{3}-\\d{4}")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ 
'[[:digit:]]{3}-[[:digit:]]{4}'", + dialect.MySQL: "name REGEXP '\\d{3}-\\d{4}'", + dialect.DuckDB: "name ~ '\\d{3}-\\d{4}'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '\\d{3}-\\d{4}')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "word_class", + CELExpr: `name.matches("\\w+@\\w+\\.\\w+")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ '[[:alnum:]_]+@[[:alnum:]_]+\\.[[:alnum:]_]+'", + dialect.MySQL: "name REGEXP '\\w+@\\w+\\.\\w+'", + dialect.DuckDB: "name ~ '\\w+@\\w+\\.\\w+'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '\\w+@\\w+\\.\\w+')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + { + Name: "complex_pattern", + CELExpr: `name.matches(".*pattern.*")`, + Category: CategoryRegex, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name ~ '.*pattern.*'", + dialect.MySQL: "name REGEXP '.*pattern.*'", + dialect.DuckDB: "name ~ '.*pattern.*'", + dialect.BigQuery: "REGEXP_CONTAINS(name, '.*pattern.*')", + }, + SkipDialect: map[dialect.Name]string{ + dialect.SQLite: "SQLite does not support regex", + }, + }, + } +} diff --git a/testcases/string_tests.go b/testcases/string_tests.go new file mode 100644 index 0000000..5e75169 --- /dev/null +++ b/testcases/string_tests.go @@ -0,0 +1,69 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// StringTests returns test cases for string functions. 
+func StringTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "starts_with", + CELExpr: `name.startsWith("a")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name LIKE 'a%' ESCAPE E'\\\\'", + dialect.MySQL: "name LIKE 'a%' ESCAPE '\\\\'", + dialect.SQLite: "name LIKE 'a%' ESCAPE '\\'", + dialect.DuckDB: "name LIKE 'a%' ESCAPE '\\\\'", + dialect.BigQuery: "name LIKE 'a%' ESCAPE '\\\\'", + }, + }, + { + Name: "ends_with", + CELExpr: `name.endsWith("z")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name LIKE '%z' ESCAPE E'\\\\'", + dialect.MySQL: "name LIKE '%z' ESCAPE '\\\\'", + dialect.SQLite: "name LIKE '%z' ESCAPE '\\'", + dialect.DuckDB: "name LIKE '%z' ESCAPE '\\\\'", + dialect.BigQuery: "name LIKE '%z' ESCAPE '\\\\'", + }, + }, + { + Name: "contains", + CELExpr: `name.contains("abc")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "POSITION('abc' IN name) > 0", + dialect.MySQL: "LOCATE('abc', name) > 0", + dialect.SQLite: "INSTR(name, 'abc') > 0", + dialect.DuckDB: "CONTAINS(name, 'abc')", + dialect.BigQuery: "INSTR(name, 'abc') != 0", + }, + }, + { + Name: "size_string", + CELExpr: `size("test")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "LENGTH('test')", + dialect.MySQL: "LENGTH('test')", + dialect.SQLite: "LENGTH('test')", + dialect.DuckDB: "LENGTH('test')", + dialect.BigQuery: "LENGTH('test')", + }, + }, + { + Name: "starts_with_and_ends_with", + CELExpr: `name.startsWith("a") && name.endsWith("z")`, + Category: CategoryString, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "name LIKE 'a%' ESCAPE E'\\\\' AND name LIKE '%z' ESCAPE E'\\\\'", + dialect.MySQL: "name LIKE 'a%' ESCAPE '\\\\' AND name LIKE '%z' ESCAPE '\\\\'", + dialect.SQLite: "name LIKE 'a%' ESCAPE '\\' AND name LIKE '%z' ESCAPE '\\'", + dialect.DuckDB: "name LIKE 'a%' ESCAPE '\\\\' AND name LIKE '%z' 
ESCAPE '\\\\'", + dialect.BigQuery: "name LIKE 'a%' ESCAPE '\\\\' AND name LIKE '%z' ESCAPE '\\\\'", + }, + }, + } +} diff --git a/testcases/testcases.go b/testcases/testcases.go new file mode 100644 index 0000000..1de6a41 --- /dev/null +++ b/testcases/testcases.go @@ -0,0 +1,98 @@ +// Package testcases defines shared test case types and helpers for multi-dialect testing. +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// Category classifies a test case for organization and selective running. +type Category string + +// Test case categories. +const ( + CategoryBasic Category = "basic" + CategoryOperator Category = "operator" + CategoryString Category = "string" + CategoryRegex Category = "regex" + CategoryJSON Category = "json" + CategoryArray Category = "array" + CategoryComprehension Category = "comprehension" + CategoryTimestamp Category = "timestamp" + CategoryParameterized Category = "parameterized" + CategoryCast Category = "cast" + CategoryFieldAccess Category = "field_access" +) + +// ConvertTestCase defines a single CEL-to-SQL conversion test case +// with expected output per dialect. +type ConvertTestCase struct { + // Name is the test case name (used for t.Run). + Name string + + // CELExpr is the CEL expression source to compile and convert. + CELExpr string + + // Category classifies the test case. + Category Category + + // EnvSetup identifies which CEL environment setup to use. + // Empty string means "default" (basic types, no schema). + EnvSetup string + + // WantSQL maps dialect name to expected SQL output. + // If a dialect is absent, the test is skipped for that dialect. + WantSQL map[dialect.Name]string + + // WantErr maps dialect name to whether an error is expected. + // If a dialect is absent, no error is expected. + WantErr map[dialect.Name]bool + + // SkipDialect maps dialect name to a skip reason. + // If a dialect is present, the test is skipped with the given message. 
+ SkipDialect map[dialect.Name]string +} + +// ForDialect returns the expected SQL for a given dialect, and whether the +// test case has an expectation for that dialect. +func (tc *ConvertTestCase) ForDialect(d dialect.Name) (sql string, hasExpected bool) { + sql, hasExpected = tc.WantSQL[d] + return +} + +// ShouldError returns whether an error is expected for the given dialect. +func (tc *ConvertTestCase) ShouldError(d dialect.Name) bool { + return tc.WantErr[d] +} + +// ShouldSkip returns the skip reason for a dialect, or empty string if not skipped. +func (tc *ConvertTestCase) ShouldSkip(d dialect.Name) string { + if tc.SkipDialect == nil { + return "" + } + return tc.SkipDialect[d] +} + +// ParameterizedTestCase defines a test case for parameterized SQL conversion. +type ParameterizedTestCase struct { + // Name is the test case name. + Name string + + // CELExpr is the CEL expression source. + CELExpr string + + // Category classifies the test case. + Category Category + + // EnvSetup identifies which CEL environment setup to use. + EnvSetup string + + // WantSQL maps dialect name to expected parameterized SQL output. + WantSQL map[dialect.Name]string + + // WantParams maps dialect name to expected parameter values. + WantParams map[dialect.Name][]any + + // WantErr maps dialect name to whether an error is expected. + WantErr map[dialect.Name]bool + + // SkipDialect maps dialect name to a skip reason. + SkipDialect map[dialect.Name]string +} diff --git a/testcases/timestamp_tests.go b/testcases/timestamp_tests.go new file mode 100644 index 0000000..81fe7d3 --- /dev/null +++ b/testcases/timestamp_tests.go @@ -0,0 +1,120 @@ +package testcases + +import "github.com/spandigital/cel2sql/v3/dialect" + +// TimestampTests returns test cases for timestamp and duration operations. 
+func TimestampTests() []ConvertTestCase { + return []ConvertTestCase{ + { + Name: "duration_second", + CELExpr: `duration("10s")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "INTERVAL 10 SECOND", + dialect.MySQL: "INTERVAL 10 SECOND", + dialect.SQLite: "'+10 seconds'", + dialect.DuckDB: "INTERVAL 10 SECOND", + dialect.BigQuery: "INTERVAL 10 SECOND", + }, + }, + { + Name: "duration_minute", + CELExpr: `duration("1h1m")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "INTERVAL 61 MINUTE", + dialect.MySQL: "INTERVAL 61 MINUTE", + dialect.SQLite: "'+61 minutes'", + dialect.DuckDB: "INTERVAL 61 MINUTE", + dialect.BigQuery: "INTERVAL 61 MINUTE", + }, + }, + { + Name: "duration_hour", + CELExpr: `duration("60m")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "INTERVAL 1 HOUR", + dialect.MySQL: "INTERVAL 1 HOUR", + dialect.SQLite: "'+1 hours'", + dialect.DuckDB: "INTERVAL 1 HOUR", + dialect.BigQuery: "INTERVAL 1 HOUR", + }, + }, + { + Name: "timestamp_getSeconds", + CELExpr: `created_at.getSeconds()`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(SECOND FROM created_at)", + dialect.MySQL: "EXTRACT(SECOND FROM created_at)", + dialect.SQLite: "CAST(strftime('%S', created_at) AS INTEGER)", + dialect.DuckDB: "EXTRACT(SECOND FROM created_at)", + dialect.BigQuery: "EXTRACT(SECOND FROM created_at)", + }, + }, + { + Name: "timestamp_getHours_withTimezone", + CELExpr: `created_at.getHours("Asia/Tokyo")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(HOUR FROM created_at AT TIME ZONE 'Asia/Tokyo')", + dialect.MySQL: "EXTRACT(HOUR FROM created_at AT TIME ZONE 'Asia/Tokyo')", + dialect.SQLite: "CAST(strftime('%H', created_at) AS INTEGER)", + dialect.DuckDB: "EXTRACT(HOUR FROM created_at AT TIME ZONE 'Asia/Tokyo')", + dialect.BigQuery: "EXTRACT(HOUR 
FROM created_at AT TIME ZONE 'Asia/Tokyo')", + }, + }, + { + Name: "timestamp_sub_duration", + CELExpr: `created_at - duration("60m") <= timestamp("2021-09-01T18:00:00Z")`, + Category: CategoryTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "created_at - INTERVAL 1 HOUR <= CAST('2021-09-01T18:00:00Z' AS TIMESTAMP WITH TIME ZONE)", + dialect.MySQL: "created_at - INTERVAL 1 HOUR <= CAST('2021-09-01T18:00:00Z' AS DATETIME)", + dialect.SQLite: "datetime(created_at, '-'||'+1 hours') <= datetime('2021-09-01T18:00:00Z')", + dialect.DuckDB: "created_at - INTERVAL 1 HOUR <= CAST('2021-09-01T18:00:00Z' AS TIMESTAMPTZ)", + dialect.BigQuery: "TIMESTAMP_SUB(created_at, INTERVAL 1 HOUR) <= CAST('2021-09-01T18:00:00Z' AS TIMESTAMP)", + }, + }, + { + Name: "interval_month", + CELExpr: `interval(1, MONTH)`, + Category: CategoryTimestamp, + EnvSetup: EnvWithTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "INTERVAL 1 MONTH", + dialect.MySQL: "INTERVAL 1 MONTH", + dialect.SQLite: "'+'||1||' months'", + dialect.DuckDB: "INTERVAL 1 MONTH", + dialect.BigQuery: "INTERVAL 1 MONTH", + }, + }, + { + Name: "date_getFullYear", + CELExpr: `birthday.getFullYear()`, + Category: CategoryTimestamp, + EnvSetup: EnvWithTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(YEAR FROM birthday)", + dialect.MySQL: "EXTRACT(YEAR FROM birthday)", + dialect.SQLite: "CAST(strftime('%Y', birthday) AS INTEGER)", + dialect.DuckDB: "EXTRACT(YEAR FROM birthday)", + dialect.BigQuery: "EXTRACT(YEAR FROM birthday)", + }, + }, + { + Name: "datetime_getMonth", + CELExpr: `scheduled_at.getMonth()`, + Category: CategoryTimestamp, + EnvSetup: EnvWithTimestamp, + WantSQL: map[dialect.Name]string{ + dialect.PostgreSQL: "EXTRACT(MONTH FROM scheduled_at) - 1", + dialect.MySQL: "EXTRACT(MONTH FROM scheduled_at) - 1", + dialect.SQLite: "CAST(strftime('%m', scheduled_at) AS INTEGER) - 1", + dialect.DuckDB: "EXTRACT(MONTH FROM scheduled_at) - 1", + 
dialect.BigQuery: "EXTRACT(MONTH FROM scheduled_at) - 1", + }, + }, + } +} diff --git a/testdata/bigquery_seed.yaml b/testdata/bigquery_seed.yaml new file mode 100644 index 0000000..4375cbf --- /dev/null +++ b/testdata/bigquery_seed.yaml @@ -0,0 +1,78 @@ +projects: + - id: test-project + datasets: + - id: testdataset + tables: + - id: test_data + columns: + - name: id + type: INT64 + - name: text_val + type: STRING + - name: int_val + type: INT64 + - name: float_val + type: FLOAT64 + - name: bool_val + type: BOOL + - name: nullable_text + type: STRING + - name: nullable_int + type: INT64 + data: + - id: 1 + text_val: "hello" + int_val: 10 + float_val: 10.5 + bool_val: true + nullable_text: "present" + nullable_int: 100 + - id: 2 + text_val: "world" + int_val: 20 + float_val: 20.5 + bool_val: false + - id: 3 + text_val: "test" + int_val: 30 + float_val: 30.5 + bool_val: true + nullable_text: "here" + nullable_int: 200 + - id: 4 + text_val: "hello world" + int_val: 5 + float_val: 5.5 + bool_val: false + nullable_text: "value" + nullable_int: 50 + - id: 5 + text_val: "testing" + int_val: 15 + float_val: 15.5 + bool_val: true + nullable_text: "test" + nullable_int: 150 + - id: products + columns: + - name: id + type: INT64 + - name: name + type: STRING + - name: price + type: FLOAT64 + - name: metadata + type: STRING + data: + - id: 1 + name: "Widget" + price: 19.99 + metadata: '{"brand": "Acme", "color": "red", "specs": {"weight": 100}}' + - id: 2 + name: "Gadget" + price: 29.99 + metadata: '{"brand": "Beta", "color": "blue", "specs": {"weight": 200}}' + - id: 3 + name: "Doohickey" + price: 39.99 + metadata: '{"brand": "Acme", "color": "green", "specs": {"weight": 150}}' diff --git a/testutil/env.go b/testutil/env.go new file mode 100644 index 0000000..1239a86 --- /dev/null +++ b/testutil/env.go @@ -0,0 +1,310 @@ +package testutil + +import ( + "fmt" + + "github.com/google/cel-go/cel" + + "github.com/spandigital/cel2sql/v3" + dialectpkg 
"github.com/spandigital/cel2sql/v3/dialect" + bigqueryDialect "github.com/spandigital/cel2sql/v3/dialect/bigquery" + duckdbDialect "github.com/spandigital/cel2sql/v3/dialect/duckdb" + mysqlDialect "github.com/spandigital/cel2sql/v3/dialect/mysql" + sqliteDialect "github.com/spandigital/cel2sql/v3/dialect/sqlite" + "github.com/spandigital/cel2sql/v3/pg" + "github.com/spandigital/cel2sql/v3/sqltypes" + "github.com/spandigital/cel2sql/v3/testcases" +) + +// EnvResult holds both the CEL environment and convert options needed for testing. +type EnvResult struct { + Env *cel.Env + Opts []cel2sql.ConvertOption +} + +// NewDefaultEnv creates a basic CEL environment with standard variable types. +func NewDefaultEnv() (*EnvResult, error) { + env, err := cel.NewEnv( + cel.Types( + sqltypes.Date, sqltypes.Time, sqltypes.DateTime, sqltypes.Interval, sqltypes.DatePart, + ), + cel.Variable("name", cel.StringType), + cel.Variable("age", cel.IntType), + cel.Variable("adult", cel.BoolType), + cel.Variable("height", cel.DoubleType), + cel.Variable("string_list", cel.ListType(cel.StringType)), + cel.Variable("string_int_map", cel.MapType(cel.StringType, cel.IntType)), + cel.Variable("null_var", cel.NullType), + cel.Variable("created_at", cel.TimestampType), + cel.Variable("page", cel.MapType(cel.StringType, cel.StringType)), + cel.Variable("salary", cel.DoubleType), + cel.Variable("active", cel.BoolType), + cel.Variable("data", cel.BytesType), + cel.Variable("tags", cel.ListType(cel.StringType)), + cel.Variable("scores", cel.ListType(cel.IntType)), + // Cast functions + cel.Function("bool", cel.Overload("bool_from_int", []*cel.Type{cel.IntType}, cel.BoolType)), + cel.Function("int", cel.Overload("int_from_bool", []*cel.Type{cel.BoolType}, cel.IntType)), + ) + if err != nil { + return nil, err + } + return &EnvResult{Env: env}, nil +} + +// NewTimestampEnv creates a CEL environment with timestamp-related types and functions. 
+func NewTimestampEnv() (*EnvResult, error) { + env, err := cel.NewEnv( + cel.Types( + sqltypes.Date, sqltypes.Time, sqltypes.DateTime, sqltypes.Interval, sqltypes.DatePart, + ), + cel.Variable("name", cel.StringType), + cel.Variable("age", cel.IntType), + cel.Variable("adult", cel.BoolType), + cel.Variable("height", cel.DoubleType), + cel.Variable("string_list", cel.ListType(cel.StringType)), + cel.Variable("string_int_map", cel.MapType(cel.StringType, cel.IntType)), + cel.Variable("null_var", cel.NullType), + cel.Variable("birthday", cel.ObjectType("DATE")), + cel.Variable("fixed_time", cel.ObjectType("TIME")), + cel.Variable("scheduled_at", cel.ObjectType("DATETIME")), + cel.Variable("created_at", cel.TimestampType), + cel.Variable("page", cel.MapType(cel.StringType, cel.StringType)), + // Date part constants + cel.Variable("YEAR", cel.ObjectType("date_part")), + cel.Variable("MONTH", cel.ObjectType("date_part")), + cel.Variable("DAY", cel.ObjectType("date_part")), + cel.Variable("HOUR", cel.ObjectType("date_part")), + cel.Variable("MINUTE", cel.ObjectType("date_part")), + cel.Variable("SECOND", cel.ObjectType("date_part")), + // SQL functions + cel.Function("date", + cel.Overload("date_string", []*cel.Type{cel.StringType}, cel.ObjectType("DATE")), + cel.Overload("date_int_int_int", []*cel.Type{cel.IntType, cel.IntType, cel.IntType}, cel.ObjectType("DATE"))), + cel.Function("time", cel.Overload("time_string", []*cel.Type{cel.StringType}, cel.ObjectType("TIME"))), + cel.Function("datetime", + cel.Overload("datetime_string", []*cel.Type{cel.StringType}, cel.ObjectType("DATETIME")), + cel.Overload("datetime_date_time", []*cel.Type{cel.ObjectType("DATE"), cel.ObjectType("TIME")}, cel.ObjectType("DATETIME"))), + cel.Function("timestamp", + cel.Overload("timestamp_datetime_string", []*cel.Type{cel.ObjectType("DATETIME"), cel.StringType}, cel.TimestampType)), + cel.Function("interval", cel.Overload("interval_int_datepart", []*cel.Type{cel.IntType, 
cel.ObjectType("date_part")}, cel.ObjectType("INTERVAL"))), + cel.Function("current_date", cel.Overload("current_date", []*cel.Type{}, cel.ObjectType("DATE"))), + cel.Function("current_datetime", cel.Overload("current_datetime_string", []*cel.Type{cel.StringType}, cel.ObjectType("DATETIME"))), + // Date/Time arithmetic operators + cel.Function("_+_", + cel.Overload("date_add_interval", []*cel.Type{cel.ObjectType("DATE"), cel.ObjectType("INTERVAL")}, cel.ObjectType("DATE")), + cel.Overload("date_add_int", []*cel.Type{cel.ObjectType("DATE"), cel.IntType}, cel.ObjectType("DATE")), + cel.Overload("time_add_interval", []*cel.Type{cel.ObjectType("TIME"), cel.ObjectType("INTERVAL")}, cel.ObjectType("TIME")), + cel.Overload("datetime_add_interval", []*cel.Type{cel.ObjectType("DATETIME"), cel.ObjectType("INTERVAL")}, cel.ObjectType("DATETIME")), + cel.Overload("timestamp_add_interval", []*cel.Type{cel.TimestampType, cel.ObjectType("INTERVAL")}, cel.TimestampType)), + cel.Function("_-_", + cel.Overload("date_sub_interval", []*cel.Type{cel.ObjectType("DATE"), cel.ObjectType("INTERVAL")}, cel.ObjectType("DATE")), + cel.Overload("time_sub_interval", []*cel.Type{cel.ObjectType("TIME"), cel.ObjectType("INTERVAL")}, cel.ObjectType("TIME")), + cel.Overload("datetime_sub_interval", []*cel.Type{cel.ObjectType("DATETIME"), cel.ObjectType("INTERVAL")}, cel.ObjectType("DATETIME")), + cel.Overload("timestamp_sub_interval", []*cel.Type{cel.TimestampType, cel.ObjectType("INTERVAL")}, cel.TimestampType)), + // Date/Time comparison operators + cel.Function("_>_", + cel.Overload("date_gt_date", []*cel.Type{cel.ObjectType("DATE"), cel.ObjectType("DATE")}, cel.BoolType)), + // Date/Time methods + cel.Function("getFullYear", cel.MemberOverload("date_getFullYear", []*cel.Type{cel.ObjectType("DATE")}, cel.IntType)), + cel.Function("getMonth", cel.MemberOverload("datetime_getMonth", []*cel.Type{cel.ObjectType("DATETIME")}, cel.IntType)), + cel.Function("getDayOfMonth", 
cel.MemberOverload("datetime_getDayOfMonth", []*cel.Type{cel.ObjectType("DATETIME")}, cel.IntType)), + cel.Function("getMinutes", cel.MemberOverload("time_getMinutes", []*cel.Type{cel.ObjectType("TIME")}, cel.IntType)), + // Cast functions + cel.Function("bool", cel.Overload("bool_from_int", []*cel.Type{cel.IntType}, cel.BoolType)), + cel.Function("int", cel.Overload("int_from_bool", []*cel.Type{cel.BoolType}, cel.IntType)), + ) + if err != nil { + return nil, err + } + return &EnvResult{Env: env}, nil +} + +// NewJSONSchemaEnv creates a CEL environment with a JSON-enabled schema type provider. +func NewJSONSchemaEnv() (*EnvResult, error) { + productSchema := testcases.NewProductPGSchema() + schemas := map[string]pg.Schema{ + "product": productSchema, + } + provider := pg.NewTypeProvider(schemas) + env, err := cel.NewEnv( + cel.CustomTypeProvider(provider), + cel.Variable("product", cel.ObjectType("product")), + ) + if err != nil { + return nil, err + } + return &EnvResult{ + Env: env, + Opts: []cel2sql.ConvertOption{cel2sql.WithSchemas(schemas)}, + }, nil +} + +// PostgreSQLEnvFactory returns an environment factory for PostgreSQL tests. +func PostgreSQLEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + return NewDefaultEnv() + case testcases.EnvWithTimestamp: + return NewTimestampEnv() + case testcases.EnvWithJSON: + return NewJSONSchemaEnv() + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// MySQLEnvFactory returns an environment factory for MySQL tests. +// It uses the same CEL environments as PostgreSQL (CEL compilation is dialect-independent) +// but sets the MySQL dialect for SQL generation. 
+func MySQLEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + result, err := NewDefaultEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(mysqlDialect.New())) + return result, nil + case testcases.EnvWithTimestamp: + result, err := NewTimestampEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(mysqlDialect.New())) + return result, nil + case testcases.EnvWithJSON: + result, err := NewJSONSchemaEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(mysqlDialect.New())) + return result, nil + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// SQLiteEnvFactory returns an environment factory for SQLite tests. +// It uses the same CEL environments as PostgreSQL (CEL compilation is dialect-independent) +// but sets the SQLite dialect for SQL generation. +func SQLiteEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + result, err := NewDefaultEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(sqliteDialect.New())) + return result, nil + case testcases.EnvWithTimestamp: + result, err := NewTimestampEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(sqliteDialect.New())) + return result, nil + case testcases.EnvWithJSON: + result, err := NewJSONSchemaEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(sqliteDialect.New())) + return result, nil + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// DuckDBEnvFactory returns an environment factory for DuckDB tests. 
+// It uses the same CEL environments as PostgreSQL (CEL compilation is dialect-independent) +// but sets the DuckDB dialect for SQL generation. +func DuckDBEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + result, err := NewDefaultEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(duckdbDialect.New())) + return result, nil + case testcases.EnvWithTimestamp: + result, err := NewTimestampEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(duckdbDialect.New())) + return result, nil + case testcases.EnvWithJSON: + result, err := NewJSONSchemaEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(duckdbDialect.New())) + return result, nil + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// BigQueryEnvFactory returns an environment factory for BigQuery tests. +// It uses the same CEL environments as PostgreSQL (CEL compilation is dialect-independent) +// but sets the BigQuery dialect for SQL generation. 
+func BigQueryEnvFactory() func(envSetup string) (*EnvResult, error) { + return func(envSetup string) (*EnvResult, error) { + switch envSetup { + case testcases.EnvDefault: + result, err := NewDefaultEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(bigqueryDialect.New())) + return result, nil + case testcases.EnvWithTimestamp: + result, err := NewTimestampEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(bigqueryDialect.New())) + return result, nil + case testcases.EnvWithJSON: + result, err := NewJSONSchemaEnv() + if err != nil { + return nil, err + } + result.Opts = append(result.Opts, cel2sql.WithDialect(bigqueryDialect.New())) + return result, nil + default: + return nil, fmt.Errorf("unknown environment setup: %s", envSetup) + } + } +} + +// DialectEnvFactory returns an environment factory for the given dialect. +// This is a convenience function that maps dialect names to their env factories. +func DialectEnvFactory(d dialectpkg.Name) func(envSetup string) (*EnvResult, error) { + switch d { + case dialectpkg.PostgreSQL: + return PostgreSQLEnvFactory() + case dialectpkg.MySQL: + return MySQLEnvFactory() + case dialectpkg.SQLite: + return SQLiteEnvFactory() + case dialectpkg.DuckDB: + return DuckDBEnvFactory() + case dialectpkg.BigQuery: + return BigQueryEnvFactory() + default: + return func(_ string) (*EnvResult, error) { + return nil, fmt.Errorf("no environment factory for dialect %s", d) + } + } +} diff --git a/testutil/runner.go b/testutil/runner.go new file mode 100644 index 0000000..74fb8a2 --- /dev/null +++ b/testutil/runner.go @@ -0,0 +1,180 @@ +// Package testutil provides multi-dialect test runners and helpers. 
+package testutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/spandigital/cel2sql/v3" + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testcases" +) + +// RunConvertTests runs a set of ConvertTestCase entries for a given dialect. +// envFactory returns an EnvResult (CEL env + convert options) for the given EnvSetup key. +// Additional opts are appended after any env-specific options. +func RunConvertTests( + t *testing.T, + dialectName dialect.Name, + cases []testcases.ConvertTestCase, + envFactory func(envSetup string) (*EnvResult, error), + opts ...cel2sql.ConvertOption, +) { + t.Helper() + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + // Check skip + if reason := tc.ShouldSkip(dialectName); reason != "" { + t.Skip(reason) + } + + // Check if we have an expectation for this dialect + wantSQL, hasExpected := tc.ForDialect(dialectName) + wantErr := tc.ShouldError(dialectName) + + if !hasExpected && !wantErr { + t.Skipf("no expected SQL for dialect %s", dialectName) + } + + // Build CEL environment + envResult, err := envFactory(tc.EnvSetup) + require.NoError(t, err, "failed to create CEL environment") + + // Compile CEL expression + ast, issues := envResult.Env.Compile(tc.CELExpr) + if issues != nil && issues.Err() != nil { + if wantErr { + return // expected compile error + } + t.Fatalf("CEL compile failed: %v", issues.Err()) + } + + // Merge options: env-specific first, then caller-provided + allOpts := make([]cel2sql.ConvertOption, 0, len(envResult.Opts)+len(opts)) + allOpts = append(allOpts, envResult.Opts...) + allOpts = append(allOpts, opts...) + + // Convert + got, err := cel2sql.Convert(ast, allOpts...) 
+ if wantErr { + assert.Error(t, err, "expected error for dialect %s", dialectName) + return + } + + if assert.NoError(t, err) { + assert.Equal(t, wantSQL, got, "SQL mismatch for dialect %s", dialectName) + } + }) + } +} + +// RunParameterizedTests runs a set of ParameterizedTestCase entries for a given dialect. +func RunParameterizedTests( + t *testing.T, + dialectName dialect.Name, + cases []testcases.ParameterizedTestCase, + envFactory func(envSetup string) (*EnvResult, error), + opts ...cel2sql.ConvertOption, +) { + t.Helper() + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + // Check skip + if tc.SkipDialect != nil { + if reason, ok := tc.SkipDialect[dialectName]; ok && reason != "" { + t.Skip(reason) + } + } + + wantSQL, hasExpected := tc.WantSQL[dialectName] + wantErr := tc.WantErr[dialectName] + + if !hasExpected && !wantErr { + t.Skipf("no expected SQL for dialect %s", dialectName) + } + + // Build CEL environment + envResult, err := envFactory(tc.EnvSetup) + require.NoError(t, err, "failed to create CEL environment") + + // Compile CEL expression + ast, issues := envResult.Env.Compile(tc.CELExpr) + if issues != nil && issues.Err() != nil { + if wantErr { + return + } + t.Fatalf("CEL compile failed: %v", issues.Err()) + } + + // Merge options + allOpts := make([]cel2sql.ConvertOption, 0, len(envResult.Opts)+len(opts)) + allOpts = append(allOpts, envResult.Opts...) + allOpts = append(allOpts, opts...) + + // Convert + result, err := cel2sql.ConvertParameterized(ast, allOpts...) + if wantErr { + assert.Error(t, err) + return + } + + if assert.NoError(t, err) { + assert.Equal(t, wantSQL, result.SQL, "SQL mismatch for dialect %s", dialectName) + + if wantParams, ok := tc.WantParams[dialectName]; ok && len(wantParams) > 0 { + assert.Equal(t, wantParams, result.Parameters, "params mismatch for dialect %s", dialectName) + } + } + }) + } +} + +// RunAllConvertTests runs all standard test suites for a given dialect. 
+func RunAllConvertTests( + t *testing.T, + dialectName dialect.Name, + envFactory func(envSetup string) (*EnvResult, error), + opts ...cel2sql.ConvertOption, +) { + t.Run("basic", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.BasicTests(), envFactory, opts...) + }) + t.Run("operators", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.OperatorTests(), envFactory, opts...) + }) + t.Run("strings", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.StringTests(), envFactory, opts...) + }) + t.Run("regex", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.RegexTests(), envFactory, opts...) + }) + t.Run("casts", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.CastTests(), envFactory, opts...) + }) + t.Run("arrays", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.ArrayTests(), envFactory, opts...) + }) + t.Run("timestamps", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.TimestampTests(), envFactory, opts...) + }) + t.Run("comprehensions", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.ComprehensionTests(), envFactory, opts...) + }) + t.Run("json", func(t *testing.T) { + RunConvertTests(t, dialectName, testcases.JSONTests(), envFactory, opts...) + }) +} + +// RunAllParameterizedTests runs all parameterized test suites for a given dialect. +func RunAllParameterizedTests( + t *testing.T, + dialectName dialect.Name, + envFactory func(envSetup string) (*EnvResult, error), + opts ...cel2sql.ConvertOption, +) { + t.Run("parameterized", func(t *testing.T) { + RunParameterizedTests(t, dialectName, testcases.ParameterizedTests(), envFactory, opts...) 
+ }) +} diff --git a/testutil/runner_bigquery_test.go b/testutil/runner_bigquery_test.go new file mode 100644 index 0000000..4381120 --- /dev/null +++ b/testutil/runner_bigquery_test.go @@ -0,0 +1,16 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testutil" +) + +func TestBigQuerySharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.BigQuery, testutil.BigQueryEnvFactory()) +} + +func TestBigQueryParameterizedSharedCases(t *testing.T) { + testutil.RunAllParameterizedTests(t, dialect.BigQuery, testutil.BigQueryEnvFactory()) +} diff --git a/testutil/runner_duckdb_test.go b/testutil/runner_duckdb_test.go new file mode 100644 index 0000000..c509510 --- /dev/null +++ b/testutil/runner_duckdb_test.go @@ -0,0 +1,16 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testutil" +) + +func TestDuckDBSharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.DuckDB, testutil.DuckDBEnvFactory()) +} + +func TestDuckDBParameterizedSharedCases(t *testing.T) { + testutil.RunAllParameterizedTests(t, dialect.DuckDB, testutil.DuckDBEnvFactory()) +} diff --git a/testutil/runner_mysql_test.go b/testutil/runner_mysql_test.go new file mode 100644 index 0000000..5087059 --- /dev/null +++ b/testutil/runner_mysql_test.go @@ -0,0 +1,12 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testutil" +) + +func TestMySQLSharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.MySQL, testutil.MySQLEnvFactory()) +} diff --git a/testutil/runner_pg_test.go b/testutil/runner_pg_test.go new file mode 100644 index 0000000..bfc3e15 --- /dev/null +++ b/testutil/runner_pg_test.go @@ -0,0 +1,16 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + 
"github.com/spandigital/cel2sql/v3/testutil" +) + +func TestPostgreSQLSharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.PostgreSQL, testutil.PostgreSQLEnvFactory()) +} + +func TestPostgreSQLParameterizedSharedCases(t *testing.T) { + testutil.RunAllParameterizedTests(t, dialect.PostgreSQL, testutil.PostgreSQLEnvFactory()) +} diff --git a/testutil/runner_sqlite_test.go b/testutil/runner_sqlite_test.go new file mode 100644 index 0000000..0e6cfb7 --- /dev/null +++ b/testutil/runner_sqlite_test.go @@ -0,0 +1,16 @@ +package testutil_test + +import ( + "testing" + + "github.com/spandigital/cel2sql/v3/dialect" + "github.com/spandigital/cel2sql/v3/testutil" +) + +func TestSQLiteSharedCases(t *testing.T) { + testutil.RunAllConvertTests(t, dialect.SQLite, testutil.SQLiteEnvFactory()) +} + +func TestSQLiteParameterizedSharedCases(t *testing.T) { + testutil.RunAllParameterizedTests(t, dialect.SQLite, testutil.SQLiteEnvFactory()) +} diff --git a/timestamps.go b/timestamps.go index d0defc6..13d62e9 100644 --- a/timestamps.go +++ b/timestamps.go @@ -2,7 +2,6 @@ package cel2sql import ( "fmt" - "strconv" "time" "github.com/google/cel-go/common/operators" @@ -55,7 +54,6 @@ func (con *converter) callTimestampOperation(fun string, lhs *exprpb.Expr, rhs * return newConversionError(errMsgInvalidTimestampOp, "timestamp operation requires at least one timestamp operand") } - // PostgreSQL uses simple + and - operators for date arithmetic var sqlOp string switch fun { case operators.Add: @@ -66,16 +64,10 @@ func (con *converter) callTimestampOperation(fun string, lhs *exprpb.Expr, rhs * return newConversionError(errMsgInvalidTimestampOp, "unsupported timestamp operation") } - if err := con.visitMaybeNested(timestamp, timestampParen); err != nil { - return err - } - con.str.WriteString(" ") - con.str.WriteString(sqlOp) - con.str.WriteString(" ") - if err := con.visitMaybeNested(duration, durationParen); err != nil { - return err - } - return nil + return 
con.dialect.WriteTimestampArithmetic(&con.str, sqlOp, + func() error { return con.visitMaybeNested(timestamp, timestampParen) }, + func() error { return con.visitMaybeNested(duration, durationParen) }, + ) } // callDuration converts CEL duration expressions to PostgreSQL INTERVAL @@ -100,105 +92,96 @@ func (con *converter) callDuration(_ *exprpb.Expr, args []*exprpb.Expr) error { if err != nil { return err } - con.str.WriteString("INTERVAL ") + var value int64 + var unit string switch d { case d.Round(time.Hour): - con.str.WriteString(strconv.FormatFloat(d.Hours(), 'f', 0, 64)) - con.str.WriteString(" HOUR") + value = int64(d.Hours()) + unit = "HOUR" case d.Round(time.Minute): - con.str.WriteString(strconv.FormatFloat(d.Minutes(), 'f', 0, 64)) - con.str.WriteString(" MINUTE") + value = int64(d.Minutes()) + unit = "MINUTE" case d.Round(time.Second): - con.str.WriteString(strconv.FormatFloat(d.Seconds(), 'f', 0, 64)) - con.str.WriteString(" SECOND") + value = int64(d.Seconds()) + unit = "SECOND" case d.Round(time.Millisecond): - con.str.WriteString(strconv.FormatInt(d.Milliseconds(), 10)) - con.str.WriteString(" MILLISECOND") + value = d.Milliseconds() + unit = "MILLISECOND" default: - con.str.WriteString(strconv.FormatInt(d.Truncate(time.Microsecond).Microseconds(), 10)) - con.str.WriteString(" MICROSECOND") + value = d.Truncate(time.Microsecond).Microseconds() + unit = "MICROSECOND" } + con.dialect.WriteDuration(&con.str, value, unit) return nil } -// callInterval creates PostgreSQL INTERVAL expressions +// callInterval creates INTERVAL expressions using the dialect func (con *converter) callInterval(_ *exprpb.Expr, args []*exprpb.Expr) error { - con.str.WriteString("INTERVAL ") - if err := con.visit(args[0]); err != nil { - return err - } - con.str.WriteString(" ") datePart := args[1] - con.str.WriteString(datePart.GetIdentExpr().GetName()) - return nil + unit := datePart.GetIdentExpr().GetName() + return con.dialect.WriteInterval(&con.str, func() error { + return 
con.visit(args[0]) + }, unit) } // callExtractFromTimestamp handles timestamp field extraction (YEAR, MONTH, DAY, etc.) func (con *converter) callExtractFromTimestamp(function string, target *exprpb.Expr, args []*exprpb.Expr) error { - // For getDayOfWeek, we need to wrap the entire EXTRACT in parentheses for modulo operation - if function == overloads.TimeGetDayOfWeek { - con.str.WriteString("(") - } - con.str.WriteString("EXTRACT(") + var part string switch function { case overloads.TimeGetFullYear: - con.str.WriteString("YEAR") + part = "YEAR" case overloads.TimeGetMonth: - con.str.WriteString("MONTH") + part = "MONTH" case overloads.TimeGetDate: - con.str.WriteString("DAY") + part = "DAY" case overloads.TimeGetHours: - con.str.WriteString("HOUR") + part = "HOUR" case overloads.TimeGetMinutes: - con.str.WriteString("MINUTE") + part = "MINUTE" case overloads.TimeGetSeconds: - con.str.WriteString("SECOND") + part = "SECOND" case overloads.TimeGetMilliseconds: - con.str.WriteString("MILLISECONDS") + part = "MILLISECONDS" case overloads.TimeGetDayOfYear: - con.str.WriteString("DOY") + part = "DOY" case overloads.TimeGetDayOfMonth: - con.str.WriteString("DAY") + part = "DAY" case overloads.TimeGetDayOfWeek: - con.str.WriteString("DOW") + part = "DOW" } - con.str.WriteString(" FROM ") - if err := con.visit(target); err != nil { - return err + + writeExpr := func() error { + return con.visit(target) } + + var writeTZ func() error if isTimestampType(con.getType(target)) && len(args) == 1 { - con.str.WriteString(" AT TIME ZONE ") - if err := con.visit(args[0]); err != nil { - return err + writeTZ = func() error { + return con.visit(args[0]) } } - con.str.WriteString(")") + + if err := con.dialect.WriteExtract(&con.str, part, writeExpr, writeTZ); err != nil { + return err + } + + // Apply CEL-specific adjustments (these are universal, not dialect-specific) switch function { case overloads.TimeGetMonth, overloads.TimeGetDayOfYear, overloads.TimeGetDayOfMonth: 
con.str.WriteString(" - 1") - case overloads.TimeGetDayOfWeek: - // PostgreSQL DOW: 0=Sunday, 1=Monday, ..., 6=Saturday - // CEL getDayOfWeek: 0=Monday, 1=Tuesday, ..., 6=Sunday (ISO 8601) - // Convert: (DOW + 6) % 7 - con.str.WriteString(" + 6) % 7") } return nil } -// callTimestampFromString converts string literals to PostgreSQL timestamps +// callTimestampFromString converts string literals to timestamps using the dialect func (con *converter) callTimestampFromString(_ *exprpb.Expr, args []*exprpb.Expr) error { if len(args) == 1 { - // For PostgreSQL, we need to cast the string to a timestamp - con.str.WriteString("CAST(") - err := con.visit(args[0]) - if err != nil { - return err - } - con.str.WriteString(" AS TIMESTAMP WITH TIME ZONE)") - return nil + return con.dialect.WriteTimestampCast(&con.str, func() error { + return con.visit(args[0]) + }) } else if len(args) == 2 { // Handle timestamp(datetime, timezone) format - // In PostgreSQL, use: datetime AT TIME ZONE timezone + // For most dialects: datetime AT TIME ZONE timezone err := con.visit(args[0]) if err != nil { return err From 814dacc289853931d2cf66aa3ee562ec3bccfcdb Mon Sep 17 00:00:00 2001 From: Richard Wooding Date: Wed, 25 Feb 2026 10:56:28 +0200 Subject: [PATCH 3/3] fix: use STRPOS instead of INSTR and remove ESCAPE clause for BigQuery dialect The BigQuery emulator errors on INSTR ("invalid position number") and does not support the ESCAPE keyword in LIKE patterns. Switch WriteContains to STRPOS and make WriteLikeEscape a no-op to fix CI integration test failures. 
Co-Authored-By: Claude Opus 4.6 --- README.md | 2 +- bigquery_integration_test.go | 2 +- dialect/bigquery/dialect.go | 13 +++++++------ testcases/string_tests.go | 8 ++++---- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 4774982..66671d3 100644 --- a/README.md +++ b/README.md @@ -167,7 +167,7 @@ sql, err := cel2sql.Convert(ast, cel2sql.WithDialect(bigquery.New())) | UNNEST | `UNNEST(x)` | `JSON_TABLE(...)` | `json_each(x)` | `UNNEST(x)` | `UNNEST(x)` | | Param placeholder | `$1, $2` | `?, ?` | `?, ?` | `$1, $2` | `@p1, @p2` | | Timestamp cast | `TIMESTAMP WITH TIME ZONE` | `DATETIME` | `datetime()` | `TIMESTAMPTZ` | `TIMESTAMP` | -| Contains | `POSITION()` | `LOCATE()` | `INSTR()` | `CONTAINS()` | `INSTR()` | +| Contains | `POSITION()` | `LOCATE()` | `INSTR()` | `CONTAINS()` | `STRPOS()` | | Index analysis | BTREE, GIN, GIN+trgm | BTREE, FULLTEXT | BTREE | ART | CLUSTERING, SEARCH_INDEX | ### Per-Dialect Type Providers diff --git a/bigquery_integration_test.go b/bigquery_integration_test.go index a56546d..1c87b12 100644 --- a/bigquery_integration_test.go +++ b/bigquery_integration_test.go @@ -261,7 +261,7 @@ func TestBigQueryOperatorsIntegration(t *testing.T) { name: "String contains", celExpr: `text_val.contains("world")`, expectedRows: 2, // "world", "hello world" - description: "String contains function (INSTR)", + description: "String contains function (STRPOS)", }, { name: "String startsWith", diff --git a/dialect/bigquery/dialect.go b/dialect/bigquery/dialect.go index d097e9a..5314237 100644 --- a/dialect/bigquery/dialect.go +++ b/dialect/bigquery/dialect.go @@ -75,9 +75,10 @@ func (d *Dialect) WriteRegexMatch(w *strings.Builder, writeTarget func() error, return nil } -// WriteLikeEscape writes the BigQuery LIKE escape clause. -func (d *Dialect) WriteLikeEscape(w *strings.Builder) { - w.WriteString(" ESCAPE '\\\\'") +// WriteLikeEscape is a no-op for BigQuery. 
+// BigQuery uses backslash as the default escape character in LIKE patterns +// and does not support the ESCAPE keyword. +func (d *Dialect) WriteLikeEscape(_ *strings.Builder) { } // WriteArrayMembership writes a BigQuery array membership test using IN UNNEST(). @@ -349,9 +350,9 @@ func (d *Dialect) WriteTimestampArithmetic(w *strings.Builder, op string, writeT // --- String Functions --- -// WriteContains writes INSTR(haystack, needle) != 0 for BigQuery. +// WriteContains writes STRPOS(haystack, needle) > 0 for BigQuery. func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle func() error) error { - w.WriteString("INSTR(") + w.WriteString("STRPOS(") if err := writeHaystack(); err != nil { return err } @@ -359,7 +360,7 @@ func (d *Dialect) WriteContains(w *strings.Builder, writeHaystack, writeNeedle f if err := writeNeedle(); err != nil { return err } - w.WriteString(") != 0") + w.WriteString(") > 0") return nil } diff --git a/testcases/string_tests.go b/testcases/string_tests.go index 5e75169..b7ec52d 100644 --- a/testcases/string_tests.go +++ b/testcases/string_tests.go @@ -14,7 +14,7 @@ func StringTests() []ConvertTestCase { dialect.MySQL: "name LIKE 'a%' ESCAPE '\\\\'", dialect.SQLite: "name LIKE 'a%' ESCAPE '\\'", dialect.DuckDB: "name LIKE 'a%' ESCAPE '\\\\'", - dialect.BigQuery: "name LIKE 'a%' ESCAPE '\\\\'", + dialect.BigQuery: "name LIKE 'a%'", }, }, { @@ -26,7 +26,7 @@ func StringTests() []ConvertTestCase { dialect.MySQL: "name LIKE '%z' ESCAPE '\\\\'", dialect.SQLite: "name LIKE '%z' ESCAPE '\\'", dialect.DuckDB: "name LIKE '%z' ESCAPE '\\\\'", - dialect.BigQuery: "name LIKE '%z' ESCAPE '\\\\'", + dialect.BigQuery: "name LIKE '%z'", }, }, { @@ -38,7 +38,7 @@ func StringTests() []ConvertTestCase { dialect.MySQL: "LOCATE('abc', name) > 0", dialect.SQLite: "INSTR(name, 'abc') > 0", dialect.DuckDB: "CONTAINS(name, 'abc')", - dialect.BigQuery: "INSTR(name, 'abc') != 0", + dialect.BigQuery: "STRPOS(name, 'abc') > 0", }, }, { @@ -62,7 
+62,7 @@ func StringTests() []ConvertTestCase { dialect.MySQL: "name LIKE 'a%' ESCAPE '\\\\' AND name LIKE '%z' ESCAPE '\\\\'", dialect.SQLite: "name LIKE 'a%' ESCAPE '\\' AND name LIKE '%z' ESCAPE '\\'", dialect.DuckDB: "name LIKE 'a%' ESCAPE '\\\\' AND name LIKE '%z' ESCAPE '\\\\'", - dialect.BigQuery: "name LIKE 'a%' ESCAPE '\\\\' AND name LIKE '%z' ESCAPE '\\\\'", + dialect.BigQuery: "name LIKE 'a%' AND name LIKE '%z'", }, }, }