Skip to content

Commit

Permalink
fix snapshot tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Guillem committed Oct 13, 2024
1 parent c175b49 commit aa6ac88
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 35 deletions.
21 changes: 3 additions & 18 deletions cdc_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package spanner
import (
"context"
"fmt"
"strconv"
"time"

"cloud.google.com/go/spanner"
Expand Down Expand Up @@ -158,7 +157,7 @@ func (c *cdcIterator) startReader(ctx context.Context, streamID string) {
if mod.Keys.Valid {
if keys, ok := mod.Keys.Value.(map[string]interface{}); ok {
for k, v := range keys {
key[k] = v
key[k] = common.FormatValue(v)
}
}
}
Expand Down Expand Up @@ -213,31 +212,17 @@ func parseValues(values spanner.NullJSON, keys spanner.NullJSON) opencdc.Structu

data := make(opencdc.StructuredData)
for key, val := range valuesMap {
data[key] = formatValue(val)
data[key] = common.FormatValue(val)
}

if keys.Valid {
keysMap, ok := keys.Value.(map[string]any)
if ok {
for key, val := range keysMap {
data[key] = formatValue(val)
data[key] = common.FormatValue(val)
}
}
}

return data
}

func formatValue(val any) any {
switch v := val.(type) {
case string:
i, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return v
}

return i
default:
return v
}
}
20 changes: 20 additions & 0 deletions common/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package common

import (
"context"
"strconv"
"time"

"cloud.google.com/go/spanner"
sdk "github.com/conduitio/conduit-connector-sdk"
Expand All @@ -27,3 +29,21 @@ func NewClient(ctx context.Context, config NewClientConfig) (*spanner.Client, er

return spanner.NewClient(ctx, config.DatabaseName, options...)
}

func FormatValue(val any) any {
switch v := val.(type) {
case *time.Time:
return v.Format(time.RFC3339)
case time.Time:
return v.Format(time.RFC3339)
case string:
i, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return v
}

return i
default:
return v
}
}
47 changes: 33 additions & 14 deletions snapshot_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ func newSnapshotIterator(ctx context.Context, config snapshotIteratorConfig) *sn

for tableName, primaryKey := range config.tableKeys {
iterator.t.Go(func() error {
ctx := iterator.t.Context(ctx)
return iterator.fetchTable(ctx, tableName, primaryKey)
})
}
Expand Down Expand Up @@ -103,10 +104,14 @@ func (s *snapshotIterator) Teardown(ctx context.Context) error {
"cannot teardown snapshot mode, fetchers exited unexpectedly: %w", err)
}

sdk.Logger(ctx).Info().Msg("killed all workers, waiting for acks")

if err := s.acks.Wait(ctx); err != nil {
return fmt.Errorf("failed to wait for snapshot acks: %w", err)
}

sdk.Logger(ctx).Info().Msg("all acks received")

// waiting for the workers to finish will allow us to have an easier time
// debugging goroutine leaks.
_ = s.t.Wait()
Expand All @@ -129,7 +134,11 @@ func (s *snapshotIterator) fetchTable(
return fmt.Errorf("failed to fetch start and end of snapshot: %w", err)
}

sdk.Logger(ctx).Debug().Msgf("fetched start and end of snapshot for table %s", tableName)
sdk.Logger(ctx).Debug().
Str("table", string(tableName)).
Int64("start", start).
Int64("end", end).
Msg("fetched start and end of snapshot")

query := fmt.Sprint(`
SELECT *
Expand Down Expand Up @@ -158,27 +167,37 @@ func (s *snapshotIterator) fetchTable(
return fmt.Errorf("failed to fetch next row: %w", err)
}

data, err := decodeRow(row)
decodedRow, err := decodeRow(row)
if err != nil {
return fmt.Errorf("failed to decode row: %w", err)
}

primaryKeyVal, ok := data[string(primaryKey)]
primaryKeyVal, ok := decodedRow[string(primaryKey)]
if !ok {
return fmt.Errorf("primary key %s not found in row %v", primaryKey, data)
return fmt.Errorf("primary key %s not found in row %v", primaryKey, decodedRow)
}

s.dataC <- fetchData{
payload: data,
data := fetchData{
payload: decodedRow,
table: tableName,
primaryKeyName: primaryKey,
primaryKeyVal: primaryKeyVal,

position: common.TablePosition{
LastRead: start,
LastRead: start + 1,
SnapshotEnd: end,
},
}

sdk.Logger(ctx).Trace().Msgf("sending row %v", decodedRow)

select {
case s.dataC <- data:
case <-s.t.Dead():
return s.t.Err()
case <-ctx.Done():
return ctx.Err()
}
}

return nil
Expand All @@ -190,7 +209,7 @@ func (s *snapshotIterator) fetchStartEnd(
tableName common.TableName,
) (start, end int64, err error) {
// fetch the start
start = s.config.position.Snapshots[tableName].LastRead
start = s.lastPosition.Snapshots[tableName].LastRead

// fetch the end
query := fmt.Sprint(`
Expand Down Expand Up @@ -231,12 +250,12 @@ func (s *snapshotIterator) buildRecord(data fetchData) opencdc.Record {
string(data.primaryKeyName): data.payload[string(data.primaryKeyName)],
}

return sdk.Util.Source.NewRecordSnapshot(
position,
metadata,
key,
data.payload,
)
payload := opencdc.StructuredData{}
for key, val := range data.payload {
payload[key] = common.FormatValue(val)
}

return sdk.Util.Source.NewRecordSnapshot(position, metadata, key, payload)
}

func decodeRow(row *spanner.Row) (opencdc.StructuredData, error) {
Expand Down
6 changes: 5 additions & 1 deletion source_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ func testSource(ctx context.Context, is *is.I) (sdk.Source, func()) {
}
}

func testSourceAtPosition(ctx context.Context, is *is.I, pos opencdc.Position) (sdk.Source, func()) {
func testSourceAtPosition(
ctx context.Context, is *is.I,
pos opencdc.Position,
) (sdk.Source, func()) {
source := NewSource()
is.NoErr(source.Configure(ctx, config.Config{
SourceConfigDatabase: testutils.DatabaseName,
SourceConfigEndpoint: testutils.EmulatorHost,
SourceConfigTables: "Singers",
}))
is.NoErr(source.Open(ctx, pos))

Expand Down
5 changes: 3 additions & 2 deletions test/testutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ func ReadAndAssertSnapshot(
ctx context.Context, is *is.I,
iterator common.Iterator, singer Singer,
) opencdc.Record {
is.Helper()

rec, err := iterator.Read(ctx)
is.NoErr(err)
is.NoErr(iterator.Ack(ctx, rec.Position))
Expand Down Expand Up @@ -309,6 +311,5 @@ func assertMetadata(is *is.I, metadata opencdc.Metadata) {

func assertKey(is *is.I, rec opencdc.Record, singer Singer) {
is.Helper()
singerID := fmt.Sprint(singer.SingerID)
isDataEqual(is, rec.Key, opencdc.StructuredData{"SingerID": singerID})
isDataEqual(is, rec.Key, opencdc.StructuredData{"SingerID": singer.SingerID})
}

0 comments on commit aa6ac88

Please sign in to comment.