diff --git a/pkg/common/event/ddl_event.go b/pkg/common/event/ddl_event.go index ef8dcd6732..0ef8946228 100644 --- a/pkg/common/event/ddl_event.go +++ b/pkg/common/event/ddl_event.go @@ -22,6 +22,8 @@ import ( "github.com/pingcap/log" "github.com/pingcap/ticdc/pkg/common" "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser" + "github.com/pingcap/tidb/pkg/parser/ast" "go.uber.org/zap" ) @@ -41,14 +43,23 @@ type DDLEvent struct { SchemaID int64 `json:"schema_id"` SchemaName string `json:"schema_name"` TableName string `json:"table_name"` + // the following two fields are just used for RenameTable, // they are the old schema/table name of the table - ExtraSchemaName string `json:"extra_schema_name"` - ExtraTableName string `json:"extra_table_name"` - Query string `json:"query"` - TableInfo *common.TableInfo `json:"-"` - StartTs uint64 `json:"start_ts"` - FinishedTs uint64 `json:"finished_ts"` + ExtraSchemaName string `json:"extra_schema_name"` + ExtraTableName string `json:"extra_table_name"` + + // target related fields carry routed names. + // They are set after the unmarshal, so no need to be serialized. + targetSchemaName string `json:"-"` + targetTableName string `json:"-"` + targetExtraSchemaName string `json:"-"` + targetExtraTableName string `json:"-"` + + Query string `json:"query"` + TableInfo *common.TableInfo `json:"-"` + StartTs uint64 `json:"start_ts"` + FinishedTs uint64 `json:"finished_ts"` // The seq of the event. It is set by event service. Seq uint64 `json:"seq"` // The epoch of the event. It is set by event service. 
@@ -201,6 +212,34 @@ func (d *DDLEvent) GetExtraTableName() string { return d.ExtraTableName } +func (d *DDLEvent) GetTargetSchemaName() string { + if d.targetSchemaName != "" { + return d.targetSchemaName + } + return d.SchemaName +} + +func (d *DDLEvent) GetTargetTableName() string { + if d.targetTableName != "" { + return d.targetTableName + } + return d.TableName +} + +func (d *DDLEvent) GetTargetExtraSchemaName() string { + if d.targetExtraSchemaName != "" { + return d.targetExtraSchemaName + } + return d.ExtraSchemaName +} + +func (d *DDLEvent) GetTargetExtraTableName() string { + if d.targetExtraTableName != "" { + return d.targetExtraTableName + } + return d.ExtraTableName +} + // GetTableID returns the logic table ID of the event. // it returns 0 when there is no tableinfo func (d *DDLEvent) GetTableID() int64 { @@ -210,6 +249,7 @@ func (d *DDLEvent) GetTableID() int64 { return 0 } +// GetEvents split the multi tables DDL into single table DDLs. func (d *DDLEvent) GetEvents() []*DDLEvent { // Some ddl event may be multi-events, we need to split it into multiple messages. 
// Such as rename table test.table1 to test.table10, test.table2 to test.table20 @@ -230,18 +270,23 @@ func (d *DDLEvent) GetEvents() []*DDLEvent { } for i, info := range d.MultipleTableInfos { event := &DDLEvent{ - Version: d.Version, - Type: byte(t), - SchemaName: info.GetSchemaName(), - TableName: info.GetTableName(), - TableInfo: info, - Query: queries[i], - StartTs: d.StartTs, - FinishedTs: d.FinishedTs, + Version: d.Version, + Type: byte(t), + SchemaName: info.GetSchemaName(), + TableName: info.GetTableName(), + targetSchemaName: info.GetTargetSchemaName(), + targetTableName: info.GetTargetTableName(), + TableInfo: info, + Query: queries[i], + StartTs: d.StartTs, + FinishedTs: d.FinishedTs, } if model.ActionType(d.Type) == model.ActionRenameTables { event.ExtraSchemaName = d.TableNameChange.DropName[i].SchemaName event.ExtraTableName = d.TableNameChange.DropName[i].TableName + targetExtraSchemaName, targetExtraTableName := extractRenameTargetExtraFromQuery(queries[i]) + event.targetExtraSchemaName = targetExtraSchemaName + event.targetExtraTableName = targetExtraTableName } events = append(events, event) } @@ -251,6 +296,19 @@ func (d *DDLEvent) GetEvents() []*DDLEvent { return []*DDLEvent{d} } +func extractRenameTargetExtraFromQuery(query string) (string, string) { + stmt, err := parser.New().ParseOneStmt(query, "", "") + if err != nil { + log.Panic("parse split rename query failed", zap.String("query", query), zap.Error(err)) + } + renameStmt, ok := stmt.(*ast.RenameTableStmt) + if !ok || len(renameStmt.TableToTables) == 0 { + log.Panic("unexpected split rename query", zap.String("query", query), zap.Any("stmt", stmt)) + } + oldTable := renameStmt.TableToTables[0].OldTable + return oldTable.Schema.O, oldTable.Name.O +} + func (d *DDLEvent) GetSeq() uint64 { return d.Seq } @@ -479,6 +537,70 @@ func (t *DDLEvent) IsPaused() bool { return false } +// NewRoutedDDLEvent builds a routed DDL event from the origin event and final routed fields. 
+func NewRoutedDDLEvent( + d *DDLEvent, + query string, + targetSchemaName, targetTableName string, + targetExtraSchemaName, targetExtraTableName string, + tableInfo *common.TableInfo, + multipleTableInfos []*common.TableInfo, + blockedTableNames []SchemaTableName, +) *DDLEvent { + if d == nil { + return nil + } + + return &DDLEvent{ + Version: d.Version, + DispatcherID: d.DispatcherID, + Type: d.Type, + SchemaID: d.SchemaID, + SchemaName: d.SchemaName, + TableName: d.TableName, + ExtraSchemaName: d.ExtraSchemaName, + ExtraTableName: d.ExtraTableName, + targetSchemaName: targetSchemaName, + targetTableName: targetTableName, + targetExtraSchemaName: targetExtraSchemaName, + targetExtraTableName: targetExtraTableName, + Query: query, + TableInfo: tableInfo, + StartTs: d.StartTs, + FinishedTs: d.FinishedTs, + Seq: d.Seq, + Epoch: d.Epoch, + // MultipleTableInfos and BlockedTableNames carry table names used by downstream + // execution paths, so the routed versions must be passed in explicitly. + MultipleTableInfos: multipleTableInfos, + BlockedTableNames: blockedTableNames, + // The following fields do not participate in table route name rewriting, + // so the routed event keeps the original values from the source event. 
+ BlockedTables: d.BlockedTables, + NeedDroppedTables: d.NeedDroppedTables, + NeedAddedTables: d.NeedAddedTables, + UpdatedSchemas: d.UpdatedSchemas, + TableNameChange: d.TableNameChange, + TiDBOnly: d.TiDBOnly, + BDRMode: d.BDRMode, + Err: d.Err, + PostTxnFlushed: clonePostTxnFlushed(d.PostTxnFlushed), + eventSize: d.eventSize, + IsBootstrap: d.IsBootstrap, + NotSync: d.NotSync, + } +} + +func clonePostTxnFlushed(postTxnFlushed []func()) []func() { + if postTxnFlushed == nil { + return nil + } + + cloned := make([]func(), len(postTxnFlushed)) + copy(cloned, postTxnFlushed) + return cloned +} + func (t *DDLEvent) Len() int32 { return 1 } diff --git a/pkg/common/event/ddl_event_test.go b/pkg/common/event/ddl_event_test.go index 6cb58852e8..9447acefe6 100644 --- a/pkg/common/event/ddl_event_test.go +++ b/pkg/common/event/ddl_event_test.go @@ -21,6 +21,8 @@ import ( "github.com/pingcap/ticdc/pkg/common" "github.com/pingcap/ticdc/pkg/errors" + "github.com/pingcap/tidb/pkg/meta/model" + "github.com/pingcap/tidb/pkg/parser/ast" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -503,3 +505,201 @@ INSERT INTO test VALUES (1); }) } } + +// TestNewRoutedDDLEvent ensures routed DDL construction preserves the origin event +// while producing an independent routed event for downstream use. 
+func TestNewRoutedDDLEvent(t *testing.T) { + helper := NewEventTestHelper(t) + defer helper.Close() + + helper.tk.MustExec("use test") + ddlJob := helper.DDL2Job(createTableSQL) + require.NotNil(t, ddlJob) + + // Create original DDL event with all fields populated + originalTableInfo := common.WrapTableInfo(ddlJob.SchemaName, ddlJob.BinlogInfo.TableInfo) + originalTableInfo.InitPrivateFields() + + multipleTableInfo1 := common.WrapTableInfo("schema1", ddlJob.BinlogInfo.TableInfo) + multipleTableInfo1.InitPrivateFields() + multipleTableInfo2 := common.WrapTableInfo("schema2", ddlJob.BinlogInfo.TableInfo) + multipleTableInfo2.InitPrivateFields() + + postFlushFunc1 := func() {} + postFlushFunc2 := func() {} + + original := &DDLEvent{ + Version: DDLEventVersion1, + DispatcherID: common.NewDispatcherID(), + Type: byte(ddlJob.Type), + SchemaID: ddlJob.SchemaID, + SchemaName: ddlJob.SchemaName, + TableName: ddlJob.TableName, + Query: ddlJob.Query, + TableInfo: originalTableInfo, + FinishedTs: ddlJob.BinlogInfo.FinishedTS, + Seq: 1, + Epoch: 2, + MultipleTableInfos: []*common.TableInfo{multipleTableInfo1, multipleTableInfo2}, + PostTxnFlushed: []func(){postFlushFunc1, postFlushFunc2}, + TiDBOnly: true, + BDRMode: "test-mode", + } + + newRoutedTableInfo := originalTableInfo.CloneWithRouting("routed_schema", "test") + routedMultipleTableInfos := []*common.TableInfo{ + multipleTableInfo1.CloneWithRouting("routed_schema1", "table1"), + multipleTableInfo2.CloneWithRouting("routed_schema2", "table2"), + } + + routed := NewRoutedDDLEvent( + original, + "CREATE TABLE routed_schema.test ...", + "routed_schema", + "", + "", + "", + newRoutedTableInfo, + routedMultipleTableInfos, + original.BlockedTableNames, + ) + require.NotNil(t, routed) + + // Verify that the routed event is a separate object. + require.False(t, original == routed, "routed event should be a different object") + + // Verify that non-routing fields are copied as-is. 
+ require.Equal(t, original.Version, routed.Version) + require.Equal(t, original.DispatcherID, routed.DispatcherID) + require.Equal(t, original.Type, routed.Type) + require.Equal(t, original.SchemaID, routed.SchemaID) + require.Equal(t, original.SchemaName, routed.SchemaName) + require.Equal(t, original.TableName, routed.TableName) + require.Equal(t, original.FinishedTs, routed.FinishedTs) + require.Equal(t, original.Seq, routed.Seq) + require.Equal(t, original.Epoch, routed.Epoch) + require.Equal(t, original.TiDBOnly, routed.TiDBOnly) + require.Equal(t, original.BDRMode, routed.BDRMode) + + // Verify that MultipleTableInfos is a new slice so later mutations remain isolated. + require.False(t, &original.MultipleTableInfos[0] == &routed.MultipleTableInfos[0], "MultipleTableInfos should be a new slice") + + // Verify that PostTxnFlushed is an independent copy (not shared) + // This is defensive: currently DDL events arrive with nil PostTxnFlushed, + // but we copy it to prevent races if callbacks are ever added before building the routed event. + require.NotNil(t, routed.PostTxnFlushed) + require.Equal(t, 2, len(routed.PostTxnFlushed), "PostTxnFlushed should have same length as original") + require.Equal(t, 2, len(original.PostTxnFlushed), "Original PostTxnFlushed should remain unchanged") + // Verify independent backing arrays. + require.NotEqual(t, &original.PostTxnFlushed[0], &routed.PostTxnFlushed[0], "PostTxnFlushed should have independent backing arrays") + + // Verify that appending to the routed event doesn't affect the original. + routed.AddPostFlushFunc(func() {}) + require.Equal(t, 3, len(routed.PostTxnFlushed), "Routed event should have appended callback") + require.Equal(t, 2, len(original.PostTxnFlushed), "Original should be unaffected by routed event append") + + // Verify that routed state doesn't affect the original. 
+ require.Equal(t, ddlJob.SchemaName, original.SchemaName, "Original SchemaName should be unchanged") + require.Equal(t, ddlJob.Query, original.Query, "Original Query should be unchanged") + require.True(t, original.TableInfo == originalTableInfo, "Original TableInfo should be unchanged") + require.True(t, original.MultipleTableInfos[0] == multipleTableInfo1, "Original MultipleTableInfos[0] should be unchanged") + require.True(t, original.MultipleTableInfos[1] == multipleTableInfo2, "Original MultipleTableInfos[1] should be unchanged") + + // Verify that the routed event has the routed state. + require.Equal(t, "routed_schema", routed.GetTargetSchemaName()) + require.Equal(t, "CREATE TABLE routed_schema.test ...", routed.Query) + require.True(t, routed.TableInfo == newRoutedTableInfo) + require.Equal(t, "routed_schema", routed.TableInfo.TableName.TargetSchema) + require.Equal(t, original.SchemaName, routed.GetSchemaName()) + require.Equal(t, original.TableName, routed.GetTableName()) + require.True(t, routed.MultipleTableInfos[0] == routedMultipleTableInfos[0]) + require.True(t, routed.MultipleTableInfos[1] == routedMultipleTableInfos[1]) + + // Test nil origin event. 
+ var nilEvent *DDLEvent + routedNil := NewRoutedDDLEvent(nilEvent, "", "", "", "", "", nil, nil, nil) + require.Nil(t, routedNil) +} + +func TestNewRoutedDDLEventPreservesSourceFields(t *testing.T) { + original := &DDLEvent{ + SchemaName: "source_db", + TableName: "new_orders", + ExtraSchemaName: "source_db", + ExtraTableName: "old_orders", + targetSchemaName: "target_db", + targetTableName: "new_orders_routed", + targetExtraSchemaName: "target_db", + targetExtraTableName: "old_orders_routed", + } + + routed := NewRoutedDDLEvent( + original, + original.Query, + "target_db_v2", + "new_orders_routed_v2", + "target_db_v2", + "old_orders_routed_v2", + original.TableInfo, + original.MultipleTableInfos, + original.BlockedTableNames, + ) + + require.Equal(t, "source_db", routed.GetSchemaName()) + require.Equal(t, "new_orders", routed.GetTableName()) + require.Equal(t, "source_db", routed.GetExtraSchemaName()) + require.Equal(t, "old_orders", routed.GetExtraTableName()) + require.Equal(t, "target_db_v2", routed.GetTargetSchemaName()) + require.Equal(t, "new_orders_routed_v2", routed.GetTargetTableName()) + require.Equal(t, "target_db_v2", routed.GetTargetExtraSchemaName()) + require.Equal(t, "old_orders_routed_v2", routed.GetTargetExtraTableName()) +} + +func TestGetEventsForRenameTablesPreservesSourceAndTargetNames(t *testing.T) { + sourceTable1 := common.WrapTableInfo("new_db1", &model.TableInfo{ + ID: 100, + Name: ast.NewCIStr("new_table1"), + UpdateTS: 10, + }) + sourceTable2 := common.WrapTableInfo("new_db2", &model.TableInfo{ + ID: 101, + Name: ast.NewCIStr("new_table2"), + UpdateTS: 11, + }) + + ddl := &DDLEvent{ + Type: byte(model.ActionRenameTables), + Query: "RENAME TABLE `old_target_db1`.`old_target_table1` TO `new_target_db1`.`new_target_table1`; RENAME TABLE `old_target_db2`.`old_target_table2` TO `new_target_db2`.`new_target_table2`", + MultipleTableInfos: []*common.TableInfo{ + sourceTable1.CloneWithRouting("new_target_db1", "new_target_table1"), + 
sourceTable2.CloneWithRouting("new_target_db2", "new_target_table2"), + }, + TableNameChange: &TableNameChange{ + DropName: []SchemaTableName{ + {SchemaName: "old_db1", TableName: "old_table1"}, + {SchemaName: "old_db2", TableName: "old_table2"}, + }, + }, + } + + events := ddl.GetEvents() + require.Len(t, events, 2) + + require.Equal(t, "new_db1", events[0].SchemaName) + require.Equal(t, "new_table1", events[0].TableName) + require.Equal(t, "new_target_db1", events[0].GetTargetSchemaName()) + require.Equal(t, "new_target_table1", events[0].GetTargetTableName()) + require.Equal(t, "old_db1", events[0].ExtraSchemaName) + require.Equal(t, "old_table1", events[0].ExtraTableName) + require.Equal(t, "old_target_db1", events[0].GetTargetExtraSchemaName()) + require.Equal(t, "old_target_table1", events[0].GetTargetExtraTableName()) + + require.Equal(t, "new_db2", events[1].SchemaName) + require.Equal(t, "new_table2", events[1].TableName) + require.Equal(t, "new_target_db2", events[1].GetTargetSchemaName()) + require.Equal(t, "new_target_table2", events[1].GetTargetTableName()) + require.Equal(t, "old_db2", events[1].ExtraSchemaName) + require.Equal(t, "old_table2", events[1].ExtraTableName) + require.Equal(t, "old_target_db2", events[1].GetTargetExtraSchemaName()) + require.Equal(t, "old_target_table2", events[1].GetTargetExtraTableName()) +} diff --git a/pkg/common/event/dml_event.go b/pkg/common/event/dml_event.go index 8868da40dd..b5d3b7d3a0 100644 --- a/pkg/common/event/dml_event.go +++ b/pkg/common/event/dml_event.go @@ -279,16 +279,32 @@ func (b *BatchDMLEvent) encodeV1() ([]byte, error) { // AssembleRows assembles the Rows from the RawRows. // It also sets the TableInfo and clears the RawRows. 
func (b *BatchDMLEvent) AssembleRows(tableInfo *common.TableInfo) { + if tableInfo == nil { + log.Panic("DMLEvent: TableInfo is nil") + } + defer func() { b.TableInfo.InitPrivateFields() }() - // rows is already set, no need to assemble again - // When the event is passed from the same node, the Rows is already set. + + // For local events (same node), rows are already set. if b.Rows != nil { - return - } - if tableInfo == nil { - log.Panic("DMLEvent: TableInfo is nil") + if !tableInfo.TableName.IsRouted() { + return + } + if b.TableInfo != nil { + originVersion := b.TableInfo.GetUpdateTS() + routedVersion := tableInfo.GetUpdateTS() + if originVersion != routedVersion { + log.Panic("table version mismatch when set routed table info", + zap.Uint64("originTableVersion", originVersion), + zap.Uint64("routedTableVersion", routedVersion)) + } + } + b.TableInfo = tableInfo + for _, dml := range b.DMLEvents { + dml.TableInfo = tableInfo + } return } @@ -297,10 +313,16 @@ func (b *BatchDMLEvent) AssembleRows(tableInfo *common.TableInfo) { return } - if b.TableInfo != nil && b.TableInfo.GetUpdateTS() != tableInfo.GetUpdateTS() { - log.Panic("DMLEvent: TableInfoVersion mismatch", zap.Uint64("dmlEventTableInfoVersion", b.TableInfo.GetUpdateTS()), zap.Uint64("tableInfoVersion", tableInfo.GetUpdateTS())) - return + if b.TableInfo != nil { + originVersion := b.TableInfo.GetUpdateTS() + routedVersion := tableInfo.GetUpdateTS() + if originVersion != routedVersion { + log.Panic("table version mismatch when decode remote raw rows", + zap.Uint64("originTableVersion", originVersion), + zap.Uint64("routedTableVersion", routedVersion)) + } } + decoder := chunk.NewCodec(tableInfo.GetFieldSlice()) b.Rows, _ = decoder.Decode(b.RawRows) b.TableInfo = tableInfo diff --git a/pkg/common/event/dml_event_test.go b/pkg/common/event/dml_event_test.go index 39f72f566e..9a0ed27e99 100644 --- a/pkg/common/event/dml_event_test.go +++ b/pkg/common/event/dml_event_test.go @@ -150,6 +150,104 @@ func 
TestBatchDMLEvent(t *testing.T) { require.Contains(t, err.Error(), "unsupported BatchDMLEvent version") } +func TestBatchDMLEventAssembleRowsRebindsRoutedTableInfoForLocalRows(t *testing.T) { + helper := NewEventTestHelper(t) + defer helper.Close() + + helper.tk.MustExec("use test") + helper.DDL2Job(createTableSQL) + + dmlEvent := helper.DML2Event("test", "t", insertDataSQL) + require.NotNil(t, dmlEvent) + + originTableInfo := dmlEvent.TableInfo + routedTableInfo := originTableInfo.CloneWithRouting("target_schema", "target_table") + require.NotNil(t, routedTableInfo) + + batchDMLEvent := &BatchDMLEvent{ + Version: BatchDMLEventVersion1, + DMLEventCount: 1, + DMLEvents: []*DMLEvent{dmlEvent}, + Rows: dmlEvent.Rows, + TableInfo: originTableInfo, + } + + batchDMLEvent.AssembleRows(routedTableInfo) + + require.Same(t, dmlEvent.Rows, batchDMLEvent.Rows) + require.Same(t, routedTableInfo, batchDMLEvent.TableInfo) + require.Same(t, routedTableInfo, batchDMLEvent.DMLEvents[0].TableInfo) + require.Equal(t, "target_schema", batchDMLEvent.TableInfo.GetTargetSchemaName()) + require.Equal(t, "target_table", batchDMLEvent.TableInfo.GetTargetTableName()) + require.Contains(t, batchDMLEvent.TableInfo.GetPreInsertSQL(), common.QuoteSchema("target_schema", "target_table")) +} + +func TestBatchDMLEventAssembleRowsKeepsOriginalTableInfoForLocalRowsWithoutRouting(t *testing.T) { + helper := NewEventTestHelper(t) + defer helper.Close() + + helper.tk.MustExec("use test") + helper.DDL2Job(createTableSQL) + + dmlEvent := helper.DML2Event("test", "t", insertDataSQL) + require.NotNil(t, dmlEvent) + + originTableInfo := dmlEvent.TableInfo + notRoutedTableInfo := originTableInfo.CloneWithRouting("", "") + require.NotNil(t, notRoutedTableInfo) + require.False(t, notRoutedTableInfo.TableName.IsRouted()) + notRoutedTableInfo.UpdateTS++ + + batchDMLEvent := &BatchDMLEvent{ + Version: BatchDMLEventVersion1, + DMLEventCount: 1, + DMLEvents: []*DMLEvent{dmlEvent}, + Rows: dmlEvent.Rows, + TableInfo: 
originTableInfo, + } + + batchDMLEvent.AssembleRows(notRoutedTableInfo) + + require.Same(t, dmlEvent.Rows, batchDMLEvent.Rows) + require.Same(t, originTableInfo, batchDMLEvent.TableInfo) + require.Same(t, originTableInfo, batchDMLEvent.DMLEvents[0].TableInfo) + require.Equal(t, "test", batchDMLEvent.TableInfo.GetTargetSchemaName()) + require.Equal(t, "t", batchDMLEvent.TableInfo.GetTargetTableName()) +} + +func TestBatchDMLEventAssembleRowsDecodesRemoteRawRows(t *testing.T) { + helper := NewEventTestHelper(t) + defer helper.Close() + + helper.tk.MustExec("use test") + helper.DDL2Job(createTableSQL) + + dmlEvent := helper.DML2Event("test", "t", insertDataSQL) + require.NotNil(t, dmlEvent) + + batchDMLEvent := &BatchDMLEvent{ + Version: BatchDMLEventVersion1, + DMLEventCount: 1, + DMLEvents: []*DMLEvent{dmlEvent}, + Rows: dmlEvent.Rows, + TableInfo: dmlEvent.TableInfo, + } + data, err := batchDMLEvent.Marshal() + require.NoError(t, err) + + reverseEvents := &BatchDMLEvent{} + err = reverseEvents.Unmarshal(data) + require.NoError(t, err) + require.Nil(t, reverseEvents.Rows) + require.NotEmpty(t, reverseEvents.RawRows) + + reverseEvents.AssembleRows(batchDMLEvent.TableInfo) + + require.Nil(t, reverseEvents.RawRows) + require.Same(t, batchDMLEvent.TableInfo, reverseEvents.TableInfo) + require.Equal(t, batchDMLEvent.Rows.ToString(batchDMLEvent.TableInfo.GetFieldSlice()), reverseEvents.Rows.ToString(batchDMLEvent.TableInfo.GetFieldSlice())) +} + func TestEncodeAnddecodeV1(t *testing.T) { helper := NewEventTestHelper(t) defer helper.Close() diff --git a/pkg/common/event/redo.go b/pkg/common/event/redo.go index f94e55bb4e..8ad7690b0f 100644 --- a/pkg/common/event/redo.go +++ b/pkg/common/event/redo.go @@ -17,7 +17,7 @@ import ( "fmt" "github.com/pingcap/log" - commonType "github.com/pingcap/ticdc/pkg/common" + "github.com/pingcap/ticdc/pkg/common" "github.com/pingcap/ticdc/pkg/util" timodel "github.com/pingcap/tidb/pkg/meta/model" "github.com/pingcap/tidb/pkg/parser/ast" @@ 
-51,10 +51,10 @@ type RedoDMLEvent struct { // RedoDDLEvent represents DDL event used in redo log persistent type RedoDDLEvent struct { - DDL *DDLEventInRedoLog `msg:"ddl"` - Type byte `msg:"type"` - TableName commonType.TableName `msg:"table-name"` - TableSchemaStore *TableSchemaStore `msg:"table-schema-store"` + DDL *DDLEventInRedoLog `msg:"ddl"` + Type byte `msg:"type"` + TableName common.TableName `msg:"table-name"` + TableSchemaStore *TableSchemaStore `msg:"table-schema-store"` } // DMLEventInRedoLog is used to store DMLEvent in redo log v2 format @@ -64,7 +64,7 @@ type DMLEventInRedoLog struct { // Table contains the table name and table ID. // NOTICE: We store the physical table ID here, not the logical table ID. - Table *commonType.TableName `msg:"table"` + Table *common.TableName `msg:"table"` Columns []*RedoColumn `msg:"columns"` PreColumns []*RedoColumn `msg:"pre-columns"` @@ -103,7 +103,7 @@ type RedoColumn struct { // RedoColumnValue stores Column change type RedoColumnValue struct { // Fields from Column and can't be marshaled directly in Column. - Value interface{} `msg:"column"` + Value any `msg:"column"` // msgp transforms empty byte slice into nil, PTAL msgp#247. 
ValueIsEmptyBytes bool `msg:"value-is-empty-bytes"` Flag uint64 `msg:"flag"` @@ -114,7 +114,7 @@ type RedoRowEvent struct { StartTs uint64 CommitTs uint64 PhysicalTableID int64 - TableInfo *commonType.TableInfo + TableInfo *common.TableInfo Event RowChange Callback func() } @@ -151,9 +151,9 @@ func (r *RedoRowEvent) ToRedoLog() *RedoLog { Type: RedoLogTypeRow, } if r.TableInfo != nil { - redoLog.RedoRow.Row.Table = &commonType.TableName{ - Schema: r.TableInfo.TableName.Schema, - Table: r.TableInfo.TableName.Table, + redoLog.RedoRow.Row.Table = &common.TableName{ + Schema: r.TableInfo.GetTargetSchemaName(), + Table: r.TableInfo.GetTargetTableName(), TableID: r.PhysicalTableID, IsPartition: r.TableInfo.TableName.IsPartition, } @@ -162,18 +162,18 @@ func (r *RedoRowEvent) ToRedoLog() *RedoLog { columnCount := len(r.TableInfo.GetColumns()) columns := make([]*RedoColumn, 0, columnCount) switch r.Event.RowType { - case commonType.RowTypeInsert: + case common.RowTypeInsert: redoLog.RedoRow.Columns = make([]RedoColumnValue, 0, columnCount) - case commonType.RowTypeDelete: + case common.RowTypeDelete: redoLog.RedoRow.PreColumns = make([]RedoColumnValue, 0, columnCount) - case commonType.RowTypeUpdate: + case common.RowTypeUpdate: redoLog.RedoRow.Columns = make([]RedoColumnValue, 0, columnCount) redoLog.RedoRow.PreColumns = make([]RedoColumnValue, 0, columnCount) default: } for i, column := range r.TableInfo.GetColumns() { - if commonType.IsColCDCVisible(column) { + if common.IsColCDCVisible(column) { columns = append(columns, &RedoColumn{ Name: column.Name.String(), Type: column.GetType(), @@ -182,13 +182,13 @@ func (r *RedoRowEvent) ToRedoLog() *RedoLog { }) isHandleKey := r.TableInfo.IsHandleKey(column.ID) switch r.Event.RowType { - case commonType.RowTypeInsert: + case common.RowTypeInsert: v := parseColumnValue(&r.Event.Row, column, i, isHandleKey) redoLog.RedoRow.Columns = append(redoLog.RedoRow.Columns, v) - case commonType.RowTypeDelete: + case common.RowTypeDelete: 
v := parseColumnValue(&r.Event.PreRow, column, i, isHandleKey) redoLog.RedoRow.PreColumns = append(redoLog.RedoRow.PreColumns, v) - case commonType.RowTypeUpdate: + case common.RowTypeUpdate: v := parseColumnValue(&r.Event.Row, column, i, isHandleKey) redoLog.RedoRow.Columns = append(redoLog.RedoRow.Columns, v) v = parseColumnValue(&r.Event.PreRow, column, i, isHandleKey) @@ -198,11 +198,11 @@ func (r *RedoRowEvent) ToRedoLog() *RedoLog { } } switch r.Event.RowType { - case commonType.RowTypeInsert: + case common.RowTypeInsert: redoLog.RedoRow.Row.Columns = columns - case commonType.RowTypeDelete: + case common.RowTypeDelete: redoLog.RedoRow.Row.PreColumns = columns - case commonType.RowTypeUpdate: + case common.RowTypeUpdate: redoLog.RedoRow.Row.Columns = columns redoLog.RedoRow.Row.PreColumns = columns } @@ -241,14 +241,19 @@ func (d *DDLEvent) ToRedoLog() *RedoLog { Type: RedoLogTypeDDL, } if d.TableInfo != nil { - redoLog.RedoDDL.TableName = d.TableInfo.TableName + redoLog.RedoDDL.TableName = common.TableName{ + Schema: d.TableInfo.GetTargetSchemaName(), + Table: d.TableInfo.GetTargetTableName(), + TableID: d.TableInfo.TableName.TableID, + IsPartition: d.TableInfo.TableName.IsPartition, + } } return redoLog } // GetCommitTs returns commit timestamp of the log event. 
-func (r *RedoLog) GetCommitTs() commonType.Ts { +func (r *RedoLog) GetCommitTs() common.Ts { switch r.Type { case RedoLogTypeRow: return r.RedoRow.Row.CommitTs @@ -301,7 +306,7 @@ func (r *RedoDMLEvent) ToDMLEvent() *DMLEvent { colInfo.SetType(col.Type) colInfo.SetCharset(col.Charset) colInfo.SetCollate(col.Collation) - flag := commonType.ColumnFlagType(rawColsValue[idx].Flag) + flag := common.ColumnFlagType(rawColsValue[idx].Flag) // if flag.IsHandleKey() { // } // if flag.IsBinary(){ @@ -350,7 +355,7 @@ func (r *RedoDMLEvent) ToDMLEvent() *DMLEvent { tidbTableInfo.Indices = append(tidbTableInfo.Indices, indexInfo) } event := &DMLEvent{ - TableInfo: commonType.NewTableInfo4Decoder(r.Row.Table.Schema, tidbTableInfo), + TableInfo: common.NewTableInfo4Decoder(r.Row.Table.Schema, tidbTableInfo), CommitTs: r.Row.CommitTs, StartTs: r.Row.StartTs, Length: 1, @@ -364,15 +369,15 @@ func (r *RedoDMLEvent) ToDMLEvent() *DMLEvent { columns := event.TableInfo.GetColumns() if r.IsDelete() { collectAllColumnsValue(r.PreColumns, columns, chk) - event.RowTypes = append(event.RowTypes, commonType.RowTypeDelete) + event.RowTypes = append(event.RowTypes, common.RowTypeDelete) } else if r.IsUpdate() { collectAllColumnsValue(r.PreColumns, columns, chk) collectAllColumnsValue(r.Columns, columns, chk) // FIXME: exclude columns with same value - event.RowTypes = append(event.RowTypes, commonType.RowTypeUpdate, commonType.RowTypeUpdate) + event.RowTypes = append(event.RowTypes, common.RowTypeUpdate, common.RowTypeUpdate) } else if r.IsInsert() { collectAllColumnsValue(r.Columns, columns, chk) - event.RowTypes = append(event.RowTypes, commonType.RowTypeInsert) + event.RowTypes = append(event.RowTypes, common.RowTypeInsert) } else { log.Panic("unknown event type for the DML event") } @@ -383,9 +388,11 @@ func (r *RedoDMLEvent) ToDMLEvent() *DMLEvent { func (r *RedoDDLEvent) ToDDLEvent() *DDLEvent { blockedTables := r.DDL.BlockedTables blockedTableNames := r.DDL.BlockedTableNames + 
schemaName := r.TableName.GetSchema() + tableName := r.TableName.GetTable() if blockedTables == nil { blockedTables = &InfluencedTables{InfluenceType: InfluenceTypeNormal} - blockedTableNames = []SchemaTableName{{SchemaName: r.TableName.Schema, TableName: r.TableName.Table}} + blockedTableNames = []SchemaTableName{{SchemaName: schemaName, TableName: tableName}} } columns := make([]*timodel.ColumnInfo, 0, len(r.DDL.Columns)) for _, col := range r.DDL.Columns { @@ -403,9 +410,9 @@ func (r *RedoDDLEvent) ToDDLEvent() *DDLEvent { } columns = append(columns, colInfo) } - tableInfo := commonType.WrapTableInfo(r.TableName.Schema, &timodel.TableInfo{ + tableInfo := common.WrapTableInfo(schemaName, &timodel.TableInfo{ ID: r.TableName.TableID, - Name: ast.NewCIStr(r.TableName.Table), + Name: ast.NewCIStr(tableName), Columns: columns, }) tableInfo.TableName.IsPartition = r.TableName.IsPartition @@ -413,8 +420,8 @@ func (r *RedoDDLEvent) ToDDLEvent() *DDLEvent { TableInfo: tableInfo, Query: r.DDL.Query, Type: r.Type, - SchemaName: r.TableName.Schema, - TableName: r.TableName.Table, + SchemaName: schemaName, + TableName: tableName, FinishedTs: r.DDL.CommitTs, StartTs: r.DDL.StartTs, BlockedTables: blockedTables, @@ -431,7 +438,7 @@ func (r *RedoDDLEvent) SetTableSchemaStore(tableSchemaStore *TableSchemaStore) { } func parseColumnValue(row *chunk.Row, colInfo *timodel.ColumnInfo, i int, isHandleKey bool) RedoColumnValue { - v := commonType.ExtractColVal(row, colInfo, i) + v := common.ExtractColVal(row, colInfo, i) switch colInfo.GetType() { case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, mysql.TypeBlob: @@ -451,7 +458,7 @@ func parseColumnValue(row *chunk.Row, colInfo *timodel.ColumnInfo, i int, isHand // For compatibility func convertFlag(colInfo *timodel.ColumnInfo, isHandleKey bool) uint64 { - var flag commonType.ColumnFlagType + var flag common.ColumnFlagType if isHandleKey { flag.SetIsHandleKey() } 
@@ -480,7 +487,7 @@ func convertFlag(colInfo *timodel.ColumnInfo, isHandleKey bool) uint64 { } // For compatibility -func getIndexColumns(tableInfo *commonType.TableInfo) [][]int { +func getIndexColumns(tableInfo *common.TableInfo) [][]int { indexColumns := make([][]int, 0, len(tableInfo.GetIndexColumns())) rowColumnsOffset := tableInfo.GetRowColumnsOffset() for _, index := range tableInfo.GetIndexColumns() { @@ -499,7 +506,7 @@ func collectAllColumnsValue(data []RedoColumnValue, columns []*timodel.ColumnInf } } -func appendCol2Chunk(idx int, raw interface{}, ft tiTypes.FieldType, chk *chunk.Chunk) { +func appendCol2Chunk(idx int, raw any, ft tiTypes.FieldType, chk *chunk.Chunk) { if raw == nil { chk.AppendNull(idx) return diff --git a/pkg/common/event/redo_gen.go b/pkg/common/event/redo_gen.go index 22e7354b8f..811d65f2d3 100644 --- a/pkg/common/event/redo_gen.go +++ b/pkg/common/event/redo_gen.go @@ -3,7 +3,7 @@ package event import ( - commonType "github.com/pingcap/ticdc/pkg/common" + "github.com/pingcap/ticdc/pkg/common" "github.com/tinylib/msgp/msgp" ) @@ -764,7 +764,7 @@ func (z *DMLEventInRedoLog) DecodeMsg(dc *msgp.Reader) (err error) { z.Table = nil } else { if z.Table == nil { - z.Table = new(commonType.TableName) + z.Table = new(common.TableName) } err = z.Table.DecodeMsg(dc) if err != nil { @@ -1092,7 +1092,7 @@ func (z *DMLEventInRedoLog) UnmarshalMsg(bts []byte) (o []byte, err error) { z.Table = nil } else { if z.Table == nil { - z.Table = new(commonType.TableName) + z.Table = new(common.TableName) } bts, err = z.Table.UnmarshalMsg(bts) if err != nil { diff --git a/pkg/common/event/redo_test.go b/pkg/common/event/redo_test.go index 03c09a4b72..adeac28168 100644 --- a/pkg/common/event/redo_test.go +++ b/pkg/common/event/redo_test.go
@@ -24,6 +24,78 @@ import ( "github.com/stretchr/testify/require" ) +func TestRedoUsesRoutedTableNames(t *testing.T) { + t.Parallel() + + helper := NewEventTestHelper(t) + defer helper.Close() + + helper.Tk().MustExec("use test") + job := helper.DDL2Job(`create table test.t(id int primary key, name varchar(32))`) + require.NotNil(t, job) + + sourceTableInfo := helper.GetTableInfo(job) + routedTableInfo := sourceTableInfo.CloneWithRouting("target_db", "target_table") + + redoDDLEvent := (&DDLEvent{ + Query: "ALTER TABLE `target_db`.`target_table` ADD COLUMN age INT", + Type: byte(timodel.ActionAddColumn), + SchemaName: "test", + TableName: "t", + TableInfo: routedTableInfo, + FinishedTs: 200, + StartTs: 100, + }).ToRedoLog().RedoDDL + + require.Equal(t, "target_db", redoDDLEvent.TableName.Schema) + require.Equal(t, "target_table", redoDDLEvent.TableName.Table) + require.Empty(t, redoDDLEvent.TableName.TargetSchema) + require.Empty(t, redoDDLEvent.TableName.TargetTable) + + ddlEvent := redoDDLEvent.ToDDLEvent() + require.Equal(t, "target_db", ddlEvent.SchemaName) + require.Equal(t, "target_table", ddlEvent.TableName) + require.Empty(t, ddlEvent.targetSchemaName) + require.Empty(t, ddlEvent.targetTableName) + require.Equal(t, "target_db", ddlEvent.GetTargetSchemaName()) + require.Equal(t, "target_table", ddlEvent.GetTargetTableName()) + require.Equal(t, "target_db", ddlEvent.TableInfo.GetSchemaName()) + require.Equal(t, "target_table", ddlEvent.TableInfo.GetTableName()) + require.Empty(t, ddlEvent.TableInfo.TableName.TargetSchema) + require.Empty(t, ddlEvent.TableInfo.TableName.TargetTable) + require.Equal(t, "target_db", ddlEvent.TableInfo.GetTargetSchemaName()) + require.Equal(t, "target_table", ddlEvent.TableInfo.GetTargetTableName()) + require.Equal(t, []SchemaTableName{{ + SchemaName: "target_db", + TableName: "target_table", + }}, ddlEvent.BlockedTableNames) + + dmlEvent := helper.DML2Event("test", "t", `insert into 
test.t values (1, 'alice')`) + dmlEvent.TableInfo = routedTableInfo + + row, ok := dmlEvent.GetNextRow() + require.True(t, ok) + + redoRow := (&RedoRowEvent{ + StartTs: dmlEvent.StartTs, + CommitTs: dmlEvent.CommitTs, + PhysicalTableID: dmlEvent.PhysicalTableID, + TableInfo: routedTableInfo, + Event: row, + }).ToRedoLog().RedoRow + + require.Equal(t, "target_db", redoRow.Row.Table.Schema) + require.Equal(t, "target_table", redoRow.Row.Table.Table) + require.Empty(t, redoRow.Row.Table.TargetSchema) + require.Empty(t, redoRow.Row.Table.TargetTable) + + decoded := redoRow.ToDMLEvent() + require.Equal(t, "target_db", decoded.TableInfo.GetSchemaName()) + require.Equal(t, "target_table", decoded.TableInfo.GetTableName()) + require.Equal(t, "target_db", decoded.TableInfo.GetTargetSchemaName()) + require.Equal(t, "target_table", decoded.TableInfo.GetTargetTableName()) +} + func TestRedoDDLEventRoundTripPreservesColumnMetadata(t *testing.T) { originalTableInfo := commonType.NewTableInfo4Decoder("test", &timodel.TableInfo{ ID: 1001, diff --git a/pkg/common/table_info.go b/pkg/common/table_info.go index c7e747ad02..86cbbbec12 100644 --- a/pkg/common/table_info.go +++ b/pkg/common/table_info.go @@ -133,13 +133,56 @@ func (ti *TableInfo) InitPrivateFields() { return } - ti.preSQLs.m[preSQLInsert] = fmt.Sprintf(ti.columnSchema.PreSQLs[preSQLInsert], ti.TableName.QuoteString()) - ti.preSQLs.m[preSQLReplace] = fmt.Sprintf(ti.columnSchema.PreSQLs[preSQLReplace], ti.TableName.QuoteString()) - ti.preSQLs.m[preSQLUpdate] = fmt.Sprintf(ti.columnSchema.PreSQLs[preSQLUpdate], ti.TableName.QuoteString()) + ti.preSQLs.m[preSQLInsert] = fmt.Sprintf(ti.columnSchema.PreSQLs[preSQLInsert], ti.TableName.QuoteTargetString()) + ti.preSQLs.m[preSQLReplace] = fmt.Sprintf(ti.columnSchema.PreSQLs[preSQLReplace], ti.TableName.QuoteTargetString()) + ti.preSQLs.m[preSQLUpdate] = fmt.Sprintf(ti.columnSchema.PreSQLs[preSQLUpdate], ti.TableName.QuoteTargetString()) ti.preSQLs.isInitialized.Store(true) } 
+// CloneWithRouting creates a shallow copy of TableInfo with routing applied. +// The new TableInfo shares the same columnSchema, View, Sequence pointers +// but has its own TableName (with TargetSchema/TargetTable set) and uninitialized preSQLs. +// This is safe because: +// - columnSchema, View, Sequence are read-only after creation +// - preSQLs will be initialized later via InitPrivateFields() using the new TableName +// - TableName is a value type that gets copied +func (ti *TableInfo) CloneWithRouting(targetSchema, targetTable string) *TableInfo { + if ti == nil { + return nil + } + // Create a new TableInfo with copied basic fields + cloned := &TableInfo{ + TableName: ti.TableName, // Value copy of TableName struct + Charset: ti.Charset, + Collate: ti.Collate, + Comment: ti.Comment, + columnSchema: ti.columnSchema, // Share the pointer (read-only) + HasPKOrNotNullUK: ti.HasPKOrNotNullUK, + View: ti.View, // Share the pointer (read-only) + Sequence: ti.Sequence, // Share the pointer (read-only) + UpdateTS: ti.UpdateTS, + ActiveActiveTable: ti.ActiveActiveTable, + SoftDeleteTable: ti.SoftDeleteTable, + // preSQLs is zero-initialized (uninitialized mutex/atomic, empty strings) + } + // Apply routing to the cloned TableName while keeping Schema/Table as source names. + cloned.TableName.TargetSchema = targetSchema + cloned.TableName.TargetTable = targetTable + + // Increment refcount for the shared columnSchema and set a finalizer to decrement + // it when the clone is garbage collected. This keeps the shared-schema storage's + // refcount balanced, so the columnSchema entry is not released from the storage + // while the clone is still reachable, even if the original TableInfo is collected first.
+ if ti.columnSchema != nil { + GetSharedColumnSchemaStorage().incColumnSchemaCount(ti.columnSchema) + runtime.SetFinalizer(cloned, func(ti *TableInfo) { + GetSharedColumnSchemaStorage().tryReleaseColumnSchema(ti.columnSchema) + }) + } + + return cloned +} + func (ti *TableInfo) Marshal() ([]byte, error) { // otherField | columnSchemaData | columnSchemaDataSize data, err := json.Marshal(ti) @@ -346,16 +389,16 @@ func (ti *TableInfo) MustGetColumnOffsetByID(id int64) int { // GetSchemaName returns the schema name of the table func (ti *TableInfo) GetSchemaName() string { - return ti.TableName.Schema + return ti.TableName.GetSchema() } // GetTableName returns the table name of the table func (ti *TableInfo) GetTableName() string { - return ti.TableName.Table + return ti.TableName.GetTable() } func (ti *TableInfo) GetTableNameCIStr() ast.CIStr { - return ast.NewCIStr(ti.TableName.Table) + return ast.NewCIStr(ti.GetTableName()) } // GetSchemaNamePtr returns the pointer to the schema name of the table @@ -373,6 +416,18 @@ func (ti *TableInfo) IsPartitionTable() bool { return ti.TableName.IsPartition } +// GetTargetSchemaName returns the target schema name for routing. +// If TargetSchema is empty, returns Schema. +func (ti *TableInfo) GetTargetSchemaName() string { + return ti.TableName.GetTargetSchema() +} + +// GetTargetTableName returns the target table name for routing. +// If TargetTable is empty, returns Table. +func (ti *TableInfo) GetTargetTableName() string { + return ti.TableName.GetTargetTable() +} + // IsView checks if TableInfo is a view. 
func (t *TableInfo) IsView() bool { return t.View != nil diff --git a/pkg/common/table_info_test.go b/pkg/common/table_info_test.go index 55c7e3fb22..5be97d2f37 100644 --- a/pkg/common/table_info_test.go +++ b/pkg/common/table_info_test.go @@ -24,6 +24,78 @@ import ( "github.com/stretchr/testify/require" ) +func TestCloneWithRouting(t *testing.T) { + t.Parallel() + + t.Run("nil TableInfo", func(t *testing.T) { + var ti *TableInfo + cloned := ti.CloneWithRouting("target_schema", "target_table") + require.Nil(t, cloned) + }) + + t.Run("basic cloning with routing", func(t *testing.T) { + original := &TableInfo{ + TableName: TableName{ + Schema: "source_db", + Table: "source_table", + TableID: 123, + }, + Charset: "utf8mb4", + Collate: "utf8mb4_bin", + Comment: "test table", + HasPKOrNotNullUK: true, + UpdateTS: 1000, + } + + cloned := original.CloneWithRouting("target_db", "target_table") + + // Verify cloned has routing applied + require.Equal(t, "source_db", cloned.TableName.Schema) + require.Equal(t, "source_table", cloned.TableName.Table) + require.Equal(t, "target_db", cloned.TableName.TargetSchema) + require.Equal(t, "target_table", cloned.TableName.TargetTable) + require.Equal(t, int64(123), cloned.TableName.TableID) + require.Equal(t, "source_db", cloned.GetSchemaName()) + require.Equal(t, "source_table", cloned.GetTableName()) + require.Equal(t, "target_db", cloned.GetTargetSchemaName()) + require.Equal(t, "target_table", cloned.GetTargetTableName()) + require.Equal(t, "source_db.source_table", cloned.TableName.String()) + require.Equal(t, "`source_db`.`source_table`", cloned.TableName.QuoteString()) + require.Equal(t, "`target_db`.`target_table`", cloned.TableName.QuoteTargetString()) + require.True(t, cloned.TableName.IsRouted()) + require.Same(t, &cloned.TableName.Schema, cloned.GetSchemaNamePtr()) + require.Same(t, &cloned.TableName.Table, cloned.GetTableNamePtr()) + + // Verify other fields are copied + require.Equal(t, "utf8mb4", cloned.Charset) + 
require.Equal(t, "utf8mb4_bin", cloned.Collate) + require.Equal(t, "test table", cloned.Comment) + require.Equal(t, true, cloned.HasPKOrNotNullUK) + require.Equal(t, uint64(1000), cloned.UpdateTS) + + // Verify original is NOT modified + require.Equal(t, "", original.TableName.TargetSchema) + require.Equal(t, "", original.TableName.TargetTable) + }) + + t.Run("target getters remain available without changing source fields", func(t *testing.T) { + original := &TableInfo{ + TableName: TableName{ + Schema: "source_db", + Table: "source_table", + TableID: 123, + }, + } + + cloned := original.CloneWithRouting("target_db", "target_table") + + require.Equal(t, "source_db", cloned.GetSchemaName()) + require.Equal(t, "source_table", cloned.GetTableName()) + require.Equal(t, "target_db", cloned.GetTargetSchemaName()) + require.Equal(t, "target_table", cloned.GetTargetTableName()) + }) +} + func TestUnmarshalJSONToTableInfoInvalidData(t *testing.T) { t.Parallel() diff --git a/pkg/common/table_name.go b/pkg/common/table_name.go index 646a18c6a7..dbf5121a77 100644 --- a/pkg/common/table_name.go +++ b/pkg/common/table_name.go @@ -19,16 +19,22 @@ import ( //go:generate msgp -// TableName represents name of a table, includes table name and schema name. +// TableName represents name of a table, includes table name and schema name type TableName struct { Schema string `toml:"db-name" msg:"db-name"` Table string `toml:"tbl-name" msg:"tbl-name"` // TableID is the logic table ID TableID int64 `toml:"tbl-id" msg:"tbl-id"` IsPartition bool `toml:"is-partition" msg:"is-partition"` + + // TargetSchema and TargetTable are used as an in-memory routing overlay. + // They are intentionally excluded from msgpack serialization because redo + // persists routed names canonically in Schema/Table. + TargetSchema string `toml:"target-db-name" msg:"-"` + TargetTable string `toml:"target-tbl-name" msg:"-"` } -// String implements fmt.Stringer interface. 
+// String implements fmt.Stringer interface func (t TableName) String() string { return fmt.Sprintf("%s.%s", t.Schema, t.Table) } @@ -37,3 +43,41 @@ func (t TableName) String() string { func (t TableName) QuoteString() string { return QuoteSchema(t.Schema, t.Table) } + +// GetSchema returns the schema name +func (t *TableName) GetSchema() string { + return t.Schema +} + +// GetTable returns the table name +func (t *TableName) GetTable() string { + return t.Table +} + +// IsRouted returns whether table routing is enabled +func (t *TableName) IsRouted() bool { + return t.TargetSchema != "" || t.TargetTable != "" +} + +// GetTargetSchema returns the target schema name for routing. +// If TargetSchema is empty, returns Schema +func (t *TableName) GetTargetSchema() string { + if t.TargetSchema != "" { + return t.TargetSchema + } + return t.Schema +} + +// GetTargetTable returns the target table name for routing +// If TargetTable is empty, returns Table. +func (t *TableName) GetTargetTable() string { + if t.TargetTable != "" { + return t.TargetTable + } + return t.Table +} + +// QuoteTargetString returns quoted full target table name for routing +func (t TableName) QuoteTargetString() string { + return QuoteSchema(t.GetTargetSchema(), t.GetTargetTable()) +} diff --git a/pkg/common/table_name_test.go b/pkg/common/table_name_test.go new file mode 100644 index 0000000000..54a2f914cb --- /dev/null +++ b/pkg/common/table_name_test.go @@ -0,0 +1,118 @@ +// Copyright 2026 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package common + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTableNameIsRouted(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tableName TableName + expected bool + }{ + { + name: "not routed", + tableName: TableName{Schema: "test", Table: "t"}, + expected: false, + }, + { + name: "target schema only", + tableName: TableName{Schema: "test", Table: "t", TargetSchema: "target"}, + expected: true, + }, + { + name: "target table only", + tableName: TableName{Schema: "test", Table: "t", TargetTable: "target_t"}, + expected: true, + }, + { + name: "target schema and table", + tableName: TableName{Schema: "test", Table: "t", TargetSchema: "target", TargetTable: "target_t"}, + expected: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.tableName.IsRouted(); got != tt.expected { + t.Fatalf("IsRouted() = %v, expected %v", got, tt.expected) + } + }) + } +} + +func TestTableNameTargetAccessors(t *testing.T) { + t.Parallel() + + t.Run("fallback to source names", func(t *testing.T) { + tableName := TableName{ + Schema: "source_db", + Table: "source_table", + } + + require.Equal(t, "source_db", tableName.GetTargetSchema()) + require.Equal(t, "source_table", tableName.GetTargetTable()) + require.Equal(t, "`source_db`.`source_table`", tableName.QuoteTargetString()) + }) + + t.Run("use routed names when present", func(t *testing.T) { + tableName := TableName{ + Schema: "source_db", + Table: "source_table", + TargetSchema: "target_db", + TargetTable: "target_table", + } + + require.Equal(t, "target_db", tableName.GetTargetSchema()) + require.Equal(t, "target_table", tableName.GetTargetTable()) + require.Equal(t, "`target_db`.`target_table`", tableName.QuoteTargetString()) + }) +} + +func TestTableNameMsgpackRoundTripDropsRoutingOverlay(t *testing.T) { + t.Parallel() + + original := TableName{ + Schema: "source_db", + Table: "source_table", + TableID: 
42, + IsPartition: true, + TargetSchema: "target_db", + TargetTable: "target_table", + } + + data, err := original.MarshalMsg(nil) + require.NoError(t, err) + + var decoded TableName + rest, err := decoded.UnmarshalMsg(data) + require.NoError(t, err) + require.Empty(t, rest) + require.Equal(t, original.Schema, decoded.Schema) + require.Equal(t, original.Table, decoded.Table) + require.Equal(t, original.TableID, decoded.TableID) + require.Equal(t, original.IsPartition, decoded.IsPartition) + require.Empty(t, decoded.TargetSchema) + require.Empty(t, decoded.TargetTable) + require.Equal(t, original.Schema, decoded.GetTargetSchema()) + require.Equal(t, original.Table, decoded.GetTargetTable()) +}