From a5649221d53e06714c9e05626e33eb91e3d080d2 Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Tue, 31 Mar 2026 10:57:27 +0100 Subject: [PATCH 1/4] Add stream status trace messages for Airbyte protocol v2 Airbyte 2.x requires sources to emit STREAM_STATUS trace messages (STARTED, COMPLETE, INCOMPLETE) for each stream. Without these, every sync fails with: "streams did not receive a terminal stream status message" Changes: - Add TRACE message type and stream status constants to types.go - Add StreamDescriptor, AirbyteStreamStatus, AirbyteTraceMessage types - Replace legacy global State() with per-stream StreamState() that emits state.type=STREAM (required by Airbyte 2.x, which rejects the LEGACY format with IllegalArgumentException) - Add StreamStatus() method to emit STARTED/COMPLETE/INCOMPLETE traces - Update AirbyteLogger interface and test mock accordingly --- cmd/internal/logger.go | 38 +++++++++++++++++++++++++--- cmd/internal/mock_types.go | 9 ++++--- cmd/internal/types.go | 52 +++++++++++++++++++++++++++++++++----- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/cmd/internal/logger.go b/cmd/internal/logger.go index 0ada8c0..15aa50a 100644 --- a/cmd/internal/logger.go +++ b/cmd/internal/logger.go @@ -14,8 +14,9 @@ type AirbyteLogger interface { ConnectionStatus(status ConnectionStatus) Record(tableNamespace, tableName string, data map[string]interface{}) Flush() - State(syncState SyncState) + StreamState(namespace, streamName string, shardStates ShardStates) Error(error string) + StreamStatus(namespace, streamName, status string) } const MaxBatchSize = 10000 @@ -82,10 +83,19 @@ func (a *airbyteLogger) Flush() { a.records = a.records[:0] } -func (a *airbyteLogger) State(syncState SyncState) { +func (a *airbyteLogger) StreamState(namespace, streamName string, shardStates ShardStates) { if err := a.recordEncoder.Encode(AirbyteMessage{ - Type: STATE, - State: &AirbyteState{syncState}, + Type: STATE, + State: &AirbyteState{ + Type: 
STATE_TYPE_STREAM, + Stream: &AirbyteStreamState{ + StreamDescriptor: StreamDescriptor{ + Name: streamName, + Namespace: namespace, + }, + StreamState: &shardStates, + }, + }, }); err != nil { a.Error(fmt.Sprintf("state encoding error: %v", err)) } @@ -103,6 +113,26 @@ func (a *airbyteLogger) Error(error string) { } } +func (a *airbyteLogger) StreamStatus(namespace, streamName, status string) { + now := time.Now() + if err := a.recordEncoder.Encode(AirbyteMessage{ + Type: TRACE, + Trace: &AirbyteTraceMessage{ + Type: TRACE_TYPE_STREAM_STATUS, + EmittedAt: float64(now.UnixMilli()), + StreamStatus: &AirbyteStreamStatus{ + StreamDescriptor: StreamDescriptor{ + Name: streamName, + Namespace: namespace, + }, + Status: status, + }, + }, + }); err != nil { + a.Error(fmt.Sprintf("stream status encoding error: %v", err)) + } +} + func (a *airbyteLogger) ConnectionStatus(status ConnectionStatus) { if err := a.recordEncoder.Encode(AirbyteMessage{ Type: CONNECTION_STATUS, diff --git a/cmd/internal/mock_types.go b/cmd/internal/mock_types.go index 742b822..2e7dc04 100644 --- a/cmd/internal/mock_types.go +++ b/cmd/internal/mock_types.go @@ -50,9 +50,8 @@ func (tal *testAirbyteLogger) Record(tableNamespace, tableName string, data map[ func (testAirbyteLogger) Flush() { } -func (testAirbyteLogger) State(syncState SyncState) { - // TODO implement me - panic("implement me") +func (testAirbyteLogger) StreamState(namespace, streamName string, shardStates ShardStates) { + // no-op for tests } func (testAirbyteLogger) Error(error string) { @@ -60,6 +59,10 @@ func (testAirbyteLogger) Error(error string) { panic("implement me") } +func (testAirbyteLogger) StreamStatus(namespace, streamName, status string) { + // no-op for tests +} + type vstreamClientMock struct { vstreamFn func(ctx context.Context, in *vtgate.VStreamRequest, opts ...grpc.CallOption) (vtgateservice.Vitess_VStreamClient, error) vstreamFnInvoked bool diff --git a/cmd/internal/types.go b/cmd/internal/types.go index 
19b5f64..a24a099 100644 --- a/cmd/internal/types.go +++ b/cmd/internal/types.go @@ -21,6 +21,17 @@ const ( LOG = "LOG" CONNECTION_STATUS = "CONNECTION_STATUS" CATALOG = "CATALOG" + TRACE = "TRACE" +) + +const ( + TRACE_TYPE_STREAM_STATUS = "STREAM_STATUS" +) + +const ( + STREAM_STATUS_STARTED = "STARTED" + STREAM_STATUS_COMPLETE = "COMPLETE" + STREAM_STATUS_INCOMPLETE = "INCOMPLETE" ) const ( @@ -385,17 +396,44 @@ func mapEnumValue(value sqltypes.Value, values []string) sqltypes.Value { return value } +const ( + STATE_TYPE_STREAM = "STREAM" +) + +type AirbyteStreamState struct { + StreamDescriptor StreamDescriptor `json:"stream_descriptor"` + StreamState *ShardStates `json:"stream_state"` +} + type AirbyteState struct { - Data SyncState `json:"data"` + Type string `json:"type"` + Stream *AirbyteStreamState `json:"stream,omitempty"` +} + +type StreamDescriptor struct { + Name string `json:"name"` + Namespace string `json:"namespace"` +} + +type AirbyteStreamStatus struct { + StreamDescriptor StreamDescriptor `json:"stream_descriptor"` + Status string `json:"status"` +} + +type AirbyteTraceMessage struct { + Type string `json:"type"` + EmittedAt float64 `json:"emitted_at"` + StreamStatus *AirbyteStreamStatus `json:"stream_status,omitempty"` } type AirbyteMessage struct { - Type string `json:"type"` - Log *AirbyteLogMessage `json:"log,omitempty"` - ConnectionStatus *ConnectionStatus `json:"connectionStatus,omitempty"` - Catalog *Catalog `json:"catalog,omitempty"` - Record *AirbyteRecord `json:"record,omitempty"` - State *AirbyteState `json:"state,omitempty"` + Type string `json:"type"` + Log *AirbyteLogMessage `json:"log,omitempty"` + ConnectionStatus *ConnectionStatus `json:"connectionStatus,omitempty"` + Catalog *Catalog `json:"catalog,omitempty"` + Record *AirbyteRecord `json:"record,omitempty"` + State *AirbyteState `json:"state,omitempty"` + Trace *AirbyteTraceMessage `json:"trace,omitempty"` } // A map of starting GTIDs for every keyspace and shard From 
03b5915f46d2656b49d6fdeffa5168b78079709b Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Tue, 31 Mar 2026 10:57:39 +0100 Subject: [PATCH 2/4] Emit per-stream status and state in read loop, handle v2 state input Update the read command to be fully compatible with Airbyte 2.x: Read loop changes: - Emit STARTED before reading each stream - Emit COMPLETE after successful read, INCOMPLETE on error - Replace os.Exit(1) with break on per-stream errors so remaining streams still get status messages - Emit per-stream STATE (type=STREAM) after each stream completes instead of one global state blob at the end State parsing changes: - Handle Airbyte v2 per-stream state format on incremental syncs. Airbyte 2.x passes state back as a JSON array of per-stream state objects, not the legacy global SyncState blob. Without this, the second sync always fails because json.Unmarshal fails on the array format, causing os.Exit(1) before any streams are processed. - Fall back to legacy format for backwards compatibility - Default empty namespace to source database name to prevent state key mismatches --- cmd/airbyte-source/read.go | 43 +++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/cmd/airbyte-source/read.go b/cmd/airbyte-source/read.go index 09b8031..c56e04b 100644 --- a/cmd/airbyte-source/read.go +++ b/cmd/airbyte-source/read.go @@ -109,9 +109,13 @@ func ReadCommand(ch *Helper) *cobra.Command { streamState, ok := syncState.Streams[streamStateKey] if !ok { ch.Logger.Error(fmt.Sprintf("Unable to read state for stream %v", streamStateKey)) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) os.Exit(1) } + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_STARTED) + + streamFailed := false for shardName, shardState := range streamState.Shards { var tc *psdbconnectv1alpha1.TableCursor @@ -119,21 +123,27 @@ func ReadCommand(ch *Helper) 
*cobra.Command { ch.Logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Using serialized cursor for stream %s", streamStateKey)) if err != nil { ch.Logger.Error(fmt.Sprintf("Invalid serialized cursor for stream %v, failed with [%v]", streamStateKey, err)) - os.Exit(1) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + streamFailed = true + break } sc, err := ch.Database.Read(ctx, cmd.OutOrStdout(), psc, configuredStream, tc) if err != nil { ch.Logger.Error(err.Error()) - os.Exit(1) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + streamFailed = true + break } if sc != nil { - // if we get any new state, we assign it here. - // otherwise, the older state is round-tripped back to Airbyte. syncState.Streams[streamStateKey].Shards[shardName] = sc } - ch.Logger.State(syncState) + } + + if !streamFailed { + ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_COMPLETE) } } }, @@ -153,9 +163,26 @@ func readState(state string, psc internal.PlanetScaleSource, streams []internal. 
Streams: map[string]internal.ShardStates{}, } if state != "" { - err := json.Unmarshal([]byte(state), &syncState) - if err != nil { - return syncState, err + // Try parsing as Airbyte v2 per-stream state array first + var perStreamStates []internal.AirbyteState + if err := json.Unmarshal([]byte(state), &perStreamStates); err == nil && len(perStreamStates) > 0 && perStreamStates[0].Type == internal.STATE_TYPE_STREAM { + logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Parsing Airbyte v2 per-stream state (%d streams)", len(perStreamStates))) + for _, s := range perStreamStates { + if s.Stream != nil && s.Stream.StreamState != nil { + ns := s.Stream.StreamDescriptor.Namespace + if ns == "" { + ns = psc.Database + } + key := ns + ":" + s.Stream.StreamDescriptor.Name + syncState.Streams[key] = *s.Stream.StreamState + } + } + } else { + // Fall back to legacy global state format + err := json.Unmarshal([]byte(state), &syncState) + if err != nil { + return syncState, err + } } } From de00c59bff8021c5d14a202d5bdd0329206742ce Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Tue, 31 Mar 2026 10:57:47 +0100 Subject: [PATCH 3/4] Add tests for Airbyte protocol v2 compliance Logger tests: - StreamState emits correct per-stream format with type=STREAM - Multiple shards included in state output - No legacy "data" field present (would cause LEGACY rejection) - StreamStatus emits TRACE messages with correct status values - JSON round-trip matches exact Airbyte protocol v2 structure Read protocol tests: - Read emits per-stream STATE, not legacy global state - STARTED and COMPLETE emitted for each configured stream - Correct message ordering: STARTED -> STATE -> COMPLETE - Multi-shard state contains all shard cursors - Read errors emit INCOMPLETE and skip state emission --- cmd/airbyte-source/read_protocol_test.go | 373 +++++++++++++++++++++++ cmd/internal/logger_test.go | 194 ++++++++++++ 2 files changed, 567 insertions(+) create mode 100644 cmd/airbyte-source/read_protocol_test.go 
create mode 100644 cmd/internal/logger_test.go diff --git a/cmd/airbyte-source/read_protocol_test.go b/cmd/airbyte-source/read_protocol_test.go new file mode 100644 index 0000000..3f9efce --- /dev/null +++ b/cmd/airbyte-source/read_protocol_test.go @@ -0,0 +1,373 @@ +package airbyte_source + +import ( + "bytes" + "context" + "encoding/json" + "io" + "os" + "testing" + + "github.com/planetscale/airbyte-source/cmd/internal" + psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockDatabase implements internal.PlanetScaleDatabase for read protocol tests. +type mockDatabase struct { + shards []string + readFunc func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) + readCalls int +} + +func (m *mockDatabase) CanConnect(ctx context.Context, ps internal.PlanetScaleSource) error { + return nil +} + +func (m *mockDatabase) DiscoverSchema(ctx context.Context, ps internal.PlanetScaleSource) (internal.Catalog, error) { + return internal.Catalog{}, nil +} + +func (m *mockDatabase) ListShards(ctx context.Context, ps internal.PlanetScaleSource) ([]string, error) { + return m.shards, nil +} + +func (m *mockDatabase) Read(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + m.readCalls++ + if m.readFunc != nil { + return m.readFunc(ctx, w, ps, s, tc) + } + newCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/updated-position", + }) + return newCursor, nil +} + +func (m *mockDatabase) Close() error { + return nil +} + +func newTestConfig() []byte { + return []byte(`{"host":"test.psdb.cloud","database":"testdb","username":"user","password":"pass"}`) +} + +func 
newTestCatalog(t *testing.T, streams ...string) string { + t.Helper() + catalog := internal.ConfiguredCatalog{} + for _, name := range streams { + catalog.Streams = append(catalog.Streams, internal.ConfiguredStream{ + Stream: internal.Stream{ + Name: name, + Namespace: "testdb", + }, + SyncMode: "full_refresh", + }) + } + b, err := json.Marshal(catalog) + require.NoError(t, err) + return string(b) +} + +func writeTempFile(t *testing.T, content []byte) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "*.json") + require.NoError(t, err) + _, err = f.Write(content) + require.NoError(t, err) + require.NoError(t, f.Close()) + return f.Name() +} + +func parseOutputMessages(t *testing.T, buf *bytes.Buffer) []internal.AirbyteMessage { + t.Helper() + var messages []internal.AirbyteMessage + decoder := json.NewDecoder(buf) + for decoder.More() { + var msg internal.AirbyteMessage + if err := decoder.Decode(&msg); err != nil { + break + } + messages = append(messages, msg) + } + return messages +} + +func setupReadCommand(t *testing.T, db *mockDatabase, catalogJSON string) (*bytes.Buffer, *Helper) { + t.Helper() + b := bytes.NewBufferString("") + h := &Helper{ + Database: db, + FileReader: testFileReader{content: newTestConfig()}, + Logger: internal.NewLogger(b), + } + return b, h +} + +func TestRead_EmitsPerStreamStateNotLegacy(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "users") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + var stateMessages []internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE { + stateMessages = 
append(stateMessages, msg) + } + } + + require.NotEmpty(t, stateMessages, "should emit at least one STATE message") + + for _, msg := range stateMessages { + assert.Equal(t, internal.STATE_TYPE_STREAM, msg.State.Type, + "state.type must be STREAM, not LEGACY") + require.NotNil(t, msg.State.Stream, + "state.stream must be present") + assert.NotEmpty(t, msg.State.Stream.StreamDescriptor.Name, + "stream_descriptor.name must be set") + assert.NotEmpty(t, msg.State.Stream.StreamDescriptor.Namespace, + "stream_descriptor.namespace must be set") + require.NotNil(t, msg.State.Stream.StreamState, + "stream_state must be present") + } +} + +func TestRead_EmitsStartedAndCompletePerStream(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "orders", "products") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + type streamStatusEntry struct { + Name string + Status string + } + var statuses []streamStatusEntry + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.Type == internal.TRACE_TYPE_STREAM_STATUS && + msg.Trace.StreamStatus != nil { + statuses = append(statuses, streamStatusEntry{ + Name: msg.Trace.StreamStatus.StreamDescriptor.Name, + Status: msg.Trace.StreamStatus.Status, + }) + } + } + + expectedStatuses := []streamStatusEntry{ + {"orders", internal.STREAM_STATUS_STARTED}, + {"orders", internal.STREAM_STATUS_COMPLETE}, + {"products", internal.STREAM_STATUS_STARTED}, + {"products", internal.STREAM_STATUS_COMPLETE}, + } + assert.Equal(t, expectedStatuses, statuses) +} + +func TestRead_StatePerStreamContainsCorrectDescriptor(t 
*testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "accounts", "sessions") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + statesByStream := map[string]internal.AirbyteMessage{} + for _, msg := range messages { + if msg.Type == internal.STATE { + name := msg.State.Stream.StreamDescriptor.Name + statesByStream[name] = msg + } + } + + assert.Contains(t, statesByStream, "accounts") + assert.Contains(t, statesByStream, "sessions") + assert.Equal(t, "testdb", statesByStream["accounts"].State.Stream.StreamDescriptor.Namespace) + assert.Equal(t, "testdb", statesByStream["sessions"].State.Stream.StreamDescriptor.Namespace) +} + +func TestRead_StateEmittedAfterStartedBeforeComplete(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "events") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + startedIdx := -1 + stateIdx := -1 + completeIdx := -1 + + for i, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil && + msg.Trace.StreamStatus.StreamDescriptor.Name == "events" { + if msg.Trace.StreamStatus.Status == internal.STREAM_STATUS_STARTED { + startedIdx = i + } + if msg.Trace.StreamStatus.Status == internal.STREAM_STATUS_COMPLETE { 
+ completeIdx = i + } + } + if msg.Type == internal.STATE && msg.State != nil && + msg.State.Stream != nil && + msg.State.Stream.StreamDescriptor.Name == "events" { + stateIdx = i + } + } + + require.Greater(t, startedIdx, -1, "STARTED should be emitted") + require.Greater(t, stateIdx, -1, "STATE should be emitted") + require.Greater(t, completeIdx, -1, "COMPLETE should be emitted") + + assert.Less(t, startedIdx, stateIdx, "STARTED should come before STATE") + assert.Less(t, stateIdx, completeIdx, "STATE should come before COMPLETE") +} + +func TestRead_MultiShardStateContainsAllShards(t *testing.T) { + db := &mockDatabase{shards: []string{"-80", "80-"}} + catalogJSON := newTestCatalog(t, "data") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + var stateMsg *internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE { + stateMsg = &msg + } + } + + require.NotNil(t, stateMsg, "should have a STATE message") + require.NotNil(t, stateMsg.State.Stream.StreamState) + assert.Len(t, stateMsg.State.Stream.StreamState.Shards, 2, + "state should contain both shards") + assert.Contains(t, stateMsg.State.Stream.StreamState.Shards, "-80") + assert.Contains(t, stateMsg.State.Stream.StreamState.Shards, "80-") +} + +func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { + db := &mockDatabase{ + shards: []string{"-"}, + readFunc: func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + if s.Stream.Name == "bad_table" { + return nil, assert.AnError + } + newCursor, _ := 
internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/pos", + }) + return newCursor, nil + }, + } + + catalog := internal.ConfiguredCatalog{ + Streams: []internal.ConfiguredStream{ + { + Stream: internal.Stream{Name: "good_table", Namespace: "testdb"}, + SyncMode: "full_refresh", + }, + { + Stream: internal.Stream{Name: "bad_table", Namespace: "testdb"}, + SyncMode: "full_refresh", + }, + }, + } + catalogBytes, _ := json.Marshal(catalog) + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, catalogBytes) + + b := bytes.NewBufferString("") + h := &Helper{ + Database: db, + FileReader: testFileReader{content: newTestConfig()}, + Logger: internal.NewLogger(b), + } + + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + streamStatuses := map[string][]string{} + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil { + name := msg.Trace.StreamStatus.StreamDescriptor.Name + streamStatuses[name] = append(streamStatuses[name], msg.Trace.StreamStatus.Status) + } + } + + assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_COMPLETE}, + streamStatuses["good_table"]) + assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_INCOMPLETE}, + streamStatuses["bad_table"]) + + // good_table should have a STATE message, bad_table should NOT + hasGoodState := false + hasBadState := false + for _, msg := range messages { + if msg.Type == internal.STATE && msg.State != nil && msg.State.Stream != nil { + if msg.State.Stream.StreamDescriptor.Name == "good_table" { + hasGoodState = true + } + if msg.State.Stream.StreamDescriptor.Name == "bad_table" { + hasBadState = true + } + } 
+ } + assert.True(t, hasGoodState, "good_table should have a STATE message") + assert.False(t, hasBadState, "bad_table should NOT have a STATE message (it failed)") +} diff --git a/cmd/internal/logger_test.go b/cmd/internal/logger_test.go new file mode 100644 index 0000000..751655e --- /dev/null +++ b/cmd/internal/logger_test.go @@ -0,0 +1,194 @@ +package internal + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamState_EmitsPerStreamFormat(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "abc123"}, + }, + } + + logger.StreamState("my-database", "users", shardStates) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STATE, msg.Type) + require.NotNil(t, msg.State) + assert.Equal(t, STATE_TYPE_STREAM, msg.State.Type) + require.NotNil(t, msg.State.Stream) + assert.Equal(t, "users", msg.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "my-database", msg.State.Stream.StreamDescriptor.Namespace) + require.NotNil(t, msg.State.Stream.StreamState) + assert.Equal(t, "abc123", msg.State.Stream.StreamState.Shards["-"].Cursor) +} + +func TestStreamState_MultipleShards(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-80": {Cursor: "cursor1"}, + "80-": {Cursor: "cursor2"}, + }, + } + + logger.StreamState("sharded-db", "orders", shardStates) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STATE_TYPE_STREAM, msg.State.Type) + assert.Equal(t, "orders", msg.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "sharded-db", msg.State.Stream.StreamDescriptor.Namespace) + assert.Len(t, msg.State.Stream.StreamState.Shards, 2) + assert.Equal(t, 
"cursor1", msg.State.Stream.StreamState.Shards["-80"].Cursor) + assert.Equal(t, "cursor2", msg.State.Stream.StreamState.Shards["80-"].Cursor) +} + +func TestStreamState_NoLegacyDataField(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "abc"}, + }, + } + + logger.StreamState("db", "table1", shardStates) + + // Parse as raw JSON to verify no "data" key exists (which would indicate LEGACY format) + var raw map[string]json.RawMessage + err := json.NewDecoder(b).Decode(&raw) + require.NoError(t, err) + + var stateRaw map[string]json.RawMessage + err = json.Unmarshal(raw["state"], &stateRaw) + require.NoError(t, err) + + _, hasData := stateRaw["data"] + assert.False(t, hasData, "state should not contain 'data' field (LEGACY format)") + + _, hasType := stateRaw["type"] + assert.True(t, hasType, "state must contain 'type' field") + + _, hasStream := stateRaw["stream"] + assert.True(t, hasStream, "state must contain 'stream' field") +} + +func TestStreamStatus_EmitsTraceMessage(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("my-db", "accounts", STREAM_STATUS_STARTED) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, TRACE, msg.Type) + require.NotNil(t, msg.Trace) + assert.Equal(t, TRACE_TYPE_STREAM_STATUS, msg.Trace.Type) + assert.True(t, msg.Trace.EmittedAt > 0) + require.NotNil(t, msg.Trace.StreamStatus) + assert.Equal(t, "accounts", msg.Trace.StreamStatus.StreamDescriptor.Name) + assert.Equal(t, "my-db", msg.Trace.StreamStatus.StreamDescriptor.Namespace) + assert.Equal(t, STREAM_STATUS_STARTED, msg.Trace.StreamStatus.Status) +} + +func TestStreamStatus_Complete(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("ns", "tbl", STREAM_STATUS_COMPLETE) + + var msg AirbyteMessage + err := 
json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STREAM_STATUS_COMPLETE, msg.Trace.StreamStatus.Status) +} + +func TestStreamStatus_Incomplete(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("ns", "tbl", STREAM_STATUS_INCOMPLETE) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STREAM_STATUS_INCOMPLETE, msg.Trace.StreamStatus.Status) +} + +func TestStreamState_JSONRoundTrip(t *testing.T) { + // Verify the JSON output can be parsed back into the exact expected Airbyte protocol structure + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamState("anam-lab", "persona", ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "encoded-cursor-data"}, + }, + }) + + // Parse into a generic structure to verify exact JSON shape + var raw map[string]interface{} + err := json.NewDecoder(b).Decode(&raw) + require.NoError(t, err) + + assert.Equal(t, "STATE", raw["type"]) + + state := raw["state"].(map[string]interface{}) + assert.Equal(t, "STREAM", state["type"]) + + stream := state["stream"].(map[string]interface{}) + descriptor := stream["stream_descriptor"].(map[string]interface{}) + assert.Equal(t, "persona", descriptor["name"]) + assert.Equal(t, "anam-lab", descriptor["namespace"]) + + streamState := stream["stream_state"].(map[string]interface{}) + shards := streamState["shards"].(map[string]interface{}) + shard := shards["-"].(map[string]interface{}) + assert.Equal(t, "encoded-cursor-data", shard["cursor"]) +} + +func TestMultipleStreamStates_EachIndependent(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamState("db", "table1", ShardStates{ + Shards: map[string]*SerializedCursor{"-": {Cursor: "c1"}}, + }) + logger.StreamState("db", "table2", ShardStates{ + Shards: map[string]*SerializedCursor{"-": {Cursor: "c2"}}, + }) + + decoder := json.NewDecoder(b) + 
+ var msg1 AirbyteMessage + require.NoError(t, decoder.Decode(&msg1)) + assert.Equal(t, "table1", msg1.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "c1", msg1.State.Stream.StreamState.Shards["-"].Cursor) + + var msg2 AirbyteMessage + require.NoError(t, decoder.Decode(&msg2)) + assert.Equal(t, "table2", msg2.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "c2", msg2.State.Stream.StreamState.Shards["-"].Cursor) +} From fa86cd45140caaedb5bd5372667d7d8d016c2284 Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Wed, 15 Apr 2026 12:06:53 +0100 Subject: [PATCH 4/4] Fix lost shard progress and silent success on stream errors Address review feedback: 1. Always emit StreamState after the shard loop, even on failure. Previously, state was only emitted when all shards succeeded. If shard A advanced and shard B failed, shard A's cursor was lost and the next retry would re-read already-synced data. 2. Return an error from the read command when any stream fails. The os.Exit(1) calls were replaced with break to allow other streams to emit proper status messages, but the command was silently exiting successfully. Now uses RunE so cobra surfaces the error and exits non-zero. Also converts remaining os.Exit(1) calls to return errors for consistency and testability, and adds a test for multi-shard partial failure checkpointing. 
--- cmd/airbyte-source/read.go | 44 +++++++++------ cmd/airbyte-source/read_protocol_test.go | 70 +++++++++++++++++++++++- 2 files changed, 93 insertions(+), 21 deletions(-) diff --git a/cmd/airbyte-source/read.go b/cmd/airbyte-source/read.go index c56e04b..d9ec84e 100644 --- a/cmd/airbyte-source/read.go +++ b/cmd/airbyte-source/read.go @@ -22,32 +22,33 @@ func init() { func ReadCommand(ch *Helper) *cobra.Command { readCmd := &cobra.Command{ - Use: "read", - Short: "Converts rows from a PlanetScale database into AirbyteRecordMessages", - Run: func(cmd *cobra.Command, args []string) { + Use: "read", + Short: "Converts rows from a PlanetScale database into AirbyteRecordMessages", + SilenceUsage: true, + RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() ch.Logger = internal.NewLogger(cmd.OutOrStdout()) if readSourceConfigFilePath == "" { fmt.Fprintf(cmd.ErrOrStderr(), "Please pass path to a valid source config file via the [%v] argument", "config") - os.Exit(1) + return fmt.Errorf("missing config file path") } if readSourceCatalogPath == "" { fmt.Fprintf(cmd.OutOrStdout(), "Please pass path to a valid source catalog file via the [%v] argument", "config") - os.Exit(1) + return fmt.Errorf("missing catalog file path") } psc, err := parseSource(ch.FileReader, readSourceConfigFilePath) if err != nil { fmt.Fprintln(cmd.OutOrStdout(), "Please provide path to a valid configuration file") - return + return err } ch.Logger.Log(internal.LOGLEVEL_INFO, "Ensure database") if err := ch.EnsureDB(psc); err != nil { fmt.Fprintln(cmd.OutOrStdout(), "Unable to connect to PlanetScale Database") - return + return err } defer func() { @@ -60,19 +61,19 @@ func ReadCommand(ch *Helper) *cobra.Command { cs, err := checkConnectionStatus(ctx, ch.Database, psc) if err != nil { ch.Logger.ConnectionStatus(cs) - return + return err } ch.Logger.Log(internal.LOGLEVEL_INFO, "Reading catalog") catalog, err := readCatalog(readSourceCatalogPath) if err != nil { 
ch.Logger.Error(fmt.Sprintf("Unable to read catalog: %+v", err)) - os.Exit(1) + return fmt.Errorf("unable to read catalog: %w", err) } if len(catalog.Streams) == 0 { ch.Logger.Log(internal.LOGLEVEL_ERROR, "Catalog has no streams") - return + return nil } state := "" @@ -81,7 +82,7 @@ func ReadCommand(ch *Helper) *cobra.Command { b, err := os.ReadFile(stateFilePath) if err != nil { ch.Logger.Error(fmt.Sprintf("Unable to read state : %v", err)) - os.Exit(1) + return fmt.Errorf("unable to read state: %w", err) } state = string(b) } @@ -90,16 +91,17 @@ func ReadCommand(ch *Helper) *cobra.Command { shards, err := ch.Database.ListShards(ctx, psc) if err != nil { ch.Logger.Error(fmt.Sprintf("Unable to list shards : %v", err)) - os.Exit(1) + return fmt.Errorf("unable to list shards: %w", err) } ch.Logger.Log(internal.LOGLEVEL_INFO, "Reading state") syncState, err := readState(state, psc, catalog.Streams, shards, ch.Logger) if err != nil { ch.Logger.Error(fmt.Sprintf("Unable to read state : %v", err)) - os.Exit(1) + return fmt.Errorf("unable to read state: %w", err) } + var readErr error for _, configuredStream := range catalog.Streams { keyspaceOrDatabase := configuredStream.Stream.Namespace if keyspaceOrDatabase == "" { @@ -110,7 +112,7 @@ func ReadCommand(ch *Helper) *cobra.Command { if !ok { ch.Logger.Error(fmt.Sprintf("Unable to read state for stream %v", streamStateKey)) ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) - os.Exit(1) + return fmt.Errorf("unable to read state for stream %v", streamStateKey) } ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_STARTED) @@ -123,7 +125,6 @@ func ReadCommand(ch *Helper) *cobra.Command { ch.Logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Using serialized cursor for stream %s", streamStateKey)) if err != nil { ch.Logger.Error(fmt.Sprintf("Invalid serialized cursor for stream %v, failed with [%v]", streamStateKey, err)) - 
ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) streamFailed = true break } @@ -131,7 +132,6 @@ func ReadCommand(ch *Helper) *cobra.Command { sc, err := ch.Database.Read(ctx, cmd.OutOrStdout(), psc, configuredStream, tc) if err != nil { ch.Logger.Error(err.Error()) - ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) streamFailed = true break } @@ -141,11 +141,19 @@ func ReadCommand(ch *Helper) *cobra.Command { } } - if !streamFailed { - ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) + // Always emit state to checkpoint whatever progress was made, + // including partial progress when only some shards succeeded. + ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) + + if streamFailed { + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + readErr = fmt.Errorf("read failed for stream %v", streamStateKey) + } else { ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_COMPLETE) } } + + return readErr }, } readCmd.Flags().StringVar(&readSourceCatalogPath, "catalog", "", "Path to the PlanetScale catalog configuration") diff --git a/cmd/airbyte-source/read_protocol_test.go b/cmd/airbyte-source/read_protocol_test.go index 3f9efce..a0b501d 100644 --- a/cmd/airbyte-source/read_protocol_test.go +++ b/cmd/airbyte-source/read_protocol_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" "os" "testing" @@ -337,7 +338,9 @@ func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { cmd.SetOut(b) require.NoError(t, cmd.Flag("config").Value.Set(configFile)) require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) - require.NoError(t, cmd.Execute()) + + err := cmd.Execute() + require.Error(t, err, "command should 
return an error when a stream fails") messages := parseOutputMessages(t, b) @@ -355,7 +358,8 @@ func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_INCOMPLETE}, streamStatuses["bad_table"]) - // good_table should have a STATE message, bad_table should NOT + // Both streams should have STATE messages: state is always emitted to + // checkpoint whatever progress was made, even on failure. hasGoodState := false hasBadState := false for _, msg := range messages { @@ -369,5 +373,65 @@ func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { } } assert.True(t, hasGoodState, "good_table should have a STATE message") - assert.False(t, hasBadState, "bad_table should NOT have a STATE message (it failed)") + assert.True(t, hasBadState, "bad_table should have a STATE message (checkpointing progress)") +} + +func TestRead_MultiShardPartialFailureCheckpointsProgress(t *testing.T) { + db := &mockDatabase{ + shards: []string{"-80", "80-"}, + readFunc: func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + // Fail the "80-" shard to simulate a partial failure. 
+ if tc.Shard == "80-" { + return nil, fmt.Errorf("shard read error") + } + newCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/advanced-position", + }) + return newCursor, nil + }, + } + catalogJSON := newTestCatalog(t, "events") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + + err := cmd.Execute() + require.Error(t, err, "command should fail when a shard errors") + + messages := parseOutputMessages(t, b) + + // A state message must be emitted even on partial failure so that + // progress from successful shards is checkpointed. Index into the + // slice rather than taking the address of the range variable, which + // would alias the last element under pre-1.22 loop semantics. + var stateMsg *internal.AirbyteMessage + for i := range messages { + if messages[i].Type == internal.STATE && messages[i].State != nil { + stateMsg = &messages[i] + } + } + require.NotNil(t, stateMsg, "state should be emitted even on partial failure") + require.NotNil(t, stateMsg.State.Stream) + require.NotNil(t, stateMsg.State.Stream.StreamState) + assert.Len(t, stateMsg.State.Stream.StreamState.Shards, 2, + "state should contain both shards") + + // Stream should be marked INCOMPLETE, not COMPLETE. + var statuses []string + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil && + msg.Trace.StreamStatus.StreamDescriptor.Name == "events" { + statuses = append(statuses, msg.Trace.StreamStatus.Status) + } + } + assert.Contains(t, statuses, internal.STREAM_STATUS_STARTED) + assert.Contains(t, statuses, internal.STREAM_STATUS_INCOMPLETE) + assert.NotContains(t, statuses, internal.STREAM_STATUS_COMPLETE) }