Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package pgkit

import (
"fmt"
"maps"
"reflect"
"slices"

Expand Down Expand Up @@ -31,12 +32,8 @@ func (s *StatementBuilder) InsertRecord(record interface{}, optTableName ...stri
return InsertBuilder{InsertBuilder: insert.Into(tableName).Columns(cols...).Values(vals...)}
}

// InsertRecords builds a multi-row INSERT from a slice of records.
//
// Every record must produce the same non-empty Map column set; a drifted
// shape (e.g. mixed nil and non-nil empty slices under ,omitzero) or an
// all-default record returns a build-time error rather than emitting
// malformed multi-row SQL.
// InsertRecords builds a multi-row INSERT from a slice of records, unioning
// columns across rows and emitting DEFAULT for any slot a given row skipped.
func (s StatementBuilder) InsertRecords(recordsSlice interface{}, optTableName ...string) InsertBuilder {
insert := sq.InsertBuilder(s.StatementBuilderType)

Expand All @@ -53,7 +50,8 @@ func (s StatementBuilder) InsertRecords(recordsSlice interface{}, optTableName .
tableName = optTableName[0]
}

var baseCols []string
rows := make([]map[string]any, 0, v.Len())
colSet := map[string]struct{}{}
for i := 0; i < v.Len(); i++ {
record := v.Index(i).Interface()

Expand All @@ -67,25 +65,39 @@ func (s StatementBuilder) InsertRecords(recordsSlice interface{}, optTableName .
if err != nil {
return InsertBuilder{InsertBuilder: insert, err: wrapErr(err)}
}
if len(cols) == 0 {
return InsertBuilder{InsertBuilder: insert, err: wrapErr(fmt.Errorf("Map returned no columns for record %d (%T); for an all-default INSERT use SQL.InsertDefaults (single-row only)", i, record))}
byCol := make(map[string]any, len(cols))
for j, c := range cols {
byCol[c] = vals[j]
colSet[c] = struct{}{}
}
rows = append(rows, byCol)
}

// slices.Sorted matches Map's lexical column order, so generated SQL
// lines up with what callers see when they call Map(record) directly.
allCols := slices.Sorted(maps.Keys(colSet))
if len(allCols) == 0 {
hint := `SQL.InsertDefaults("<table>")`
if tableName != "" {
hint = fmt.Sprintf("SQL.InsertDefaults(%q)", tableName)
}
return InsertBuilder{InsertBuilder: insert, err: wrapErr(fmt.Errorf("Map returned no columns across any of the %d records; for all-default rows use %s in a loop", v.Len(), hint))}
}

if i == 0 {
baseCols = cols
insert = insert.Columns(cols...).Values(vals...)
} else {
if !slices.Equal(cols, baseCols) {
return InsertBuilder{
InsertBuilder: insert,
err: wrapErr(fmt.Errorf("record %d columns %v differ from record 0 columns %v", i, cols, baseCols)),
}
insert = insert.Into(tableName).Columns(allCols...)
for _, row := range rows {
padded := make([]any, len(allCols))
for i, c := range allCols {
if v, ok := row[c]; ok {
padded[i] = v
} else {
padded[i] = sqlDefault
}
insert = insert.Values(vals...)
}
insert = insert.Values(padded...)
}

return InsertBuilder{InsertBuilder: insert.Into(tableName)}
return InsertBuilder{InsertBuilder: insert}
}

// InsertDefaults builds INSERT INTO <table> DEFAULT VALUES; table must be non-empty.
Expand Down
135 changes: 105 additions & 30 deletions builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,7 @@ import (
"github.com/goware/pgkit/v2"
)

func TestInsertRecords_ColumnDriftRejected(t *testing.T) {
// ,omitzero produces different column shapes for nil vs non-nil empty
// slices; squirrel would otherwise stitch the mismatched widths into
// malformed multi-row SQL and surface only at exec time.
type Item struct {
ID int `db:"id"`
Tags []string `db:"tags,omitzero"`
}

sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []Item{
{ID: 1, Tags: nil},
{ID: 2, Tags: []string{}},
}
b := sb.InsertRecords(records, "items")
require.Error(t, b.Err())
assert.Contains(t, b.Err().Error(), "differ from record 0")
}

func TestInsertRecords_UniformShape(t *testing.T) {
// Sanity: batches with consistent column shape across rows still build.
type Item struct {
ID int `db:"id"`
Tags []string `db:"tags,omitzero"`
Expand All @@ -43,6 +23,10 @@ func TestInsertRecords_UniformShape(t *testing.T) {
}
b := sb.InsertRecords(records, "items")
require.NoError(t, b.Err())
sql, args, err := b.ToSql()
require.NoError(t, err)
assert.Equal(t, `INSERT INTO items (id,tags) VALUES ($1,$2),($3,$4)`, sql)
assert.Equal(t, []any{1, []string{"a"}, 2, []string{"b"}}, args)
}

func TestInsertDefaults_PlainSQL(t *testing.T) {
Expand Down Expand Up @@ -126,23 +110,58 @@ func TestInsertRecord_AllDefaultsErrorHintsAtInsertDefaults(t *testing.T) {
assert.Contains(t, b.Err().Error(), `SQL.InsertDefaults("items")`)
}

func TestInsertRecords_EmptyColumnsRejected(t *testing.T) {
// Multi-row INSERT ... DEFAULT VALUES is not valid PG; the batch path
// rejects all-default records and points at the single-row InsertRecord.
func TestInsertRecords_MixedShape_UnionsAndDefaults(t *testing.T) {
// Heterogeneous batch: each row contributes a different column subset.
// The union becomes the INSERT column list; missing slots become DEFAULT.
type Item struct {
ID int `db:"id"`
Name string `db:"name,omitzero"`
Tags []string `db:"tags,omitzero"`
}

sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []Item{{}, {}}
records := []Item{
{ID: 1, Name: "first"}, // cols = [id, name]
{ID: 2, Tags: []string{"foo"}}, // cols = [id, tags]
{ID: 3, Name: "third", Tags: []string{}}, // cols = [id, name, tags] (omitzero keeps non-nil empty)
}
b := sb.InsertRecords(records, "items")
require.Error(t, b.Err())
assert.Contains(t, b.Err().Error(), "no columns")
require.NoError(t, b.Err())
sql, args, err := b.ToSql()
require.NoError(t, err)
assert.Equal(t,
`INSERT INTO items (id,name,tags) VALUES ($1,$2,DEFAULT),($3,DEFAULT,$4),($5,$6,$7)`,
sql,
)
assert.Equal(t, []any{1, "first", 2, []string{"foo"}, 3, "third", []string{}}, args)
}

func TestInsertRecords_OmitZeroMixedSlices(t *testing.T) {
// #50 used to reject this with a drift error. The union-with-DEFAULT
// approach makes it valid: ,omitzero distinguishes nil (skipped → DEFAULT)
// from non-nil empty (included with empty literal).
type Item struct {
ID int `db:"id"`
Tags []string `db:"tags,omitzero"`
}

sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []Item{
{ID: 1, Tags: nil}, // tags skipped → will become DEFAULT
{ID: 2, Tags: []string{}}, // tags included → empty array literal
}
b := sb.InsertRecords(records, "items")
require.NoError(t, b.Err())
sql, args, err := b.ToSql()
require.NoError(t, err)
assert.Equal(t, `INSERT INTO items (id,tags) VALUES ($1,DEFAULT),($2,$3)`, sql)
assert.Equal(t, []any{1, 2, []string{}}, args)
}

func TestInsertRecords_OmitEmptyMapDriftRejected(t *testing.T) {
// Latent footgun ,omitzero exposes: legacy ,omitempty on a map already
// produced shape drift (nil map skipped, non-nil empty map kept via the
// DeepEqual fallback). The validation catches this case for free.
func TestInsertRecords_OmitEmptyMixedMaps(t *testing.T) {
// Legacy footgun resolved: ,omitempty on a map has always treated nil
// and non-nil empty differently (DeepEqual sees them as distinct).
// Now the union path handles it instead of rejecting.
type Item struct {
ID int `db:"id"`
Tags map[string]string `db:"tags,omitempty"`
Expand All @@ -154,5 +173,61 @@ func TestInsertRecords_OmitEmptyMapDriftRejected(t *testing.T) {
{ID: 2, Tags: map[string]string{}},
}
b := sb.InsertRecords(records, "items")
require.NoError(t, b.Err())
sql, _, err := b.ToSql()
require.NoError(t, err)
assert.Equal(t, `INSERT INTO items (id,tags) VALUES ($1,DEFAULT),($2,$3)`, sql)
}

func TestInsertRecords_EmptyRowMixedWithNonEmpty(t *testing.T) {
// A row with all-skipped fields can still appear in a batch: another row
// contributes the column union, the empty row pads to all DEFAULT.
type Item struct {
Name string `db:"name,omitzero"`
Tags []string `db:"tags,omitzero"`
}

sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []Item{
{}, // empty → all DEFAULT
{Name: "second", Tags: []string{"a"}}, // contributes the union
}
b := sb.InsertRecords(records, "items")
require.NoError(t, b.Err())
sql, args, err := b.ToSql()
require.NoError(t, err)
assert.Equal(t, `INSERT INTO items (name,tags) VALUES (DEFAULT,DEFAULT),($1,$2)`, sql)
assert.Equal(t, []any{"second", []string{"a"}}, args)
}

func TestInsertRecords_AllRowsEmpty_Rejected(t *testing.T) {
// Whole-batch empty union: no row contributed any column. Genuinely
// out of InsertRecords' scope — caller wants InsertDefaults per row.
type Item struct {
Name string `db:"name,omitzero"`
Tags []string `db:"tags,omitzero"`
}

sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []Item{{}, {}}
b := sb.InsertRecords(records, "items")
require.Error(t, b.Err())
assert.Contains(t, b.Err().Error(), "no columns")
assert.Contains(t, b.Err().Error(), `SQL.InsertDefaults("items")`)
}

func TestInsertRecords_MapRecords(t *testing.T) {
// Map accepts records as map[string]any (mapper.go's reflect.Map case).
// Heterogeneous map batches should union just like struct batches do.
sb := &pgkit.StatementBuilder{StatementBuilderType: sq.StatementBuilder.PlaceholderFormat(sq.Dollar)}
records := []map[string]any{
{"id": 1, "name": "first"},
{"id": 2, "tags": "foo"},
}
b := sb.InsertRecords(records, "items")
require.NoError(t, b.Err())
sql, _, err := b.ToSql()
require.NoError(t, err)
// Column order is lexical (Map sorts deterministically).
assert.Equal(t, `INSERT INTO items (id,name,tags) VALUES ($1,$2,DEFAULT),($3,DEFAULT,$4)`, sql)
}
48 changes: 48 additions & 0 deletions tests/pgkit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1184,3 +1184,51 @@ func TestInsertDefaultsRoundTrip(t *testing.T) {
defer rows.Close()
require.True(t, rows.Next(), "row should exist")
}

type MixedShape struct {
ID int64 `db:"id,omitzero"`
Name string `db:"name"`
Tags []string `db:"tags,omitzero"`
Note *string `db:"note,omitempty"`
Created time.Time `db:"created_at,omitempty"`
}

func TestInsertRecordsMixedShapeRoundTrip(t *testing.T) {
// Three rows, each contributing a different column subset. Each row
// should land with caller values where supplied and DB defaults
// (or NULL on nullable cols) where the row opted out.
truncateTable(t, "mixed_shape")

note := "third row's note"
records := []*MixedShape{
{Name: "first"}, // tags → DEFAULT '{}', note → NULL
{Name: "second", Tags: []string{"a", "b"}}, // note → NULL
{Name: "third", Note: &note}, // tags → DEFAULT '{}'
}

_, err := DB.Query.Exec(context.Background(), DB.SQL.InsertRecords(records, "mixed_shape"))
require.NoError(t, err)

var out []*MixedShape
err = DB.Query.GetAll(
context.Background(),
DB.SQL.Select("*").From("mixed_shape").OrderBy("id"),
&out,
)
require.NoError(t, err)
require.Len(t, out, 3)

assert.Equal(t, "first", out[0].Name)
assert.Empty(t, out[0].Tags)
assert.Nil(t, out[0].Note)
assert.False(t, out[0].Created.IsZero(), "created_at populated by DB default")

assert.Equal(t, "second", out[1].Name)
assert.Equal(t, []string{"a", "b"}, out[1].Tags)
assert.Nil(t, out[1].Note)

assert.Equal(t, "third", out[2].Name)
assert.Empty(t, out[2].Tags)
require.NotNil(t, out[2].Note)
assert.Equal(t, "third row's note", *out[2].Note)
}
10 changes: 10 additions & 0 deletions tests/testdata/pgkit_test_db.sql
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,13 @@ CREATE TABLE default_only (
id BIGSERIAL PRIMARY KEY,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL
);

-- mixed_shape proves heterogeneous batches survive end-to-end: nullable +
-- defaulted columns let different rows opt out of different fields.
CREATE TABLE mixed_shape (
id BIGSERIAL PRIMARY KEY,
name TEXT NOT NULL,
tags TEXT[] NOT NULL DEFAULT '{}',
note TEXT NULL,
created_at TIMESTAMP WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP NOT NULL
);
Loading