From 00cffa4ccfef8c40bf7eb72204e8c0e8b89905b8 Mon Sep 17 00:00:00 2001 From: Cao Kai Date: Thu, 11 Jun 2026 14:12:58 +0800 Subject: [PATCH 1/6] fix(plan): dedup update from source matches --- pkg/sql/plan/bind_update.go | 127 ++++++++++++++++++ pkg/sql/plan/build_test.go | 24 ++++ .../dml/update/update_pg_style_from.result | 22 +++ .../cases/dml/update/update_pg_style_from.sql | 20 +++ 4 files changed, 193 insertions(+) diff --git a/pkg/sql/plan/bind_update.go b/pkg/sql/plan/bind_update.go index dc5499c38d6dd..eefafa0f71afe 100644 --- a/pkg/sql/plan/bind_update.go +++ b/pkg/sql/plan/bind_update.go @@ -196,6 +196,13 @@ func (builder *QueryBuilder) bindUpdate(stmt *tree.Update, bindCtx *BindContext) selectNode := builder.qry.Nodes[lastNodeID] selectNodeTag := selectNode.BindingTags[0] + if stmt.From != nil && len(stmt.From.Tables) > 0 { + lastNodeID, selectNode, selectNodeTag, err = builder.appendUpdateFromDedupNode( + bindCtx, lastNodeID, selectNode, selectNodeTag, dmlCtx, oldColName2Idx, newColName2Idx) + if err != nil { + return 0, err + } + } for i, alias := range dmlCtx.aliases { if len(dmlCtx.updateCol2Expr[i]) == 0 { @@ -957,3 +964,123 @@ func (builder *QueryBuilder) bindUpdate(stmt *tree.Update, bindCtx *BindContext) return lastNodeID, err } + +func (builder *QueryBuilder) appendUpdateFromDedupNode( + bindCtx *BindContext, + lastNodeID int32, + selectNode *plan.Node, + selectNodeTag int32, + dmlCtx *DMLContext, + oldColName2Idx map[string]int32, + newColName2Idx map[string]int32, +) (int32, *plan.Node, int32, error) { + groupByExprs := make([]*plan.Expr, 0) + aggList := make([]*plan.Expr, 0) + groupPos := make(map[int32]int32) + aggPos := make(map[int32]int32) + groupTag := builder.genNewBindTag() + aggregateTag := builder.genNewBindTag() + + childColExpr := func(pos int32) *plan.Expr { + e := selectNode.ProjectList[pos] + name := "" + if col, ok := e.Expr.(*plan.Expr_Col); ok { + name = col.Col.Name + } + return &plan.Expr{ + Typ: e.Typ, + Expr: &plan.Expr_Col{ + Col: &plan.ColRef{ + RelPos: selectNodeTag, + ColPos: pos, + Name: name, + }, + }, + } + } + + for i, alias := range dmlCtx.aliases { + if len(dmlCtx.updateCol2Expr[i]) == 0 { + continue + } + + for _, col := range dmlCtx.tableDefs[i].Cols { + key := alias + "." + col.Name + oldPos, ok := oldColName2Idx[key] + if !ok { + continue + } + if _, exists := groupPos[oldPos]; !exists { + groupPos[oldPos] = int32(len(groupByExprs)) + groupByExprs = append(groupByExprs, childColExpr(oldPos)) + } + + updatePos, ok := newColName2Idx[key] + if !ok { + continue + } + if _, exists := aggPos[updatePos]; exists { + continue + } + aggExpr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "any_value", []*plan.Expr{childColExpr(updatePos)}) + if err != nil { + return 0, nil, 0, err + } + aggPos[updatePos] = int32(len(aggList)) + aggList = append(aggList, aggExpr) + } + } + + projectList := make([]*plan.Expr, len(selectNode.ProjectList)) + for pos, e := range selectNode.ProjectList { + colPos := int32(pos) + relPos := groupTag + if aggColPos, ok := aggPos[colPos]; ok { + colPos = aggColPos + relPos = aggregateTag + } else if groupColPos, ok := groupPos[colPos]; ok { + colPos = groupColPos + } else { + aggExpr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "any_value", []*plan.Expr{childColExpr(int32(pos))}) + if err != nil { + return 0, nil, 0, err + } + colPos = int32(len(aggList)) + relPos = aggregateTag + aggList = append(aggList, aggExpr) + } + name := "" + if col, ok := e.Expr.(*plan.Expr_Col); ok { + name = col.Col.Name + } + projectList[pos] = &plan.Expr{ + Typ: e.Typ, + Expr: &plan.Expr_Col{ + Col: &plan.ColRef{ + RelPos: relPos, + ColPos: colPos, + Name: name, + }, + }, + } + } + + aggNode := &plan.Node{ + NodeType: plan.Node_AGG, + Children: []int32{lastNodeID}, + GroupBy: groupByExprs, + AggList: aggList, + BindingTags: []int32{groupTag, aggregateTag}, + SpillMem: builder.aggSpillMem, + } + lastNodeID = builder.appendNode(aggNode, bindCtx) + + projectNode := &plan.Node{ + NodeType: plan.Node_PROJECT, + Children: []int32{lastNodeID}, + ProjectList: projectList, + BindingTags: []int32{builder.genNewBindTag()}, + } + lastNodeID = builder.appendNode(projectNode, bindCtx) + return lastNodeID, projectNode, projectNode.BindingTags[0], nil +} diff --git a/pkg/sql/plan/build_test.go b/pkg/sql/plan/build_test.go index 1e931369d12bd..32d80a61ef63c 100644 --- a/pkg/sql/plan/build_test.go +++ b/pkg/sql/plan/build_test.go @@ -720,6 +720,30 @@ func TestUpdate(t *testing.T) { runTestShouldError(mock, t, sqls) } +func TestUpdatePgStyleFromDedupsDuplicateSourceMatchesOnNewPath(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, + "UPDATE NATION SET N_NAME = NATION2.N_NAME FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("build UPDATE FROM plan: %v", err) + } + + tableDef := mock.ctxt.tables["nation"] + for _, node := range logicPlan.GetQuery().Nodes { + if node.NodeType != plan.Node_AGG { + continue + } + if len(node.GroupBy) != len(tableDef.Cols) || len(node.AggList) != 1 { + continue + } + if fn := node.AggList[0].GetF(); fn != nil && fn.Func.ObjName == "any_value" { + return + } + } + t.Fatalf("UPDATE FROM should dedup duplicate source matches with AGG any_value over update columns") +} + func TestUpdateFallbackMultiTargetGeneratedColumnsKeepProjectLayout(t *testing.T) { mock := NewMockOptimizer(true) setMockGeneratedColumn(t, mock, "emp", "ename", "job") diff --git a/test/distributed/cases/dml/update/update_pg_style_from.result b/test/distributed/cases/dml/update/update_pg_style_from.result index c6e2313015890..a64db0c246f41 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.result +++ b/test/distributed/cases/dml/update/update_pg_style_from.result @@ -189,6 +189,28 @@ SELECT id, p_id, v FROM dup_t ORDER BY id; DROP TABLE dup_t; DROP TABLE dup_p; DROP TABLE dup_s; +DROP TABLE IF EXISTS dup_no_fk_t; +DROP TABLE IF EXISTS dup_no_fk_s; +CREATE TABLE dup_no_fk_t ( +id INT PRIMARY KEY, +v VARCHAR(20) +); +CREATE TABLE dup_no_fk_s ( +t_id INT, +v VARCHAR(20) +); +INSERT INTO dup_no_fk_t VALUES (1, 'orig'), (2, 'orig'); +INSERT INTO dup_no_fk_s VALUES (1, 'first'), (1, 'second'), (2, 'only'); +UPDATE dup_no_fk_t SET v = s.v FROM dup_no_fk_s s WHERE s.t_id = dup_no_fk_t.id; +SELECT COUNT(*) AS row_cnt, COUNT(DISTINCT id) AS id_cnt FROM dup_no_fk_t; +➤ row_cnt[-5,64,0] ¦ id_cnt[-5,64,0] 𝄀 +2 ¦ 2 +SELECT id, COUNT(*) AS per_id_cnt FROM dup_no_fk_t GROUP BY id ORDER BY id; +➤ id[4,32,0] ¦ per_id_cnt[-5,64,0] 𝄀 +1 ¦ 1 𝄀 +2 ¦ 1 +DROP TABLE dup_no_fk_t; +DROP TABLE dup_no_fk_s; DROP TABLE IF EXISTS company; DROP TABLE IF EXISTS vec_join_case; DROP TABLE IF EXISTS region; diff --git a/test/distributed/cases/dml/update/update_pg_style_from.sql b/test/distributed/cases/dml/update/update_pg_style_from.sql index e8bfd94004cde..691f65b754dbd 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.sql +++ b/test/distributed/cases/dml/update/update_pg_style_from.sql @@ -196,6 +196,26 @@ DROP TABLE dup_t; DROP TABLE dup_p; DROP TABLE dup_s; +-- Duplicate-match on the new UPDATE path without FK constraints must still +-- update each target row once instead of producing duplicate primary keys. +DROP TABLE IF EXISTS dup_no_fk_t; +DROP TABLE IF EXISTS dup_no_fk_s; +CREATE TABLE dup_no_fk_t ( + id INT PRIMARY KEY, + v VARCHAR(20) +); +CREATE TABLE dup_no_fk_s ( + t_id INT, + v VARCHAR(20) +); +INSERT INTO dup_no_fk_t VALUES (1, 'orig'), (2, 'orig'); +INSERT INTO dup_no_fk_s VALUES (1, 'first'), (1, 'second'), (2, 'only'); +UPDATE dup_no_fk_t SET v = s.v FROM dup_no_fk_s s WHERE s.t_id = dup_no_fk_t.id; +SELECT COUNT(*) AS row_cnt, COUNT(DISTINCT id) AS id_cnt FROM dup_no_fk_t; +SELECT id, COUNT(*) AS per_id_cnt FROM dup_no_fk_t GROUP BY id ORDER BY id; +DROP TABLE dup_no_fk_t; +DROP TABLE dup_no_fk_s; + DROP TABLE IF EXISTS company; DROP TABLE IF EXISTS vec_join_case; DROP TABLE IF EXISTS region; From 1315da8dbffe4fc104db665ca2d24c4ea561f4f3 Mon Sep 17 00:00:00 2001 From: Cao Kai Date: Thu, 11 Jun 2026 15:36:24 +0800 Subject: [PATCH 2/6] fix(plan): handle update from dedup review --- pkg/sql/plan/bind_update.go | 15 +++-- pkg/sql/plan/build_test.go | 66 +++++++++++++++++++ pkg/sql/plan/function/list_agg.go | 2 + .../dml/update/update_pg_style_from.result | 3 + .../cases/dml/update/update_pg_style_from.sql | 1 + 5 files changed, 80 insertions(+), 7 deletions(-) diff --git a/pkg/sql/plan/bind_update.go b/pkg/sql/plan/bind_update.go index eefafa0f71afe..6ec89e738b9c8 100644 --- a/pkg/sql/plan/bind_update.go +++ b/pkg/sql/plan/bind_update.go @@ -196,13 +196,6 @@ func (builder *QueryBuilder) bindUpdate(stmt *tree.Update, bindCtx *BindContext) selectNode := builder.qry.Nodes[lastNodeID] selectNodeTag := selectNode.BindingTags[0] - if stmt.From != nil && len(stmt.From.Tables) > 0 { - lastNodeID, selectNode, selectNodeTag, err = builder.appendUpdateFromDedupNode( - bindCtx, lastNodeID, selectNode, selectNodeTag, dmlCtx, oldColName2Idx, newColName2Idx) - if err != nil { - return 0, err - } - } for i, alias := range dmlCtx.aliases { if len(dmlCtx.updateCol2Expr[i]) == 0 { @@ -291,6 +284,14 @@ func (builder *QueryBuilder) bindUpdate(stmt *tree.Update, bindCtx *BindContext) } } + if stmt.From != nil && len(stmt.From.Tables) > 0 { + lastNodeID, selectNode, selectNodeTag, err = builder.appendUpdateFromDedupNode( + bindCtx, lastNodeID, selectNode, selectNodeTag, dmlCtx, oldColName2Idx, newColName2Idx) + if err != nil { + return 0, err + } + } + for i, tableDef := range dmlCtx.tableDefs { if updateAutoIncrCols[i] { lastNodeID = builder.appendNode(&plan.Node{ diff --git a/pkg/sql/plan/build_test.go b/pkg/sql/plan/build_test.go index 32d80a61ef63c..53ead3ab69a11 100644 --- a/pkg/sql/plan/build_test.go +++ b/pkg/sql/plan/build_test.go @@ -744,6 +744,46 @@ func TestUpdatePgStyleFromDedupsDuplicateSourceMatchesOnNewPath(t *testing.T) { t.Fatalf("UPDATE FROM should dedup duplicate source matches with AGG any_value over update columns") } +func TestUpdatePgStyleFromDedupExpandsDefaultBeforeAnyValue(t *testing.T) { + mock := NewMockOptimizer(true) + setMockDefaultExpr(t, mock, "nation", "n_name", "name-default") + + logicPlan, err := runOneStmt(mock, t, + "UPDATE NATION SET N_NAME = DEFAULT FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("build UPDATE FROM with DEFAULT: %v", err) + } + + for _, node := range logicPlan.GetQuery().Nodes { + if node.NodeType != plan.Node_AGG { + continue + } + for _, aggExpr := range node.AggList { + if exprContainsDefaultVal(aggExpr) { + t.Fatalf("dedup any_value should not wrap raw DEFAULT marker: %v", aggExpr) + } + if fn := aggExpr.GetF(); fn != nil && fn.Func.ObjName == "any_value" && + exprContainsStringLiteral(aggExpr, "name-default") { + return + } + } + } + t.Fatalf("UPDATE FROM dedup should aggregate the expanded DEFAULT expression") +} + +func TestUpdatePgStyleFromDedupAllowsVectorUpdateColumn(t *testing.T) { + mock := NewMockOptimizer(true) + vecTyp := plan.Type{Id: int32(types.T_array_float32), Width: 4} + setMockColumnType(t, mock, "nation", "n_comment", vecTyp) + setMockColumnType(t, mock, "nation2", "n_comment", vecTyp) + + _, err := runOneStmt(mock, t, + "UPDATE NATION SET N_COMMENT = NATION2.N_COMMENT FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("UPDATE FROM should allow vector update columns through any_value dedup: %v", err) + } +} + func TestUpdateFallbackMultiTargetGeneratedColumnsKeepProjectLayout(t *testing.T) { mock := NewMockOptimizer(true) setMockGeneratedColumn(t, mock, "emp", "ename", "job") @@ -884,6 +924,11 @@ func setMockOnUpdateExpr(t *testing.T, mock *MockOptimizer, tableName, colName, } } +func setMockColumnType(t *testing.T, mock *MockOptimizer, tableName, colName string, typ plan.Type) { + col := requireMockColumn(t, mock, tableName, colName) + col.Typ = typ +} + func requireMockColumn(t *testing.T, mock *MockOptimizer, tableName, colName string) *ColDef { tableDef := mock.ctxt.tables[tableName] if tableDef == nil { @@ -1000,6 +1045,27 @@ func exprContainsStringLiteral(expr *plan.Expr, value string) bool { return false } +func exprContainsDefaultVal(expr *plan.Expr) bool { + switch e := expr.Expr.(type) { + case *plan.Expr_Lit: + _, ok := e.Lit.Value.(*plan.Literal_Defaultval) + return ok + case *plan.Expr_F: + for _, arg := range e.F.Args { + if exprContainsDefaultVal(arg) { + return true + } + } + case *plan.Expr_List: + for _, item := range e.List.List { + if exprContainsDefaultVal(item) { + return true + } + } + } + return false +} + func exprContainsColName(expr *plan.Expr, name string) bool { switch e := expr.Expr.(type) { case *plan.Expr_Col: diff --git a/pkg/sql/plan/function/list_agg.go b/pkg/sql/plan/function/list_agg.go index 38bb30c9c6bf9..93effb95ca9cd 100644 --- a/pkg/sql/plan/function/list_agg.go +++ b/pkg/sql/plan/function/list_agg.go @@ -663,6 +663,8 @@ var AnyValueSupportedTypes = []types.T{ types.T_varchar, types.T_char, types.T_blob, types.T_text, types.T_datalink, types.T_uuid, types.T_binary, types.T_varbinary, types.T_json, + types.T_array_float32, types.T_array_float64, + types.T_geometry, types.T_geometry32, types.T_Rowid, } diff --git a/test/distributed/cases/dml/update/update_pg_style_from.result b/test/distributed/cases/dml/update/update_pg_style_from.result index a64db0c246f41..00f8e5af774ca 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.result +++ b/test/distributed/cases/dml/update/update_pg_style_from.result @@ -209,6 +209,9 @@ SELECT id, COUNT(*) AS per_id_cnt FROM dup_no_fk_t GROUP BY id ORDER BY id; ➤ id[4,32,0] ¦ per_id_cnt[-5,64,0] 𝄀 1 ¦ 1 𝄀 2 ¦ 1 +SELECT v FROM dup_no_fk_t WHERE id = 2; +➤ v[12,-1,0] 𝄀 +only DROP TABLE dup_no_fk_t; DROP TABLE dup_no_fk_s; DROP TABLE IF EXISTS company; diff --git a/test/distributed/cases/dml/update/update_pg_style_from.sql b/test/distributed/cases/dml/update/update_pg_style_from.sql index 691f65b754dbd..f5bdf447362d0 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.sql +++ b/test/distributed/cases/dml/update/update_pg_style_from.sql @@ -213,6 +213,7 @@ INSERT INTO dup_no_fk_s VALUES (1, 'first'), (1, 'second'), (2, 'only'); UPDATE dup_no_fk_t SET v = s.v FROM dup_no_fk_s s WHERE s.t_id = dup_no_fk_t.id; SELECT COUNT(*) AS row_cnt, COUNT(DISTINCT id) AS id_cnt FROM dup_no_fk_t; SELECT id, COUNT(*) AS per_id_cnt FROM dup_no_fk_t GROUP BY id ORDER BY id; +SELECT v FROM dup_no_fk_t WHERE id = 2; DROP TABLE dup_no_fk_t; DROP TABLE dup_no_fk_s; From 5fcdb0256985b6c8cdc078d012933217a3ff5bbd Mon Sep 17 00:00:00 2001 From: Cao Kai Date: Thu, 11 Jun 2026 16:50:33 +0800 Subject: [PATCH 3/6] fix(plan): handle update from dedup follow-up --- pkg/sql/colexec/aggexec/any2.go | 8 +++-- pkg/sql/plan/bind_update.go | 28 +++++++++++------ pkg/sql/plan/build_test.go | 31 +++++++++++++++++++ pkg/sql/plan/function/list_agg.go | 3 +- .../dml/update/update_pg_style_from.result | 16 ++++++++++ .../cases/dml/update/update_pg_style_from.sql | 17 ++++++++++ 6 files changed, 91 insertions(+), 12 deletions(-) diff --git a/pkg/sql/colexec/aggexec/any2.go b/pkg/sql/colexec/aggexec/any2.go index ff092eb45309e..74d8c209c17e5 100644 --- a/pkg/sql/colexec/aggexec/any2.go +++ b/pkg/sql/colexec/aggexec/any2.go @@ -48,7 +48,9 @@ func (exec *anyExec) BatchFill(offset int, groups []uint64, vectors []*vector.Ve if exec.state[x].vecs[0].IsNull(uint64(y)) { exec.state[x].vecs[0].UnsetNull(uint64(y)) bs := vectors[0].GetRawBytesAt(int(idx)) - exec.state[x].vecs[0].SetRawBytesAt(int(y), bs, exec.mp) + if err := exec.state[x].vecs[0].SetRawBytesAt(int(y), bs, exec.mp); err != nil { + return err + } } } } @@ -74,7 +76,9 @@ func (exec *anyExec) BatchMerge(next AggFuncExec, offset int, groups []uint64) e if exec.state[x1].vecs[0].IsNull(uint64(y1)) { exec.state[x1].vecs[0].UnsetNull(uint64(y1)) bs := other.state[x2].vecs[0].GetRawBytesAt(int(y2)) - exec.state[x1].vecs[0].SetRawBytesAt(int(y1), bs, exec.mp) + if err := exec.state[x1].vecs[0].SetRawBytesAt(int(y1), bs, exec.mp); err != nil { + return err + } } } return nil diff --git a/pkg/sql/plan/bind_update.go b/pkg/sql/plan/bind_update.go index 6ec89e738b9c8..c97429a44870b 100644 --- a/pkg/sql/plan/bind_update.go +++ b/pkg/sql/plan/bind_update.go @@ -269,7 +269,25 @@ func (builder *QueryBuilder) bindUpdate(stmt *tree.Update, bindCtx *BindContext) } } - // Recompute generated columns (both STORED and VIRTUAL are computed on write) + } + + if stmt.From != nil && len(stmt.From.Tables) > 0 { + lastNodeID, selectNode, selectNodeTag, err = builder.appendUpdateFromDedupNode( + bindCtx, lastNodeID, selectNode, selectNodeTag, dmlCtx, oldColName2Idx, newColName2Idx) + if err != nil { + return 0, err + } + } + + for i, alias := range dmlCtx.aliases { + if len(dmlCtx.updateCol2Expr[i]) == 0 { + continue + } + + tableDef := dmlCtx.tableDefs[i] + + // Recompute generated columns after UPDATE FROM dedup so generated + // expressions read the same deduped base values that will be written. for _, col := range tableDef.Cols { if col.GeneratedCol == nil { continue @@ -284,14 +302,6 @@ func (builder *QueryBuilder) bindUpdate(stmt *tree.Update, bindCtx *BindContext) } } - if stmt.From != nil && len(stmt.From.Tables) > 0 { - lastNodeID, selectNode, selectNodeTag, err = builder.appendUpdateFromDedupNode( - bindCtx, lastNodeID, selectNode, selectNodeTag, dmlCtx, oldColName2Idx, newColName2Idx) - if err != nil { - return 0, err - } - } - for i, tableDef := range dmlCtx.tableDefs { if updateAutoIncrCols[i] { lastNodeID = builder.appendNode(&plan.Node{ diff --git a/pkg/sql/plan/build_test.go b/pkg/sql/plan/build_test.go index 53ead3ab69a11..de1a5968934b1 100644 --- a/pkg/sql/plan/build_test.go +++ b/pkg/sql/plan/build_test.go @@ -784,6 +784,37 @@ func TestUpdatePgStyleFromDedupAllowsVectorUpdateColumn(t *testing.T) { } } +func TestUpdatePgStyleFromDedupAllowsDecimal256AndEnumUpdateColumns(t *testing.T) { + tests := []struct { + name string + typ plan.Type + sql string + }{ + { + name: "decimal256", + typ: plan.Type{Id: int32(types.T_decimal256), Width: 65, Scale: 30}, + sql: "UPDATE NATION SET N_COMMENT = 1.23 FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY", + }, + { + name: "enum", + typ: plan.Type{Id: int32(types.T_enum), Enumvalues: "small,medium,large"}, + sql: "UPDATE NATION SET N_COMMENT = 'small' FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := NewMockOptimizer(true) + setMockColumnType(t, mock, "nation", "n_comment", tt.typ) + + _, err := runOneStmt(mock, t, tt.sql) + if err != nil { + t.Fatalf("UPDATE FROM should allow %s update columns through any_value dedup: %v", tt.name, err) + } + }) + } +} + func TestUpdateFallbackMultiTargetGeneratedColumnsKeepProjectLayout(t *testing.T) { mock := NewMockOptimizer(true) setMockGeneratedColumn(t, mock, "emp", "ename", "job") diff --git a/pkg/sql/plan/function/list_agg.go b/pkg/sql/plan/function/list_agg.go index 93effb95ca9cd..e5fdd58cba021 100644 --- a/pkg/sql/plan/function/list_agg.go +++ b/pkg/sql/plan/function/list_agg.go @@ -656,7 +656,7 @@ var AnyValueSupportedTypes = []types.T{ types.T_float32, types.T_float64, types.T_date, types.T_datetime, types.T_timestamp, types.T_time, - types.T_decimal64, types.T_decimal128, + types.T_decimal64, types.T_decimal128, types.T_decimal256, types.T_bit, types.T_year, types.T_bool, types.T_bit, @@ -665,6 +665,7 @@ var AnyValueSupportedTypes = []types.T{ types.T_binary, types.T_varbinary, types.T_json, types.T_array_float32, types.T_array_float64, types.T_geometry, types.T_geometry32, + types.T_enum, types.T_Rowid, } diff --git a/test/distributed/cases/dml/update/update_pg_style_from.result b/test/distributed/cases/dml/update/update_pg_style_from.result index 00f8e5af774ca..c856484ddce5c 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.result +++ b/test/distributed/cases/dml/update/update_pg_style_from.result @@ -118,6 +118,22 @@ SELECT id, base, gen_col FROM gen_t ORDER BY id; 3 ¦ 30 ¦ 60 DROP TABLE gen_t; DROP TABLE gen_src; +DROP TABLE IF EXISTS gen_dup_t; +DROP TABLE IF EXISTS gen_dup_s; +CREATE TABLE gen_dup_t ( +id INT PRIMARY KEY, +base INT, +gen_col INT AS (ifnull(base, 0)) STORED +); +INSERT INTO gen_dup_t (id, base) VALUES (1, 5); +CREATE TABLE gen_dup_s (t_id INT, new_base INT); +INSERT INTO gen_dup_s VALUES (1, NULL), (1, 7); +UPDATE gen_dup_t SET base = s.new_base FROM gen_dup_s s WHERE s.t_id = gen_dup_t.id; +SELECT id, base, gen_col FROM gen_dup_t ORDER BY id; +➤ id[4,32,0] ¦ base[4,32,0] ¦ gen_col[4,32,0] 𝄀 +1 ¦ 7 ¦ 7 +DROP TABLE gen_dup_t; +DROP TABLE gen_dup_s; DROP TABLE IF EXISTS fk_parent; DROP TABLE IF EXISTS fk_child; DROP TABLE IF EXISTS fk_src; diff --git a/test/distributed/cases/dml/update/update_pg_style_from.sql b/test/distributed/cases/dml/update/update_pg_style_from.sql index f5bdf447362d0..71a82bb183bbe 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.sql +++ b/test/distributed/cases/dml/update/update_pg_style_from.sql @@ -121,6 +121,23 @@ SELECT id, base, gen_col FROM gen_t ORDER BY id; DROP TABLE gen_t; DROP TABLE gen_src; +-- Duplicate source rows must not make generated columns come from a different +-- source row than their referenced base columns. +DROP TABLE IF EXISTS gen_dup_t; +DROP TABLE IF EXISTS gen_dup_s; +CREATE TABLE gen_dup_t ( + id INT PRIMARY KEY, + base INT, + gen_col INT AS (ifnull(base, 0)) STORED +); +INSERT INTO gen_dup_t (id, base) VALUES (1, 5); +CREATE TABLE gen_dup_s (t_id INT, new_base INT); +INSERT INTO gen_dup_s VALUES (1, NULL), (1, 7); +UPDATE gen_dup_t SET base = s.new_base FROM gen_dup_s s WHERE s.t_id = gen_dup_t.id; +SELECT id, base, gen_col FROM gen_dup_t ORDER BY id; +DROP TABLE gen_dup_t; +DROP TABLE gen_dup_s; + -- FK target with a generated column: the FK forces the fallback planner -- (buildTableUpdate). Generated column protection must still apply there. DROP TABLE IF EXISTS fk_parent; From b570b50f33d193e67128a117f97414d53c304939 Mon Sep 17 00:00:00 2001 From: Cao Kai Date: Thu, 11 Jun 2026 17:56:02 +0800 Subject: [PATCH 4/6] test(plan): cover update from dedup follow-up --- pkg/sql/colexec/aggexec/any2_test.go | 81 ++++++++++++++++++++++++++++ pkg/sql/plan/build_test.go | 71 +++++++++++++++++++++++- 2 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 pkg/sql/colexec/aggexec/any2_test.go diff --git a/pkg/sql/colexec/aggexec/any2_test.go b/pkg/sql/colexec/aggexec/any2_test.go new file mode 100644 index 0000000000000..681c488448152 --- /dev/null +++ b/pkg/sql/colexec/aggexec/any2_test.go @@ -0,0 +1,81 @@ +// Copyright 2026 Matrix Origin +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggexec + +import ( + "testing" + + "github.com/matrixorigin/matrixone/pkg/common/mpool" + "github.com/matrixorigin/matrixone/pkg/container/types" + "github.com/matrixorigin/matrixone/pkg/container/vector" + "github.com/stretchr/testify/require" +) + +func TestAnyValueBatchFillReturnsSetRawBytesAtError(t *testing.T) { + inputMp := mpool.MustNewZero() + input := vector.NewVec(types.T_varchar.ToType()) + require.NoError(t, vector.AppendBytes(input, make([]byte, 4096), false, inputMp)) + defer input.Free(inputMp) + + exec, limitedMp, filler := newLimitedAnyValueExec(t) + defer cleanupLimitedAnyValueExec(exec, limitedMp, filler) + + err := exec.BatchFill(0, []uint64{1}, []*vector.Vector{input}) + require.Error(t, err) +} + +func TestAnyValueBatchMergeReturnsSetRawBytesAtError(t *testing.T) { + inputMp := mpool.MustNewZero() + input := vector.NewVec(types.T_varchar.ToType()) + require.NoError(t, vector.AppendBytes(input, make([]byte, 4096), false, inputMp)) + defer input.Free(inputMp) + + sourceMp := mpool.MustNewZero() + source := makeAnyValueExec(sourceMp, 1, types.T_varchar.ToType()).(*anyExec) + require.NoError(t, source.GroupGrow(1)) + require.NoError(t, source.BatchFill(0, []uint64{1}, []*vector.Vector{input})) + defer source.Free() + + target, limitedMp, filler := newLimitedAnyValueExec(t) + defer cleanupLimitedAnyValueExec(target, limitedMp, filler) + + err := target.BatchMerge(source, 0, []uint64{1}) + require.Error(t, err) +} + +func newLimitedAnyValueExec(t *testing.T) (*anyExec, *mpool.MPool, []byte) { + t.Helper() + + limitedMp, err := mpool.NewMPool("any-value-limited", 1024*1024, mpool.NoFixed) + require.NoError(t, err) + + exec := makeAnyValueExec(limitedMp, 1, types.T_varchar.ToType()).(*anyExec) + require.NoError(t, exec.GroupGrow(1)) + + remaining := 1024*1024 - limitedMp.CurrNB() + require.Greater(t, remaining, int64(4096)) + filler, err := limitedMp.Alloc(int(remaining-1024), true) + require.NoError(t, err) + + return exec, limitedMp, filler +} + +func cleanupLimitedAnyValueExec(exec *anyExec, mp *mpool.MPool, filler []byte) { + if filler != nil { + mp.Free(filler) + } + exec.Free() + mpool.DeleteMPool(mp) +} diff --git a/pkg/sql/plan/build_test.go b/pkg/sql/plan/build_test.go index de1a5968934b1..9c77eb2438a61 100644 --- a/pkg/sql/plan/build_test.go +++ b/pkg/sql/plan/build_test.go @@ -784,6 +784,27 @@ func TestUpdatePgStyleFromDedupAllowsVectorUpdateColumn(t *testing.T) { } } +func TestUpdatePgStyleFromDedupDoesNotAggregateGeneratedColumns(t *testing.T) { + mock := NewMockOptimizer(true) + setMockGeneratedColumn(t, mock, "nation", "n_comment", "n_name") + + logicPlan, err := runOneStmt(mock, t, + "UPDATE NATION SET N_NAME = NATION2.N_NAME FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("build UPDATE FROM with generated column: %v", err) + } + + aggNode := requireUpdateFromDedupAggNode(t, logicPlan.GetQuery(), len(mock.ctxt.tables["nation"].Cols)) + if got := countFuncInExprs(aggNode.AggList, "any_value"); got != 1 { + t.Fatalf("dedup should aggregate only explicit update columns, got %d any_value exprs: %v", got, aggNode.AggList) + } + for _, aggExpr := range aggNode.AggList { + if exprContainsColName(aggExpr, "n_comment") { + t.Fatalf("generated column should be recomputed after dedup, not aggregated by dedup: %v", aggExpr) + } + } +} + func TestUpdatePgStyleFromDedupAllowsDecimal256AndEnumUpdateColumns(t *testing.T) { tests := []struct { name string @@ -793,12 +814,12 @@ func TestUpdatePgStyleFromDedupAllowsDecimal256AndEnumUpdateColumns(t *testing.T { name: "decimal256", typ: plan.Type{Id: int32(types.T_decimal256), Width: 65, Scale: 30}, - sql: "UPDATE NATION SET N_COMMENT = 1.23 FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY", + sql: "UPDATE NATION SET N_COMMENT = REGION.R_COMMENT FROM REGION WHERE NATION.N_REGIONKEY = REGION.R_REGIONKEY", }, { name: "enum", typ: plan.Type{Id: int32(types.T_enum), Enumvalues: "small,medium,large"}, - sql: "UPDATE NATION SET N_COMMENT = 'small' FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY", + sql: "UPDATE NATION SET N_COMMENT = CASE WHEN 1 > 0 THEN 'small' ELSE 'medium' END FROM REGION WHERE NATION.N_REGIONKEY = REGION.R_REGIONKEY", }, } @@ -806,6 +827,7 @@ func TestUpdatePgStyleFromDedupAllowsDecimal256AndEnumUpdateColumns(t *testing.T t.Run(tt.name, func(t *testing.T) { mock := NewMockOptimizer(true) setMockColumnType(t, mock, "nation", "n_comment", tt.typ) + setMockColumnType(t, mock, "region", "r_comment", tt.typ) _, err := runOneStmt(mock, t, tt.sql) if err != nil { @@ -1006,6 +1028,21 @@ func requireFallbackSourceProjectNode(t *testing.T, query *Query, projectLen int return nil } +func requireUpdateFromDedupAggNode(t *testing.T, query *Query, groupByLen int) *Node { + for _, node := range query.Nodes { + if node.NodeType != plan.Node_AGG || len(node.GroupBy) != groupByLen { + continue + } + for _, aggExpr := range node.AggList { + if fn := aggExpr.GetF(); fn != nil && fn.Func.ObjName == "any_value" { + return node + } + } + } + t.Fatalf("missing UPDATE FROM dedup agg node with %d group keys", groupByLen) + return nil +} + func requireQueryExpr(t *testing.T, query *Query, accept func(*plan.Expr) bool, message string) *plan.Expr { for _, node := range query.Nodes { if isSinkScanProjectNode(query, node) { @@ -1097,6 +1134,36 @@ func exprContainsDefaultVal(expr *plan.Expr) bool { return false } +func countFuncInExprs(exprs []*plan.Expr, funcName string) int { + count := 0 + for _, expr := range exprs { + count += countFuncInExpr(expr, funcName) + } + return count +} + +func countFuncInExpr(expr *plan.Expr, funcName string) int { + switch e := expr.Expr.(type) { + case *plan.Expr_F: + count := 0 + if e.F.Func.ObjName == funcName { + count++ + } + for _, arg := range e.F.Args { + count += countFuncInExpr(arg, funcName) + } + return count + case *plan.Expr_List: + count := 0 + for _, item := range e.List.List { + count += countFuncInExpr(item, funcName) + } + return count + default: + return 0 + } +} + func exprContainsColName(expr *plan.Expr, name string) bool { switch e := expr.Expr.(type) { case *plan.Expr_Col: From dce983f7aff00a488b97d5ffbbeb8759ca0ad456 Mon Sep 17 00:00:00 2001 From: Cao Kai Date: Fri, 12 Jun 2026 13:27:24 +0800 Subject: [PATCH 5/6] fix(plan): pick whole row for update from dedup --- pkg/sql/plan/bind_update.go | 157 +++++++++----- pkg/sql/plan/build_test.go | 191 +++++++++++------- .../dml/update/update_pg_style_from.result | 31 ++- .../cases/dml/update/update_pg_style_from.sql | 27 ++- 4 files changed, 282 insertions(+), 124 deletions(-) diff --git a/pkg/sql/plan/bind_update.go b/pkg/sql/plan/bind_update.go index c97429a44870b..bf5b8e445cad7 100644 --- a/pkg/sql/plan/bind_update.go +++ b/pkg/sql/plan/bind_update.go @@ -20,6 +20,7 @@ import ( "github.com/matrixorigin/matrixone/pkg/catalog" "github.com/matrixorigin/matrixone/pkg/common/moerr" + "github.com/matrixorigin/matrixone/pkg/container/types" "github.com/matrixorigin/matrixone/pkg/pb/plan" "github.com/matrixorigin/matrixone/pkg/sql/parsers/tree" ) @@ -985,13 +986,8 @@ func (builder *QueryBuilder) appendUpdateFromDedupNode( oldColName2Idx map[string]int32, newColName2Idx map[string]int32, ) (int32, *plan.Node, int32, error) { - groupByExprs := make([]*plan.Expr, 0) - aggList := make([]*plan.Expr, 0) - groupPos := make(map[int32]int32) - aggPos := make(map[int32]int32) - groupTag := builder.genNewBindTag() - aggregateTag := builder.genNewBindTag() - + partitionByExprs := make([]*plan.Expr, 0) + partitionPos := make(map[int32]struct{}) childColExpr := func(pos int32) *plan.Expr { e := selectNode.ProjectList[pos] name := "" @@ -1021,45 +1017,112 @@ func (builder *QueryBuilder) appendUpdateFromDedupNode( if !ok { continue } - if _, exists := groupPos[oldPos]; !exists { - groupPos[oldPos] = int32(len(groupByExprs)) - groupByExprs = append(groupByExprs, childColExpr(oldPos)) - } - - updatePos, ok := newColName2Idx[key] - if !ok { - continue - } - if _, exists := aggPos[updatePos]; exists { - continue + if _, exists := partitionPos[oldPos]; !exists { + partitionPos[oldPos] = struct{}{} + partitionByExprs = append(partitionByExprs, childColExpr(oldPos)) } - aggExpr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "any_value", []*plan.Expr{childColExpr(updatePos)}) - if err != nil { - return 0, nil, 0, err - } - aggPos[updatePos] = int32(len(aggList)) - aggList = append(aggList, aggExpr) } } + if len(partitionByExprs) == 0 { + return lastNodeID, selectNode, selectNodeTag, nil + } + + windowTag := builder.genNewBindTag() + partitionBy := make([]*plan.OrderBySpec, 0, len(partitionByExprs)) + for _, expr := range partitionByExprs { + partitionBy = append(partitionBy, &plan.OrderBySpec{ + Expr: expr, + Flag: plan.OrderBySpec_INTERNAL, + }) + } + lastNodeID = builder.appendNode(&plan.Node{ + NodeType: plan.Node_PARTITION, + Children: []int32{lastNodeID}, + OrderBy: partitionBy, + BindingTags: []int32{windowTag}, + }, bindCtx) + + rowNumberFunc, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "row_number", nil) + if err != nil { + return 0, nil, 0, err + } + rowNumberExpr := &plan.Expr{ + Typ: rowNumberFunc.Typ, + Expr: &plan.Expr_W{ + W: &plan.WindowSpec{ + WindowFunc: rowNumberFunc, + Name: "row_number", + PartitionBy: partitionByExprs, + Frame: &plan.FrameClause{ + Type: plan.FrameClause_ROWS, + Start: &plan.FrameBound{ + Type: plan.FrameBound_PRECEDING, + UnBounded: true, + }, + End: &plan.FrameBound{ + Type: plan.FrameBound_FOLLOWING, + UnBounded: true, + }, + }, + }, + }, + } + rowNumberIdx := int32(0) + rowNumberProjectPos := int32(len(selectNode.ProjectList)) + lastNodeID = builder.appendNode(&plan.Node{ + NodeType: plan.Node_WINDOW, + Children: []int32{lastNodeID}, + WinSpecList: []*plan.Expr{rowNumberExpr}, + WindowIdx: rowNumberIdx, + BindingTags: []int32{windowTag}, + }, bindCtx) + + windowProjectTag := builder.genNewBindTag() + windowProjectList := make([]*plan.Expr, 0, len(selectNode.ProjectList)+1) + for pos := range selectNode.ProjectList { + windowProjectList = append(windowProjectList, childColExpr(int32(pos))) + } + windowProjectList = append(windowProjectList, &plan.Expr{ + Typ: rowNumberFunc.Typ, + Expr: &plan.Expr_Col{ + Col: &plan.ColRef{ + RelPos: windowTag, + ColPos: rowNumberIdx, + Name: "__mo_update_from_dedup_row_number", + }, + }, + }) + lastNodeID = builder.appendNode(&plan.Node{ + NodeType: plan.Node_PROJECT, + Children: []int32{lastNodeID}, + ProjectList: windowProjectList, + BindingTags: []int32{windowProjectTag}, + }, bindCtx) + + rowNumberCol := &plan.Expr{ + Typ: plan.Type{Id: int32(types.T_int64), NotNullable: true}, + Expr: &plan.Expr_Col{ + Col: &plan.ColRef{ + RelPos: windowProjectTag, + ColPos: rowNumberProjectPos, + Name: "__mo_update_from_dedup_row_number", + }, + }, + } + keepFirstRowExpr, err := BindFuncExprImplByPlanExpr( + builder.GetContext(), "=", []*plan.Expr{rowNumberCol, makePlan2Int64ConstExprWithType(1)}) + if err != nil { + return 0, nil, 0, err + } + lastNodeID = builder.appendNode(&plan.Node{ + NodeType: plan.Node_FILTER, + Children: []int32{lastNodeID}, + FilterList: []*plan.Expr{keepFirstRowExpr}, + }, bindCtx) + projectList := make([]*plan.Expr, len(selectNode.ProjectList)) for pos, e := range selectNode.ProjectList { - colPos := int32(pos) - relPos := groupTag - if aggColPos, ok := aggPos[colPos]; ok { - colPos = aggColPos - relPos = aggregateTag - } else if groupColPos, ok := groupPos[colPos]; ok { - colPos = groupColPos - } else { - aggExpr, err := BindFuncExprImplByPlanExpr(builder.GetContext(), "any_value", []*plan.Expr{childColExpr(int32(pos))}) - if err != nil { - return 0, nil, 0, err - } - colPos = int32(len(aggList)) - relPos = aggregateTag - aggList = append(aggList, aggExpr) - } name := "" if col, ok := e.Expr.(*plan.Expr_Col); ok { name = col.Col.Name @@ -1068,24 +1131,14 @@ func (builder *QueryBuilder) appendUpdateFromDedupNode( Typ: e.Typ, Expr: &plan.Expr_Col{ Col: &plan.ColRef{ - RelPos: relPos, - ColPos: colPos, + RelPos: windowProjectTag, + ColPos: int32(pos), Name: name, }, }, } } - aggNode := &plan.Node{ - NodeType: plan.Node_AGG, - Children: []int32{lastNodeID}, - GroupBy: groupByExprs, - AggList: aggList, - BindingTags: []int32{groupTag, aggregateTag}, - SpillMem: builder.aggSpillMem, - } - lastNodeID = builder.appendNode(aggNode, bindCtx) - projectNode := &plan.Node{ NodeType: plan.Node_PROJECT, Children: []int32{lastNodeID}, diff --git a/pkg/sql/plan/build_test.go b/pkg/sql/plan/build_test.go index 9c77eb2438a61..879ff7f8f838e 100644 --- a/pkg/sql/plan/build_test.go +++ b/pkg/sql/plan/build_test.go @@ -729,22 +729,35 @@ func TestUpdatePgStyleFromDedupsDuplicateSourceMatchesOnNewPath(t *testing.T) { t.Fatalf("build UPDATE FROM plan: %v", err) } + query := logicPlan.GetQuery() tableDef := mock.ctxt.tables["nation"] - for _, node := range logicPlan.GetQuery().Nodes { - if node.NodeType != plan.Node_AGG { - continue - } - if len(node.GroupBy) != len(tableDef.Cols) || len(node.AggList) != 1 { - continue - } - if fn := node.AggList[0].GetF(); fn != nil && fn.Func.ObjName == "any_value" { - return - } + if hasUpdateFromDedupAnyValueAgg(query, len(tableDef.Cols)) { + t.Fatalf("UPDATE FROM dedup should not aggregate update columns with any_value") + } + if !hasUpdateFromDedupWindow(query, len(tableDef.Cols)) { + t.Fatalf("UPDATE FROM should dedup duplicate source matches with row_number window") } - t.Fatalf("UPDATE FROM should dedup duplicate source matches with AGG any_value over update columns") } -func TestUpdatePgStyleFromDedupExpandsDefaultBeforeAnyValue(t *testing.T) { +func TestUpdatePgStyleFromDedupPicksWholeSourceRow(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, + "UPDATE NATION SET N_NAME = NATION2.N_NAME, N_COMMENT = NATION2.N_COMMENT FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") + if err != nil { + t.Fatalf("build UPDATE FROM plan: %v", err) + } + + query := logicPlan.GetQuery() + if hasUpdateFromDedupAnyValueAgg(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("UPDATE FROM dedup must pick a whole source row, not aggregate each update column with any_value") + } + if !hasUpdateFromDedupWindow(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("UPDATE FROM dedup should use row_number window partitioned by target old columns") + } +} + +func TestUpdatePgStyleFromDedupExpandsDefaultBeforeDedup(t *testing.T) { mock := NewMockOptimizer(true) setMockDefaultExpr(t, mock, "nation", "n_name", "name-default") @@ -754,21 +767,16 @@ func TestUpdatePgStyleFromDedupExpandsDefaultBeforeAnyValue(t *testing.T) { t.Fatalf("build UPDATE FROM with DEFAULT: %v", err) } - for _, node := range logicPlan.GetQuery().Nodes { - if node.NodeType != plan.Node_AGG { - continue - } - for _, aggExpr := range node.AggList { - if exprContainsDefaultVal(aggExpr) { - t.Fatalf("dedup any_value should not wrap raw DEFAULT marker: %v", aggExpr) - } - if fn := aggExpr.GetF(); fn != nil && fn.Func.ObjName == "any_value" && - exprContainsStringLiteral(aggExpr, "name-default") { - return - } - } + query := logicPlan.GetQuery() + if queryContainsDefaultVal(query) { + t.Fatalf("UPDATE FROM dedup should run after DEFAULT expansion") + } + if !queryContainsStringLiteral(query, "name-default") { + t.Fatalf("UPDATE FROM dedup should retain the expanded DEFAULT expression") + } + if hasUpdateFromDedupAnyValueAgg(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("UPDATE FROM dedup should not wrap DEFAULT with any_value") } - t.Fatalf("UPDATE FROM dedup should aggregate the expanded DEFAULT expression") } func TestUpdatePgStyleFromDedupAllowsVectorUpdateColumn(t *testing.T) { @@ -780,11 +788,11 @@ func TestUpdatePgStyleFromDedupAllowsVectorUpdateColumn(t *testing.T) { _, err := runOneStmt(mock, t, "UPDATE NATION SET N_COMMENT = NATION2.N_COMMENT FROM NATION2 WHERE NATION.N_REGIONKEY = NATION2.R_REGIONKEY") if err != nil { - t.Fatalf("UPDATE FROM should allow vector update columns through any_value dedup: %v", err) + t.Fatalf("UPDATE FROM should allow vector update columns through row-level dedup: %v", err) } } -func TestUpdatePgStyleFromDedupDoesNotAggregateGeneratedColumns(t *testing.T) { +func TestUpdatePgStyleFromDedupKeepsGeneratedColumnsAfterDedup(t *testing.T) { mock := NewMockOptimizer(true) setMockGeneratedColumn(t, mock, "nation", "n_comment", "n_name") @@ -794,14 +802,12 @@ func TestUpdatePgStyleFromDedupDoesNotAggregateGeneratedColumns(t *testing.T) { t.Fatalf("build UPDATE FROM with generated column: %v", err) } - aggNode := requireUpdateFromDedupAggNode(t, logicPlan.GetQuery(), len(mock.ctxt.tables["nation"].Cols)) - if got := countFuncInExprs(aggNode.AggList, "any_value"); got != 1 { - t.Fatalf("dedup should aggregate only explicit update columns, got %d any_value exprs: %v", got, aggNode.AggList) + query := logicPlan.GetQuery() + if hasUpdateFromDedupAnyValueAgg(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("dedup should not aggregate generated or update columns with any_value") } - for _, aggExpr := range aggNode.AggList { - if exprContainsColName(aggExpr, "n_comment") { - t.Fatalf("generated column should be recomputed after dedup, not aggregated by dedup: %v", aggExpr) - } + if !hasUpdateFromDedupWindow(query, len(mock.ctxt.tables["nation"].Cols)) { + t.Fatalf("UPDATE FROM with generated column should still use row-level dedup") } } @@ -831,7 +837,7 @@ func TestUpdatePgStyleFromDedupAllowsDecimal256AndEnumUpdateColumns(t *testing.T _, err := runOneStmt(mock, t, tt.sql) if err != nil { - t.Fatalf("UPDATE FROM should allow %s update columns through any_value dedup: %v", tt.name, err) + t.Fatalf("UPDATE FROM should allow %s update columns through row-level dedup: %v", tt.name, err) } }) } @@ -1028,19 +1034,70 @@ func requireFallbackSourceProjectNode(t *testing.T, query *Query, projectLen int return nil } -func requireUpdateFromDedupAggNode(t *testing.T, query *Query, groupByLen int) *Node { +func hasUpdateFromDedupAnyValueAgg(query *Query, groupByLen int) bool { for _, node := range query.Nodes { if node.NodeType != plan.Node_AGG || len(node.GroupBy) != groupByLen { continue } for _, aggExpr := range node.AggList { if fn := aggExpr.GetF(); fn != nil && fn.Func.ObjName == "any_value" { - return node + return true } } } - t.Fatalf("missing UPDATE FROM dedup agg node with %d group keys", groupByLen) - return nil + return false +} + +func hasUpdateFromDedupWindow(query *Query, partitionByLen int) bool { + for _, node := range query.Nodes { + if node.NodeType != plan.Node_WINDOW { + continue + } + for _, winExpr := range node.WinSpecList { + spec := winExpr.GetW() + if spec == nil || spec.Name != "row_number" || len(spec.PartitionBy) != partitionByLen { + continue + } + return true + } + } + return false +} + +func queryContainsStringLiteral(query *Query, value string) bool { + return queryContainsExpr(query, func(expr *plan.Expr) bool { + return exprContainsStringLiteral(expr, value) + }) +} + +func queryContainsDefaultVal(query *Query) bool { + return queryContainsExpr(query, exprContainsDefaultVal) +} + +func queryContainsExpr(query *Query, accept func(*plan.Expr) bool) bool { + for _, node := range query.Nodes { + exprLists := [][]*plan.Expr{ + node.ProjectList, + node.OnList, + node.FilterList, + node.GroupBy, + node.AggList, + node.WinSpecList, + } + for _, exprList := range exprLists { + for _, expr := range exprList { + if accept(expr) { + return true + } + } + } + for _, order := range node.OrderBy { + if accept(order.Expr) { + return true + } + } + } + return false } func requireQueryExpr(t *testing.T, query *Query, accept func(*plan.Expr) bool, message string) *plan.Expr { @@ -1109,6 +1166,20 @@ func exprContainsStringLiteral(expr *plan.Expr, value string) bool { return true } } + case *plan.Expr_W: + if exprContainsStringLiteral(e.W.WindowFunc, value) { + return true + } + for _, partition := range e.W.PartitionBy { + if exprContainsStringLiteral(partition, value) { + return true + } + } + for _, order := range e.W.OrderBy { + if exprContainsStringLiteral(order.Expr, value) { + return true + } + } } return false } @@ -1130,38 +1201,22 @@ func exprContainsDefaultVal(expr *plan.Expr) bool { return true } } - } - return false -} - -func countFuncInExprs(exprs []*plan.Expr, funcName string) int { - count := 0 - for _, expr := range exprs { - count += countFuncInExpr(expr, funcName) - } - return count -} - -func countFuncInExpr(expr *plan.Expr, funcName string) int { - switch e := expr.Expr.(type) { - case *plan.Expr_F: - count := 0 - if e.F.Func.ObjName == funcName { - count++ + case *plan.Expr_W: + if exprContainsDefaultVal(e.W.WindowFunc) { + return true } - for _, arg := range e.F.Args { - count += countFuncInExpr(arg, funcName) + for _, partition := range e.W.PartitionBy { + if exprContainsDefaultVal(partition) { + return true + } } - return count - case *plan.Expr_List: - count := 0 - for _, item := range e.List.List { - count += countFuncInExpr(item, funcName) + for _, order := range e.W.OrderBy { + if exprContainsDefaultVal(order.Expr) { + return true + } } - return count - default: - return 0 } + return false } func exprContainsColName(expr *plan.Expr, name string) bool { diff --git a/test/distributed/cases/dml/update/update_pg_style_from.result b/test/distributed/cases/dml/update/update_pg_style_from.result index c856484ddce5c..539c875e39978 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.result +++ b/test/distributed/cases/dml/update/update_pg_style_from.result @@ -129,11 +129,36 @@ INSERT INTO gen_dup_t (id, base) VALUES (1, 5); CREATE TABLE gen_dup_s (t_id INT, new_base INT); INSERT INTO gen_dup_s VALUES (1, NULL), (1, 7); UPDATE gen_dup_t SET base = s.new_base FROM gen_dup_s s WHERE s.t_id = gen_dup_t.id; -SELECT id, base, gen_col FROM gen_dup_t ORDER BY id; -➤ id[4,32,0] ¦ base[4,32,0] ¦ gen_col[4,32,0] 𝄀 -1 ¦ 7 ¦ 7 +SELECT COUNT(*) AS valid_generated_row FROM gen_dup_t +WHERE (base IS NULL AND gen_col = 0) OR (base = 7 AND gen_col = 7); +➤ valid_generated_row[-5,64,0] 𝄀 +1 DROP TABLE gen_dup_t; DROP TABLE gen_dup_s; +DROP TABLE IF EXISTS whole_row_t; +DROP TABLE IF EXISTS whole_row_s; +CREATE TABLE whole_row_t ( +id INT PRIMARY KEY, +a INT, +b VARCHAR(20) +); +CREATE TABLE whole_row_s ( +t_id INT, +new_a INT, +new_b VARCHAR(20) +); +INSERT INTO whole_row_t VALUES (1, 0, 'orig'); +INSERT INTO whole_row_s VALUES (1, NULL, 'from-null-a'), (1, 7, NULL); +UPDATE whole_row_t SET a = s.new_a, b = s.new_b FROM whole_row_s s WHERE s.t_id = whole_row_t.id; +SELECT COUNT(*) AS valid_whole_row FROM whole_row_t +WHERE (a IS NULL AND b = 'from-null-a') OR (a = 7 AND b IS NULL); +➤ valid_whole_row[-5,64,0] 𝄀 +1 +SELECT COUNT(*) AS synthesized_row FROM whole_row_t WHERE a = 7 AND b = 'from-null-a'; +➤ synthesized_row[-5,64,0] 𝄀 +0 +DROP TABLE whole_row_t; +DROP TABLE whole_row_s; DROP TABLE IF EXISTS fk_parent; DROP TABLE IF EXISTS fk_child; DROP TABLE IF EXISTS fk_src; diff --git a/test/distributed/cases/dml/update/update_pg_style_from.sql b/test/distributed/cases/dml/update/update_pg_style_from.sql index 71a82bb183bbe..0a77d1e833dff 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.sql +++ b/test/distributed/cases/dml/update/update_pg_style_from.sql @@ -134,10 +134,35 @@ INSERT INTO gen_dup_t (id, base) VALUES (1, 5); CREATE TABLE gen_dup_s (t_id INT, new_base INT); INSERT INTO gen_dup_s VALUES (1, NULL), (1, 7); UPDATE gen_dup_t SET base = s.new_base FROM gen_dup_s s WHERE s.t_id = gen_dup_t.id; -SELECT id, base, gen_col FROM gen_dup_t ORDER BY id; +SELECT COUNT(*) AS valid_generated_row FROM gen_dup_t +WHERE (base IS NULL AND gen_col = 0) OR (base = 7 AND gen_col = 7); DROP TABLE gen_dup_t; DROP TABLE gen_dup_s; +-- Duplicate source rows must be deduped as whole rows. Per-column any_value() +-- can synthesize (new_a = 7, new_b = 'from-null-a'), which is not present in +-- the source. +DROP TABLE IF EXISTS whole_row_t; +DROP TABLE IF EXISTS whole_row_s; +CREATE TABLE whole_row_t ( + id INT PRIMARY KEY, + a INT, + b VARCHAR(20) +); +CREATE TABLE whole_row_s ( + t_id INT, + new_a INT, + new_b VARCHAR(20) +); +INSERT INTO whole_row_t VALUES (1, 0, 'orig'); +INSERT INTO whole_row_s VALUES (1, NULL, 'from-null-a'), (1, 7, NULL); +UPDATE whole_row_t SET a = s.new_a, b = s.new_b FROM whole_row_s s WHERE s.t_id = whole_row_t.id; +SELECT COUNT(*) AS valid_whole_row FROM whole_row_t +WHERE (a IS NULL AND b = 'from-null-a') OR (a = 7 AND b IS NULL); +SELECT COUNT(*) AS synthesized_row FROM whole_row_t WHERE a = 7 AND b = 'from-null-a'; +DROP TABLE whole_row_t; +DROP TABLE whole_row_s; + -- FK target with a generated column: the FK forces the fallback planner -- (buildTableUpdate). Generated column protection must still apply there. DROP TABLE IF EXISTS fk_parent; From e95703970e706bb393a3124382ed8569297c262f Mon Sep 17 00:00:00 2001 From: Cao Kai Date: Fri, 12 Jun 2026 21:36:40 +0800 Subject: [PATCH 6/6] fix(plan): dedup fallback update from rows --- pkg/sql/plan/bind_update.go | 51 ++++++++++++------- pkg/sql/plan/build_constraint_util.go | 6 +-- pkg/sql/plan/build_test.go | 18 +++++++ pkg/sql/plan/build_update.go | 30 +++++++++++ .../dml/update/update_pg_style_from.result | 30 +++++++++++ .../cases/dml/update/update_pg_style_from.sql | 34 +++++++++++-- 6 files changed, 143 insertions(+), 26 deletions(-) diff --git a/pkg/sql/plan/bind_update.go b/pkg/sql/plan/bind_update.go index bf5b8e445cad7..eaf1905ce9a42 100644 --- a/pkg/sql/plan/bind_update.go +++ b/pkg/sql/plan/bind_update.go @@ -986,8 +986,37 @@ func (builder *QueryBuilder) appendUpdateFromDedupNode( oldColName2Idx map[string]int32, newColName2Idx map[string]int32, ) (int32, *plan.Node, int32, error) { - partitionByExprs := make([]*plan.Expr, 0) + partitionColPositions := make([]int32, 0) partitionPos := make(map[int32]struct{}) + for i, alias := range dmlCtx.aliases { + if len(dmlCtx.updateCol2Expr[i]) == 0 { + continue + } + + for _, col := range dmlCtx.tableDefs[i].Cols { + key := alias + "." + col.Name + oldPos, ok := oldColName2Idx[key] + if !ok { + continue + } + if _, exists := partitionPos[oldPos]; !exists { + partitionPos[oldPos] = struct{}{} + partitionColPositions = append(partitionColPositions, oldPos) + } + } + } + + return builder.appendRowNumberDedupNode(bindCtx, lastNodeID, selectNode, selectNodeTag, partitionColPositions) +} + +func (builder *QueryBuilder) appendRowNumberDedupNode( + bindCtx *BindContext, + lastNodeID int32, + selectNode *plan.Node, + selectNodeTag int32, + partitionColPositions []int32, +) (int32, *plan.Node, int32, error) { + partitionByExprs := make([]*plan.Expr, 0, len(partitionColPositions)) childColExpr := func(pos int32) *plan.Expr { e := selectNode.ProjectList[pos] name := "" @@ -1005,25 +1034,9 @@ func (builder *QueryBuilder) appendUpdateFromDedupNode( }, } } - - for i, alias := range dmlCtx.aliases { - if len(dmlCtx.updateCol2Expr[i]) == 0 { - continue - } - - for _, col := range dmlCtx.tableDefs[i].Cols { - key := alias + "." + col.Name - oldPos, ok := oldColName2Idx[key] - if !ok { - continue - } - if _, exists := partitionPos[oldPos]; !exists { - partitionPos[oldPos] = struct{}{} - partitionByExprs = append(partitionByExprs, childColExpr(oldPos)) - } - } + for _, pos := range partitionColPositions { + partitionByExprs = append(partitionByExprs, childColExpr(pos)) } - if len(partitionByExprs) == 0 { return lastNodeID, selectNode, selectNodeTag, nil } diff --git a/pkg/sql/plan/build_constraint_util.go b/pkg/sql/plan/build_constraint_util.go index 81796b8df75e8..6d4403a39e1a9 100644 --- a/pkg/sql/plan/build_constraint_util.go +++ b/pkg/sql/plan/build_constraint_util.go @@ -100,10 +100,8 @@ func getUpdateTableInfo(ctx CompilerContext, stmt *tree.Update) (*dmlTableInfo, } // A PostgreSQL-style UPDATE ... FROM may match a single target row from - // multiple source rows. Force the agg-based dedup path (any_value over - // the target's primary key) just like the classic multi-table syntax - // would; without this flag the fallback planner would silently produce - // duplicate-row writes. + // multiple source rows. Mark it for dedup so the planner does not + // produce duplicate-row writes. if stmt.From != nil && len(stmt.From.Tables) > 0 { tblInfo.needAggFilter = true } diff --git a/pkg/sql/plan/build_test.go b/pkg/sql/plan/build_test.go index 879ff7f8f838e..69ec988151708 100644 --- a/pkg/sql/plan/build_test.go +++ b/pkg/sql/plan/build_test.go @@ -757,6 +757,24 @@ func TestUpdatePgStyleFromDedupPicksWholeSourceRow(t *testing.T) { } } +func TestUpdateFallbackPgStyleFromDedupPicksWholeSourceRow(t *testing.T) { + mock := NewMockOptimizer(true) + + logicPlan, err := runOneStmt(mock, t, + "UPDATE emp SET sal = dept.deptno, comm = dept.deptno FROM dept WHERE emp.deptno = dept.deptno") + if err != nil { + t.Fatalf("build fallback UPDATE FROM plan: %v", err) + } + + query := logicPlan.GetQuery() + if hasUpdateFromDedupAnyValueAgg(query, len(mock.ctxt.tables["emp"].Cols)) { + t.Fatalf("fallback UPDATE FROM dedup must pick a whole source row, not aggregate each update column with any_value") + } + if !hasUpdateFromDedupWindow(query, len(mock.ctxt.tables["emp"].Cols)) { + t.Fatalf("fallback UPDATE FROM dedup should use row_number window partitioned by target old columns") + } +} + func TestUpdatePgStyleFromDedupExpandsDefaultBeforeDedup(t *testing.T) { mock := NewMockOptimizer(true) setMockDefaultExpr(t, mock, "nation", "n_name", "name-default") diff --git a/pkg/sql/plan/build_update.go b/pkg/sql/plan/build_update.go index 432a97d3eb772..05ccd456103f0 100644 --- a/pkg/sql/plan/build_update.go +++ b/pkg/sql/plan/build_update.go @@ -47,6 +47,22 @@ func buildTableUpdate(stmt *tree.Update, ctx CompilerContext, isPrepareStmt bool if err != nil { return nil, err } + if stmt.From != nil && len(stmt.From.Tables) > 0 && tblInfo.needAggFilter { + lastNode := builder.qry.Nodes[lastNodeId] + lastNodeId, _, _, err = builder.appendRowNumberDedupNode( + queryBindCtx, + lastNodeId, + lastNode, + lastNode.BindingTags[0], + fallbackUpdateFromDedupPartitionCols(updatePlanCtxs), + ) + if err != nil { + return nil, err + } + for _, updatePlanCtx := range updatePlanCtxs { + updatePlanCtx.needAggFilter = false + } + } sourceStep := builder.appendStep(lastNodeId) query, err := builder.createQuery() @@ -256,6 +272,20 @@ func rewriteGeneratedColumnsForUpdate(builder *QueryBuilder, planCtxs []*dmlPlan return nil } +func fallbackUpdateFromDedupPartitionCols(planCtxs []*dmlPlanCtx) []int32 { + partitionCols := make([]int32, 0) + offset := int32(0) + for _, planCtx := range planCtxs { + if planCtx.updateColLength > 0 { + for i := range planCtx.tableDef.Cols { + partitionCols = append(partitionCols, offset+int32(i)) + } + } + offset += int32(len(planCtx.tableDef.Cols) + planCtx.updateColLength) + } + return partitionCols +} + func selectUpdateTables(builder *QueryBuilder, bindCtx *BindContext, stmt *tree.Update, tableInfo *dmlTableInfo) (int32, []*dmlPlanCtx, error) { // Merge target table list with PostgreSQL-style FROM sources so that the // inner SELECT can resolve column references against both. tableInfo only diff --git a/test/distributed/cases/dml/update/update_pg_style_from.result b/test/distributed/cases/dml/update/update_pg_style_from.result index 539c875e39978..a4b30fc0052f2 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.result +++ b/test/distributed/cases/dml/update/update_pg_style_from.result @@ -230,6 +230,36 @@ SELECT id, p_id, v FROM dup_t ORDER BY id; DROP TABLE dup_t; DROP TABLE dup_p; DROP TABLE dup_s; +DROP TABLE IF EXISTS fk_whole_row_t; +DROP TABLE IF EXISTS fk_whole_row_p; +DROP TABLE IF EXISTS fk_whole_row_s; +CREATE TABLE fk_whole_row_p (id INT PRIMARY KEY); +INSERT INTO fk_whole_row_p VALUES (1); +CREATE TABLE fk_whole_row_t ( +id INT PRIMARY KEY, +p_id INT, +a INT, +b VARCHAR(20), +FOREIGN KEY (p_id) REFERENCES fk_whole_row_p(id) +); +CREATE TABLE fk_whole_row_s ( +t_id INT, +new_a INT, +new_b VARCHAR(20) +); +INSERT INTO fk_whole_row_t VALUES (1, 1, 0, 'orig'); +INSERT INTO fk_whole_row_s VALUES (1, NULL, 'from-null-a'), (1, 7, NULL); +UPDATE fk_whole_row_t SET a = s.new_a, b = s.new_b FROM fk_whole_row_s s WHERE s.t_id = fk_whole_row_t.id; +SELECT COUNT(*) AS valid_whole_row FROM fk_whole_row_t +WHERE (a IS NULL AND b = 'from-null-a') OR (a = 7 AND b IS NULL); +➤ valid_whole_row[-5,64,0] 𝄀 +1 +SELECT COUNT(*) AS synthesized_row FROM fk_whole_row_t WHERE a = 7 AND b = 'from-null-a'; +➤ synthesized_row[-5,64,0] 𝄀 +0 +DROP TABLE fk_whole_row_t; +DROP TABLE fk_whole_row_p; +DROP TABLE fk_whole_row_s; DROP TABLE IF EXISTS dup_no_fk_t; DROP TABLE IF EXISTS dup_no_fk_s; CREATE TABLE dup_no_fk_t ( diff --git a/test/distributed/cases/dml/update/update_pg_style_from.sql b/test/distributed/cases/dml/update/update_pg_style_from.sql index 0a77d1e833dff..8ca6e9ec93d20 100644 --- a/test/distributed/cases/dml/update/update_pg_style_from.sql +++ b/test/distributed/cases/dml/update/update_pg_style_from.sql @@ -215,9 +215,8 @@ DROP TABLE ob_s; -- Duplicate-match on the fallback path: target row 1 is matched by both -- (10,1,...) and (11,1,...) source rows. Because dup_t has a FK the fallback --- planner (buildTableUpdate) handles this, and needAggFilter must be set so --- the AGG any_value() dedup runs. Without v3's needAggFilter wiring for --- stmt.From, this would silently double-write target row 1. +-- planner (buildTableUpdate) handles this. It must still dedup duplicate +-- matches instead of silently double-writing target row 1. DROP TABLE IF EXISTS dup_t; DROP TABLE IF EXISTS dup_p; DROP TABLE IF EXISTS dup_s; @@ -238,6 +237,35 @@ DROP TABLE dup_t; DROP TABLE dup_p; DROP TABLE dup_s; +-- Fallback UPDATE ... FROM dedup must also pick a whole source row. FK on the +-- target forces buildTableUpdate instead of the new bindUpdate path. +DROP TABLE IF EXISTS fk_whole_row_t; +DROP TABLE IF EXISTS fk_whole_row_p; +DROP TABLE IF EXISTS fk_whole_row_s; +CREATE TABLE fk_whole_row_p (id INT PRIMARY KEY); +INSERT INTO fk_whole_row_p VALUES (1); +CREATE TABLE fk_whole_row_t ( + id INT PRIMARY KEY, + p_id INT, + a INT, + b VARCHAR(20), + FOREIGN KEY (p_id) REFERENCES fk_whole_row_p(id) +); +CREATE TABLE fk_whole_row_s ( + t_id INT, + new_a INT, + new_b VARCHAR(20) +); +INSERT INTO fk_whole_row_t VALUES (1, 1, 0, 'orig'); +INSERT INTO fk_whole_row_s VALUES (1, NULL, 'from-null-a'), (1, 7, NULL); +UPDATE fk_whole_row_t SET a = s.new_a, b = s.new_b FROM fk_whole_row_s s WHERE s.t_id = fk_whole_row_t.id; +SELECT COUNT(*) AS valid_whole_row FROM fk_whole_row_t +WHERE (a IS NULL AND b = 'from-null-a') OR (a = 7 AND b IS NULL); +SELECT COUNT(*) AS synthesized_row FROM fk_whole_row_t WHERE a = 7 AND b = 'from-null-a'; +DROP TABLE fk_whole_row_t; +DROP TABLE fk_whole_row_p; +DROP TABLE fk_whole_row_s; + -- Duplicate-match on the new UPDATE path without FK constraints must still -- update each target row once instead of producing duplicate primary keys. DROP TABLE IF EXISTS dup_no_fk_t;