diff --git a/README.md b/README.md index 195760e..38ce9c4 100644 --- a/README.md +++ b/README.md @@ -218,6 +218,8 @@ Usage of timescaledb-parallel-copy: Number of rows to insert overall; 0 means to insert all -log-batches Whether to time individual batches. + -on-conflict-do-nothing + Skip duplicate rows on unique constraint violations -quote character The QUOTE character to use during COPY (default '"') -reporting-period duration @@ -362,3 +364,20 @@ timestamp,temperature,humidity ``` Both files can use the same mapping configuration and import successfully into the same database table, even though they use different column names for the temperature data. The tool only validates for duplicate database columns among the columns actually present in each specific input file. + +### Conflict Resolution + +Use `--on-conflict-do-nothing` to automatically skip duplicate rows when unique constraint violations occur: + +```bash +# Skip duplicate rows and continue importing +$ timescaledb-parallel-copy --connection $DATABASE_URL --table metrics --file data.csv \ + --on-conflict-do-nothing +``` + +This uses PostgreSQL's `ON CONFLICT DO NOTHING` clause to ignore rows that would violate unique constraints, allowing the import to continue with just the non-duplicate data. + +Note that an `ON CONFLICT` clause is not allowed in a `COPY FROM` statement. The tool falls back to copying your data into a temporary table and running `INSERT INTO ... SELECT * FROM ... ON CONFLICT DO NOTHING`. + +This flag is intended to skip genuine duplicates, not to absorb incremental changes to rows. It is safe to use when you expect your data to contain duplicate rows, but it is not appropriate for an ingestion pipeline where you expect updates to rows that share the same unique key. + diff --git a/cmd/timescaledb-parallel-copy/main.go b/cmd/timescaledb-parallel-copy/main.go index 06b9252..2e514da 100644 --- a/cmd/timescaledb-parallel-copy/main.go +++ b/cmd/timescaledb-parallel-copy/main.go @@ -15,6 +15,7 @@ import ( "time" "github.com/timescale/timescaledb-parallel-copy/pkg/csvcopy" + "github.com/timescale/timescaledb-parallel-copy/pkg/errorhandlers" ) const ( @@ -55,6 +56,8 @@ var ( verbose bool showVersion bool + onConflictDoNothing bool + dbName string ) @@ -92,6 +95,8 @@ func init() { flag.DurationVar(&reportingPeriod, "reporting-period", 0*time.Second, "Period to report insert stats; if 0s, intermediate results will not be reported") flag.BoolVar(&verbose, "verbose", false, "Print more information about copying statistics") + flag.BoolVar(&onConflictDoNothing, "on-conflict-do-nothing", false, "Skip duplicate rows on unique constraint violations") + flag.BoolVar(&showVersion, "version", false, "Show the version of this tool") flag.Parse() @@ -117,6 +122,7 @@ func main() { log.Fatalf("Error: -header-line-count is deprecated. 
Use -skip-lines instead") } + logger := &csvCopierLogger{} opts := []csvcopy.Option{ @@ -157,6 +163,13 @@ func main() { if skipBatchErrors { batchErrorHandler = csvcopy.BatchHandlerNoop() } + + if onConflictDoNothing { + batchErrorHandler = errorhandlers.BatchConflictHandler( + errorhandlers.WithConflictHandlerNext(batchErrorHandler), + ) + } + if verbose || skipBatchErrors { batchErrorHandler = csvcopy.BatchHandlerLog(logger, batchErrorHandler) } diff --git a/go.mod b/go.mod index 0b8717a..3456322 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.22.0 toolchain go1.23.4 require ( - github.com/google/uuid v1.6.0 github.com/jackc/pgx/v5 v5.7.2 github.com/jmoiron/sqlx v1.4.0 github.com/stretchr/testify v1.10.0 @@ -34,6 +33,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect diff --git a/pkg/buffer/buffer.go b/pkg/buffer/buffer.go new file mode 100644 index 0000000..1cef050 --- /dev/null +++ b/pkg/buffer/buffer.go @@ -0,0 +1,171 @@ +// Package seekablebuffers provides a seekable wrapper around net.Buffers +// that enables retry functionality for database copy operations. +package buffer + +import ( + "fmt" + "io" +) + +// Buffers contains zero or more runs of bytes to write. +// +// On certain machines, for certain types of connections, this is +// optimized into an OS-specific batch write operation (such as +// "writev"). +type Seekable struct { + buf [][]byte + position int64 +} + +var ( + _ io.WriterTo = (*Seekable)(nil) + _ io.Reader = (*Seekable)(nil) + _ io.Writer = (*Seekable)(nil) + _ io.Seeker = (*Seekable)(nil) +) + +func NewSeekable(buf [][]byte) *Seekable { + return &Seekable{ + buf: buf, + position: 0, + } +} + +func (v *Seekable) HasData() bool { + return v.position < v.TotalSize() +} + +func (v *Seekable) TotalSize() int64 { + var size int64 + for _, b := range v.buf { + size += int64(len(b)) + } + return size +} + +// WriteTo writes contents of the buffers to w. +// +// WriteTo implements [io.WriterTo] for [Buffers]. +// +// WriteTo modifies the slice v as well as v[i] for 0 <= i < len(v), +// but does not modify v[i][j] for any i, j. +func (v *Seekable) WriteTo(w io.Writer) (n int64, err error) { + if v.position >= v.TotalSize() { + return 0, nil + } + + currentPos := v.position + + for _, buf := range v.buf { + bufLen := int64(len(buf)) + if currentPos >= bufLen { + currentPos -= bufLen + continue + } + + startInBuf := currentPos + bytesToWrite := buf[startInBuf:] + + nb, err := w.Write(bytesToWrite) + n += int64(nb) + if err != nil { + v.position += n + return n, err + } + currentPos = 0 + } + + v.position += n + return n, nil +} + +// Read from the buffers. +// +// Read implements [io.Reader] for [Buffers]. +// +// Read modifies the slice v as well as v[i] for 0 <= i < len(v), +// but does not modify v[i][j] for any i, j. 
+func (v *Seekable) Read(p []byte) (n int, err error) { + if v.position >= v.TotalSize() { + return 0, io.EOF + } + + remaining := len(p) + currentPos := v.position + + for i, buf := range v.buf { + if remaining == 0 { + break + } + + bufLen := int64(len(buf)) + if currentPos >= bufLen { + currentPos -= bufLen + continue + } + + startInBuf := currentPos + bytesToRead := bufLen - startInBuf + if int64(remaining) < bytesToRead { + bytesToRead = int64(remaining) + } + + copied := copy(p[n:], buf[startInBuf:startInBuf+bytesToRead]) + n += copied + remaining -= copied + currentPos = 0 + + if i == len(v.buf)-1 && n < len(p) { + err = io.EOF + } + } + + v.position += int64(n) + return n, err +} + +// Write appends data to the buffer +func (v *Seekable) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + + // Create a copy of the input data to avoid issues with caller reusing the slice + data := make([]byte, len(p)) + copy(data, p) + + v.buf = append(v.buf, data) + return len(p), nil +} + +// WriteString appends a string to the buffer +func (v *Seekable) WriteString(s string) (n int, err error) { + return v.Write([]byte(s)) +} + +// Seek sets the position for next Read or Write operation +func (v *Seekable) Seek(offset int64, whence int) (int64, error) { + totalSize := v.TotalSize() + + var newPos int64 + switch whence { + case io.SeekStart: + newPos = offset + case io.SeekCurrent: + newPos = v.position + offset + case io.SeekEnd: + newPos = totalSize + offset + default: + return v.position, fmt.Errorf("invalid whence value: %d", whence) + } + + if newPos < 0 { + return v.position, fmt.Errorf("seek position cannot be negative: %d", newPos) + } + if newPos > totalSize { + return v.position, fmt.Errorf("seek position beyond buffer size: %d > %d", newPos, totalSize) + } + + v.position = newPos + return v.position, nil +} diff --git a/pkg/buffer/buffer_test.go b/pkg/buffer/buffer_test.go new file mode 100644 index 0000000..4073446 --- /dev/null +++ b/pkg/buffer/buffer_test.go @@ -0,0 +1,433 @@ +package buffer + +import ( + "bytes" + "io" + "strings" + "testing" +) + +func TestSeekAndRead(t *testing.T) { + data := [][]byte{ + []byte("hello"), + []byte(" "), + []byte("world"), + } + + sb := NewSeekable(data) + + // Test initial read + buf := make([]byte, 11) + n, err := sb.Read(buf) + if n != 11 || (err != nil && err != io.EOF) { + t.Errorf("Read() = (%d, %v), expected 11 bytes", n, err) + } + if string(buf) != "hello world" { + t.Errorf("Read data = %q, want %q", string(buf), "hello world") + } + + // Test seek to beginning + pos, err := sb.Seek(0, io.SeekStart) + if pos != 0 || err != nil { + t.Errorf("Seek(0, SeekStart) = (%d, %v), want (0, nil)", pos, err) + } + + // Test read after seek to beginning + buf = make([]byte, 11) + n, err = sb.Read(buf) + if n != 11 || (err != nil && err != io.EOF) { + t.Errorf("Read after seek = (%d, %v), expected 11 bytes", n, err) + } + if string(buf) != "hello world" { + t.Errorf("Read after seek = %q, want %q", string(buf), "hello world") + } +} + +func TestSeekPositions(t *testing.T) { + data := [][]byte{ + []byte("abc"), + []byte("def"), + []byte("ghi"), + } + + sb := NewSeekable(data) + + tests := []struct { + name string + offset int64 + whence int + expected int64 + readData string + }{ + {"SeekStart 0", 0, io.SeekStart, 0, "abcdefghi"}, + {"SeekStart 3", 3, io.SeekStart, 3, "defghi"}, + {"SeekStart 6", 6, io.SeekStart, 6, "ghi"}, + {"SeekStart 9", 9, io.SeekStart, 9, ""}, // At end + {"SeekCurrent -3", -3, io.SeekCurrent, 6, "ghi"}, + 
{"SeekCurrent -6", -6, io.SeekCurrent, 0, "abcdefghi"}, + {"SeekEnd 0", 0, io.SeekEnd, 9, ""}, + {"SeekEnd -3", -3, io.SeekEnd, 6, "ghi"}, + {"SeekEnd -9", -9, io.SeekEnd, 0, "abcdefghi"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset position for each test case to ensure predictable state + if strings.Contains(tt.name, "SeekCurrent") { + // For SeekCurrent tests, set up the expected starting position + if tt.name == "SeekCurrent -3" { + _, _ = sb.Seek(9, io.SeekStart) // Start from end + } else if tt.name == "SeekCurrent -6" { + _, _ = sb.Seek(6, io.SeekStart) // Start from position 6 + } + } + + pos, err := sb.Seek(tt.offset, tt.whence) + if pos != tt.expected || err != nil { + t.Errorf("Seek(%d, %d) = (%d, %v), want (%d, nil)", + tt.offset, tt.whence, pos, err, tt.expected) + } + + // Read remaining data + buf := make([]byte, 20) + n, err := sb.Read(buf) + if err != nil && err != io.EOF { + t.Errorf("Read after seek failed: %v", err) + } + readData := string(buf[:n]) + if readData != tt.readData { + t.Errorf("Read data = %q, want %q", readData, tt.readData) + } + }) + } +} + +func TestSeekErrors(t *testing.T) { + data := [][]byte{[]byte("test")} + sb := NewSeekable(data) + + tests := []struct { + name string + offset int64 + whence int + }{ + {"negative position", -1, io.SeekStart}, + {"beyond end", 10, io.SeekStart}, + {"invalid whence", 0, 99}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := sb.Seek(tt.offset, tt.whence) + if err == nil { + t.Errorf("Seek(%d, %d) should return error", tt.offset, tt.whence) + } + }) + } +} + +func TestWriteTo(t *testing.T) { + data := [][]byte{ + []byte("hello"), + []byte(" "), + []byte("world"), + } + + sb := NewSeekable(data) + + // Test WriteTo from beginning + var buf bytes.Buffer + n, err := sb.WriteTo(&buf) + if n != 11 || err != nil { + t.Errorf("WriteTo() = (%d, %v), want (11, nil)", n, err) + } + if buf.String() != "hello world" { + t.Errorf("WriteTo data = %q, want %q", buf.String(), "hello world") + } + + // Test WriteTo after seek + _, _ = sb.Seek(6, io.SeekStart) // Start from "world" + buf.Reset() + n, err = sb.WriteTo(&buf) + if n != 5 || err != nil { + t.Errorf("WriteTo after seek = (%d, %v), want (5, nil)", n, err) + } + if buf.String() != "world" { + t.Errorf("WriteTo after seek data = %q, want %q", buf.String(), "world") + } +} + +func TestPartialReads(t *testing.T) { + data := [][]byte{ + []byte("abcde"), + []byte("fghij"), + []byte("klmno"), + } + + sb := NewSeekable(data) + + // Test reading in small chunks + result := "" + buf := make([]byte, 3) + + for { + n, err := sb.Read(buf) + if n > 0 { + result += string(buf[:n]) + } + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Read error: %v", err) + } + } + + expected := "abcdefghijklmno" + if result != expected { + t.Errorf("Partial reads result = %q, want %q", result, expected) + } +} + +func TestRetryScenario(t *testing.T) { + // Simulate CSV data that might need retry + csvData := [][]byte{ + []byte("id,name,value\n"), + []byte("1,John,100\n"), + []byte("2,Jane,200\n"), + []byte("3,Bob,300\n"), + } + + sb := NewSeekable(csvData) + + // Simulate first attempt that reads some data + attempt1 := &strings.Builder{} + buf := make([]byte, 10) // Small buffer to simulate partial read + n, err := sb.Read(buf) + if err != nil && err != io.EOF { + t.Fatalf("First read failed: %v", err) + } + attempt1.Write(buf[:n]) + + // Get current position for verification + currentPos, _ := sb.Seek(0, 
io.SeekCurrent) + if currentPos != 10 { + t.Errorf("Position after first read = %d, want 10", currentPos) + } + + // Simulate retry after reset + pos, err := sb.Seek(0, io.SeekStart) + if pos != 0 || err != nil { + t.Errorf("Reset seek failed: (%d, %v)", pos, err) + } + + attempt2 := &strings.Builder{} + _, err = io.Copy(attempt2, sb) + if err != nil { + t.Fatalf("Retry failed: %v", err) + } + + expected := "id,name,value\n1,John,100\n2,Jane,200\n3,Bob,300\n" + if attempt2.String() != expected { + t.Errorf("Retry result = %q, want %q", attempt2.String(), expected) + } + + // Verify partial read in first attempt + if len(attempt1.String()) != 10 { + t.Errorf("First attempt read %d bytes, want 10", len(attempt1.String())) + } +} + +func TestSeekBoundaries(t *testing.T) { + data := [][]byte{ + []byte("abc"), + []byte("def"), + []byte("ghi"), + } + + sb := NewSeekable(data) + + // Test seeking to exact buffer boundaries + positions := []int64{0, 3, 6, 9} // Start, end of first, end of second, end + + for _, pos := range positions { + actualPos, err := sb.Seek(pos, io.SeekStart) + if actualPos != pos || err != nil { + t.Errorf("Seek to boundary %d = (%d, %v), want (%d, nil)", + pos, actualPos, err, pos) + } + + // Verify we can read from this position + buf := make([]byte, 1) + n, err := sb.Read(buf) + + if pos == 9 { // At end + if n != 0 || err != io.EOF { + t.Errorf("Read at end pos %d = (%d, %v), want (0, EOF)", pos, n, err) + } + } else { + if n != 1 || (err != nil && err != io.EOF) { + t.Errorf("Read at pos %d = (%d, %v), want (1, nil or EOF)", pos, n, err) + } + } + } +} + +func TestWrite(t *testing.T) { + // Test basic write functionality + sb := NewSeekable([][]byte{}) + + n, err := sb.Write([]byte("hello")) + if n != 5 || err != nil { + t.Errorf("Write('hello') = (%d, %v), want (5, nil)", n, err) + } + + n, err = sb.Write([]byte(" world")) + if n != 6 || err != nil { + t.Errorf("Write(' world') = (%d, %v), want (6, nil)", n, err) + } + + // Test reading back the written data + _, err = sb.Seek(0, io.SeekStart) + if err != nil { + t.Errorf("Seek to start failed: %v", err) + } + buf := make([]byte, 20) + n, err = sb.Read(buf) + if n != 11 || (err != nil && err != io.EOF) { + t.Errorf("Read after write = (%d, %v), want (11, nil or EOF)", n, err) + } + if string(buf[:n]) != "hello world" { + t.Errorf("Read data = %q, want %q", string(buf[:n]), "hello world") + } +} + +func TestWriteString(t *testing.T) { + sb := NewSeekable([][]byte{}) + + n, err := sb.WriteString("test") + if n != 4 || err != nil { + t.Errorf("WriteString('test') = (%d, %v), want (4, nil)", n, err) + } + + n, err = sb.WriteString(" string") + if n != 7 || err != nil { + t.Errorf("WriteString(' string') = (%d, %v), want (7, nil)", n, err) + } + + // Test reading back + _, err = sb.Seek(0, io.SeekStart) + if err != nil { + t.Errorf("Seek to start failed: %v", err) + } + buf := make([]byte, 20) + n, err = sb.Read(buf) + if n != 11 || (err != nil && err != io.EOF) { + t.Errorf("Read after WriteString = (%d, %v), want (11, nil or EOF)", n, err) + } + if string(buf[:n]) != "test string" { + t.Errorf("Read data = %q, want %q", string(buf[:n]), "test string") + } +} + +func TestWriteAndSeek(t *testing.T) { + // Test writing data, seeking, and reading from different positions + sb := NewSeekable([][]byte{}) + + _, _ = sb.Write([]byte("abc")) + _, _ = sb.Write([]byte("def")) + _, _ = sb.Write([]byte("ghi")) + + // Test reading from beginning + _, err := sb.Seek(0, io.SeekStart) + if err != nil { + t.Errorf("Seek to start failed: 
%v", err) + } + buf := make([]byte, 3) + n, err := sb.Read(buf) + if n != 3 || err != nil { + t.Errorf("Read from start = (%d, %v), want (3, nil)", n, err) + } + if string(buf) != "abc" { + t.Errorf("Read data = %q, want %q", string(buf), "abc") + } + + // Test reading from middle + _, err = sb.Seek(3, io.SeekStart) + if err != nil { + t.Errorf("Seek to start failed: %v", err) + } + n, err = sb.Read(buf) + if n != 3 || err != nil { + t.Errorf("Read from middle = (%d, %v), want (3, nil)", n, err) + } + if string(buf) != "def" { + t.Errorf("Read data = %q, want %q", string(buf), "def") + } + + // Test reading from end + _, err = sb.Seek(6, io.SeekStart) + if err != nil { + t.Errorf("Seek to start failed: %v", err) + } + n, err = sb.Read(buf) + if n != 3 || (err != nil && err != io.EOF) { + t.Errorf("Read from end = (%d, %v), want (3, nil or EOF)", n, err) + } + if string(buf) != "ghi" { + t.Errorf("Read data = %q, want %q", string(buf), "ghi") + } +} + +func TestWriteEmpty(t *testing.T) { + sb := NewSeekable([][]byte{}) + + // Test writing empty slice + n, err := sb.Write([]byte{}) + if n != 0 || err != nil { + t.Errorf("Write([]) = (%d, %v), want (0, nil)", n, err) + } + + // Test writing nil slice + n, err = sb.Write(nil) + if n != 0 || err != nil { + t.Errorf("Write(nil) = (%d, %v), want (0, nil)", n, err) + } + + // Buffer should still be empty + buf := make([]byte, 10) + n, err = sb.Read(buf) + if n != 0 || err != io.EOF { + t.Errorf("Read from empty buffer = (%d, %v), want (0, EOF)", n, err) + } +} + +func TestMixedWriteAndInitialData(t *testing.T) { + // Test starting with initial data and then writing more + initialData := [][]byte{ + []byte("initial"), + []byte(" data"), + } + + sb := NewSeekable(initialData) + + // Write additional data + _, _ = sb.Write([]byte(" plus")) + _, _ = sb.Write([]byte(" more")) + + // Read everything + _, err := sb.Seek(0, io.SeekStart) + if err != nil { + t.Errorf("Seek to start failed: %v", err) + } + buf := make([]byte, 30) + n, err := sb.Read(buf) + if err != nil && err != io.EOF { + t.Errorf("Read error: %v", err) + } + + expected := "initial data plus more" + if string(buf[:n]) != expected { + t.Errorf("Read data = %q, want %q", string(buf[:n]), expected) + } +} diff --git a/pkg/csvcopy/batch_error.go b/pkg/csvcopy/batch_error.go index b69818f..138f90c 100644 --- a/pkg/csvcopy/batch_error.go +++ b/pkg/csvcopy/batch_error.go @@ -1,23 +1,35 @@ package csvcopy +import ( + "context" + + "github.com/jmoiron/sqlx" +) + // BatchHandlerLog prints a log line that reports the error in the given batch +// If next is nil, it uses BatchHandlerNoop func BatchHandlerLog(log Logger, next BatchErrorHandler) BatchErrorHandler { - return BatchErrorHandler(func(batch Batch, reason error) *BatchError { - log.Infof("Batch %d, starting at byte %d with len %d, has error: %s", batch.Location.StartRow, batch.Location.ByteOffset, batch.Location.ByteLen, reason.Error()) + return BatchErrorHandler(func(ctx context.Context, c *Copier, db *sqlx.Conn, batch Batch, reason error) HandleBatchErrorResult { + c.LogInfo(ctx, "BatchHandlerLog: Batch %d, starting at byte %d with len %d, has error: %s", batch.Location.StartRow, batch.Location.ByteOffset, batch.Location.ByteLen, reason.Error()) if next != nil { - return next(batch, reason) + return next(ctx, c, db, batch, reason) } - return NewErrContinue(reason) + return BatchHandlerNoop()(ctx, c, db, batch, reason) }) } // BatchHandlerNoop no operation +// Marks all rows as skipped func BatchHandlerNoop() BatchErrorHandler { - return 
BatchErrorHandler(func(_ Batch, reason error) *BatchError { return NewErrContinue(reason) }) + return BatchErrorHandler(func(_ context.Context, _ *Copier, _ *sqlx.Conn, _ Batch, reason error) HandleBatchErrorResult { + return NewErrContinue(reason) + }) } // BatchHandlerError fails the process func BatchHandlerError() BatchErrorHandler { - return BatchErrorHandler(func(_ Batch, reason error) *BatchError { return NewErrStop(reason) }) + return BatchErrorHandler(func(_ context.Context, _ *Copier, _ *sqlx.Conn, _ Batch, reason error) HandleBatchErrorResult { + return NewErrStop(reason) + }) } diff --git a/pkg/csvcopy/csvcopy.go b/pkg/csvcopy/csvcopy.go index 3fbe762..3e0bcee 100644 --- a/pkg/csvcopy/csvcopy.go +++ b/pkg/csvcopy/csvcopy.go @@ -3,6 +3,7 @@ package csvcopy import ( "bufio" "context" + "database/sql" "errors" "fmt" "io" @@ -19,6 +20,24 @@ import ( "github.com/jmoiron/sqlx" ) +// contextKey is used for context values to avoid collisions +type contextKey string + +const workerIDKey contextKey = "workerID" + +// WithWorkerID adds a worker ID to the context +func WithWorkerID(ctx context.Context, workerID int) context.Context { + return context.WithValue(ctx, workerIDKey, workerID) +} + +// GetWorkerIDFromContext extracts the worker ID from context, returns -1 if not found +func GetWorkerIDFromContext(ctx context.Context) int { + if workerID, ok := ctx.Value(workerIDKey).(int); ok { + return workerID + } + return -1 +} + const TAB_CHAR_STR = "\\t" type HeaderHandling int @@ -49,7 +68,7 @@ type Copier struct { copyOptions string schemaName string - logger Logger + Logger Logger splitCharacter string quoteCharacter string escapeCharacter string @@ -80,6 +99,26 @@ type Copier struct { failHandler BatchErrorHandler } +// LogInfo logs a message with worker ID extracted from context if available +func (c *Copier) LogInfo(ctx context.Context, msg string, args ...interface{}) { + if !c.verbose { + return + } + if workerID := GetWorkerIDFromContext(ctx); workerID >= 0 { + c.Logger.Infof("[WORKER-%d] "+msg, append([]interface{}{workerID}, args...)...) + } else { + c.Logger.Infof(msg, args...) + } +} + +func (c *Copier) LogError(ctx context.Context, msg string, args ...interface{}) { + if workerID := GetWorkerIDFromContext(ctx); workerID >= 0 { + c.Logger.Infof("[WORKER-%d] "+msg, append([]interface{}{workerID}, args...)...) + } else { + c.Logger.Infof(msg, args...) 
+ } +} + func NewCopier( connString string, tableName string, @@ -91,7 +130,7 @@ func NewCopier( // Defaults schemaName: "public", - logger: &noopLogger{}, + Logger: &noopLogger{}, copyOptions: "CSV", splitCharacter: ",", quoteCharacter: "", @@ -119,11 +158,11 @@ func NewCopier( } if copier.skip > 0 && copier.verbose { - copier.logger.Infof("Skipping the first %d lines of the input.", copier.skip) + copier.Logger.Infof("Skipping the first %d lines of the input.", copier.skip) } if copier.reportingFunction == nil { - copier.reportingFunction = DefaultReportFunc(copier.logger) + copier.reportingFunction = DefaultReportFunc(copier.Logger) } return copier, nil @@ -137,7 +176,7 @@ func (c *Copier) Truncate() (err error) { defer func() { err = dbx.Close() }() - _, err = dbx.Exec(fmt.Sprintf("TRUNCATE %s", c.getFullTableName())) + _, err = dbx.Exec(fmt.Sprintf("TRUNCATE %s", c.GetFullTableName())) if err != nil { return fmt.Errorf("failed to truncate table: %w", err) } @@ -150,7 +189,7 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) { if err := ensureTransactionTable(ctx, c.connString); err != nil { return Result{}, fmt.Errorf("failed to ensure transaction table, %w", err) } - c.logger.Infof("Cleaning old transactions older than %s", c.idempotencyWindow) + c.Logger.Infof("Cleaning old transactions older than %s", c.idempotencyWindow) if err := cleanOldTransactions(ctx, c.connString, c.idempotencyWindow); err != nil { return Result{}, fmt.Errorf("failed to clean old transactions, %w", err) } @@ -200,12 +239,16 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) { workerWg.Add(1) go func(i int) { defer workerWg.Done() - err := c.processBatches(ctx, batchChan) + // Add worker ID to context for all operations in this worker + workerCtx := WithWorkerID(ctx, i) + c.LogInfo(workerCtx, "start worker") + err := c.processBatches(workerCtx, batchChan, i) if err != nil { + c.LogError(workerCtx, "worker error: %v", err) errCh <- err cancel() } - c.logger.Infof("stop worker %d", i) + c.LogInfo(workerCtx, "stop worker") }(i) } @@ -215,7 +258,7 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) { defer cancelSupportCtx() // Reporting thread if c.reportingPeriod > (0 * time.Second) { - c.logger.Infof("There will be reports every %s", c.reportingPeriod.String()) + c.Logger.Infof("There will be reports every %s", c.reportingPeriod.String()) supportWg.Add(1) go func() { defer supportWg.Done() @@ -244,12 +287,12 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) { workerWg.Add(1) go func() { defer workerWg.Done() - if err := scan(ctx, counter, bufferedReader, batchChan, opts); err != nil { + if err := scan(ctx, c.LogInfo, counter, bufferedReader, batchChan, opts); err != nil { errCh <- fmt.Errorf("failed reading input: %w", err) cancel() } close(batchChan) - c.logger.Infof("stop scan") + c.LogInfo(ctx, "stop scan") }() workerWg.Wait() @@ -307,7 +350,7 @@ func (c *Copier) useAutomaticColumnMapping(headers []string) error { quotedHeaders[i] = pgx.Identifier{header}.Sanitize() } c.columns = strings.Join(quotedHeaders, ",") - c.logger.Infof("automatic column mapping: %s", c.columns) + c.LogInfo(context.TODO(), "automatic column mapping: %s", c.columns) return nil } @@ -364,7 +407,7 @@ func (c *Copier) calculateColumnsFromHeaders(bufferedReader *bufio.Reader) error } c.columns = strings.Join(columns, ",") - c.logger.Infof("Using column mapping: %s", c.columns) + c.LogInfo(context.TODO(), "Using column 
mapping: %s", c.columns) return nil } @@ -413,15 +456,11 @@ func (e ErrAtRow) Unwrap() error { return e.Err } -// processBatches reads batches from channel c and copies them to the target -// server while tracking stats on the write. -func (c *Copier) processBatches(ctx context.Context, ch chan Batch) (err error) { - dbx, err := connect(c.connString) - if err != nil { - return err - } - defer dbx.Close() +func (c *Copier) CopyCmd() string { + return c.CopyCmdWithContext(context.Background()) +} +func (c *Copier) CopyCmdWithContext(ctx context.Context) string { delimStr := "'" + c.splitCharacter + "'" if c.splitCharacter == TAB_CHAR_STR { delimStr = "E" + delimStr @@ -437,13 +476,32 @@ func (c *Copier) processBatches(ctx context.Context, ch chan Batch) (err error) quotes, strings.ReplaceAll(c.escapeCharacter, "'", "''")) } - var copyCmd string + var baseCmd string if c.columns != "" { - copyCmd = fmt.Sprintf("COPY %s(%s) FROM STDIN WITH DELIMITER %s %s %s", c.getFullTableName(), c.columns, delimStr, quotes, c.copyOptions) + baseCmd = fmt.Sprintf("COPY %s(%s) FROM STDIN WITH DELIMITER %s %s %s", c.GetFullTableName(), c.columns, delimStr, quotes, c.copyOptions) } else { - copyCmd = fmt.Sprintf("COPY %s FROM STDIN WITH DELIMITER %s %s %s", c.getFullTableName(), delimStr, quotes, c.copyOptions) + baseCmd = fmt.Sprintf("COPY %s FROM STDIN WITH DELIMITER %s %s %s", c.GetFullTableName(), delimStr, quotes, c.copyOptions) + } + + // Add worker ID comment if available in context + if workerID := GetWorkerIDFromContext(ctx); workerID >= 0 { + baseCmd = fmt.Sprintf("/* Worker-%d */ %s", workerID, baseCmd) + } + + return baseCmd +} + +// processBatches reads batches from channel c and copies them to the target +// server while tracking stats on the write. +func (c *Copier) processBatches(ctx context.Context, ch chan Batch, workerID int) (err error) { + dbx, err := connect(c.connString) + if err != nil { + return err } - c.logger.Infof("Copy command: %s", copyCmd) + defer dbx.Close() + + copyCmd := c.CopyCmdWithContext(ctx) + c.LogInfo(ctx, "Copy command: %s", copyCmd) for { if ctx.Err() != nil { @@ -458,31 +516,40 @@ func (c *Copier) processBatches(ctx context.Context, ch chan Batch) (err error) } atomic.AddInt64(&c.totalRows, int64(batch.Location.RowCount)) + if c.logBatches { + c.LogInfo(ctx, "Processing: starting at row %d: rows count %d, byte len %d", + batch.Location.StartRow, batch.Location.RowCount, batch.Location.ByteLen) + } + start := time.Now() rows, err := copyFromBatch(ctx, dbx, batch, copyCmd) if err != nil { - handleErr := c.handleCopyError(ctx, dbx, batch, err) + handleResult, handleErr := c.handleCopyError(ctx, dbx, batch, err) if handleErr != nil { + c.LogError(ctx, "Error handler failed for batch %d: %v", batch.Location.StartRow, handleErr) return handleErr } + atomic.AddInt64(&c.skippedRows, handleResult.SkippedRows) + rows = handleResult.InsertedRows } atomic.AddInt64(&c.insertedRows, rows) - if err, ok := err.(*ErrBatchAlreadyProcessed); ok { - if err.State.State == "completed" { - atomic.AddInt64(&c.skippedRows, int64(batch.Location.RowCount)) - } - } - if c.logBatches { took := time.Since(start) - fmt.Printf("[BATCH] starting at row %d, took %v, row count %d, byte len %d, row rate %f/sec\n", batch.Location.StartRow, took, batch.Location.RowCount, batch.Location.ByteLen, float64(batch.Location.RowCount)/float64(took.Seconds())) + c.LogInfo(ctx, "Processing: starting at row %d, took %v, row count %d, byte len %d, row rate %f/sec", batch.Location.StartRow, took, 
batch.Location.RowCount, batch.Location.ByteLen, float64(batch.Location.RowCount)/float64(took.Seconds())) } } } } -func (c *Copier) handleCopyError(ctx context.Context, db *sqlx.DB, batch Batch, copyErr error) error { +type HandleCopyErrorResult struct { + // Rows actually inserted + InsertedRows int64 + // Rows found but skipped due to a known reason + SkippedRows int64 +} + +func (c *Copier) handleCopyError(ctx context.Context, db *sqlx.DB, batch Batch, copyErr error) (HandleCopyErrorResult, error) { errAt := &ErrAtRow{ Err: copyErr, BatchLocation: batch.Location, @@ -494,46 +561,97 @@ func (c *Copier) handleCopyError(ctx context.Context, db *sqlx.DB, batch Batch, } if err, ok := copyErr.(*ErrBatchAlreadyProcessed); ok { - c.logger.Infof("skip batch %s already processed with state %s", batch.Location, err.State.State) - return nil + c.LogInfo(ctx, "skip batch %s already processed with state %s", batch.Location, err.State.State) + if err.State.State == "completed" { + return HandleCopyErrorResult{ + InsertedRows: 0, + SkippedRows: int64(batch.Location.RowCount), + }, nil + } + return HandleCopyErrorResult{ + InsertedRows: 0, + SkippedRows: 0, + }, nil + } - var failHandlerError *BatchError + connx, err := db.Connx(ctx) + if err != nil { + return HandleCopyErrorResult{}, fmt.Errorf("failed to connect to database") + } + defer connx.Close() + + if !batch.Location.HasImportID() { + if c.failHandler == nil { + return HandleCopyErrorResult{ + InsertedRows: 0, + SkippedRows: 0, + }, errAt + } + + failHandlerError := c.failHandler(ctx, c, connx, batch, errAt) + if !failHandlerError.Continue { + return HandleCopyErrorResult{ + InsertedRows: failHandlerError.InsertedRows, + SkippedRows: failHandlerError.SkippedRows, + }, failHandlerError + } + return HandleCopyErrorResult{ + InsertedRows: failHandlerError.InsertedRows, + SkippedRows: failHandlerError.SkippedRows, + }, nil + } + + // If we have an import ID, we need to start a transaction before the error handling to ensure both can run in the same transaction + tx, err := connx.BeginTxx(ctx, &sql.TxOptions{}) + if err != nil { + return HandleCopyErrorResult{}, fmt.Errorf("failed to start transaction, %w", err) + } + defer func() { _ = tx.Rollback() }() + + var failHandlerError HandleBatchErrorResult // If failHandler is defined, attempt to handle the error if c.failHandler != nil { - failHandlerError = c.failHandler(batch, errAt) - if failHandlerError == nil { - // If fail handler error does not return an error, - // make it so it recovers the previous error and continues execution - failHandlerError = NewErrContinue(errAt) - } + failHandlerError = c.failHandler(ctx, c, connx, batch, errAt) } else { failHandlerError = NewErrStop(errAt) } - c.logger.Infof("handling error %#v", failHandlerError) + c.LogInfo(ctx, "handling error for batch %s: %#v", batch.Location, failHandlerError) - if batch.Location.HasImportID() && !isTemporaryError(failHandlerError) { - connx, err := db.Connx(ctx) + tr := newTransactionAt(batch.Location) + + // If the fail handler is marked as handled, the transaction will be marked as completed. 
Independently if it still contains an error + if failHandlerError.Handled { + err = tr.setCompleted(ctx, tx) if err != nil { - return fmt.Errorf("failed to connect to database") + return HandleCopyErrorResult{}, fmt.Errorf("failed to set state to completed for batch %s, %w", batch.Location, err) } - defer connx.Close() - - tr := newTransactionAt(batch.Location) - err = tr.setFailed(ctx, connx, failHandlerError.Error()) + } else if !isTemporaryError(failHandlerError.Err) { + err = tr.setFailed(ctx, tx, failHandlerError.Error()) if err != nil { if !isDuplicateKeyError(err) { - return fmt.Errorf("failed to set state to failed, %w", err) + return HandleCopyErrorResult{}, fmt.Errorf("failed to set state to failed for batch %s, %w", batch.Location, err) } } } + err = tx.Commit() + if err != nil { + return HandleCopyErrorResult{}, fmt.Errorf("failed to commit transaction for batch %s, %w", batch.Location, err) + } + if !failHandlerError.Continue { - return failHandlerError + return HandleCopyErrorResult{ + InsertedRows: failHandlerError.InsertedRows, + SkippedRows: failHandlerError.SkippedRows, + }, failHandlerError } - return nil + return HandleCopyErrorResult{ + InsertedRows: failHandlerError.InsertedRows, + SkippedRows: failHandlerError.SkippedRows, + }, nil } @@ -589,10 +707,18 @@ func (c *Copier) report(ctx context.Context) { } } -func (c *Copier) getFullTableName() string { +func (c *Copier) GetFullTableName() string { return fmt.Sprintf(`"%s"."%s"`, c.schemaName, c.tableName) } +func (c *Copier) GetTableName() string { + return c.tableName +} + +func (c *Copier) GetSchemaName() string { + return c.schemaName +} + func (c *Copier) GetInsertedRows() int64 { return atomic.LoadInt64(&c.insertedRows) } diff --git a/pkg/csvcopy/csvcopy_test.go b/pkg/csvcopy/csvcopy_test.go index f52d700..e8a517a 100644 --- a/pkg/csvcopy/csvcopy_test.go +++ b/pkg/csvcopy/csvcopy_test.go @@ -17,6 +17,7 @@ import ( "github.com/testcontainers/testcontainers-go" "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" + "github.com/timescale/timescaledb-parallel-copy/pkg/buffer" ) func TestWriteDataToCSV(t *testing.T) { @@ -909,7 +910,7 @@ type MockErrorHandler struct { stop bool } -func (fs *MockErrorHandler) HandleError(batch Batch, reason error) *BatchError { +func (fs *MockErrorHandler) HandleError(ctx context.Context, c *Copier, db *sqlx.Conn, batch Batch, reason error) HandleBatchErrorResult { if fs.Errors == nil { fs.Errors = map[int]error{} } @@ -975,7 +976,7 @@ func TestFailedBatchHandlerFailure(t *testing.T) { writer.Flush() - copier, err := NewCopier(connStr, "metrics", WithColumns("device_id,label,value"), WithBatchSize(2), WithBatchErrorHandler(func(batch Batch, err error) *BatchError { + copier, err := NewCopier(connStr, "metrics", WithColumns("device_id,label,value"), WithBatchSize(2), WithBatchErrorHandler(func(_ context.Context, _ *Copier, _ *sqlx.Conn, _ Batch, err error) HandleBatchErrorResult { return NewErrStop(fmt.Errorf("couldn't handle error %w", err)) })) require.NoError(t, err) @@ -1074,14 +1075,14 @@ func TestTransactionState(t *testing.T) { assert.Equal(t, "test-file-id", row.ImportID) assert.Equal(t, int64(0), row.StartRow) assert.Equal(t, 2, row.RowCount) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) batch2, row, err := batch1.Next(ctx, connx) require.NoError(t, err) assert.Equal(t, "test-file-id", row.ImportID) assert.Equal(t, int64(2), row.StartRow) 
assert.Equal(t, 2, row.RowCount) - assert.Equal(t, transactionRowStateFailed, row.State) + assert.Equal(t, TransactionRowStateFailed, row.State) assert.NotEmpty(t, row.FailureReason) batch3, row, err := batch2.Next(ctx, connx) @@ -1089,7 +1090,7 @@ func TestTransactionState(t *testing.T) { assert.Equal(t, "test-file-id", row.ImportID) assert.Equal(t, int64(4), row.StartRow) assert.Equal(t, 2, row.RowCount) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) batch4, row, err := batch3.Next(ctx, connx) require.NoError(t, err) @@ -1178,15 +1179,15 @@ func TestTransactionIdempotency(t *testing.T) { batch1, row, err := LoadTransaction(ctx, connx, "test-file-id") require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) batch2, row, err := batch1.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateFailed, row.State) + assert.Equal(t, TransactionRowStateFailed, row.State) _, row, err = batch2.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) _, err = tmpfile.Seek(0, 0) require.NoError(t, err) @@ -1211,16 +1212,16 @@ func TestTransactionIdempotency(t *testing.T) { batch1, row, err = LoadTransaction(ctx, connx, "test-file-id") require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) batch2, row, err = batch1.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateFailed, row.State) + assert.Equal(t, TransactionRowStateFailed, row.State) _, row, err = batch2.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) var total int err = connx.QueryRowxContext(ctx, "SELECT COUNT(*) FROM public.metrics").Scan(&total) @@ -1483,15 +1484,15 @@ func TestTransactionFailureRetry(t *testing.T) { batch1, row, err := LoadTransaction(ctx, connx, "test-file-id") require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) batch2, row, err := batch1.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateFailed, row.State) + assert.Equal(t, TransactionRowStateFailed, row.State) _, row, err = batch2.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) reader, err = os.Open(goodFile.Name()) require.NoError(t, err) @@ -1513,16 +1514,16 @@ func TestTransactionFailureRetry(t *testing.T) { batch1, row, err = LoadTransaction(ctx, connx, "test-file-id") require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) batch2, row, err = batch1.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) _, row, err = batch2.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) var total int err = connx.QueryRowxContext(ctx, "SELECT COUNT(*) FROM public.metrics").Scan(&total) @@ -1639,16 +1640,16 @@ func TestTransactionFailureRetry(t *testing.T) { batch1, row, 
err := LoadTransaction(ctx, connx, "test-file-id") require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) batch2, row, err := batch1.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateFailed, row.State) + assert.Equal(t, TransactionRowStateFailed, row.State) assert.Contains(t, *row.FailureReason, "forced-failure") _, row, err = batch2.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) reader, err = os.Open(retryFile.Name()) require.NoError(t, err) @@ -1670,17 +1671,17 @@ func TestTransactionFailureRetry(t *testing.T) { batch1, row, err = LoadTransaction(ctx, connx, "test-file-id") require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) batch2, row, err = batch1.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateFailed, row.State) + assert.Equal(t, TransactionRowStateFailed, row.State) assert.Contains(t, *row.FailureReason, "still fails") _, row, err = batch2.Next(ctx, connx) require.NoError(t, err) - assert.Equal(t, transactionRowStateCompleted, row.State) + assert.Equal(t, TransactionRowStateCompleted, row.State) var total int err = connx.QueryRowxContext(ctx, "SELECT COUNT(*) FROM public.metrics").Scan(&total) @@ -1856,7 +1857,7 @@ func TestCalculateColumnsFromHeaders(t *testing.T) { columnMapping: ColumnsMapping(tt.columnMapping), quoteCharacter: tt.quoteCharacter, escapeCharacter: tt.escapeCharacter, - logger: &noopLogger{}, + Logger: &noopLogger{}, } // Create a buffered reader with the test CSV headers @@ -1885,7 +1886,7 @@ func TestCalculateColumnsFromHeaders_NoMapping(t *testing.T) { copier := &Copier{ skip: 1, columnMapping: ColumnsMapping{}, // Empty mapping - logger: &noopLogger{}, + Logger: &noopLogger{}, } csvData := "id,name,email\ndata1,data2,data3\n" @@ -1952,3 +1953,120 @@ func TestColumnsMapping_Get(t *testing.T) { }) } } + +// This test serves as an example of how the CopyFromBatch implementation is atomic. +func TestAtomicityAssurance(t *testing.T) { + ctx := context.Background() + + pgContainer, err := postgres.Run(ctx, + "postgres:15.3-alpine", + postgres.WithDatabase("test-db"), + postgres.WithUsername("postgres"), + postgres.WithPassword("postgres"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). 
+ WithOccurrence(2).WithStartupTimeout(5*time.Second)), + ) + require.NoError(t, err) + + t.Cleanup(func() { + err := pgContainer.Terminate(ctx) + require.NoError(t, err) + }) + + connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") + require.NoError(t, err) + + db, err := sqlx.ConnectContext(ctx, "pgx/v5", connStr) + require.NoError(t, err) + defer db.Close() + + // Setup test table + _, err = db.ExecContext(ctx, "CREATE TABLE public.test_metrics (device_id int, label text, value float8)") + require.NoError(t, err) + + // Ensure transaction table exists + err = ensureTransactionTable(ctx, connStr) + require.NoError(t, err) + + // Create test batch data + csvLines := [][]byte{ + []byte("42,test,4.2\n"), + []byte("24,data,2.4\n"), + } + seekableData := buffer.NewSeekable(csvLines) + + batch := Batch{ + Data: seekableData, + Location: Location{ + ImportID: "test-atomicity-assurance", + StartRow: 0, + RowCount: 2, + ByteOffset: 0, + ByteLen: len("42,test,4.2\n24,data,2.4\n"), + }, + } + + // Test the current implementation (should be atomic) + testCurrentImplementation := func() error { + connx, err := db.Connx(ctx) + if err != nil { + return fmt.Errorf("acquiring DBx connection for COPY: %w", err) + } + defer connx.Close() + + // Use BeginTxx as the current code does + tx, err := connx.BeginTxx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to start transaction: %w", err) + } + + defer func() { + _ = tx.Rollback() + }() + + // Set transaction control row first + tr := newTransactionAt(batch.Location) + err = tr.setCompleted(ctx, tx) + if err != nil { + return fmt.Errorf("failed to insert control row, %w", err) + } + + // Perform COPY operation (this should run within the transaction) + copyCmd := "COPY public.test_metrics(device_id,label,value) FROM STDIN WITH DELIMITER ',' CSV" + _, err = CopyFromLines(ctx, connx.Conn, batch.Data, copyCmd) + if err != nil { + return fmt.Errorf("failed to copy from lines %w", err) + } + + // Simulate failure - this should cause rollback of EVERYTHING if atomic + return fmt.Errorf("simulated failure - should rollback both COPY data and control row") + } + + // Run the test + err = testCurrentImplementation() + require.Error(t, err) + require.Contains(t, err.Error(), "simulated failure") + + // Check what actually happened + var targetRowCount int + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM public.test_metrics").Scan(&targetRowCount) + require.NoError(t, err) + + var controlRowCount int + err = db.QueryRowContext(ctx, "SELECT COUNT(*) FROM timescaledb_parallel_copy.transactions WHERE import_id = 'test-atomicity-assurance'").Scan(&controlRowCount) + require.NoError(t, err) + + // This test REQUIRES atomicity - it will FAIL if the implementation is broken + if targetRowCount != 0 || controlRowCount != 0 { + t.Errorf("ATOMICITY VIOLATION: Expected both counts to be 0, got targetRows=%d, controlRows=%d", + targetRowCount, controlRowCount) + t.Errorf("This means the COPY operation or control row was not properly rolled back") + t.FailNow() + } + + // If we get here, atomicity is working correctly + assert.Equal(t, 0, targetRowCount, "COPY data must be rolled back") + assert.Equal(t, 0, controlRowCount, "Control row must be rolled back") + t.Logf("SUCCESS: Current implementation maintains atomicity - both operations rolled back correctly") +} diff --git a/pkg/csvcopy/db.go b/pkg/csvcopy/db.go index 30a9367..8768578 100644 --- a/pkg/csvcopy/db.go +++ b/pkg/csvcopy/db.go @@ -29,10 +29,10 @@ func connect(connStr string) (*sqlx.DB, 
error) { return db, nil } -// copyFromLines bulk-loads data using the given copyCmd. lines must provide a +// CopyFromLines bulk-loads data using the given copyCmd. lines must provide a // set of complete lines of CSV data, including the end-of-line delimiters. // Returns the number of rows inserted. -func copyFromLines(ctx context.Context, conn *sql.Conn, lines io.Reader, copyCmd string) (int64, error) { +func CopyFromLines(ctx context.Context, conn *sql.Conn, lines io.Reader, copyCmd string) (int64, error) { var rowCount int64 // pgx requires us to use the low-level API for a raw COPY FROM operation. err := conn.Raw(func(driverConn interface{}) error { @@ -63,13 +63,15 @@ func copyFromBatch(ctx context.Context, db *sqlx.DB, batch Batch, copyCmd string defer connx.Close() if !batch.Location.HasImportID() { - rowCount, err := copyFromLines(ctx, connx.Conn, &batch.data, copyCmd) + rowCount, err := CopyFromLines(ctx, connx.Conn, batch.Data, copyCmd) if err != nil { return rowCount, fmt.Errorf("failed to copy from lines %w", err) } return rowCount, nil } + // This puts the connx in transaction mode, so when we use connx.Conn, it will be in the same transaction. + // Refer to TestAtomicityAssurance for a working example that proves atomicity. tx, err := connx.BeginTxx(ctx, &sql.TxOptions{}) if err != nil { return 0, fmt.Errorf("failed to start transaction: %w", err) @@ -101,7 +103,7 @@ func copyFromBatch(ctx context.Context, db *sqlx.DB, batch Batch, copyCmd string return 0, fmt.Errorf("failed to insert control row, %w", err) } - rowCount, err := copyFromLines(ctx, connx.Conn, &batch.data, copyCmd) + rowCount, err := CopyFromLines(ctx, connx.Conn, batch.Data, copyCmd) if err != nil { return rowCount, fmt.Errorf("failed to copy from lines %w", err) } diff --git a/pkg/csvcopy/options.go b/pkg/csvcopy/options.go index acf4f7c..91fa908 100644 --- a/pkg/csvcopy/options.go +++ b/pkg/csvcopy/options.go @@ -1,10 +1,13 @@ package csvcopy import ( + "context" "errors" "fmt" "strings" "time" + + "github.com/jmoiron/sqlx" ) type Option func(c *Copier) error @@ -20,7 +23,7 @@ func (l *noopLogger) Infof(msg string, args ...interface{}) {} // WithLogger sets the logger where the application will print debug messages func WithLogger(logger Logger) Option { return func(c *Copier) error { - c.logger = logger + c.Logger = logger return nil } } @@ -224,30 +227,46 @@ func WithSchemaName(schema string) Option { } } -func NewErrContinue(err error) *BatchError { - return &BatchError{ - Continue: true, - Err: err, +func NewErrContinue(err error) HandleBatchErrorResult { + return HandleBatchErrorResult{ + Continue: true, + Err: err, + Handled: false, + InsertedRows: 0, + SkippedRows: 0, } } -func NewErrStop(err error) *BatchError { - return &BatchError{ - Continue: false, - Err: err, +func NewErrStop(err error) HandleBatchErrorResult { + return HandleBatchErrorResult{ + Continue: false, + Err: err, + Handled: false, + InsertedRows: 0, + SkippedRows: 0, } } -type BatchError struct { +type HandleBatchErrorResult struct { + // Continue if true, The code will continue processing new batches. Otherwise, it will stop. Continue bool - Err error + // Handled if true, It means the error was correctly handled and the resulting batch will be marked as completed + Handled bool + // Rows number of rows successfully processed by the error handler. + // This ensures metrics stay up to date when error handlers can process the rows. 
+ InsertedRows int64 + // Rows found but skipped due to a known reason + SkippedRows int64 + // Err is the error that was returned by the error handler. It may just return the original error to act as a middleware or a new error to indicate failure reason. + // The transaction will fail with this error if it is not a temporary error. + Err error } -func (err BatchError) Error() string { +func (err HandleBatchErrorResult) Error() string { return fmt.Sprintf("continue: %t, %s", err.Continue, err.Err) } -func (err BatchError) Unwrap() error { +func (err HandleBatchErrorResult) Unwrap() error { return err.Err } @@ -257,7 +276,19 @@ func (err BatchError) Unwrap() error { // If the error is not handled properly, returning an error will stop the workers // If ErrContinue is returned, the batch will be marked as failed but continue processing // if ErrStop is returned, the processing will stop -type BatchErrorHandler func(batch Batch, err error) *BatchError +type BatchErrorHandler func( + ctx context.Context, + // c is the copier that is being used to handle the error + c *Copier, + // db is the database connection running a transaction that is being used to handle the error + // It connects to the target database + // The error handler is not the owner of the transaction, so it must not commit or rollback it. + db *sqlx.Conn, + // batch is the batch that has the error + batch Batch, + // err is the error that was returned the copy operation. + err error, +) HandleBatchErrorResult // WithBatchErrorHandler specifies which fail handler implementation to use func WithBatchErrorHandler(handler BatchErrorHandler) Option { diff --git a/pkg/csvcopy/options_test.go b/pkg/csvcopy/options_test.go index 648aa9b..a51effc 100644 --- a/pkg/csvcopy/options_test.go +++ b/pkg/csvcopy/options_test.go @@ -170,7 +170,6 @@ func TestOptionsMutualExclusivity(t *testing.T) { errorContains: "column mapping is already set", }, - // Valid combinations that should work { name: "WithSkipHeader false + WithColumns should work", diff --git a/pkg/csvcopy/scan.go b/pkg/csvcopy/scan.go index 46a13dd..0dbdc25 100644 --- a/pkg/csvcopy/scan.go +++ b/pkg/csvcopy/scan.go @@ -6,7 +6,8 @@ import ( "fmt" "io" "log" - "net" + + "github.com/timescale/timescaledb-parallel-copy/pkg/buffer" ) // scanOptions contains all the configurable knobs for Scan. @@ -27,13 +28,13 @@ type scanOptions struct { // Batch represents an operation to copy data into the DB type Batch struct { - data net.Buffers + Data *buffer.Seekable Location Location } -func newBatch(data net.Buffers, location Location) Batch { +func newBatch(data *buffer.Seekable, location Location) Batch { b := Batch{ - data: data, + Data: data, Location: location, } return b @@ -41,7 +42,9 @@ func newBatch(data net.Buffers, location Location) Batch { // newBatchFromReader used for testing purposes func newBatchFromReader(r io.Reader) Batch { - b := Batch{} + b := Batch{ + Data: buffer.NewSeekable([][]byte{}), + } buf := make([]byte, 32*1024) for { @@ -55,7 +58,7 @@ func newBatchFromReader(r io.Reader) Batch { b.Location.ByteLen += n // Process the data read from the buffer - b.data = append(b.data, buf[:n]) + _, _ = b.Data.Write(buf[:n]) // Write cannot fail, just exists to meet Writer interface } return b @@ -109,7 +112,8 @@ func (l Location) HasImportID() bool { // // The caller is responsible for setting up the CountReader and buffered reader, // and for skipping any headers before calling this function. 
-func scan(ctx context.Context, counter *CountReader, reader *bufio.Reader, out chan<- Batch, opts scanOptions) error { +func scan(ctx context.Context, logger func(ctx context.Context, msg string, args ...interface{}), + counter *CountReader, reader *bufio.Reader, out chan<- Batch, opts scanOptions) error { var rowsRead int64 batchSize := 20 * 1024 * 1024 // 20 MB batch size @@ -134,9 +138,9 @@ func scan(ctx context.Context, counter *CountReader, reader *bufio.Reader, out c // (which would have bad memory usage and performance characteristics for // larger CSV datasets, and be wasted anyway as soon as the underlying // Postgres connection divides the data into smaller CopyData chunks), keep - // the slices as-is and store them in net.Buffers, which is a convenient - // io.Reader abstraction wrapped over a [][]byte. - bufs := make(net.Buffers, 0) + // the slices as-is and store them in our Seekable buffer, which is a convenient + // io.Reader abstraction wrapped over a [][]byte with additional seek functionality. + bufs := buffer.NewSeekable([][]byte{}) var bufferedRows int // finishedRow is true if the current row has been fully read and counted @@ -153,7 +157,7 @@ func scan(ctx context.Context, counter *CountReader, reader *bufio.Reader, out c case <-ctx.Done(): return ctx.Err() } - bufs = make(net.Buffers, 0) + bufs = buffer.NewSeekable([][]byte{}) bufferedRows = 0 byteStart = byteEnd return nil @@ -186,7 +190,7 @@ func scan(ctx context.Context, counter *CountReader, reader *bufio.Reader, out c byteEnd := counter.Total - reader.Buffered() // Chunk will be bigger than ChunkByteSize if we append the current line. Let's send the data we have int he buffer if byteEnd-byteStart > batchSize { - log.Printf("reached max batch size, sending %d rows", bufferedRows) + logger(ctx, "reached max batch size, sending %d rows", bufferedRows) err := send(byteEndBeforeLine) if err != nil { return err @@ -194,16 +198,13 @@ func scan(ctx context.Context, counter *CountReader, reader *bufio.Reader, out c } finishedRow = false - // ReadSlice doesn't make a copy of the data; to avoid an overwrite - // on the next call, we need to make one now. - buf := make([]byte, len(data)) - copy(buf, data) - bufs = append(bufs, buf) + + _, _ = bufs.Write(data) // Write cannot fail, just exists to meet Writer interface // Figure out whether we're still inside a quoted value, in which // case the row hasn't ended yet even if we're at the end of a line. // TODO: This may no be a feasible scenario given that we require a full row to be in the buffer. - scanner.Scan(buf) + scanner.Scan(data) if eol && !scanner.NeedsMore() { finishedRow = true bufferedRows++ @@ -222,7 +223,7 @@ func scan(ctx context.Context, counter *CountReader, reader *bufio.Reader, out c if err == io.EOF { // if we have data in the buffer and we are not at the end of a row, we need to count the last row // this can happen if the last row is not terminated by a newline - if len(bufs) > 0 && !finishedRow { + if bufs.TotalSize() > 0 && !finishedRow { bufferedRows++ rowsRead++ } @@ -232,7 +233,7 @@ func scan(ctx context.Context, counter *CountReader, reader *bufio.Reader, out c } } // Finished reading input, make sure last batch goes out. 
- if len(bufs) > 0 { + if bufs.HasData() { byteEnd := counter.Total - reader.Buffered() select { case out <- newBatch( diff --git a/pkg/csvcopy/scan_test.go b/pkg/csvcopy/scan_test.go index a398797..7a309d1 100644 --- a/pkg/csvcopy/scan_test.go +++ b/pkg/csvcopy/scan_test.go @@ -2,7 +2,6 @@ package csvcopy import ( "bufio" - "bytes" "context" "errors" "fmt" @@ -378,7 +377,10 @@ d" i := 0 for buf := range rowChan { assert.EqualValues(t, c.expectedRowCount[i], buf.Location.RowCount, "on batch %d", i) - actual = append(actual, string(bytes.Join(buf.data, nil))) + // Read all data from the Seekable buffer + _, _ = buf.Data.Seek(0, io.SeekStart) + data, _ := io.ReadAll(buf.Data) + actual = append(actual, string(data)) i++ } @@ -412,7 +414,7 @@ d" } } - err := scan(context.Background(), counter, bufferedReader, rowChan, opts) + err := scan(context.Background(), noopLoggerFunc, counter, bufferedReader, rowChan, opts) if err != nil { if c.expectedError == "" { assert.NoError(t, err) @@ -474,7 +476,7 @@ d" } } - err := scan(context.Background(), counter, bufferedReader, rowChan, opts) + err := scan(context.Background(), noopLoggerFunc, counter, bufferedReader, rowChan, opts) if !errors.Is(err, expected) { t.Errorf("Scan() returned unexpected error: %v", err) t.Logf("want: %v", expected) @@ -595,7 +597,7 @@ func BenchmarkScan(b *testing.B) { } } - err := scan(context.Background(), counter, bufferedReader, rowChan, opts) + err := scan(context.Background(), noopLoggerFunc, counter, bufferedReader, rowChan, opts) if err != nil { b.Errorf("Scan() returned unexpected error: %v", err) } @@ -617,3 +619,5 @@ func RandString(n int) string { } return string(b) } + +var noopLoggerFunc = func(ctx context.Context, msg string, args ...interface{}) {} diff --git a/pkg/csvcopy/transaction.go b/pkg/csvcopy/transaction.go index 138ba42..6d0ebfd 100644 --- a/pkg/csvcopy/transaction.go +++ b/pkg/csvcopy/transaction.go @@ -14,11 +14,11 @@ type Transaction struct { loc Location } -type transactionRowState string +type TransactionRowState string const ( - transactionRowStateCompleted transactionRowState = "completed" - transactionRowStateFailed transactionRowState = "failed" + TransactionRowStateCompleted TransactionRowState = "completed" + TransactionRowStateFailed TransactionRowState = "failed" ) type TransactionRow struct { @@ -28,7 +28,7 @@ type TransactionRow struct { ByteOffset int ByteLen int CreatedAt time.Time - State transactionRowState + State TransactionRowState FailureReason *string } diff --git a/pkg/errorhandlers/conflict_handler.go b/pkg/errorhandlers/conflict_handler.go new file mode 100644 index 0000000..40100a3 --- /dev/null +++ b/pkg/errorhandlers/conflict_handler.go @@ -0,0 +1,125 @@ +package errorhandlers + +import ( + "context" + "crypto/rand" + "encoding/hex" + "errors" + "fmt" + "io" + "strings" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jmoiron/sqlx" + "github.com/timescale/timescaledb-parallel-copy/pkg/csvcopy" +) + +// generateRandomTableSuffix creates a random suffix for temporary table names +func generateRandomTableSuffix() string { + bytes := make([]byte, 6) // 6 bytes = 12 hex characters + _, _ = rand.Read(bytes) + return hex.EncodeToString(bytes) +} + +// ConflictHandlerConfig holds configuration for BatchConflictHandler +type ConflictHandlerConfig struct { + Next csvcopy.BatchErrorHandler +} + +// ConflictHandlerOption allows configuring the conflict handler +type ConflictHandlerOption func(*ConflictHandlerConfig) + + +// WithConflictHandlerNext sets the next batch error handler 
+func WithConflictHandlerNext(next csvcopy.BatchErrorHandler) ConflictHandlerOption { + return func(config *ConflictHandlerConfig) { + config.Next = next + } +} + +// BatchConflictHandler handles unique constraint violations during batch processing +// by creating temporary tables and using ON CONFLICT DO NOTHING to skip duplicates. +// This allows CSV imports to continue processing even when duplicate rows are encountered. +// +// The handler works by: +// 1. Detecting PostgreSQL unique constraint violations (error code 23505) +// 2. Creating a temporary table with the same structure as the destination +// 3. Copying the batch data to the temporary table +// 4. Using INSERT ... ON CONFLICT DO NOTHING to transfer only non-duplicate rows +// 5. Relying on PostgreSQL to drop the temporary table automatically at the end of the session +// +// If a next handler is configured via WithConflictHandlerNext, errors that are not unique constraint violations are forwarded to it. +func BatchConflictHandler(options ...ConflictHandlerOption) csvcopy.BatchErrorHandler { + config := &ConflictHandlerConfig{} + for _, option := range options { + option(config) + } + const UniqueViolationError = "23505" + + return csvcopy.BatchErrorHandler(func(ctx context.Context, c *csvcopy.Copier, db *sqlx.Conn, batch csvcopy.Batch, reason error) csvcopy.HandleBatchErrorResult { + c.LogInfo(ctx, "BatchConflictHandler called: batch %d, byte offset %d, len %d", batch.Location.StartRow, batch.Location.ByteOffset, batch.Location.ByteLen) + + pgerr := &pgconn.PgError{} + if !errors.As(reason, &pgerr) { + c.LogInfo(ctx, "BatchConflictHandler: error is not a PostgreSQL error. Type: %T, Error: %v", reason, reason) + if config.Next != nil { + return config.Next(ctx, c, db, batch, reason) + } + return csvcopy.NewErrStop(reason) + } + + if pgerr.Code != UniqueViolationError { + c.LogInfo(ctx, "BatchConflictHandler: not a unique constraint violation (code %s != %s). Forwarding to next handler.", pgerr.Code, UniqueViolationError) + if config.Next != nil { + return config.Next(ctx, c, db, batch, reason) + } + return csvcopy.NewErrStop(reason) + } + + c.LogInfo(ctx, "BatchConflictHandler: Batch %d, has conflict: %s", batch.Location.StartRow, reason.Error()) + _, err := batch.Data.Seek(0, io.SeekStart) + if err != nil { + return csvcopy.NewErrStop(fmt.Errorf("failed to seek to start of batch data: %w", err)) + } + + // Create a temporary table with a random name (PostgreSQL drops it automatically at session end) + randomSuffix := generateRandomTableSuffix() + temporalTableName := fmt.Sprintf("tmp_batch_%s", randomSuffix) + + c.LogInfo(ctx, "BatchConflictHandler: Creating temporary table %s", temporalTableName) + _, err = db.ExecContext(ctx, fmt.Sprintf("/* Worker-%d */ CREATE TEMPORARY TABLE %s (LIKE %s INCLUDING DEFAULTS)", csvcopy.GetWorkerIDFromContext(ctx), temporalTableName, c.GetFullTableName())) + if err != nil { + return csvcopy.NewErrStop(fmt.Errorf("failed to create temporary table %s: %w", temporalTableName, err)) + } + + // Create the copy command for the temporary table + tempCopyCmd := strings.Replace(c.CopyCmdWithContext(ctx), c.GetFullTableName(), temporalTableName, 1) + rows, err := csvcopy.CopyFromLines(ctx, db.Conn, batch.Data, tempCopyCmd) + if err != nil { + return csvcopy.NewErrStop(fmt.Errorf("failed to copy from lines: %w", err)) + } + + c.LogInfo(ctx, "BatchConflictHandler: Copied %d rows to temporary table %s", rows, temporalTableName) + + // Insert data using ON CONFLICT DO NOTHING to skip duplicates + insertSQL := fmt.Sprintf("/* Worker-%d */ INSERT INTO %s SELECT * FROM %s ON CONFLICT DO NOTHING", csvcopy.GetWorkerIDFromContext(ctx), c.GetFullTableName(), temporalTableName) + result, err := db.ExecContext(ctx, insertSQL) + if err != nil { + return csvcopy.NewErrStop(fmt.Errorf("failed to insert from temporary table %s to %s: %w", temporalTableName, c.GetFullTableName(), err)) + } + insertedRows, _ := result.RowsAffected() + + c.LogInfo(ctx, "BatchConflictHandler: Processed %d rows from temporary table %s to %s", insertedRows, temporalTableName, c.GetFullTableName()) + + // No need to drop the temporary table - PostgreSQL cleans it up automatically at the end of the session + + return csvcopy.HandleBatchErrorResult{ + Continue: true, + InsertedRows: insertedRows, + SkippedRows: rows - insertedRows, + Handled: true, + } + }) +} + +
diff --git a/pkg/errorhandlers/conflict_handler_test.go b/pkg/errorhandlers/conflict_handler_test.go new file mode 100644 index 0000000..88cb5fb --- /dev/null +++ b/pkg/errorhandlers/conflict_handler_test.go @@ -0,0 +1,267 @@ +package errorhandlers + +import ( + "context" + "encoding/csv" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/postgres" + "github.com/testcontainers/testcontainers-go/wait" + "github.com/timescale/timescaledb-parallel-copy/pkg/csvcopy" +) + +func TestBatchConflictHandler_WithUniqueConstraint(t *testing.T) { + ctx := context.Background() + + pgContainer, err := postgres.Run(ctx, + "postgres:15.3-alpine", + postgres.WithDatabase("test-db"), + postgres.WithUsername("postgres"), + postgres.WithPassword("postgres"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections").
+ WithOccurrence(2).WithStartupTimeout(5*time.Second)), + ) + require.NoError(t, err) + + t.Cleanup(func() { + err := pgContainer.Terminate(ctx) + require.NoError(t, err) + }) + + connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") + require.NoError(t, err) + + conn, err := pgx.Connect(ctx, connStr) + require.NoError(t, err) + defer conn.Close(ctx) + + // Create table with unique constraint + _, err = conn.Exec(ctx, ` + CREATE TABLE public.test_metrics ( + device_id int, + label text, + value float8, + UNIQUE(device_id, label) + ) + `) + require.NoError(t, err) + + // Create temporary CSV file with duplicate data + tmpfile, err := os.CreateTemp("", "batch_conflict_test") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + + writer := csv.NewWriter(tmpfile) + data := [][]string{ + // Batch 1 - will succeed + {"1", "temp", "25.5"}, + {"2", "humidity", "60.0"}, + // Batch 2 - contains conflict + {"1", "temp", "26.0"}, // Duplicate! Should be skipped + {"3", "pressure", "1013.25"}, + // Batch 3 - contains another conflict + {"2", "humidity", "65.0"}, // Another duplicate! Should be skipped + {"4", "temp", "24.8"}, + } + + for _, record := range data { + err := writer.Write(record) + require.NoError(t, err) + } + writer.Flush() + + // Test with BatchConflictHandler - should handle conflicts gracefully + copier, err := csvcopy.NewCopier(connStr, "test_metrics", + csvcopy.WithColumns("device_id,label,value"), + csvcopy.WithBatchSize(2), + csvcopy.WithBatchErrorHandler(BatchConflictHandler(WithConflictHandlerNext(csvcopy.BatchHandlerNoop()))), + csvcopy.WithImportID("test-conflict-handling"), + ) + require.NoError(t, err) + + reader, err := os.Open(tmpfile.Name()) + require.NoError(t, err) + defer reader.Close() + + result, err := copier.Copy(context.Background(), reader) + require.NoError(t, err, "Copy should succeed with conflict handler") + + // Verify results + assert.EqualValues(t, 6, result.TotalRows, "Should process all 6 rows") + assert.EqualValues(t, 4, result.InsertedRows, "Should insert 4 unique rows") + assert.EqualValues(t, 2, result.SkippedRows, "Should report the 2 duplicate rows as skipped") + + // Verify actual data in database + var actualCount int + err = conn.QueryRow(ctx, "SELECT COUNT(*) FROM public.test_metrics").Scan(&actualCount) + require.NoError(t, err) + assert.Equal(t, 4, actualCount, "Should have exactly 4 unique rows in database") + + // Verify specific rows exist (first occurrence of each unique combination) + rows, err := conn.Query(ctx, "SELECT device_id, label, value FROM public.test_metrics ORDER BY device_id, label") + require.NoError(t, err) + defer rows.Close() + + expectedRows := []struct { + deviceID int + label string + value float64 + }{ + {1, "temp", 25.5}, // First occurrence + {2, "humidity", 60.0}, // First occurrence + {3, "pressure", 1013.25}, + {4, "temp", 24.8}, + } + + i := 0 + for rows.Next() { + require.Less(t, i, len(expectedRows), "More rows than expected") + + var deviceID int + var label string + var value float64 + + err = rows.Scan(&deviceID, &label, &value) + require.NoError(t, err) + + expected := expectedRows[i] + assert.Equal(t, expected.deviceID, deviceID, "Device ID mismatch at row %d", i) + assert.Equal(t, expected.label, label, "Label mismatch at row %d", i) + assert.InDelta(t, expected.value, value, 0.01, "Value mismatch at row %d", i) + i++ + } + assert.Equal(t, len(expectedRows), i, "Should have exactly %d rows", len(expectedRows)) +} + +func 
TestBatchConflictHandler_WithoutBatchConflictHandler(t *testing.T) { + ctx := context.Background() + + pgContainer, err := postgres.Run(ctx, + "postgres:15.3-alpine", + postgres.WithDatabase("test-db"), + postgres.WithUsername("postgres"), + postgres.WithPassword("postgres"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2).WithStartupTimeout(5*time.Second)), + ) + require.NoError(t, err) + + t.Cleanup(func() { + err := pgContainer.Terminate(ctx) + require.NoError(t, err) + }) + + connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable") + require.NoError(t, err) + + conn, err := pgx.Connect(ctx, connStr) + require.NoError(t, err) + defer conn.Close(ctx) + + // Create table with unique constraint + _, err = conn.Exec(ctx, ` + CREATE TABLE public.test_metrics ( + device_id int, + label text, + value float8, + UNIQUE(device_id, label) + ) + `) + require.NoError(t, err) + + // Create temporary CSV file with duplicate data (same as previous test) + tmpfile, err := os.CreateTemp("", "batch_no_batch_conflict_handler_test") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + + writer := csv.NewWriter(tmpfile) + data := [][]string{ + // Batch 1 - will succeed + {"1", "temp", "25.5"}, + {"2", "humidity", "60.0"}, + // Batch 2 - contains conflict, should cause failure without handler + {"1", "temp", "26.0"}, // Duplicate! Should cause error + {"3", "pressure", "1013.25"}, + // Batch 3 - won't be reached due to failure + {"2", "humidity", "65.0"}, + {"4", "temp", "24.8"}, + } + + for _, record := range data { + err := writer.Write(record) + require.NoError(t, err) + } + writer.Flush() + + // Test without BatchConflictHandler - should fail on unique constraint violation + copier, err := csvcopy.NewCopier(connStr, "test_metrics", + csvcopy.WithColumns("device_id,label,value"), + csvcopy.WithBatchSize(2), + csvcopy.WithImportID("test-no-conflict-handling"), + ) + require.NoError(t, err) + + reader, err := os.Open(tmpfile.Name()) + require.NoError(t, err) + defer reader.Close() + + result, err := copier.Copy(context.Background(), reader) + require.Error(t, err, "Copy should fail without conflict handler") + + // Verify error is related to unique constraint violation + assert.Contains(t, err.Error(), "duplicate key value violates unique constraint", + "Error should mention unique constraint violation") + + // Verify partial results - first batch should have succeeded + require.NotNil(t, result) + assert.EqualValues(t, 2, result.InsertedRows, "Should have inserted first batch (2 rows)") + assert.EqualValues(t, 4, result.TotalRows, "Should have processed up to the failed batch") + + // Verify actual data in database - only first batch should be there + var actualCount int + err = conn.QueryRow(ctx, "SELECT COUNT(*) FROM public.test_metrics").Scan(&actualCount) + require.NoError(t, err) + assert.Equal(t, 2, actualCount, "Should have only 2 rows from first successful batch") + + // Verify the specific rows that were inserted before failure + rows, err := conn.Query(ctx, "SELECT device_id, label, value FROM public.test_metrics ORDER BY device_id") + require.NoError(t, err) + defer rows.Close() + + expectedRows := []struct { + deviceID int + label string + value float64 + }{ + {1, "temp", 25.5}, + {2, "humidity", 60.0}, + } + + i := 0 + for rows.Next() { + require.Less(t, i, len(expectedRows), "More rows than expected") + + var deviceID int + var label string + var value float64 + + err = rows.Scan(&deviceID, &label, 
&value) + require.NoError(t, err) + + expected := expectedRows[i] + assert.Equal(t, expected.deviceID, deviceID, "Device ID mismatch at row %d", i) + assert.Equal(t, expected.label, label, "Label mismatch at row %d", i) + assert.InDelta(t, expected.value, value, 0.01, "Value mismatch at row %d", i) + i++ + } + assert.Equal(t, len(expectedRows), i, "Should have exactly %d rows", len(expectedRows)) +} +
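
For reviewers who want to see the new pieces end to end, here is a minimal sketch of how a library consumer might wire `errorhandlers.BatchConflictHandler` into a `csvcopy.Copier`. It mirrors the wiring used in `TestBatchConflictHandler_WithUniqueConstraint`; the connection string, table name, column list, batch size, and file name are placeholder assumptions rather than values taken from this change. Under the hood the handler retries a failed batch through a temporary table (`CREATE TEMPORARY TABLE ... (LIKE ... INCLUDING DEFAULTS)`, a `COPY` into it, then `INSERT INTO ... SELECT * FROM ... ON CONFLICT DO NOTHING`), so only the duplicate rows are dropped.

```go
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/timescale/timescaledb-parallel-copy/pkg/csvcopy"
	"github.com/timescale/timescaledb-parallel-copy/pkg/errorhandlers"
)

func main() {
	// Skip duplicates on unique constraint violations; any other batch error
	// falls through to the next handler (a no-op handler in this sketch).
	handler := errorhandlers.BatchConflictHandler(
		errorhandlers.WithConflictHandlerNext(csvcopy.BatchHandlerNoop()),
	)

	// Placeholder connection string, table, columns, and batch size.
	copier, err := csvcopy.NewCopier("postgres://localhost:5432/mydb", "metrics",
		csvcopy.WithColumns("device_id,label,value"),
		csvcopy.WithBatchSize(5000),
		csvcopy.WithBatchErrorHandler(handler),
	)
	if err != nil {
		log.Fatal(err)
	}

	f, err := os.Open("data.csv") // placeholder input file
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	result, err := copier.Copy(context.Background(), f)
	if err != nil {
		log.Fatal(err)
	}

	// SkippedRows counts the rows dropped by ON CONFLICT DO NOTHING.
	fmt.Printf("total=%d inserted=%d skipped=%d\n",
		result.TotalRows, result.InsertedRows, result.SkippedRows)
}
```

This is the same composition that the `--on-conflict-do-nothing` flag sets up in `cmd/timescaledb-parallel-copy/main.go`.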