diff --git a/pkg/sql/colexec/aggexec/any2.go b/pkg/sql/colexec/aggexec/any2.go index ff092eb45309e..74d8c209c17e5 100644 --- a/pkg/sql/colexec/aggexec/any2.go +++ b/pkg/sql/colexec/aggexec/any2.go @@ -48,7 +48,9 @@ func (exec *anyExec) BatchFill(offset int, groups []uint64, vectors []*vector.Ve if exec.state[x].vecs[0].IsNull(uint64(y)) { exec.state[x].vecs[0].UnsetNull(uint64(y)) bs := vectors[0].GetRawBytesAt(int(idx)) - exec.state[x].vecs[0].SetRawBytesAt(int(y), bs, exec.mp) + if err := exec.state[x].vecs[0].SetRawBytesAt(int(y), bs, exec.mp); err != nil { + return err + } } } } @@ -74,7 +76,9 @@ func (exec *anyExec) BatchMerge(next AggFuncExec, offset int, groups []uint64) e if exec.state[x1].vecs[0].IsNull(uint64(y1)) { exec.state[x1].vecs[0].UnsetNull(uint64(y1)) bs := other.state[x2].vecs[0].GetRawBytesAt(int(y2)) - exec.state[x1].vecs[0].SetRawBytesAt(int(y1), bs, exec.mp) + if err := exec.state[x1].vecs[0].SetRawBytesAt(int(y1), bs, exec.mp); err != nil { + return err + } } } return nil diff --git a/pkg/sql/colexec/aggexec/any2_test.go b/pkg/sql/colexec/aggexec/any2_test.go new file mode 100644 index 0000000000000..681c488448152 --- /dev/null +++ b/pkg/sql/colexec/aggexec/any2_test.go @@ -0,0 +1,81 @@ +// Copyright 2026 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggexec + +import ( + "testing" + + "github.com/matrixorigin/matrixone/pkg/common/mpool" + "github.com/matrixorigin/matrixone/pkg/container/types" + "github.com/matrixorigin/matrixone/pkg/container/vector" + "github.com/stretchr/testify/require" +) + +func TestAnyValueBatchFillReturnsSetRawBytesAtError(t *testing.T) { + inputMp := mpool.MustNewZero() + input := vector.NewVec(types.T_varchar.ToType()) + require.NoError(t, vector.AppendBytes(input, make([]byte, 4096), false, inputMp)) + defer input.Free(inputMp) + + exec, limitedMp, filler := newLimitedAnyValueExec(t) + defer cleanupLimitedAnyValueExec(exec, limitedMp, filler) + + err := exec.BatchFill(0, []uint64{1}, []*vector.Vector{input}) + require.Error(t, err) +} + +func TestAnyValueBatchMergeReturnsSetRawBytesAtError(t *testing.T) { + inputMp := mpool.MustNewZero() + input := vector.NewVec(types.T_varchar.ToType()) + require.NoError(t, vector.AppendBytes(input, make([]byte, 4096), false, inputMp)) + defer input.Free(inputMp) + + sourceMp := mpool.MustNewZero() + source := makeAnyValueExec(sourceMp, 1, types.T_varchar.ToType()).(*anyExec) + require.NoError(t, source.GroupGrow(1)) + require.NoError(t, source.BatchFill(0, []uint64{1}, []*vector.Vector{input})) + defer source.Free() + + target, limitedMp, filler := newLimitedAnyValueExec(t) + defer cleanupLimitedAnyValueExec(target, limitedMp, filler) + + err := target.BatchMerge(source, 0, []uint64{1}) + require.Error(t, err) +} + +func newLimitedAnyValueExec(t *testing.T) (*anyExec, *mpool.MPool, []byte) { + t.Helper() + + limitedMp, err := mpool.NewMPool("any-value-limited", 1024*1024, mpool.NoFixed) + require.NoError(t, err) + + exec := makeAnyValueExec(limitedMp, 1, types.T_varchar.ToType()).(*anyExec) + require.NoError(t, exec.GroupGrow(1)) + + remaining := 1024*1024 - limitedMp.CurrNB() + require.Greater(t, remaining, int64(4096)) + filler, err := limitedMp.Alloc(int(remaining-1024), true) + require.NoError(t, err) + + return exec, limitedMp, filler +} + +func cleanupLimitedAnyValueExec(exec *anyExec, mp *mpool.MPool, filler []byte) { + if filler != nil { + mp.Free(filler) + } + exec.Free() + mpool.DeleteMPool(mp) +} diff --git a/pkg/sql/plan/bind_update.go b/pkg/sql/plan/bind_update.go index dc5499c38d6dd..bf5b8e445cad7 100644 --- a/pkg/sql/plan/bind_update.go +++ b/pkg/sql/plan/bind_update.go @@ -20,6 +20,7 @@ import ( "github.com/matrixorigin/matrixone/pkg/catalog" "github.com/matrixorigin/matrixone/pkg/common/moerr" + "github.com/matrixorigin/matrixone/pkg/container/types" "github.com/matrixorigin/matrixone/pkg/pb/plan" "github.com/matrixorigin/matrixone/pkg/sql/parsers/tree" ) @@ -269,7 +270,25 @@ func (builder *QueryBuilder) bindUpdate(stmt *tree.Update, bindCtx *BindContext) } } - // Recompute generated columns (both STORED and VIRTUAL are computed on write) + } + + if stmt.From != nil && len(stmt.From.Tables) > 0 { + lastNodeID, selectNode, selectNodeTag, err = builder.appendUpdateFromDedupNode( + bindCtx, lastNodeID, selectNode, selectNodeTag, dmlCtx, oldColName2Idx, newColName2Idx) + if err != nil { + return 0, err + } + } + + for i, alias := range dmlCtx.aliases { + if len(dmlCtx.updateCol2Expr[i]) == 0 { + continue + } + + tableDef := dmlCtx.tableDefs[i] + + // Recompute generated columns after UPDATE FROM dedup so generated + // expressions read the same deduped base values that will be written. for _, col := range tableDef.Cols { if col.GeneratedCol == nil { continue @@ -957,3 +976,175 @@ func (builder *QueryBuilder) bindUpdate(stmt *tree.Update, bindCtx *BindContext) return lastNodeID, err } + +func (builder *QueryBuilder) appendUpdateFromDedupNode( + bindCtx *BindContext, + lastNodeID int32, + selectNode *plan.Node, + selectNodeTag int32, + dmlCtx *DMLContext, + oldColName2Idx map[string]int32, + newColName2Idx map[string]int32, +) (int32, *plan.Node, int32, error) { + partitionByExprs := make([]*plan.Expr, 0) + partitionPos := make(map[int32]struct{}) + childColExpr := func(pos int32) *plan.Expr { + e := selectNode.ProjectList[pos] + name := "" + if col, ok := e.Expr.(*plan.Expr_Col); ok { + name = col.Col.Name + } + return &plan.Expr{ + Typ: e.Typ, + Expr: &plan.Expr_Col{ + Col: &plan.ColRef{ + RelPos: selectNodeTag, + ColPos: pos, + Name: name, + }, + }, + } + } + + for i, alias := range dmlCtx.aliases { + if len(dmlCtx.updateCol2Expr[i]) == 0 { + continue + } + + for _, col := range dmlCtx.tableDefs[i].Cols { + key := alias + "." + col.Name + oldPos, ok := oldColName2Idx[key] + if !ok { + continue + } + if _, exists := partitionPos[oldPos]; !exists { + partitionPos[oldPos] = struct{}{} + partitionByExprs = append(partitionByExprs, childColExpr(oldPos)) + } + } + } + + if len(partitionByExprs) == 0 { + return lastNodeID, selectNode, selectNodeTag, nil + } + + windowTag := builder.genNewBindTag() + partitionBy := make([]*plan.OrderBySpec, 0, len(partitionByExprs)) + for _, expr := range partitionByExprs { + partitionBy = append(partitionBy, &plan.OrderBySpec{ + Expr: expr, + Flag: plan.OrderBySpec_INTERNAL, + }) + } + lastNodeID = builder.appendNode(&plan.Node{ + NodeType: plan.Node_PARTITION, + Children: []int32{lastNodeID}, + OrderBy: partitionBy, + BindingTags: []int32{windowTag}, + }, bindCtx) + + rowNumberFunc, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "row_number", nil) + if err != nil { + return 0, nil, 0, err + } + rowNumberExpr := &plan.Expr{ + Typ: rowNumberFunc.Typ, + Expr: &plan.Expr_W{ + W: &plan.WindowSpec{ + WindowFunc: rowNumberFunc, + Name: "row_number", + PartitionBy: partitionByExprs, + Frame: &plan.FrameClause{ + Type: plan.FrameClause_ROWS, + Start: &plan.FrameBound{ + Type: plan.FrameBound_PRECEDING, + UnBounded: true, + }, + End: &plan.FrameBound{ + Type: plan.FrameBound_FOLLOWING, + UnBounded: true, + }, + }, + }, + }, + } + rowNumberIdx := int32(0) + rowNumberProjectPos := int32(len(selectNode.ProjectList)) + lastNodeID = builder.appendNode(&plan.Node{ + NodeType: plan.Node_WINDOW, + Children: []int32{lastNodeID}, + WinSpecList: []*plan.Expr{rowNumberExpr}, + WindowIdx: rowNumberIdx, + BindingTags: []int32{windowTag}, + }, bindCtx) + + windowProjectTag := builder.genNewBindTag() + windowProjectList := make([]*plan.Expr, 0, len(selectNode.ProjectList)+1) + for pos := range selectNode.ProjectList { + windowProjectList = append(windowProjectList, childColExpr(int32(pos))) + } + windowProjectList = append(windowProjectList, &plan.Expr{ + Typ: rowNumberFunc.Typ, + Expr: &plan.Expr_Col{ + Col: &plan.ColRef{ + RelPos: windowTag, + ColPos: rowNumberIdx, + Name: "__mo_update_from_dedup_row_number", + }, + }, + }) + lastNodeID = builder.appendNode(&plan.Node{ + NodeType: plan.Node_PROJECT, + Children: []int32{lastNodeID}, + ProjectList: windowProjectList, + BindingTags: []int32{windowProjectTag}, + }, bindCtx) + + rowNumberCol := &plan.Expr{ + Typ: plan.Type{Id: int32(types.T_int64), NotNullable: true}, + Expr: &plan.Expr_Col{ + Col: &plan.ColRef{ + RelPos: windowProjectTag, + ColPos: rowNumberProjectPos, + Name: "__mo_update_from_dedup_row_number", + }, + }, + } + keepFirstRowExpr, err := BindFuncExprImplByPlanExpr( + builder.GetContext(), "=", []*plan.Expr{rowNumberCol, makePlan2Int64ConstExprWithType(1)}) + if err != nil { + return 0, nil, 0, err + } + lastNodeID = builder.appendNode(&plan.Node{ + NodeType: plan.Node_FILTER, + Children: []int32{lastNodeID}, + FilterList: []*plan.Expr{keepFirstRowExpr}, + }, bindCtx) + + projectList := make([]*plan.Expr, len(selectNode.ProjectList)) + for pos, e := range selectNode.ProjectList { + name := "" + if col, ok := e.Expr.(*plan.Expr_Col); ok { + name = col.Col.Name + } + projectList[pos] = &plan.Expr{ + Typ: e.Typ, + Expr: &plan.Expr_Col{ + Col: &plan.ColRef{ + RelPos: windowProjectTag, + ColPos: int32(pos), + Name: name, + }, + }, + } + } + + projectNode := &plan.Node{ + NodeType: plan.Node_PROJECT, + Children: []int32{lastNodeID}, + ProjectList: projectList, + BindingTags: []int32{builder.genNewBindTag()}, + } + lastNodeID = builder.appendNode(projectNode, bindCtx) + return lastNodeID, projectNode, projectNode.BindingTags[0], nil +} diff --git a/pkg/sql/plan/build_test.go b/pkg/sql/plan/build_test.go index 1e931369d12bd..879ff7f8f838e 100644 --- a/pkg/sql/plan/build_test.go +++ b/pkg/sql/plan/build_test.go @@ -720,6 +720,129 @@ func TestUpdate(t *testing.T) { runTestShouldError(mock, t, sqls) } +func TestUpdatePgStyleFromDedupsDuplicateSourceMatchesOnNewPath(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, + "UPDATE NATION SET N_NAME = NATION2.N_NAME FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("build UPDATE FROM plan: %v", err) + } + + query := logicPlan.GetQuery() + tableDef := mock.ctxt.tables["nation"] + if hasUpdateFromDedupAnyValueAgg(query, len(tableDef.Cols)) { + t.Fatalf("UPDATE FROM dedup should not aggregate update columns with any_value") + } + if !hasUpdateFromDedupWindow(query, len(tableDef.Cols)) { + t.Fatalf("UPDATE FROM should dedup duplicate source matches with row_number window") + } +} + +func TestUpdatePgStyleFromDedupPicksWholeSourceRow(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, + "UPDATE NATION SET N_NAME = NATION2.N_NAME, N_COMMENT = NATION2.N_COMMENT FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("build UPDATE FROM plan: %v", err) + } + + query := logicPlan.GetQuery() + if hasUpdateFromDedupAnyValueAgg(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("UPDATE FROM dedup must pick a whole source row, not aggregate each update column with any_value") + } + if !hasUpdateFromDedupWindow(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("UPDATE FROM dedup should use row_number window partitioned by target old columns") + } +} + +func TestUpdatePgStyleFromDedupExpandsDefaultBeforeDedup(t *testing.T) { + mock := NewMockOptimizer(true) + setMockDefaultExpr(t, mock, "nation", "n_name", "name-default") + + logicPlan, err := runOneStmt(mock, t, + "UPDATE NATION SET N_NAME = DEFAULT FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("build UPDATE FROM with DEFAULT: %v", err) + } + + query := logicPlan.GetQuery() + if queryContainsDefaultVal(query) { + t.Fatalf("UPDATE FROM dedup should run after DEFAULT expansion") + } + if !queryContainsStringLiteral(query, "name-default") { + t.Fatalf("UPDATE FROM dedup should retain the expanded DEFAULT expression") + } + if hasUpdateFromDedupAnyValueAgg(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("UPDATE FROM dedup should not wrap DEFAULT with any_value") + } +} + +func TestUpdatePgStyleFromDedupAllowsVectorUpdateColumn(t *testing.T) { + mock := NewMockOptimizer(true) + vecTyp := plan.Type{Id: int32(types.T_array_float32), Width: 4} + setMockColumnType(t, mock, "nation", "n_comment", vecTyp) + setMockColumnType(t, mock, "nation2", "n_comment", vecTyp) + + _, err := runOneStmt(mock, t, + "UPDATE NATION SET N_COMMENT = NATION2.N_COMMENT FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("UPDATE FROM should allow vector update columns through row-level dedup: %v", err) + } +} + +func TestUpdatePgStyleFromDedupKeepsGeneratedColumnsAfterDedup(t *testing.T) { + mock := NewMockOptimizer(true) + setMockGeneratedColumn(t, mock, "nation", "n_comment", "n_name") + + logicPlan, err := runOneStmt(mock, t, + "UPDATE NATION SET N_NAME = NATION2.N_NAME FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("build UPDATE FROM with generated column: %v", err) + } + + query := logicPlan.GetQuery() + if hasUpdateFromDedupAnyValueAgg(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("dedup should not aggregate generated or update columns with any_value") + } + if !hasUpdateFromDedupWindow(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("UPDATE FROM with generated column should still use row-level dedup") + } +} + +func TestUpdatePgStyleFromDedupAllowsDecimal256AndEnumUpdateColumns(t *testing.T) { + tests := []struct { + name string + typ plan.Type + sql string + }{ + { + name: "decimal256", + typ: plan.Type{Id: int32(types.T_decimal256), Width: 65, Scale: 30}, + sql: "UPDATE NATION SET N_COMMENT = REGION.R_COMMENT FROM REGION WHERE NATION.N_REGIONKEY = REGION.R_REGIONKEY", + }, + { + name: "enum", + typ: plan.Type{Id: int32(types.T_enum), Enumvalues: "small,medium,large"}, + sql: "UPDATE NATION SET N_COMMENT = CASE WHEN 1 > 0 THEN 'small' ELSE 'medium' END FROM REGION WHERE NATION.N_REGIONKEY = REGION.R_REGIONKEY", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := NewMockOptimizer(true) + setMockColumnType(t, mock, "nation", "n_comment", tt.typ) + setMockColumnType(t, mock, "region", "r_comment", tt.typ) + + _, err := runOneStmt(mock, t, tt.sql) + if err != nil { + t.Fatalf("UPDATE FROM should allow %s update columns through row-level dedup: %v", tt.name, err) + } + }) + } +} + func TestUpdateFallbackMultiTargetGeneratedColumnsKeepProjectLayout(t *testing.T) { mock := NewMockOptimizer(true) setMockGeneratedColumn(t, mock, "emp", "ename", "job") @@ -860,6 +983,11 @@ func setMockOnUpdateExpr(t *testing.T, mock *MockOptimizer, tableName, colName, } } +func setMockColumnType(t *testing.T, mock *MockOptimizer, tableName, colName string, typ plan.Type) { + col := requireMockColumn(t, mock, tableName, colName) + col.Typ = typ +} + func requireMockColumn(t *testing.T, mock *MockOptimizer, tableName, colName string) *ColDef { tableDef := mock.ctxt.tables[tableName] if tableDef == nil { @@ -906,6 +1034,72 @@ func requireFallbackSourceProjectNode(t *testing.T, query *Query, projectLen int return nil } +func hasUpdateFromDedupAnyValueAgg(query *Query, groupByLen int) bool { + for _, node := range query.Nodes { + if node.NodeType != plan.Node_AGG || len(node.GroupBy) != groupByLen { + continue + } + for _, aggExpr := range node.AggList { + if fn := aggExpr.GetF(); fn != nil && fn.Func.ObjName == "any_value" { + return true + } + } + } + return false +} + +func hasUpdateFromDedupWindow(query *Query, partitionByLen int) bool { + for _, node := range query.Nodes { + if node.NodeType != plan.Node_WINDOW { + continue + } + for _, winExpr := range node.WinSpecList { + spec := winExpr.GetW() + if spec == nil || spec.Name != "row_number" || len(spec.PartitionBy) != partitionByLen { + continue + } + return true + } + } + return false +} + +func queryContainsStringLiteral(query *Query, value string) bool { + return queryContainsExpr(query, func(expr *plan.Expr) bool { + return exprContainsStringLiteral(expr, value) + }) +} + +func queryContainsDefaultVal(query *Query) bool { + return queryContainsExpr(query, exprContainsDefaultVal) +} + +func queryContainsExpr(query *Query, accept func(*plan.Expr) bool) bool { + for _, node := range query.Nodes { + exprLists := [][]*plan.Expr{ + node.ProjectList, + node.OnList, + node.FilterList, + node.GroupBy, + node.AggList, + node.WinSpecList, + } + for _, exprList := range exprLists { + for _, expr := range exprList { + if accept(expr) { + return true + } + } + } + for _, order := range node.OrderBy { + if accept(order.Expr) { + return true + } + } + } + return false +} + func requireQueryExpr(t *testing.T, query *Query, accept func(*plan.Expr) bool, message string) *plan.Expr { for _, node := range query.Nodes { if isSinkScanProjectNode(query, node) { @@ -972,6 +1166,55 @@ func exprContainsStringLiteral(expr *plan.Expr, value string) bool { return true } } + case *plan.Expr_W: + if exprContainsStringLiteral(e.W.WindowFunc, value) { + return true + } + for _, partition := range e.W.PartitionBy { + if exprContainsStringLiteral(partition, value) { + return true + } + } + for _, order := range e.W.OrderBy { + if exprContainsStringLiteral(order.Expr, value) { + return true + } + } + } + return false +} + +func exprContainsDefaultVal(expr *plan.Expr) bool { + switch e := expr.Expr.(type) { + case *plan.Expr_Lit: + _, ok := e.Lit.Value.(*plan.Literal_Defaultval) + return ok + case *plan.Expr_F: + for _, arg := range e.F.Args { + if exprContainsDefaultVal(arg) { + return true + } + } + case *plan.Expr_List: + for _, item := range e.List.List { + if exprContainsDefaultVal(item) { + return true + } + } + case *plan.Expr_W: + if exprContainsDefaultVal(e.W.WindowFunc) { + return true + } + for _, partition := range e.W.PartitionBy { + if exprContainsDefaultVal(partition) { + return true + } + } + for _, order := range e.W.OrderBy { + if exprContainsDefaultVal(order.Expr) { + return true + } + } } return false } diff --git a/pkg/sql/plan/function/list_agg.go b/pkg/sql/plan/function/list_agg.go index 38bb30c9c6bf9..e5fdd58cba021 100644 --- a/pkg/sql/plan/function/list_agg.go +++ b/pkg/sql/plan/function/list_agg.go @@ -656,13 +656,16 @@ var AnyValueSupportedTypes = []types.T{ types.T_float32, types.T_float64, types.T_date, types.T_datetime, types.T_timestamp, types.T_time, - types.T_decimal64, types.T_decimal128, + types.T_decimal64, types.T_decimal128, types.T_decimal256, types.T_bit, types.T_year, types.T_bool, types.T_bit, types.T_varchar, types.T_char, types.T_blob, types.T_text, types.T_datalink, types.T_uuid, types.T_binary, types.T_varbinary, types.T_json, + types.T_array_float32, types.T_array_float64, + types.T_geometry, types.T_geometry32, + types.T_enum, types.T_Rowid, } diff --git a/test/distributed/cases/dml/update/update_pg_style_from.result b/test/distributed/cases/dml/update/update_pg_style_from.result index c6e2313015890..539c875e39978 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.result +++ b/test/distributed/cases/dml/update/update_pg_style_from.result @@ -118,6 +118,47 @@ SELECT id, base, gen_col FROM gen_t ORDER BY id; 3 ¦ 30 ¦ 60 DROP TABLE gen_t; DROP TABLE gen_src; +DROP TABLE IF EXISTS gen_dup_t; +DROP TABLE IF EXISTS gen_dup_s; +CREATE TABLE gen_dup_t ( +id INT PRIMARY KEY, +base INT, +gen_col INT AS (ifnull(base, 0)) STORED +); +INSERT INTO gen_dup_t (id, base) VALUES (1, 5); +CREATE TABLE gen_dup_s (t_id INT, new_base INT); +INSERT INTO gen_dup_s VALUES (1, NULL), (1, 7); +UPDATE gen_dup_t SET base = s.new_base FROM gen_dup_s s WHERE s.t_id = gen_dup_t.id; +SELECT COUNT(*) AS valid_generated_row FROM gen_dup_t +WHERE (base IS NULL AND gen_col = 0) OR (base = 7 AND gen_col = 7); +➤ valid_generated_row[-5,64,0] 𝄀 +1 +DROP TABLE gen_dup_t; +DROP TABLE gen_dup_s; +DROP TABLE IF EXISTS whole_row_t; +DROP TABLE IF EXISTS whole_row_s; +CREATE TABLE whole_row_t ( +id INT PRIMARY KEY, +a INT, +b VARCHAR(20) +); +CREATE TABLE whole_row_s ( +t_id INT, +new_a INT, +new_b VARCHAR(20) +); +INSERT INTO whole_row_t VALUES (1, 0, 'orig'); +INSERT INTO whole_row_s VALUES (1, NULL, 'from-null-a'), (1, 7, NULL); +UPDATE whole_row_t SET a = s.new_a, b = s.new_b FROM whole_row_s s WHERE s.t_id = whole_row_t.id; +SELECT COUNT(*) AS valid_whole_row FROM whole_row_t +WHERE (a IS NULL AND b = 'from-null-a') OR (a = 7 AND b IS NULL); +➤ valid_whole_row[-5,64,0] 𝄀 +1 +SELECT COUNT(*) AS synthesized_row FROM whole_row_t WHERE a = 7 AND b = 'from-null-a'; +➤ synthesized_row[-5,64,0] 𝄀 +0 +DROP TABLE whole_row_t; +DROP TABLE whole_row_s; DROP TABLE IF EXISTS fk_parent; DROP TABLE IF EXISTS fk_child; DROP TABLE IF EXISTS fk_src; @@ -189,6 +230,31 @@ SELECT id, p_id, v FROM dup_t ORDER BY id; DROP TABLE dup_t; DROP TABLE dup_p; DROP TABLE dup_s; +DROP TABLE IF EXISTS dup_no_fk_t; +DROP TABLE IF EXISTS dup_no_fk_s; +CREATE TABLE dup_no_fk_t ( +id INT PRIMARY KEY, +v VARCHAR(20) +); +CREATE TABLE dup_no_fk_s ( +t_id INT, +v VARCHAR(20) +); +INSERT INTO dup_no_fk_t VALUES (1, 'orig'), (2, 'orig'); +INSERT INTO dup_no_fk_s VALUES (1, 'first'), (1, 'second'), (2, 'only'); +UPDATE dup_no_fk_t SET v = s.v FROM dup_no_fk_s s WHERE s.t_id = dup_no_fk_t.id; +SELECT COUNT(*) AS row_cnt, COUNT(DISTINCT id) AS id_cnt FROM dup_no_fk_t; +➤ row_cnt[-5,64,0] ¦ id_cnt[-5,64,0] 𝄀 +2 ¦ 2 +SELECT id, COUNT(*) AS per_id_cnt FROM dup_no_fk_t GROUP BY id ORDER BY id; +➤ id[4,32,0] ¦ per_id_cnt[-5,64,0] 𝄀 +1 ¦ 1 𝄀 +2 ¦ 1 +SELECT v FROM dup_no_fk_t WHERE id = 2; +➤ v[12,-1,0] 𝄀 +only +DROP TABLE dup_no_fk_t; +DROP TABLE dup_no_fk_s; DROP TABLE IF EXISTS company; DROP TABLE IF EXISTS vec_join_case; DROP TABLE IF EXISTS region; diff --git a/test/distributed/cases/dml/update/update_pg_style_from.sql b/test/distributed/cases/dml/update/update_pg_style_from.sql index e8bfd94004cde..0a77d1e833dff 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.sql +++ b/test/distributed/cases/dml/update/update_pg_style_from.sql @@ -121,6 +121,48 @@ SELECT id, base, gen_col FROM gen_t ORDER BY id; DROP TABLE gen_t; DROP TABLE gen_src; +-- Duplicate source rows must not make generated columns come from a different +-- source row than their referenced base columns. +DROP TABLE IF EXISTS gen_dup_t; +DROP TABLE IF EXISTS gen_dup_s; +CREATE TABLE gen_dup_t ( + id INT PRIMARY KEY, + base INT, + gen_col INT AS (ifnull(base, 0)) STORED +); +INSERT INTO gen_dup_t (id, base) VALUES (1, 5); +CREATE TABLE gen_dup_s (t_id INT, new_base INT); +INSERT INTO gen_dup_s VALUES (1, NULL), (1, 7); +UPDATE gen_dup_t SET base = s.new_base FROM gen_dup_s s WHERE s.t_id = gen_dup_t.id; +SELECT COUNT(*) AS valid_generated_row FROM gen_dup_t +WHERE (base IS NULL AND gen_col = 0) OR (base = 7 AND gen_col = 7); +DROP TABLE gen_dup_t; +DROP TABLE gen_dup_s; + +-- Duplicate source rows must be deduped as whole rows. Per-column any_value() +-- can synthesize (new_a = 7, new_b = 'from-null-a'), which is not present in +-- the source. +DROP TABLE IF EXISTS whole_row_t; +DROP TABLE IF EXISTS whole_row_s; +CREATE TABLE whole_row_t ( + id INT PRIMARY KEY, + a INT, + b VARCHAR(20) +); +CREATE TABLE whole_row_s ( + t_id INT, + new_a INT, + new_b VARCHAR(20) +); +INSERT INTO whole_row_t VALUES (1, 0, 'orig'); +INSERT INTO whole_row_s VALUES (1, NULL, 'from-null-a'), (1, 7, NULL); +UPDATE whole_row_t SET a = s.new_a, b = s.new_b FROM whole_row_s s WHERE s.t_id = whole_row_t.id; +SELECT COUNT(*) AS valid_whole_row FROM whole_row_t +WHERE (a IS NULL AND b = 'from-null-a') OR (a = 7 AND b IS NULL); +SELECT COUNT(*) AS synthesized_row FROM whole_row_t WHERE a = 7 AND b = 'from-null-a'; +DROP TABLE whole_row_t; +DROP TABLE whole_row_s; + -- FK target with a generated column: the FK forces the fallback planner -- (buildTableUpdate). Generated column protection must still apply there. DROP TABLE IF EXISTS fk_parent; @@ -196,6 +238,27 @@ DROP TABLE dup_t; DROP TABLE dup_p; DROP TABLE dup_s; +-- Duplicate-match on the new UPDATE path without FK constraints must still +-- update each target row once instead of producing duplicate primary keys. +DROP TABLE IF EXISTS dup_no_fk_t; +DROP TABLE IF EXISTS dup_no_fk_s; +CREATE TABLE dup_no_fk_t ( + id INT PRIMARY KEY, + v VARCHAR(20) +); +CREATE TABLE dup_no_fk_s ( + t_id INT, + v VARCHAR(20) +); +INSERT INTO dup_no_fk_t VALUES (1, 'orig'), (2, 'orig'); +INSERT INTO dup_no_fk_s VALUES (1, 'first'), (1, 'second'), (2, 'only'); +UPDATE dup_no_fk_t SET v = s.v FROM dup_no_fk_s s WHERE s.t_id = dup_no_fk_t.id; +SELECT COUNT(*) AS row_cnt, COUNT(DISTINCT id) AS id_cnt FROM dup_no_fk_t; +SELECT id, COUNT(*) AS per_id_cnt FROM dup_no_fk_t GROUP BY id ORDER BY id; +SELECT v FROM dup_no_fk_t WHERE id = 2; +DROP TABLE dup_no_fk_t; +DROP TABLE dup_no_fk_s; + DROP TABLE IF EXISTS company; DROP TABLE IF EXISTS vec_join_case; DROP TABLE IF EXISTS region;