diff --git a/README.md b/README.md
index 9a179b3..8f0d38e 100644
--- a/README.md
+++ b/README.md
@@ -174,10 +174,16 @@ Other options and flags are also available:
 $ timescaledb-parallel-copy --help
 Usage of timescaledb-parallel-copy:
-  -batch-error-output-dir string
-        directory to store batch errors. Settings this will save a .csv file with the contents of the batch that failed and continue with the rest of the data.
+  -auto-column-mapping
+        Automatically map CSV headers to database columns with the same names
+  -batch-byte-size int
+        Max number of bytes to send in a batch (default 20971520)
   -batch-size int
-        Number of rows per insert (default 5000)
+        Number of rows per insert. It will be limited by batch-byte-size (default 5000)
+  -buffer-byte-size int
+        Number of bytes to buffer; it has to be big enough to hold a full row (default 2097152)
+  -column-mapping string
+        Column mapping from CSV to database columns (format: "csv_col1:db_col1,csv_col2:db_col2" or JSON)
   -columns string
         Comma-separated columns present in CSV
   -connection string
@@ -222,6 +228,7 @@ Usage of timescaledb-parallel-copy:
         Number of parallel requests to make (default 1)
 ```
+
 ## Purpose
 
 PostgreSQL native `COPY` function is transactional and single-threaded, and may not be suitable for ingesting large
@@ -237,7 +244,7 @@ less often. This improves memory management and keeps operations on the disk as
 We welcome contributions to this utility, which like TimescaleDB is released under the Apache2 Open Source License.
 The same [Contributors Agreement](//github.com/timescale/timescaledb/blob/master/CONTRIBUTING.md) applies; please sign
 the [Contributor License Agreement](https://cla-assistant.io/timescale/timescaledb-parallel-copy) (CLA) if you're a new contributor.
 
-### Running Tests
+## Running Tests
 
 Some of the tests require a running Postgres database. Set the `TEST_CONNINFO` environment variable to point at the
 database you want to run tests against.
 
 For example:
 ```
@@ -250,3 +257,94 @@
 $ createdb gotest
 $ TEST_CONNINFO='dbname=gotest user=myuser' go test -v ./...
 ```
+
+## Advanced usage
+
+### Column Mapping
+
+The tool exposes two flags, `--column-mapping` and `--auto-column-mapping`, that let you handle CSV headers intelligently.
+
+`--column-mapping` lets you specify how the columns in your CSV map to database columns. It supports two formats:
+
+**Simple format:**
+```bash
+# Map CSV columns to database columns with different names
+$ timescaledb-parallel-copy --connection $DATABASE_URL --table metrics --file data.csv \
+    --column-mapping "timestamp:time,temperature:temp_celsius,humidity:humidity_percent"
+```
+
+**JSON format:**
+```bash
+# Same mapping using JSON format
+$ timescaledb-parallel-copy --connection $DATABASE_URL --table metrics --file data.csv \
+    --column-mapping '{"timestamp":"time","temperature":"temp_celsius","humidity":"humidity_percent"}'
+```
+
+Example CSV file with headers:
+```csv
+timestamp,temperature,humidity
+2023-01-01 00:00:00,20.5,65.2
+2023-01-01 01:00:00,21.0,64.8
+```
+
+This maps the CSV columns to database columns: `timestamp` → `time`, `temperature` → `temp_celsius`, `humidity` → `humidity_percent`.
+
+`--auto-column-mapping` covers the common case where your CSV columns have the same names as your database columns.
+
+```bash
+# Automatically map CSV headers to database columns with identical names
+$ timescaledb-parallel-copy --connection $DATABASE_URL --table sensors --file sensors.csv \
+    --auto-column-mapping
+```
+
+Example CSV file with headers matching database columns:
+```csv
+time,device_id,temperature,humidity
+2023-01-01 00:00:00,sensor_001,20.5,65.2
+2023-01-01 01:00:00,sensor_002,21.0,64.8
+```
+
+Both flags automatically skip the header row and cannot be used together with `--skip-header` or `--columns`.
+
+**Flexible Column Mapping:**
+
+Column mappings can include entries for columns that are not present in the input CSV file. This allows you to use the same mapping configuration across multiple input files with different column sets:
+
+```bash
+# Define a comprehensive mapping that works with multiple CSV formats
+$ timescaledb-parallel-copy --connection $DATABASE_URL --table sensors --file partial_data.csv \
+    --column-mapping "timestamp:time,temp:temperature,humidity:humidity_percent,pressure:pressure_hpa,location:device_location"
+```
+
+Example CSV file with only some of the mapped columns:
+```csv
+timestamp,temp,humidity
+2023-01-01 00:00:00,20.5,65.2
+2023-01-01 01:00:00,21.0,64.8
+```
+
+In this case, only the `timestamp`, `temp`, and `humidity` columns from the CSV are processed and mapped to `time`, `temperature`, and `humidity_percent` respectively. The unused mappings for `pressure` and `location` are simply ignored.
+
+You can also map different CSV column names to the same database column, as long as only one of them appears in any given input file:
+
+```bash
+# Map both 'temp' and 'temperature' to the same database column
+$ timescaledb-parallel-copy --connection $DATABASE_URL --table sensors --file data.csv \
+    --column-mapping "timestamp:time,temp:temperature,temperature:temperature,humidity:humidity_percent"
+```
+
+This allows importing from different file formats into the same table:
+
+**File A** (uses 'temp'):
+```csv
+timestamp,temp,humidity
+2023-01-01 00:00:00,20.5,65.2
+```
+
+**File B** (uses 'temperature'):
+```csv
+timestamp,temperature,humidity
+2023-01-01 02:00:00,22.1,63.5
+```
+
+Both files can use the same mapping configuration and import successfully into the same table, even though they use different column names for the temperature data. The tool only validates for duplicate database columns among the columns actually present in each specific input file.
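+
+**Skipping extra lines before the header:**
+
+The `--skip-lines` flag discards the first n lines of the input before any header handling, so it can be combined with either mapping flag when files carry leading junk such as comment lines. A minimal sketch (the file layout and file name below are hypothetical examples):
+
+```bash
+# Skip three comment lines, then auto-map the header row that follows
+$ timescaledb-parallel-copy --connection $DATABASE_URL --table sensors --file commented.csv \
+    --skip-lines 3 --auto-column-mapping
+```
+
+```csv
+# exported by sensor-gateway
+# comment lines precede the real header
+# the next line contains the actual headers
+time,device_id,temperature,humidity
+2023-01-01 00:00:00,sensor_001,20.5,65.2
+```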
diff --git a/cmd/timescaledb-parallel-copy/main.go b/cmd/timescaledb-parallel-copy/main.go
index 91a4a62..06b9252 100644
--- a/cmd/timescaledb-parallel-copy/main.go
+++ b/cmd/timescaledb-parallel-copy/main.go
@@ -3,6 +3,7 @@ package main
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"flag"
 	"fmt"
@@ -10,6 +11,7 @@ import (
 	"log"
 	"os"
 	"runtime"
+	"strings"
 	"time"
 
 	"github.com/timescale/timescaledb-parallel-copy/pkg/csvcopy"
@@ -33,11 +35,14 @@ var (
 	quoteCharacter  string
 	escapeCharacter string
 
-	fromFile        string
-	columns         string
-	skipHeader      bool
-	headerLinesCnt  int
-	skipBatchErrors bool
+	fromFile          string
+	columns           string
+	columnMapping     string
+	autoColumnMapping bool
+	skipHeader        bool
+	headerLinesCnt    int
+	skipLines         int
+	skipBatchErrors   bool
 
 	importID string
 	workers  int
@@ -68,8 +73,12 @@ func init() {
 	flag.StringVar(&escapeCharacter, "escape", "", "The ESCAPE `character` to use during COPY (default '\"')")
 	flag.StringVar(&fromFile, "file", "", "File to read from rather than stdin")
 	flag.StringVar(&columns, "columns", "", "Comma-separated columns present in CSV")
+	flag.StringVar(&columnMapping, "column-mapping", "", "Column mapping from CSV to database columns (format: \"csv_col1:db_col1,csv_col2:db_col2\" or JSON)")
+	flag.BoolVar(&autoColumnMapping, "auto-column-mapping", false, "Automatically map CSV headers to database columns with the same names")
+
 	flag.BoolVar(&skipHeader, "skip-header", false, "Skip the first line of the input")
-	flag.IntVar(&headerLinesCnt, "header-line-count", 1, "Number of header lines")
+	flag.IntVar(&headerLinesCnt, "header-line-count", 1, "(deprecated) Number of header lines")
+	flag.IntVar(&skipLines, "skip-lines", 0, "Skip the first n lines of the input. It is applied before skip-header")
 
 	flag.BoolVar(&skipBatchErrors, "skip-batch-errors", false, "if true, the copy will continue even if a batch fails")
 
@@ -103,6 +112,11 @@ func main() {
 	if dbName != "" {
 		log.Fatalf("Error: Deprecated flag -db-name is being used. Update -connection to connect to the given database")
 	}
+
+	if headerLinesCnt != 1 {
+		log.Fatalf("Error: -header-line-count is deprecated. Use -skip-lines instead")
+	}
+
 	logger := &csvCopierLogger{}
 
 	opts := []csvcopy.Option{
@@ -127,6 +141,18 @@ func main() {
 		opts = append(opts, csvcopy.WithImportID(importID))
 	}
 
+	if columnMapping != "" {
+		mapping, err := parseColumnMapping(columnMapping)
+		if err != nil {
+			log.Fatalf("Error parsing column mapping: %v", err)
+		}
+		opts = append(opts, csvcopy.WithColumnMapping(mapping))
+	}
+
+	if autoColumnMapping {
+		opts = append(opts, csvcopy.WithAutoColumnMapping())
+	}
+
 	batchErrorHandler := csvcopy.BatchHandlerError()
 	if skipBatchErrors {
 		batchErrorHandler = csvcopy.BatchHandlerNoop()
@@ -136,10 +162,12 @@ func main() {
 	}
 	opts = append(opts, csvcopy.WithBatchErrorHandler(batchErrorHandler))
 
+	if skipLines > 0 {
+		opts = append(opts, csvcopy.WithSkipHeaderCount(skipLines))
+	}
+
 	if skipHeader {
-		opts = append(opts,
-			csvcopy.WithSkipHeaderCount(headerLinesCnt),
-		)
+		opts = append(opts, csvcopy.WithSkipHeader(true))
 	}
 
 	copier, err := csvcopy.NewCopier(
@@ -190,3 +218,73 @@ func main() {
 	}
 	fmt.Println(res)
 }
+
+// parseColumnMapping parses a column mapping string into csvcopy.ColumnsMapping.
+// Supports two formats:
+// 1. Simple: "csv_col1:db_col1,csv_col2:db_col2"
+// 2. JSON: {"csv_col1":"db_col1","csv_col2":"db_col2"}
+func parseColumnMapping(mappingStr string) (csvcopy.ColumnsMapping, error) {
+	if mappingStr == "" {
+		return nil, nil
+	}
+
+	mappingStr = strings.TrimSpace(mappingStr)
+
+	// Check if it's JSON format (starts with '{')
+	if strings.HasPrefix(mappingStr, "{") {
+		return parseJSONColumnMapping(mappingStr)
+	}
+
+	// Parse simple format: "csv_col1:db_col1,csv_col2:db_col2"
+	return parseSimpleColumnMapping(mappingStr)
+}
+
+// parseJSONColumnMapping parses JSON format column mapping
+func parseJSONColumnMapping(jsonStr string) (csvcopy.ColumnsMapping, error) {
+	var mappingMap map[string]string
+	if err := json.Unmarshal([]byte(jsonStr), &mappingMap); err != nil {
+		return nil, fmt.Errorf("invalid JSON format for column mapping: %w", err)
+	}
+
+	var mapping csvcopy.ColumnsMapping
+	for csvCol, dbCol := range mappingMap {
+		mapping = append(mapping, csvcopy.ColumnMapping{
+			CSVColumnName:      csvCol,
+			DatabaseColumnName: dbCol,
+		})
+	}
+
+	return mapping, nil
+}
+
+// parseSimpleColumnMapping parses simple format: "csv_col1:db_col1,csv_col2:db_col2"
+func parseSimpleColumnMapping(simpleStr string) (csvcopy.ColumnsMapping, error) {
+	pairs := strings.Split(simpleStr, ",")
+	var mapping csvcopy.ColumnsMapping
+
+	for i, pair := range pairs {
+		pair = strings.TrimSpace(pair)
+		if pair == "" {
+			continue
+		}
+
+		parts := strings.Split(pair, ":")
+		if len(parts) != 2 {
+			return nil, fmt.Errorf("invalid column mapping format at position %d: '%s', expected 'csv_column:db_column'", i+1, pair)
+		}
+
+		csvCol := strings.TrimSpace(parts[0])
+		dbCol := strings.TrimSpace(parts[1])
+
+		if csvCol == "" || dbCol == "" {
+			return nil, fmt.Errorf("empty column name in mapping at position %d: '%s'", i+1, pair)
+		}
+
+		mapping = append(mapping, csvcopy.ColumnMapping{
+			CSVColumnName:      csvCol,
+			DatabaseColumnName: dbCol,
+		})
+	}
+
+	return mapping, nil
+}
diff --git a/pkg/csvcopy/csvcopy.go b/pkg/csvcopy/csvcopy.go
index c1323f4..9ca5e51 100644
--- a/pkg/csvcopy/csvcopy.go
+++ b/pkg/csvcopy/csvcopy.go
@@ -1,6 +1,7 @@
 package csvcopy
 
 import (
+	"bufio"
 	"context"
 	"errors"
 	"fmt"
@@ -12,6 +13,7 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/jackc/pgx/v5"
 	"github.com/jackc/pgx/v5/pgconn"
 	_ "github.com/jackc/pgx/v5/stdlib"
 	"github.com/jmoiron/sqlx"
@@ -19,6 +21,15 @@
 const TAB_CHAR_STR = "\\t"
 
+type HeaderHandling int
+
+const (
+	HeaderNone HeaderHandling = iota
+	HeaderSkip
+	HeaderAutoColumnMapping
+	HeaderColumnMapping
+)
+
 type Result struct {
 	// InsertedRows is the number of rows inserted into the database by this copier instance
 	InsertedRows int64
@@ -55,6 +66,8 @@ type Copier struct {
 	skip              int
 	importID          string
 	idempotencyWindow time.Duration
+	columnMapping     ColumnsMapping
+	useFileHeaders    HeaderHandling
 
 	// Rows that are inserted in the database by this copier instance
 	insertedRows int64
@@ -131,7 +144,6 @@ func (c *Copier) Truncate() (err error) {
 }
 
 func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) {
-
 	if c.HasImportID() {
 		if err := ensureTransactionTable(ctx, c.connString); err != nil {
 			return Result{}, fmt.Errorf("failed to ensure transaction table, %w", err)
@@ -142,6 +154,33 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) {
 		}
 	}
 
+	// Setup reader with buffering for header skipping
+	bufferSize := 2 * 1024 * 1024 // 2 MB buffer
+	if c.bufferSize > 0 {
+		bufferSize = c.bufferSize
+	}
+
+	counter := &CountReader{Reader: reader}
+	bufferedReader := bufio.NewReaderSize(counter, bufferSize)
+
+	if c.useFileHeaders == HeaderSkip {
+		c.skip++
+	}
+
+	if c.skip > 0 {
+		if err := skipLines(bufferedReader, c.skip); err != nil {
+			return Result{}, fmt.Errorf("failed to skip lines: %w", err)
+		}
+	}
+
+	if c.useFileHeaders == HeaderAutoColumnMapping || c.useFileHeaders == HeaderColumnMapping {
+		// Increment number of skipped lines to account for the header line
+		c.skip++
+		if err := c.calculateColumnsFromHeaders(bufferedReader); err != nil {
+			return Result{}, fmt.Errorf("failed to calculate columns from headers: %w", err)
+		}
+	}
+
 	var workerWg sync.WaitGroup
 	batchChan := make(chan Batch, c.workers*2)
 
@@ -178,12 +217,11 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) {
 	}
 
 	opts := scanOptions{
-		Size:           c.batchSize,
-		Skip:           c.skip,
-		Limit:          c.limit,
-		BufferByteSize: c.bufferSize,
-		BatchByteSize:  c.batchByteSize,
-		ImportID:       c.importID,
+		Size:          c.batchSize,
+		Skip:          c.skip,
+		Limit:         c.limit,
+		BatchByteSize: c.batchByteSize,
+		ImportID:      c.importID,
 	}
 
 	if c.quoteCharacter != "" {
@@ -199,7 +237,7 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) {
 	workerWg.Add(1)
 	go func() {
 		defer workerWg.Done()
-		if err := scan(ctx, reader, batchChan, opts); err != nil {
+		if err := scan(ctx, counter, bufferedReader, batchChan, opts); err != nil {
 			errCh <- fmt.Errorf("failed reading input: %w", err)
 			cancel()
 		}
@@ -238,6 +276,91 @@ func (c *Copier) Copy(ctx context.Context, reader io.Reader) (Result, error) {
 	return result, nil
 }
 
+func parseCSVHeaders(bufferedReader *bufio.Reader, quoteCharacter, escapeCharacter, splitCharacter string) ([]string, error) {
+	quote := byte('"')
+	if quoteCharacter != "" {
+		quote = quoteCharacter[0]
+	}
+	escape := quote
+	if escapeCharacter != "" {
+		escape = escapeCharacter[0]
+	}
+
+	comma := ','
+	if splitCharacter != "" {
+		comma = rune(splitCharacter[0])
+	}
+
+	return parseHeaders(bufferedReader, quote, escape, comma)
+}
+
+func (c *Copier) useAutomaticColumnMapping(headers []string) error {
+	quotedHeaders := make([]string, len(headers))
+	for i, header := range headers {
+		quotedHeaders[i] = pgx.Identifier{header}.Sanitize()
+	}
+	c.columns = strings.Join(quotedHeaders, ",")
+	c.logger.Infof("automatic column mapping: %s", c.columns)
+	return nil
+}
+
+func validateColumnMapping(columnMapping ColumnsMapping) error {
+	seenMappingCSVColumns := make(map[string]bool)
+	for _, mapping := range columnMapping {
+		if seenMappingCSVColumns[mapping.CSVColumnName] {
+			return fmt.Errorf("duplicate source column name: %q", mapping.CSVColumnName)
+		}
+		seenMappingCSVColumns[mapping.CSVColumnName] = true
+	}
+	return nil
+}
+
+func buildColumnsFromMapping(headers []string, columnMapping ColumnsMapping) ([]string, error) {
+	columns := make([]string, 0, len(headers))
+	seenColumns := make(map[string]bool)
+
+	for _, header := range headers {
+		dbColumn, ok := columnMapping.Get(header)
+		if !ok {
+			return nil, fmt.Errorf("column mapping not found for header %s", header)
+		}
+
+		sanitizedColumn := pgx.Identifier{dbColumn}.Sanitize()
+		if seenColumns[sanitizedColumn] {
+			return nil, fmt.Errorf("duplicate database column name: %s", sanitizedColumn)
+		}
+
+		seenColumns[sanitizedColumn] = true
+		columns = append(columns, sanitizedColumn)
+	}
+
+	return columns, nil
+}
+
+func (c *Copier) calculateColumnsFromHeaders(bufferedReader *bufio.Reader) error {
+	headers, err := parseCSVHeaders(bufferedReader, c.quoteCharacter, c.escapeCharacter, c.splitCharacter)
+	if err != nil {
+		return fmt.Errorf("failed to parse headers: %w", err)
+	}
+
+	if len(c.columnMapping) == 0 {
+		return c.useAutomaticColumnMapping(headers)
+	}
+
+	if err := validateColumnMapping(c.columnMapping); err != nil {
+		return err
+	}
+
+	columns, err := buildColumnsFromMapping(headers, c.columnMapping)
+	if err != nil {
+		return err
+	}
+
+	c.columns = strings.Join(columns, ",")
+	c.logger.Infof("Using column mapping: %s", c.columns)
+	return nil
+}
+
 type ErrAtRow struct {
 	Err error
 	// Row is the row reported by PgError
@@ -478,3 +601,21 @@ func (c *Copier) GetTotalRows() int64 {
 func (c *Copier) HasImportID() bool {
 	return c.importID != ""
 }
+
+// ColumnsMapping defines mapping from CSV column name to database column name
+type ColumnsMapping []ColumnMapping
+
+func (c ColumnsMapping) Get(header string) (string, bool) {
+	for _, mapping := range c {
+		if mapping.CSVColumnName == header {
+			return mapping.DatabaseColumnName, true
+		}
+	}
+	return "", false
+}
+
+// ColumnMapping defines mapping from CSV column name to database column name
+type ColumnMapping struct {
+	CSVColumnName      string // CSV column name from header
+	DatabaseColumnName string // Database column name for COPY statement
+}
diff --git a/pkg/csvcopy/csvcopy_test.go b/pkg/csvcopy/csvcopy_test.go
index 252b537..f52d700 100644
--- a/pkg/csvcopy/csvcopy_test.go
+++ b/pkg/csvcopy/csvcopy_test.go
@@ -1,10 +1,12 @@
 package csvcopy
 
 import (
+	"bufio"
 	"context"
 	"encoding/csv"
 	"fmt"
 	"os"
+	"strings"
 	"testing"
 	"time"
 
@@ -313,6 +315,90 @@ func TestErrorAtRow(t *testing.T) {
 	assert.EqualValues(t, len(batch), errAtRow.BatchLocation.ByteLen)
 }
 
+func TestErrorAtRowAndSkipLines(t *testing.T) {
+	ctx := context.Background()
+
+	pgContainer, err := postgres.Run(ctx,
+		"postgres:15.3-alpine",
+		postgres.WithDatabase("test-db"),
+		postgres.WithUsername("postgres"),
+		postgres.WithPassword("postgres"),
+		testcontainers.WithWaitStrategy(
+			wait.ForLog("database system is ready to accept connections").
+				WithOccurrence(2).WithStartupTimeout(5*time.Second)),
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Cleanup(func() {
+		if err := pgContainer.Terminate(ctx); err != nil {
+			t.Fatalf("failed to terminate pgContainer: %s", err)
+		}
+	})
+
+	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
+	require.NoError(t, err)
+
+	conn, err := pgx.Connect(ctx, connStr)
+	require.NoError(t, err)
+	defer conn.Close(ctx)
+	_, err = conn.Exec(ctx, "create table public.metrics (device_id int, label text, value float8)")
+	require.NoError(t, err)
+
+	// Create a temporary CSV file
+	tmpfile, err := os.CreateTemp("", "example")
+	require.NoError(t, err)
+	defer os.Remove(tmpfile.Name())
+
+	// Write data to the CSV file
+	writer := csv.NewWriter(tmpfile)
+
+	data := [][]string{
+		{"# This is a comment"},
+		{"42", "xasev", "4.2"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "hello"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "2.4"},
+	}
+
+	for _, record := range data {
+		if err := writer.Write(record); err != nil {
+			t.Fatalf("Error writing record to CSV: %v", err)
+		}
+	}
+
+	writer.Flush()
+
+	copier, err := NewCopier(connStr, "metrics", WithColumns("device_id,label,value"), WithBatchSize(2), WithSkipHeaderCount(1))
+	require.NoError(t, err)
+	reader, err := os.Open(tmpfile.Name())
+	require.NoError(t, err)
+	r, err := copier.Copy(context.Background(), reader)
+	assert.Error(t, err)
+
+	require.NotNil(t, r)
+	assert.EqualValues(t, 2, int(r.InsertedRows))
+	assert.EqualValues(t, 4, int(r.TotalRows))
+	assert.EqualValues(t, 0, int(r.SkippedRows))
+
+	errAtRow := &ErrAtRow{}
+	assert.ErrorAs(t, err, &errAtRow)
+	assert.EqualValues(t, 4, errAtRow.RowAtLocation())
+
+	prev := `# This is a comment
+42,xasev,4.2
+24,qased,2.4
+`
+	assert.EqualValues(t, len(prev), errAtRow.BatchLocation.ByteOffset)
+	batch := `24,qased,2.4
+24,qased,hello
+`
+	assert.EqualValues(t, len(batch), errAtRow.BatchLocation.ByteLen)
+}
+
 func TestErrorAtRowWithHeader(t *testing.T) {
 	ctx := context.Background()
 
@@ -396,6 +482,183 @@ func TestErrorAtRowWithHeader(t *testing.T) {
 	assert.EqualValues(t, len(batch), errAtRow.BatchLocation.ByteLen)
 }
 
+func TestErrorAtRowAutoColumnMappingAndSkipLines(t *testing.T) {
+	ctx := context.Background()
+
+	pgContainer, err := postgres.Run(ctx,
+		"postgres:15.3-alpine",
+		postgres.WithDatabase("test-db"),
+		postgres.WithUsername("postgres"),
+		postgres.WithPassword("postgres"),
+		testcontainers.WithWaitStrategy(
+			wait.ForLog("database system is ready to accept connections").
+				WithOccurrence(2).WithStartupTimeout(5*time.Second)),
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Cleanup(func() {
+		if err := pgContainer.Terminate(ctx); err != nil {
+			t.Fatalf("failed to terminate pgContainer: %s", err)
+		}
+	})
+
+	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
+	require.NoError(t, err)
+
+	conn, err := pgx.Connect(ctx, connStr)
+	require.NoError(t, err)
+	defer conn.Close(ctx)
+	_, err = conn.Exec(ctx, "create table public.metrics (device_id int, label text, value float8)")
+	require.NoError(t, err)
+
+	// Create a temporary CSV file
+	tmpfile, err := os.CreateTemp("", "example")
+	require.NoError(t, err)
+	defer os.Remove(tmpfile.Name())
+
+	// Write data to the CSV file
+	writer := csv.NewWriter(tmpfile)
+
+	data := [][]string{
+		{"# This is a comment"},
+		{"# This is another comment"},
+		{"# And the following line contains the actual headers"},
+		{"device_id", "label", "value"},
+		{"42", "xasev", "4.2"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "hello"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "2.4"},
+	}
+
+	for _, record := range data {
+		if err := writer.Write(record); err != nil {
+			t.Fatalf("Error writing record to CSV: %v", err)
+		}
+	}
+
+	writer.Flush()
+
+	copier, err := NewCopier(connStr, "metrics", WithAutoColumnMapping(), WithSkipHeaderCount(3), WithBatchSize(2))
+	require.NoError(t, err)
+	reader, err := os.Open(tmpfile.Name())
+	require.NoError(t, err)
+	r, err := copier.Copy(context.Background(), reader)
+	assert.Error(t, err)
+
+	require.NotNil(t, r)
+	assert.EqualValues(t, 2, int(r.InsertedRows))
+	assert.EqualValues(t, 4, int(r.TotalRows))
+	assert.EqualValues(t, 0, int(r.SkippedRows))
+	errAtRow := &ErrAtRow{}
+	assert.ErrorAs(t, err, &errAtRow)
+	assert.EqualValues(t, 7, errAtRow.RowAtLocation()) // skipped lines are also counted
+
+	prev := `# This is a comment
+# This is another comment
+# And the following line contains the actual headers
+device_id,label,value
+42,xasev,4.2
+24,qased,2.4
+`
+	assert.EqualValues(t, len(prev), errAtRow.BatchLocation.ByteOffset)
+	batch := `24,qased,2.4
+24,qased,hello
+`
+	assert.EqualValues(t, len(batch), errAtRow.BatchLocation.ByteLen)
+}
+
+func TestErrorAtRowWithColumnMapping(t *testing.T) {
+	ctx := context.Background()
+
+	pgContainer, err := postgres.Run(ctx,
+		"postgres:15.3-alpine",
+		postgres.WithDatabase("test-db"),
+		postgres.WithUsername("postgres"),
+		postgres.WithPassword("postgres"),
+		testcontainers.WithWaitStrategy(
+			wait.ForLog("database system is ready to accept connections").
+				WithOccurrence(2).WithStartupTimeout(5*time.Second)),
+	)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	t.Cleanup(func() {
+		if err := pgContainer.Terminate(ctx); err != nil {
+			t.Fatalf("failed to terminate pgContainer: %s", err)
+		}
+	})
+
+	connStr, err := pgContainer.ConnectionString(ctx, "sslmode=disable")
+	require.NoError(t, err)
+
+	conn, err := pgx.Connect(ctx, connStr)
+	require.NoError(t, err)
+	defer conn.Close(ctx)
+	_, err = conn.Exec(ctx, "create table public.metrics (device_id int, label text, value float8)")
+	require.NoError(t, err)
+
+	// Create a temporary CSV file
+	tmpfile, err := os.CreateTemp("", "example")
+	require.NoError(t, err)
+	defer os.Remove(tmpfile.Name())
+
+	// Write data to the CSV file
+	writer := csv.NewWriter(tmpfile)
+
+	data := [][]string{
+		{"a", "b", "c"},
+		{"42", "xasev", "4.2"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "hello"},
+		{"24", "qased", "2.4"},
+		{"24", "qased", "2.4"},
+	}
+
+	for _, record := range data {
+		if err := writer.Write(record); err != nil {
+			t.Fatalf("Error writing record to CSV: %v", err)
+		}
+	}
+
+	writer.Flush()
+
+	copier, err := NewCopier(connStr, "metrics", WithColumnMapping([]ColumnMapping{
+		{CSVColumnName: "a", DatabaseColumnName: "device_id"},
+		{CSVColumnName: "b", DatabaseColumnName: "label"},
+		{CSVColumnName: "c", DatabaseColumnName: "value"},
+	}), WithBatchSize(2))
+	require.NoError(t, err)
+	reader, err := os.Open(tmpfile.Name())
+	require.NoError(t, err)
+	r, err := copier.Copy(context.Background(), reader)
+	assert.Error(t, err)
+
+	require.NotNil(t, r)
+	assert.EqualValues(t, 2, int(r.InsertedRows))
+	assert.EqualValues(t, 4, int(r.TotalRows))
+	assert.EqualValues(t, 0, int(r.SkippedRows))
+
+	errAtRow := &ErrAtRow{}
+	assert.ErrorAs(t, err, &errAtRow)
+	assert.EqualValues(t, 4, errAtRow.RowAtLocation()) // header line is also counted
+
+	prev := `a,b,c
+42,xasev,4.2
+24,qased,2.4
+`
+	assert.EqualValues(t, len(prev), errAtRow.BatchLocation.ByteOffset)
+	batch := `24,qased,2.4
+24,qased,hello
+`
+	assert.EqualValues(t, len(batch), errAtRow.BatchLocation.ByteLen)
+}
+
 func TestWriteReportProgress(t *testing.T) {
 	ctx := context.Background()
 
@@ -1425,3 +1688,267 @@ func TestTransactionFailureRetry(t *testing.T) {
 		assert.Equal(t, 4, total)
 	})
 }
+
+func TestCalculateColumnsFromHeaders(t *testing.T) {
+	tests := []struct {
+		name            string
+		csvHeaders      string
+		columnMapping   []ColumnMapping
+		quoteCharacter  string
+		escapeCharacter string
+		expectedColumns string
+		expectedError   string
+	}{
+		{
+			name:       "simple mapping",
+			csvHeaders: "user_id,full_name,email_address",
+			columnMapping: []ColumnMapping{
+				{CSVColumnName: "user_id", DatabaseColumnName: "id"},
+				{CSVColumnName: "full_name", DatabaseColumnName: "name"},
+				{CSVColumnName: "email_address", DatabaseColumnName: "email"},
+			},
+			expectedColumns: "\"id\",\"name\",\"email\"",
+		},
+		{
+			name:       "partial mapping",
+			csvHeaders: "id,name,age,email",
+			columnMapping: []ColumnMapping{
+				{CSVColumnName: "id", DatabaseColumnName: "user_id"},
+				{CSVColumnName: "name", DatabaseColumnName: "full_name"},
+				{CSVColumnName: "email", DatabaseColumnName: "email_addr"},
+			},
+			expectedError: "column mapping not found for header age",
+		},
+		{
+			name:       "quoted headers",
+			csvHeaders: `"user id","full name","email address"`,
+			columnMapping: []ColumnMapping{
+				{CSVColumnName: "user id", DatabaseColumnName: "id"},
+				{CSVColumnName: "full name", DatabaseColumnName: "name"},
+				{CSVColumnName: "email address", DatabaseColumnName: "email"},
+			},
expectedColumns: "\"id\",\"name\",\"email\"", + }, + { + name: "headers with spaces (no quotes)", + csvHeaders: "user id,full name,email address", + columnMapping: []ColumnMapping{ + {CSVColumnName: "user id", DatabaseColumnName: "id"}, + {CSVColumnName: "full name", DatabaseColumnName: "name"}, + {CSVColumnName: "email address", DatabaseColumnName: "email"}, + }, + expectedColumns: "\"id\",\"name\",\"email\"", + }, + { + name: "empty header", + csvHeaders: "id,,email", + columnMapping: []ColumnMapping{ + {CSVColumnName: "id", DatabaseColumnName: "user_id"}, + {CSVColumnName: "", DatabaseColumnName: "middle_col"}, + {CSVColumnName: "email", DatabaseColumnName: "email_addr"}, + }, + expectedColumns: "\"user_id\",\"middle_col\",\"email_addr\"", + }, + { + name: "single column", + csvHeaders: "id", + columnMapping: []ColumnMapping{ + {CSVColumnName: "id", DatabaseColumnName: "user_id"}, + }, + expectedColumns: "\"user_id\"", + }, + { + name: "complex quoted headers with commas", + csvHeaders: `"user,id","full,name","email,address"`, + columnMapping: []ColumnMapping{ + {CSVColumnName: "user,id", DatabaseColumnName: "id"}, + {CSVColumnName: "full,name", DatabaseColumnName: "name"}, + {CSVColumnName: "email,address", DatabaseColumnName: "email"}, + }, + expectedColumns: "\"id\",\"name\",\"email\"", + }, + { + name: "custom quote character", + csvHeaders: "'user id','full name','email address'", + quoteCharacter: "'", + escapeCharacter: "'", + columnMapping: []ColumnMapping{ + {CSVColumnName: "user id", DatabaseColumnName: "id"}, + {CSVColumnName: "full name", DatabaseColumnName: "name"}, + {CSVColumnName: "email address", DatabaseColumnName: "email"}, + }, + expectedColumns: "\"id\",\"name\",\"email\"", + }, + { + name: "case sensitive mapping", + csvHeaders: "ID,Name,Email", + columnMapping: []ColumnMapping{ + {CSVColumnName: "id", DatabaseColumnName: "user_id"}, + {CSVColumnName: "Name", DatabaseColumnName: "full_name"}, + {CSVColumnName: "Email", DatabaseColumnName: "email_addr"}, + }, + expectedError: "column mapping not found for header ID", + }, + { + name: "order preservation", + csvHeaders: "email,id,name", + columnMapping: []ColumnMapping{ + {CSVColumnName: "id", DatabaseColumnName: "user_id"}, + {CSVColumnName: "name", DatabaseColumnName: "full_name"}, + {CSVColumnName: "email", DatabaseColumnName: "email_addr"}, + }, + expectedColumns: "\"email_addr\",\"user_id\",\"full_name\"", + }, + { + name: "no column mapping - use all headers", + csvHeaders: `"user id","full name","email address"`, + columnMapping: []ColumnMapping{}, // Empty mapping - triggers "No column mapping provided" log + expectedColumns: "\"user id\",\"full name\",\"email address\"", + }, + { + name: "column mapping with more keys than CSV headers", + csvHeaders: "id,name", + columnMapping: []ColumnMapping{ + {CSVColumnName: "id", DatabaseColumnName: "user_id"}, + {CSVColumnName: "name", DatabaseColumnName: "full_name"}, + {CSVColumnName: "email", DatabaseColumnName: "email_addr"}, // Extra mapping key + {CSVColumnName: "age", DatabaseColumnName: "user_age"}, // Another extra mapping key + }, + expectedColumns: "\"user_id\",\"full_name\"", // Only mapped columns from CSV headers + }, + { + name: "duplicate database columns in mapping", + csvHeaders: "first_name,last_name,email", + columnMapping: []ColumnMapping{ + {CSVColumnName: "first_name", DatabaseColumnName: "name"}, + {CSVColumnName: "last_name", DatabaseColumnName: "name"}, // Same database column + {CSVColumnName: "email", DatabaseColumnName: "email_addr"}, + }, 
+ expectedError: "duplicate database column name: \"name\"", + }, + { + name: "duplicate database columns in mapping but doesn't create a conflict", + csvHeaders: "first_name,email", + columnMapping: []ColumnMapping{ + {CSVColumnName: "first_name", DatabaseColumnName: "name"}, + {CSVColumnName: "name", DatabaseColumnName: "name"}, // legacy field mapping exmaple + {CSVColumnName: "email", DatabaseColumnName: "email_addr"}, + }, + expectedColumns: "\"name\",\"email_addr\"", + }, + { + name: "duplicate csv column name in mapping", + csvHeaders: "first_name,email", + columnMapping: []ColumnMapping{ + {CSVColumnName: "first_name", DatabaseColumnName: "name"}, + {CSVColumnName: "first_name", DatabaseColumnName: "first_name"}, // ERROR: it is duplicated + {CSVColumnName: "email", DatabaseColumnName: "email_addr"}, + }, + expectedError: "duplicate source column name: \"first_name\"", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a copier with the test configuration + copier := &Copier{ + skip: 1, + columnMapping: ColumnsMapping(tt.columnMapping), + quoteCharacter: tt.quoteCharacter, + escapeCharacter: tt.escapeCharacter, + logger: &noopLogger{}, + } + + // Create a buffered reader with the test CSV headers + csvData := tt.csvHeaders + "\ndata1,data2,data3\n" + reader := strings.NewReader(csvData) + counter := &CountReader{Reader: reader} + bufferedReader := bufio.NewReaderSize(counter, 1024) + + // Call the function under test + err := copier.calculateColumnsFromHeaders(bufferedReader) + + // Check the results + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expectedColumns, copier.columns) + } + }) + } +} + +func TestCalculateColumnsFromHeaders_NoMapping(t *testing.T) { + // Test the case where no column mapping is provided + copier := &Copier{ + skip: 1, + columnMapping: ColumnsMapping{}, // Empty mapping + logger: &noopLogger{}, + } + + csvData := "id,name,email\ndata1,data2,data3\n" + reader := strings.NewReader(csvData) + counter := &CountReader{Reader: reader} + bufferedReader := bufio.NewReaderSize(counter, 1024) + + err := copier.calculateColumnsFromHeaders(bufferedReader) + + require.NoError(t, err) + assert.Equal(t, "\"id\",\"name\",\"email\"", copier.columns) +} + +func TestColumnsMapping_Get(t *testing.T) { + mapping := ColumnsMapping{ + {CSVColumnName: "user_id", DatabaseColumnName: "id"}, + {CSVColumnName: "full_name", DatabaseColumnName: "name"}, + {CSVColumnName: "email_address", DatabaseColumnName: "email"}, + } + + tests := []struct { + name string + header string + expectedColumn string + expectedFound bool + }{ + { + name: "existing mapping", + header: "user_id", + expectedColumn: "id", + expectedFound: true, + }, + { + name: "another existing mapping", + header: "email_address", + expectedColumn: "email", + expectedFound: true, + }, + { + name: "non-existing mapping", + header: "age", + expectedColumn: "", + expectedFound: false, + }, + { + name: "empty header", + header: "", + expectedColumn: "", + expectedFound: false, + }, + { + name: "case sensitive", + header: "USER_ID", + expectedColumn: "", + expectedFound: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + column, found := mapping.Get(tt.header) + assert.Equal(t, tt.expectedFound, found) + assert.Equal(t, tt.expectedColumn, column) + }) + } +} diff --git a/pkg/csvcopy/options.go b/pkg/csvcopy/options.go index c2b7c30..00f649d 
--- a/pkg/csvcopy/options.go
+++ b/pkg/csvcopy/options.go
@@ -100,6 +100,9 @@
 // WithColumns accepts a list of comma separated values for the csv columns
 func WithColumns(columns string) Option {
 	return func(c *Copier) error {
+		if c.useFileHeaders == HeaderAutoColumnMapping || c.useFileHeaders == HeaderColumnMapping {
+			return errors.New("column mapping is already set. Use only one of: WithColumns, WithColumnMapping, or WithAutoColumnMapping")
+		}
 		c.columns = columns
 		return nil
 	}
@@ -108,12 +111,10 @@
 // WithSkipHeader is set, skips the first row of the csv file
 func WithSkipHeader(skipHeader bool) Option {
 	return func(c *Copier) error {
-		if c.skip != 0 {
-			return errors.New("skip is already set. Use SkipHeader or SkipHeaderCount")
-		}
-		if skipHeader {
-			c.skip = 1
+		if !skipHeader {
+			// Nothing to configure; leave header handling unchanged
+			return nil
+		}
+		if c.useFileHeaders != HeaderNone {
+			return errors.New("header handling is already configured. Use only one of: WithSkipHeader, WithColumnMapping, or WithAutoColumnMapping")
 		}
+		c.useFileHeaders = HeaderSkip
 		return nil
 	}
 }
@@ -122,7 +123,7 @@
 func WithSkipHeaderCount(headerLineCount int) Option {
 	return func(c *Copier) error {
 		if c.skip != 0 {
-			return errors.New("skip is already set. Use SkipHeader or SkipHeaderCount")
+			return errors.New("skip is already set")
 		}
 		if headerLineCount <= 0 {
 			return errors.New("header line count must be greater than zero")
@@ -293,3 +294,47 @@ func WithIdempotencyWindow(window time.Duration) Option {
 		return nil
 	}
 }
+
+// WithColumnMapping sets the column mapping from CSV header names to database column names.
+// Each ColumnMapping specifies a CSVColumnName and a DatabaseColumnName.
+// The header line is read to resolve the mapping and is skipped automatically.
+func WithColumnMapping(mappings []ColumnMapping) Option {
+	return func(c *Copier) error {
+		if mappings == nil {
+			return errors.New("column mapping cannot be nil")
+		}
+		if c.useFileHeaders != HeaderNone {
+			return errors.New("header handling is already configured. Use only one of: WithSkipHeader, WithColumnMapping, or WithAutoColumnMapping")
+		}
+		if c.columns != "" {
+			return errors.New("columns are already set. Use only one of: WithColumns, WithColumnMapping, or WithAutoColumnMapping")
+		}
+		for i, mapping := range mappings {
+			if mapping.CSVColumnName == "" {
+				return fmt.Errorf("column mapping at index %d has empty CSVColumnName", i)
+			}
+			if mapping.DatabaseColumnName == "" {
+				return fmt.Errorf("column mapping at index %d has empty DatabaseColumnName", i)
+			}
+		}
+		c.columnMapping = mappings
+		c.useFileHeaders = HeaderColumnMapping
+		return nil
+	}
+}
+
+// WithAutoColumnMapping enables automatic column mapping where CSV header names
+// are used as database column names (1:1 mapping).
+// The header line is skipped automatically.
+func WithAutoColumnMapping() Option {
+	return func(c *Copier) error {
+		if c.useFileHeaders != HeaderNone {
+			return errors.New("header handling is already configured. Use only one of: WithSkipHeader, WithColumnMapping, or WithAutoColumnMapping")
+		}
+		if c.columns != "" {
+			return errors.New("columns are already set. Use only one of: WithColumns, WithColumnMapping, or WithAutoColumnMapping")
+		}
+		c.useFileHeaders = HeaderAutoColumnMapping
+		return nil
+	}
+}
diff --git a/pkg/csvcopy/options_test.go b/pkg/csvcopy/options_test.go
new file mode 100644
index 0000000..65cb474
--- /dev/null
+++ b/pkg/csvcopy/options_test.go
@@ -0,0 +1,379 @@
+package csvcopy
+
+import (
+	"strings"
+	"testing"
+)
+
+func TestOptionsMutualExclusivity(t *testing.T) {
+	tests := []struct {
+		name          string
+		options       []Option
+		expectError   bool
+		errorContains string
+	}{
+		// Valid individual configurations
+		{
+			name:        "WithSkipHeader alone should work",
+			options:     []Option{WithSkipHeader(true)},
+			expectError: false,
+		},
+		{
+			name:        "WithSkipHeader false should work",
+			options:     []Option{WithSkipHeader(false)},
+			expectError: false,
+		},
+		{
+			name: "WithColumnMapping alone should work",
+			options: []Option{WithColumnMapping([]ColumnMapping{
+				{CSVColumnName: "csv_col", DatabaseColumnName: "db_col"},
+			})},
+			expectError: false,
+		},
+		{
+			name:        "WithAutoColumnMapping alone should work",
+			options:     []Option{WithAutoColumnMapping()},
+			expectError: false,
+		},
+		{
+			name:        "WithColumns alone should work",
+			options:     []Option{WithColumns("col1,col2,col3")},
+			expectError: false,
+		},
+		{
+			name:        "WithSkipHeaderCount alone should work",
+			options:     []Option{WithSkipHeaderCount(2)},
+			expectError: false,
+		},
+
+		// Mutual exclusivity tests - WithSkipHeader conflicts
+		{
+			name: "WithSkipHeader + WithColumnMapping should fail",
+			options: []Option{
+				WithSkipHeader(true),
+				WithColumnMapping([]ColumnMapping{
+					{CSVColumnName: "csv_col", DatabaseColumnName: "db_col"},
+				}),
+			},
+			expectError:   true,
+			errorContains: "header handling is already configured",
+		},
+		{
+			name: "WithSkipHeader + WithAutoColumnMapping should fail",
+			options: []Option{
+				WithSkipHeader(true),
+				WithAutoColumnMapping(),
+			},
+			expectError:   true,
+			errorContains: "header handling is already configured",
+		},
+		{
+			name: "WithColumnMapping + WithSkipHeader should fail",
+			options: []Option{
+				WithColumnMapping([]ColumnMapping{
+					{CSVColumnName: "csv_col", DatabaseColumnName: "db_col"},
+				}),
+				WithSkipHeader(true),
+			},
+			expectError:   true,
+			errorContains: "header handling is already configured",
+		},
+		{
+			name: "WithAutoColumnMapping + WithSkipHeader should fail",
+			options: []Option{
+				WithAutoColumnMapping(),
+				WithSkipHeader(true),
+			},
+			expectError:   true,
+			errorContains: "header handling is already configured",
+		},
+
+		// Mutual exclusivity tests - Column mapping conflicts
+		{
+			name: "WithColumnMapping + WithAutoColumnMapping should fail",
+			options: []Option{
+				WithColumnMapping([]ColumnMapping{
+					{CSVColumnName: "csv_col", DatabaseColumnName: "db_col"},
+				}),
+				WithAutoColumnMapping(),
+			},
+			expectError:   true,
+			errorContains: "header handling is already configured",
+		},
+		{
+			name: "WithAutoColumnMapping + WithColumnMapping should fail",
+			options: []Option{
+				WithAutoColumnMapping(),
+				WithColumnMapping([]ColumnMapping{
+					{CSVColumnName: "csv_col", DatabaseColumnName: "db_col"},
+				}),
+			},
+			expectError:   true,
+			errorContains: "header handling is already configured",
+		},
+
+		// Triple conflicts
+		{
+			name: "All three options should fail",
+			options: []Option{
+				WithSkipHeader(true),
+				WithColumnMapping([]ColumnMapping{
+					{CSVColumnName: "csv_col", DatabaseColumnName: "db_col"},
+				}),
+				WithAutoColumnMapping(),
+			},
+			expectError:   true,
+			errorContains: "header handling is already configured",
+		},
+
+		// WithColumns conflicts with column mapping
+		{
+			name: "WithColumns + WithColumnMapping should fail",
+			options: []Option{
+				WithColumns("col1,col2"),
+				WithColumnMapping([]ColumnMapping{
+					{CSVColumnName: "csv_col", DatabaseColumnName: "db_col"},
+				}),
+			},
+			expectError:   true,
+			errorContains: "columns are already set",
+		},
+		{
+			name: "WithColumns + WithAutoColumnMapping should fail",
+			options: []Option{
+				WithColumns("col1,col2"),
+				WithAutoColumnMapping(),
+			},
+			expectError:   true,
+			errorContains: "columns are already set",
+		},
+		{
+			name: "WithColumnMapping + WithColumns should fail",
+			options: []Option{
+				WithColumnMapping([]ColumnMapping{
+					{CSVColumnName: "csv_col", DatabaseColumnName: "db_col"},
+				}),
+				WithColumns("col1,col2"),
+			},
+			expectError:   true,
+			errorContains: "column mapping is already set",
+		},
+		{
+			name: "WithAutoColumnMapping + WithColumns should fail",
+			options: []Option{
+				WithAutoColumnMapping(),
+				WithColumns("col1,col2"),
+			},
+			expectError:   true,
+			errorContains: "column mapping is already set",
+		},
+
+		// Valid combinations that should work
+		{
+			name: "WithSkipHeader false + WithColumns should work",
+			options: []Option{
+				WithSkipHeader(false),
+				WithColumns("col1,col2"),
+			},
+			expectError: false,
+		},
+		{
+			name: "WithSkipHeaderCount + WithColumns should work",
+			options: []Option{
+				WithSkipHeaderCount(2),
+				WithColumns("col1,col2"),
+			},
+			expectError: false,
+		},
+		{
+			name: "WithSkipHeader + WithSkipHeaderCount should work",
+			options: []Option{
+				WithSkipHeader(true),
+				WithSkipHeaderCount(3),
+			},
+			expectError: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := NewCopier("test-conn", "test-table", tt.options...)
+
+			if tt.expectError {
+				if err == nil {
+					t.Errorf("Expected error but got none")
+					return
+				}
+				if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
+					t.Errorf("Expected error to contain '%s', but got: %s", tt.errorContains, err.Error())
+				}
+			} else {
+				if err != nil {
+					t.Errorf("Expected no error but got: %s", err.Error())
+				}
+			}
+		})
+	}
+}
+
+func TestColumnMappingValidation(t *testing.T) {
+	tests := []struct {
+		name          string
+		mappings      []ColumnMapping
+		expectError   bool
+		errorContains string
+	}{
+		{
+			name: "Valid column mapping should work",
+			mappings: []ColumnMapping{
+				{CSVColumnName: "csv_col1", DatabaseColumnName: "db_col1"},
+				{CSVColumnName: "csv_col2", DatabaseColumnName: "db_col2"},
+			},
+			expectError: false,
+		},
+		{
+			name:          "Nil mappings should fail",
+			mappings:      nil,
+			expectError:   true,
+			errorContains: "column mapping cannot be nil",
+		},
+		{
+			name: "Empty CSV column name should fail",
+			mappings: []ColumnMapping{
+				{CSVColumnName: "", DatabaseColumnName: "db_col"},
+			},
+			expectError:   true,
+			errorContains: "empty CSVColumnName",
+		},
+		{
+			name: "Empty database column name should fail",
+			mappings: []ColumnMapping{
+				{CSVColumnName: "csv_col", DatabaseColumnName: ""},
+			},
+			expectError:   true,
+			errorContains: "empty DatabaseColumnName",
+		},
+		{
+			name: "Multiple mappings with one invalid should fail",
+			mappings: []ColumnMapping{
+				{CSVColumnName: "csv_col1", DatabaseColumnName: "db_col1"},
+				{CSVColumnName: "", DatabaseColumnName: "db_col2"},
+			},
+			expectError:   true,
+			errorContains: "empty CSVColumnName",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := NewCopier("test-conn", "test-table", WithColumnMapping(tt.mappings))
+
+			if tt.expectError {
+				if err == nil {
+					t.Errorf("Expected error but got none")
+					return
+				}
+				if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
+					t.Errorf("Expected error to contain '%s', but got: %s", tt.errorContains, err.Error())
+				}
+			} else {
+				if err != nil {
+					t.Errorf("Expected no error but got: %s", err.Error())
+				}
+			}
+		})
+	}
+}
+
+func TestHeaderHandlingEnumValues(t *testing.T) {
+	tests := []struct {
+		name           string
+		option         Option
+		expectedHeader HeaderHandling
+	}{
+		{
+			name:           "WithSkipHeader(true) should set HeaderSkip",
+			option:         WithSkipHeader(true),
+			expectedHeader: HeaderSkip,
+		},
+		{
+			name:           "WithSkipHeader(false) should keep HeaderNone",
+			option:         WithSkipHeader(false),
+			expectedHeader: HeaderNone,
+		},
+		{
+			name:           "WithAutoColumnMapping should set HeaderAutoColumnMapping",
+			option:         WithAutoColumnMapping(),
+			expectedHeader: HeaderAutoColumnMapping,
+		},
+		{
+			name: "WithColumnMapping should set HeaderColumnMapping",
+			option: WithColumnMapping([]ColumnMapping{
+				{CSVColumnName: "csv_col", DatabaseColumnName: "db_col"},
+			}),
+			expectedHeader: HeaderColumnMapping,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			copier, err := NewCopier("test-conn", "test-table", tt.option)
+			if err != nil {
+				t.Errorf("Unexpected error: %s", err.Error())
+				return
+			}
+
+			if copier.useFileHeaders != tt.expectedHeader {
+				t.Errorf("Expected useFileHeaders to be %d, but got %d", tt.expectedHeader, copier.useFileHeaders)
+			}
+		})
+	}
+}
+
+func TestSkipHeaderCountValidation(t *testing.T) {
+	tests := []struct {
+		name          string
+		count         int
+		expectError   bool
+		errorContains string
+	}{
+		{
+			name:        "Valid skip count should work",
+			count:       3,
+			expectError: false,
+		},
+		{
+			name:          "Zero skip count should fail",
+			count:         0,
+			expectError:   true,
+			errorContains: "header line count must be greater than zero",
+		},
+		{
+			name:          "Negative skip count should fail",
+			count:         -1,
+			expectError:   true,
+			errorContains: "header line count must be greater than zero",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			_, err := NewCopier("test-conn", "test-table", WithSkipHeaderCount(tt.count))
+
+			if tt.expectError {
+				if err == nil {
+					t.Errorf("Expected error but got none")
+					return
+				}
+				if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
+					t.Errorf("Expected error to contain '%s', but got: %s", tt.errorContains, err.Error())
+				}
+			} else {
+				if err != nil {
+					t.Errorf("Expected no error but got: %s", err.Error())
+				}
+			}
+		})
+	}
+}
diff --git a/pkg/csvcopy/scan.go b/pkg/csvcopy/scan.go
index bf4c0ea..46a13dd 100644
--- a/pkg/csvcopy/scan.go
+++ b/pkg/csvcopy/scan.go
@@ -11,11 +11,10 @@ import (
 
 // scanOptions contains all the configurable knobs for Scan.
 type scanOptions struct {
-	Size           int   // maximum number of rows per batch, It may be less than this if ChunkByteSize is reached first
-	Skip           int   // how many header lines to skip at the beginning
-	Limit          int64 // total number of rows to scan after the header.
-	BufferByteSize int   // buffer size for the reader. it has to be big enough to hold a full row
-	BatchByteSize  int   // Max byte size for a batch.
+	Size          int   // maximum number of rows per batch; it may be smaller if BatchByteSize is reached first
+	Skip          int   // how many header lines to skip at the beginning
+	Limit         int64 // total number of rows to scan after the header.
+	BatchByteSize int   // Max byte size for a batch.
 
 	Quote  byte // the QUOTE character; defaults to '"'
 	Escape byte // the ESCAPE character; defaults to QUOTE
 
@@ -100,53 +99,24 @@ func (l Location) HasImportID() bool {
 	return l.ImportID != ""
 }
 
-// scan reads all lines from an io.Reader, partitions them into net.Buffers with
-// opts.Size rows each, and writes each batch to the out channel. If opts.Skip
-// is greater than zero, that number of lines will be discarded from the
-// beginning of the data. If opts.Limit is greater than zero, then scan will
-// stop once it has written that number of rows, across all batches, to the
-// channel.
+// scan reads all lines from a pre-configured buffered reader, partitions them into net.Buffers with
+// opts.Size rows each, and writes each batch to the out channel. If opts.Limit is greater than zero,
+// then scan will stop once it has written that number of rows, across all batches, to the channel.
 //
 // scan expects the input to be in Postgres CSV format. Since this format allows
 // rows to be split over multiple lines, the caller may provide opts.Quote and
 // opts.Escape as the QUOTE and ESCAPE characters used for the CSV input.
-func scan(ctx context.Context, r io.Reader, out chan<- Batch, opts scanOptions) error {
+//
+// The caller is responsible for setting up the CountReader and buffered reader,
+// and for skipping any headers before calling this function.
+func scan(ctx context.Context, counter *CountReader, reader *bufio.Reader, out chan<- Batch, opts scanOptions) error {
 	var rowsRead int64
-	counter := &CountReader{Reader: r}
-
-	bufferSize := 2 * 1024 * 1024 // 2 MB buffer
-	if opts.BufferByteSize > 0 {
-		bufferSize = opts.BufferByteSize
-	}
 
 	batchSize := 20 * 1024 * 1024 // 20 MB batch size
 	if opts.BatchByteSize > 0 {
 		batchSize = opts.BatchByteSize
 	}
 
-	if batchSize < bufferSize {
-		return fmt.Errorf("batch size (%d) is smaller than buffer size (%d)", batchSize, bufferSize)
-	}
-
-	reader := bufio.NewReaderSize(counter, bufferSize)
-
-	for skip := opts.Skip; skip > 0; {
-		// The use of ReadLine() here avoids copying or buffering data that
-		// we're just going to discard.
-		_, isPrefix, err := reader.ReadLine()
-
-		if err == io.EOF {
-			// No data?
-			return nil
-		} else if err != nil {
-			return fmt.Errorf("skipping header: %w", err)
-		}
-		if !isPrefix {
-			// We pulled a full row from the buffer.
-			skip--
-		}
-	}
-
 	quote := byte('"')
 	if opts.Quote != 0 {
 		quote = opts.Quote
@@ -198,7 +168,7 @@ func scan(ctx context.Context, r io.Reader, out chan<- Batch, opts scanOptions)
 	switch err {
 	case bufio.ErrBufferFull:
 		// If we hit buffer full, we do not have enough data to read a full row
-		return fmt.Errorf("reading lines, %w", err)
+		return fmt.Errorf("reading lines, %w. you should probably increase the buffer size", err)
 
 	case io.EOF:
 		// Also fine, but unlike ErrBufferFull we won't have another
@@ -380,3 +350,112 @@ func (c *CountReader) Read(b []byte) (int, error) {
 	c.Total += n
 	return n, err
 }
+
+// skipLines skips the specified number of lines starting from the very beginning of the file.
+func skipLines(reader *bufio.Reader, skip int) error {
+	for skip > 0 {
+		// The use of ReadLine() here avoids copying or buffering data that
+		// we're just going to discard.
+		_, isPrefix, err := reader.ReadLine()
+
+		if err == io.EOF {
+			// No data?
+			return nil
+		} else if err != nil {
+			return fmt.Errorf("skipping line: %w", err)
+		}
+		if !isPrefix {
+			// We pulled a full row from the buffer.
+			skip--
+		}
+	}
+	return nil
+}
+
+// parseHeaders parses the first header line and skips remaining header lines
+func parseHeaders(reader *bufio.Reader, quote, escape byte, comma rune) ([]string, error) {
+	// Read the first header line
+	var headerLine []byte
+	for {
+		data, isPrefix, err := reader.ReadLine()
+		if err == io.EOF {
+			return []string{}, nil
+		} else if err != nil {
+			return nil, fmt.Errorf("reading header: %w", err)
+		}
+
+		headerLine = append(headerLine, data...)
+		if !isPrefix {
+			// We have a complete line
+			break
+		}
+	}
+
+	// Parse the CSV header line using PostgreSQL CSV format
+	// (which differs from standard CSV in escape handling)
+	headers, err := parsePostgreSQLCSVLine(string(headerLine), comma, quote, escape)
+	if err != nil {
+		return nil, fmt.Errorf("parsing header line: %w", err)
+	}
+
+	return headers, nil
+}
+
+// parsePostgreSQLCSVLine parses a CSV line using PostgreSQL CSV format rules
+// This handles quote, escape, and comma characters as PostgreSQL COPY expects
+func parsePostgreSQLCSVLine(line string, comma rune, quote, escape byte) ([]string, error) {
+	var fields []string
+	var field []byte
+	var inQuote bool
+
+	for i := 0; i < len(line); i++ {
+		b := line[i]
+
+		if inQuote {
+			if b == escape && i+1 < len(line) {
+				// Handle escape sequences - look ahead to see what's being escaped
+				next := line[i+1]
+				if next == quote || next == escape {
+					// Valid escape sequence, add the escaped character
+					field = append(field, next)
+					i++ // Skip the next character as it's been consumed
+					continue
+				}
+			}
+
+			if b == quote {
+				// End of quoted field
+				inQuote = false
+				continue
+			}
+
+			// Regular character inside quotes
+			field = append(field, b)
+		} else {
+			if b == quote {
+				// Start of quoted field
+				inQuote = true
+				continue
+			}
+
+			if rune(b) == comma {
+				// Field separator
+				fields = append(fields, string(field))
+				field = field[:0]
+				continue
+			}
+
+			// Regular character outside quotes
+			field = append(field, b)
+		}
+	}
+
+	// Add the last field
+	fields = append(fields, string(field))
+
+	if inQuote {
+		return nil, fmt.Errorf("unterminated quoted field in header line")
+	}
+
+	return fields, nil
+}
diff --git a/pkg/csvcopy/scan_test.go b/pkg/csvcopy/scan_test.go
index bd0e435..a398797 100644
--- a/pkg/csvcopy/scan_test.go
+++ b/pkg/csvcopy/scan_test.go
@@ -303,7 +303,7 @@ d"
 			size:       2,
 			batchSize:  1024,
 			bufferSize: 2048,
-			expectedError: "batch size (1024) is smaller than buffer size (2048)",
+			expectedError: "you should probably increase the buffer size",
 		},
 		{
 			name: "batch size is hit before line limit",
@@ -329,6 +329,42 @@ d"
 				1,
 			},
 		},
+		{
+			name: "simple quoted headers with skip",
+			input: []string{
+				`"user id","full name","email address"`,
+				`1,"John Doe","john@example.com"`,
+				`2,"Jane Smith","jane@example.com"`,
+			},
+			size: 2,
+			skip: 1,
+			expected: []string{
+				`1,"John Doe","john@example.com"
+2,"Jane Smith","jane@example.com"`,
+			},
+			expectedRowCount: []int{
+				2,
+			},
+		},
+		{
+			name: "skip first lines, then parse headers",
+			input: []string{
+				`# This is a comment`,
+				`# This is another comment`,
+				`# And the following line contains the actual headers`,
+				`a,b,c`,
+				`1,2,3`,
+				`4,5,6`,
+			},
+			size: 3,
+			skip: 3, // skip the comments, not the header line
+			expected: []string{
+				"a,b,c\n1,2,3\n4,5,6",
+			},
+			expectedRowCount: []int{
+				3,
+			},
+		},
 	}
 
 	for _, c := range cases {
@@ -352,16 +388,31 @@ d"
 			all := strings.Join(c.input, "\n")
 			reader := strings.NewReader(all)
 			opts := scanOptions{
-				Size:           c.size,
-				Skip:           c.skip,
-				Limit:          c.limit,
-				Quote:          byte(c.quote),
-				Escape:         byte(c.escape),
-				BufferByteSize: c.bufferSize,
-				BatchByteSize:  c.batchSize,
+				Size:          c.size,
+				Skip:          c.skip,
+				Limit:         c.limit,
+				Quote:         byte(c.quote),
+				Escape:        byte(c.escape),
+				BatchByteSize: c.batchSize,
 			}
 
+			counter := &CountReader{Reader: reader}
+			bufferSize := 2 * 1024 * 1024
+			if c.bufferSize > 0 {
+				bufferSize = c.bufferSize
+			}
+			bufferedReader := bufio.NewReaderSize(counter, bufferSize)
+
+			// Skip headers if needed
+			if opts.Skip > 0 {
+				err := skipLines(bufferedReader, opts.Skip)
+				if err != nil {
+					assert.NoError(t, err)
+					return
+				}
+			}
+
-			err := scan(context.Background(), reader, rowChan, opts)
+			err := scan(context.Background(), counter, bufferedReader, rowChan, opts)
 			if err != nil {
 				if c.expectedError == "" {
 					assert.NoError(t, err)
@@ -411,7 +462,19 @@ d"
 			Skip: c.skip,
 		}
 
-		err := scan(context.Background(), reader, rowChan, opts)
+		counter := &CountReader{Reader: reader}
+		bufferedReader := bufio.NewReaderSize(counter, 2*1024*1024)
+
+		// Skip headers if needed
+		if opts.Skip > 0 {
+			err := skipLines(bufferedReader, opts.Skip)
+			if !errors.Is(err, expected) {
+				t.Errorf("Scan() returned unexpected error: %v", err)
+				t.Logf("want: %v", expected)
+			}
+		}
+
+		err := scan(context.Background(), counter, bufferedReader, rowChan, opts)
 		if !errors.Is(err, expected) {
 			t.Errorf("Scan() returned unexpected error: %v", err)
 			t.Logf("want: %v", expected)
@@ -520,7 +583,19 @@ func BenchmarkScan(b *testing.B) {
 
 	for i := 0; i < b.N; i++ {
 		reader.Reset(data) // rewind to the beginning
-		err := scan(context.Background(), reader, rowChan, opts)
+		counter := &CountReader{Reader: reader}
+		bufferedReader := bufio.NewReaderSize(counter, 2*1024*1024)
+
+		// Skip headers if needed
+		if opts.Skip > 0 {
+			err := skipLines(bufferedReader, opts.Skip)
+			if err != nil {
+				b.Errorf("Failed to skip headers: %v", err)
+				return
+			}
+		}
+
+		err := scan(context.Background(), counter, bufferedReader, rowChan, opts)
 		if err != nil {
 			b.Errorf("Scan() returned unexpected error: %v", err)
 		}