diff --git a/internal/generator/output/general/test/unit_test.go b/internal/generator/output/general/test/unit_test.go
index 4e7a91f..c5a3bc2 100644
--- a/internal/generator/output/general/test/unit_test.go
+++ b/internal/generator/output/general/test/unit_test.go
@@ -14,6 +14,9 @@ import (
 	"github.com/tarantool/sdvg/internal/generator/common"
 	"github.com/tarantool/sdvg/internal/generator/models"
 	outputGeneral "github.com/tarantool/sdvg/internal/generator/output/general"
+	"github.com/tarantool/sdvg/internal/generator/output/general/writer"
+	outputCsv "github.com/tarantool/sdvg/internal/generator/output/general/writer/csv"
+	outputParquet "github.com/tarantool/sdvg/internal/generator/output/general/writer/parquet"
 	"github.com/tarantool/sdvg/internal/generator/usecase"
 	useCaseGeneral "github.com/tarantool/sdvg/internal/generator/usecase/general"
 )
@@ -264,6 +267,57 @@ cause: dir for model is not empty
 	}
 }
 
+// TestWriterInitTeardown checks that every writer can be torn down right after Init,
+// before any row has been written.
+func TestWriterInitTeardown(t *testing.T) {
+	outputDir := t.TempDir()
+
+	// Both writers are built up front, in the same order as the subtests run.
+	csvWriter := outputCsv.NewWriter(
+		context.TODO(),
+		nil,
+		&models.CSVConfig{
+			FloatPrecision: 1,
+			DatetimeFormat: "2006-01-02T15:04:05Z07:00",
+			Delimiter:      ",",
+			WithoutHeaders: false,
+		},
+		nil,
+		outputDir,
+		false,
+		make(chan<- uint64),
+	)
+
+	parquetWriter := outputParquet.NewWriter(
+		&models.Model{Columns: make([]*models.Column, 0)},
+		&models.ParquetConfig{
+			CompressionCodec: "UNCOMPRESSED",
+			FloatPrecision:   2,
+			DateTimeFormat:   models.ParquetDateTimeMillisFormat,
+		},
+		nil,
+		outputParquet.NewFileSystem(),
+		outputDir,
+		false,
+		make(chan<- uint64),
+	)
+
+	testCases := []struct {
+		name   string
+		writer writer.Writer
+	}{
+		{name: "csv", writer: csvWriter},
+		{name: "parquet", writer: parquetWriter},
+	}
+
+	for _, testCase := range testCases {
+		t.Run(testCase.name, func(t *testing.T) {
+			require.NoError(t, testCase.writer.Init())
+			require.NoError(t, testCase.writer.Teardown())
+		})
+	}
+}
+
 //nolint:lll
 func generate(t *testing.T, cfg *models.GenerationConfig, uc usecase.UseCase,
continueGeneration, forceGeneration bool) error { t.Helper() diff --git a/internal/generator/output/general/writer/csv/csv.go b/internal/generator/output/general/writer/csv/csv.go index 066e300..fc10477 100644 --- a/internal/generator/output/general/writer/csv/csv.go +++ b/internal/generator/output/general/writer/csv/csv.go @@ -41,6 +41,8 @@ type Writer struct { fileDescriptor *os.File csvWriter *stdCSV.Writer flushTicker *time.Ticker + flushWg *sync.WaitGroup + flushStopChan chan struct{} totalWrittenRows uint64 bufferedRows uint64 @@ -51,7 +53,6 @@ type Writer struct { writerWg *sync.WaitGroup writerMutex *sync.Mutex started bool - stopChan chan struct{} } // NewWriter function creates Writer object. @@ -72,13 +73,14 @@ func NewWriter( outputPath: outputPath, continueGeneration: continueGeneration, flushTicker: time.NewTicker(flushInterval), + flushWg: &sync.WaitGroup{}, + flushStopChan: make(chan struct{}), writtenRowsChan: writtenRowsChan, writerChan: make(chan *models.DataRow), errorsChan: make(chan error, 1), writerWg: &sync.WaitGroup{}, writerMutex: &sync.Mutex{}, started: false, - stopChan: make(chan struct{}), } } @@ -104,6 +106,7 @@ func (w *Writer) Init() error { w.started = true w.writerWg.Add(1) + w.flushWg.Add(1) go w.writer() go w.flusher() @@ -142,9 +145,11 @@ func (w *Writer) writer() { } func (w *Writer) flusher() { + defer w.flushWg.Done() + for { select { - case <-w.stopChan: + case <-w.flushStopChan: return case <-w.flushTicker.C: if w.csvWriter != nil { @@ -427,14 +432,19 @@ func (w *Writer) Teardown() error { w.writerWg.Wait() w.flushTicker.Stop() - w.stopChan <- struct{}{} + close(w.flushStopChan) + w.flushWg.Wait() - if err := w.flush(); err != nil { - return err + if w.csvWriter != nil { + if err := w.flush(); err != nil { + return err + } } - if err := w.fileDescriptor.Close(); err != nil { - return errors.New(err.Error()) + if w.fileDescriptor != nil { + if err := w.fileDescriptor.Close(); err != nil { + return errors.New(err.Error()) + 
} } select { diff --git a/internal/generator/output/general/writer/parquet/parquet.go b/internal/generator/output/general/writer/parquet/parquet.go index da0f362..f1e2e8a 100644 --- a/internal/generator/output/general/writer/parquet/parquet.go +++ b/internal/generator/output/general/writer/parquet/parquet.go @@ -28,6 +28,9 @@ import ( const ( flushInterval = 5 * time.Second + //nolint:godox + // TODO: find the optimal value, or calculate it so that 512Mb of data is flushed to disk at a time. + recordBuilderReserve = 5000 ) var ( @@ -68,7 +71,10 @@ type Writer struct { parquetWriter *pqarrow.FileWriter writerProperties *parquet.WriterProperties recordBuilder *array.RecordBuilder - flushTicker *time.Ticker + + flushTicker *time.Ticker + flushWg *sync.WaitGroup + flushStopChan chan struct{} totalWrittenRows uint64 bufferedRows uint64 @@ -77,7 +83,6 @@ type Writer struct { errorChan chan error writerMutex *sync.Mutex started bool - stopCh chan struct{} } type FileSystem interface { @@ -106,11 +111,12 @@ func NewWriter( continueGeneration: continueGeneration, fs: fs, flushTicker: time.NewTicker(flushInterval), + flushWg: &sync.WaitGroup{}, + flushStopChan: make(chan struct{}), writtenRowsChan: writtenRowsChan, errorChan: make(chan error), writerMutex: &sync.Mutex{}, started: false, - stopCh: make(chan struct{}), } } @@ -273,10 +279,9 @@ func (w *Writer) Init() error { w.parquetModelSchema = modelSchema w.writerProperties = parquet.NewWriterProperties(writerProperties...) 
+
 	w.recordBuilder = array.NewRecordBuilder(memory.DefaultAllocator, w.parquetModelSchema)
-	//nolint:mnd,godox
-	// TODO: find optimal value, or calculate it to flush on disk 512Mb data
-	w.recordBuilder.Reserve(5000)
+	w.recordBuilder.Reserve(recordBuilderReserve)
 
 	if err = os.MkdirAll(w.outputPath, os.ModePerm); err != nil {
 		return errors.New(err.Error())
@@ -301,7 +306,7 @@ func (w *Writer) Init() error {
 func (w *Writer) flusher() {
 	for {
 		select {
-		case <-w.stopCh:
+		case <-w.flushStopChan:
 			return
 		case <-w.flushTicker.C:
 			//nolint:godox
@@ -661,14 +666,31 @@ func (w *Writer) WriteRow(row *models.DataRow) error {
 // Teardown function waits recording finish and stops parquet writer and closes opened file descriptor.
 func (w *Writer) Teardown() error {
 	w.flushTicker.Stop()
-	w.stopCh <- struct{}{}
+	// flushStopChan is unbuffered, so this send returns only after flusher has taken the stop signal.
+	w.flushStopChan <- struct{}{}
+	w.flushWg.Wait() // NOTE(review): no flushWg.Add/Done for this writer is visible in the patch - confirm.
 
-	if err := w.flush(); err != nil {
-		return errors.New(err.Error())
+	// Decide under the lock whether there is anything to flush, and release the
+	// lock on every path before calling flush.
+	w.writerMutex.Lock()
+	initialized := w.recordBuilder != nil && w.parquetWriter != nil
+	w.writerMutex.Unlock()
+
+	if initialized {
+		if err := w.flush(); err != nil {
+			return errors.New(err.Error())
+		}
 	}
 
-	if err := w.parquetWriter.Close(); err != nil {
-		return errors.New(err.Error())
+	// Keep the writer locked while it is closed; the deferred Unlock releases
+	// the lock on every return path instead of leaving it held after Teardown.
+	w.writerMutex.Lock()
+	defer w.writerMutex.Unlock()
+
+	if w.parquetWriter != nil {
+		if err := w.parquetWriter.Close(); err != nil {
+			return errors.New(err.Error())
+		}
 	}
 
 	select {