diff --git a/Makefile b/Makefile index 75d1ecb..4383bf3 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,7 @@ test/lint/fix: test/unit: go test -race ./... + go test ./internal/generator/cli/confirm -run '' ./internal/generator/cli/confirm/confirm_race_off_test.go test/cover: module=./... test/cover: diff --git a/doc/en/usage.md b/doc/en/usage.md index 3605b33..b27cf40 100644 --- a/doc/en/usage.md +++ b/doc/en/usage.md @@ -177,6 +177,8 @@ Structure `output.params` for format `csv`: - `datetime_format`: Date-time format. Default is `2006-01-02T15:04:05Z07:00`. - `without_headers`: Flag indicating if CSV headers should be excluded from data files. - `delimiter`: Single-character CSV delimiter. Default is `,`. +- `partition_files_limit`: Limit on the number of partition files, upon reaching which a prompt will appear asking whether to continue. + Ignored if the `--force` flag is specified. Default is `1000`. Structure `output.params` for format `parquet`: @@ -184,6 +186,8 @@ Structure `output.params` for format `parquet`: Default is `UNCOMPRESSED`. - `float_precision`: Floating-point number precision. Default is `2`. - `datetime_format`: Date-time format. Supported values: `millis`, `micros`. Default is `millis`. +- `partition_files_limit`: Limit on the number of partition files, upon reaching which a prompt will appear asking whether to continue. + Ignored if the `--force` flag is specified. Default is `1000`. Structure `output.params` for format `http`: @@ -458,7 +462,7 @@ sdvg generate ./models.yml ### Ignoring conflicts If you want to automatically remove conflicting files from the output directory -and continue generation without additional prompts, use the `-F` or `--force` flag: +and continue generation without additional prompts, use the `-f` or `--force` flag: ```shell sdvg generate --force ./models.yml @@ -469,7 +473,7 @@ sdvg generate --force ./models.yml To continue generation from the last recorded row: ```shell -sdvg generate --continue-generation ./models.yml +sdvg generate --continue ./models.yml ``` > **Important**: To correctly continue generation, you must not change the generation configuration diff --git a/doc/ru/usage.md b/doc/ru/usage.md index b88f2ec..90b3d9b 100644 --- a/doc/ru/usage.md +++ b/doc/ru/usage.md @@ -183,6 +183,8 @@ open_ai: - `datetime_format`: Формат даты и времени. По умолчанию `2006-01-02T15:04:05Z07:00`. - `without_headers`: Флаг, указывающий, исключать ли CSV заголовок из файлов с данными. - `delimiter`: Односимвольный CSV разделитель. По умолчанию `,`. +- `partition_files_limit`: Ограничение количества файлов партиций, при достижении которого всплывет вопрос о продолжении. + Игнорируется при указании флага `--force`. По умолчанию `1000` Структура `output.params` для формата `parquet`: @@ -190,6 +192,8 @@ open_ai: По умолчанию `UNCOMPRESSED`. - `float_precision`: Точность чисел с плавающей запятой. По умолчанию `2`. - `datetime_format`: Формат даты и времени. Поддерживаемые значения: `millis`, `micros`. По умолчанию `millis`. +- `partition_files_limit`: Ограничение количества файлов партиций, при достижении которого всплывет вопрос о продолжении. + Игнорируется при указании флага `--force`. По умолчанию `1000` Структура `output.params` для формата `http`: @@ -464,7 +468,7 @@ sdvg generate ./models.yml ### Игнорирование конфликтов Если вы хотите автоматически удалить конфликтующие файлы в выходной директории -и продолжить генерацию без дополнительных сообщений, используйте флаг `-F` или `--force`: +и продолжить генерацию без дополнительных сообщений, используйте флаг `-f` или `--force`: ```shell sdvg generate --force ./models.yml @@ -475,7 +479,7 @@ sdvg generate --force ./models.yml Для продолжения генерации с последней записанной строки: ```shell -sdvg generate --continue-generation ./models.yml +sdvg generate --continue ./models.yml ``` > **Важно**: для корректного продолжения генерации нельзя менять конфигурацию генерации и уже сгенерированные данные. diff --git a/internal/generator/cli/commands/consts.go b/internal/generator/cli/commands/consts.go index 35271bb..9aeb192 100644 --- a/internal/generator/cli/commands/consts.go +++ b/internal/generator/cli/commands/consts.go @@ -6,15 +6,15 @@ const ( ConfigPathDefaultValue = "" ConfigPathUsage = "Location of config file" - ContinueGenerationFlag = "continue-generation" - ContinueGenerationShortFlag = "C" + ContinueGenerationFlag = "continue" + ContinueGenerationShortFlag = "c" ContinueGenerationDefaultValue = false ContinueGenerationUsage = "Continue generation from the last recorded row" ForceGenerationFlag = "force" - ForceGenerationShortFlag = "F" + ForceGenerationShortFlag = "f" ForceGenerationFlagDefaultValue = false - ForceGenerationUsage = "Force generation even if output file conflicts found" + ForceGenerationUsage = "Force generation even if output file conflicts found and partition files limit reached" //nolint:lll TTYFlag = "tty" TTYShortFlag = "t" diff --git a/internal/generator/cli/commands/generate/generate.go b/internal/generator/cli/commands/generate/generate.go index caa2837..fa94b28 100644 --- a/internal/generator/cli/commands/generate/generate.go +++ b/internal/generator/cli/commands/generate/generate.go @@ -12,6 +12,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/tarantool/sdvg/internal/generator/cli/commands" + "github.com/tarantool/sdvg/internal/generator/cli/confirm" "github.com/tarantool/sdvg/internal/generator/cli/options" "github.com/tarantool/sdvg/internal/generator/cli/progress" "github.com/tarantool/sdvg/internal/generator/cli/progress/bar" @@ -124,7 +125,9 @@ func runGenerate(ctx context.Context, opts *generateOptions) error { return err } - out := general.NewOutput(generationCfg, opts.continueGeneration, opts.forceGeneration) + progressTrackerManager, confirm := initProgressTrackerManager(ctx, opts.renderer, opts.useTTY, opts.forceGeneration) + + out := general.NewOutput(generationCfg, opts.continueGeneration, opts.forceGeneration, confirm) taskID, err := opts.useCase.CreateTask( ctx, usecase.TaskConfig{ @@ -143,12 +146,11 @@ func runGenerate(ctx context.Context, opts *generateOptions) error { ) startProgressTracking( - ctx, + progressTrackerManager, opts.useCase, taskID, &finished, &wg, - opts.useTTY, ) err = opts.useCase.WaitResult(taskID) @@ -173,26 +175,51 @@ func runGenerate(ctx context.Context, opts *generateOptions) error { return nil } +// initProgressTrackerManager inits progress bar manager (progress.Tracker) +// and builds confirm.Confirm func based on useTTY and forceGeneration. +func initProgressTrackerManager( + ctx context.Context, + renderer render.Renderer, + useTTY bool, + forceGeneration bool, +) (progress.Tracker, confirm.Confirm) { + var ( + progressTrackerManager progress.Tracker + confirmFunc confirm.Confirm + ) + + if useTTY { + progressTrackerManager = bar.NewProgressBarManager(ctx) + + confirmFunc = confirm.BuildConfirmTTY(renderer, progressTrackerManager) + } else { + isUpdatePaused := &atomic.Bool{} + + progressTrackerManager = log.NewProgressLogManager(ctx, isUpdatePaused) + + confirmFunc = confirm.BuildConfirmNoTTY(renderer, progressTrackerManager, isUpdatePaused) + } + + if forceGeneration { + confirmFunc = func(_ context.Context, _ string) (bool, error) { + return true, nil + } + } + + return progressTrackerManager, confirmFunc +} + // startProgressTracking runs function to track progress of task // by getting progress from usecase object and displaying it. func startProgressTracking( - ctx context.Context, + progressTrackerManager progress.Tracker, uc usecase.UseCase, taskID string, finished *atomic.Bool, wg *sync.WaitGroup, - useTTY bool, ) { const delay = 500 * time.Millisecond - var progressTrackerManager progress.Tracker - - if useTTY { - progressTrackerManager = bar.NewProgressBarManager(ctx) - } else { - progressTrackerManager = log.NewProgressLogManager(ctx) - } - wg.Add(1) go func() { diff --git a/internal/generator/cli/commands/generate/generate_test.go b/internal/generator/cli/commands/generate/generate_test.go index 03f5798..7596225 100644 --- a/internal/generator/cli/commands/generate/generate_test.go +++ b/internal/generator/cli/commands/generate/generate_test.go @@ -256,7 +256,7 @@ func TestNewGenerateCommand(t *testing.T) { cliOpts.SetOut(streams.NewOut(os.Stdout)) cmd := NewGenerateCommand(cliOpts) - cmd.SetArgs([]string{"-F"}) + cmd.SetArgs([]string{"-f"}) err = cmd.Execute() diff --git a/internal/generator/cli/commands/serve/handlers.go b/internal/generator/cli/commands/serve/handlers.go index 7a36fd8..0d709f0 100644 --- a/internal/generator/cli/commands/serve/handlers.go +++ b/internal/generator/cli/commands/serve/handlers.go @@ -58,7 +58,7 @@ func handleGenerate(opts handlerOptions, c echo.Context) error { generationConfig.OutputConfig.Dir = models.DefaultOutputDir - out := general.NewOutput(&generationConfig, false, true) + out := general.NewOutput(&generationConfig, false, true, nil) taskID, err := opts.useCase.CreateTask( c.Request().Context(), usecase.TaskConfig{ diff --git a/internal/generator/cli/confirm/confirm.go b/internal/generator/cli/confirm/confirm.go new file mode 100644 index 0000000..adfc66e --- /dev/null +++ b/internal/generator/cli/confirm/confirm.go @@ -0,0 +1,119 @@ +package confirm + +import ( + "context" + "fmt" + "io" + "strings" + "sync/atomic" + + "github.com/manifoldco/promptui" + "github.com/pkg/errors" + "github.com/tarantool/sdvg/internal/generator/cli/render" + "github.com/tarantool/sdvg/internal/generator/cli/utils" +) + +var ErrPromptFailed = errors.New("prompt failed") + +// Confirm asks user a yes/no question. Returns true for “yes”. +type Confirm func(ctx context.Context, question string) (bool, error) + +//nolint:gocritic +func BuildConfirmTTY(in io.Reader, out io.Writer) func(ctx context.Context, question string) (bool, error) { + return func(ctx context.Context, question string) (bool, error) { + fmt.Fprintln(out) + + cancelableIn := newCancelableReader(in) + defer cancelableIn.Close() + + prompt := promptui.Prompt{ + Label: question + " [y/N]: ", + Default: "y", + Stdin: cancelableIn, + Stdout: utils.DummyReadWriteCloser{Writer: out}, + } + validate := func(s string) error { + if len(s) == 1 && strings.Contains("YyNn", s) || prompt.Default != "" && len(s) == 0 { + return nil + } + + return errors.New("invalid input") + } + prompt.Validate = validate + + var ( + input string + err error + promptFinished = make(chan struct{}) + ) + + go func() { + input, err = prompt.Run() // goroutine will block here until user input + + promptFinished <- struct{}{} + }() + + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-promptFinished: + } + + if err != nil { + return false, errors.Wrap(ErrPromptFailed, err.Error()) + } + + return strings.Contains("Yy", input), nil + } +} + +func BuildConfirmNoTTY( + in render.Renderer, + out io.Writer, + isUpdatePaused *atomic.Bool, +) func(ctx context.Context, question string) (bool, error) { + return func(ctx context.Context, question string) (bool, error) { + // here we pause ProgressLogManager to stop sending progress messages + isUpdatePaused.Store(true) + defer isUpdatePaused.Store(false) + + for { + fmt.Fprintf(out, "%s [y/N]: ", question) + + var ( + input string + err error + inputReadFinished = make(chan struct{}) + ) + + go func() { + input, err = in.ReadLine() // goroutine will block here until user input + + inputReadFinished <- struct{}{} + }() + + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-inputReadFinished: + } + + if err != nil { + return false, err + } + + if !in.IsTerminal() { + fmt.Fprintln(out, input) + } + + switch strings.ToLower(strings.TrimSpace(input)) { + case "y", "yes": + return true, nil + case "", "n", "no": + return false, nil + default: + fmt.Fprintln(out, "Please enter y or n") + } + } + } +} diff --git a/internal/generator/cli/confirm/confirm_racy_test.go b/internal/generator/cli/confirm/confirm_racy_test.go new file mode 100644 index 0000000..668d785 --- /dev/null +++ b/internal/generator/cli/confirm/confirm_racy_test.go @@ -0,0 +1,88 @@ +//go:build !race + +package confirm + +import ( + "bytes" + "context" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestConfirmTTY(t *testing.T) { + testCases := []struct { + name string + question string + input string + expected bool + expectedErr error + }{ + { + name: "Y", + question: "question", + input: "Y", + expected: true, + }, + { + name: "y", + question: "question", + input: "y", + expected: true, + }, + { + name: "yes", + question: "question", + input: "yes", + expectedErr: ErrPromptFailed, + }, + { + name: "N", + question: "question", + input: "N", + expected: false, + }, + { + name: "n", + question: "question", + input: "n", + expected: false, + }, + { + name: "no", + question: "question", + input: "no", + expectedErr: ErrPromptFailed, + }, + { + name: "Context canceled", + expectedErr: context.Canceled, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + input := bytes.Buffer{} + output := bytes.Buffer{} + + confirm := BuildConfirmTTY(&input, &output) + + ctx := context.Background() + + if errors.Is(tc.expectedErr, context.Canceled) { + var cancel context.CancelFunc + + ctx, cancel = context.WithCancel(ctx) + cancel() + } + + input.WriteString(tc.input + "\n") + + res, err := confirm(ctx, tc.question) + require.ErrorIs(t, err, tc.expectedErr, "expected: %v, got: %v", tc.expectedErr, err) + + require.Equal(t, tc.expected, res) + }) + } +} diff --git a/internal/generator/cli/confirm/confirm_test.go b/internal/generator/cli/confirm/confirm_test.go new file mode 100644 index 0000000..2f0e0ac --- /dev/null +++ b/internal/generator/cli/confirm/confirm_test.go @@ -0,0 +1,166 @@ +package confirm + +import ( + "bytes" + "context" + "sync/atomic" + "testing" + "time" + + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + rendererMock "github.com/tarantool/sdvg/internal/generator/cli/render/mock" +) + +var errMockTest = errors.New("mock test error") + +func TestConfirmNoTTY(t *testing.T) { + testCases := []struct { + name string + question string + ch chan time.Time + expected bool + expectedErr error + mockFunc func(r *rendererMock.Renderer) + }{ + { + name: "Y", + question: "question", + expected: true, + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("Y"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + }, + }, + { + name: "y", + question: "question", + expected: true, + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("y"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + }, + }, + { + name: "yes", + question: "question", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("yes"+"\n", errMockTest) + }, + expectedErr: errMockTest, + }, + { + name: "N", + question: "question", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("N"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + }, + expected: false, + }, + { + name: "n", + question: "question", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("n"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + }, + expected: false, + }, + { + name: "no", + question: "question", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("no"+"\n", errMockTest) + }, + expectedErr: errMockTest, + }, + { + name: "Context canceled", + mockFunc: func(r *rendererMock.Renderer) { + r.On("ReadLine"). + Return("", nil).Maybe() + }, + expectedErr: context.Canceled, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := rendererMock.NewRenderer(t) + tc.mockFunc(r) + + output := bytes.Buffer{} + isUpdatePaused := atomic.Bool{} + + confirm := BuildConfirmNoTTY(r, &output, &isUpdatePaused) + + ctx := context.Background() + + if errors.Is(tc.expectedErr, context.Canceled) { + var cancel context.CancelFunc + + ctx, cancel = context.WithCancel(ctx) + cancel() + } + + res, err := confirm(ctx, tc.question) + require.ErrorIs(t, err, tc.expectedErr, "expected: %v, got: %v", tc.expectedErr, err) + + require.Equal(t, tc.expected, res) + }) + } +} + +func TestConfirmNoTTY_IsUpdatePaused(t *testing.T) { + output := bytes.Buffer{} + + isUpdatePaused := atomic.Bool{} + + r := rendererMock.NewRenderer(t) + + confirm := BuildConfirmNoTTY(r, &output, &isUpdatePaused) + + mockFunc := func(r *rendererMock.Renderer, ch chan time.Time) { + r.On("ReadLine").WaitUntil(ch). + Return("Y"+"\n", nil) + + r. + On("IsTerminal"). + Return(true) + } + + ch := make(chan time.Time) + + mockFunc(r, ch) + + //nolint:errcheck + go confirm(context.Background(), "") + + start := time.Now() + ch <- start + + for isUpdatePaused.Load() { + if time.Since(start) > 2*time.Second { + t.Fatal("isUpdatePaused has not been called") + } + } +} diff --git a/internal/generator/cli/confirm/reader.go b/internal/generator/cli/confirm/reader.go new file mode 100644 index 0000000..8e5b67a --- /dev/null +++ b/internal/generator/cli/confirm/reader.go @@ -0,0 +1,38 @@ +package confirm + +import "io" + +// cancelableReader wraps an io.Reader and can be closed to make future reads fail. +type cancelableReader struct { + r io.Reader + closed chan struct{} +} + +// newCancelableReader creates a ReadCloser from an io.Reader. +// Closing it will make subsequent Read() calls return io.EOF. +func newCancelableReader(r io.Reader) io.ReadCloser { + return &cancelableReader{ + r: r, + closed: make(chan struct{}), + } +} + +func (c *cancelableReader) Read(p []byte) (int, error) { + select { + case <-c.closed: + return 0, io.EOF + default: + return c.r.Read(p) //nolint:wrapcheck + } +} + +func (c *cancelableReader) Close() error { + select { + case <-c.closed: + // already closed + default: + close(c.closed) + } + + return nil +} diff --git a/internal/generator/cli/progress/bar/bar.go b/internal/generator/cli/progress/bar/bar.go index baf1fb3..e6590aa 100644 --- a/internal/generator/cli/progress/bar/bar.go +++ b/internal/generator/cli/progress/bar/bar.go @@ -77,3 +77,8 @@ func (p *ProgressBarManager) UpdateProgress(name string, progress usecase.Progre func (p *ProgressBarManager) Wait() { p.progressManager.Wait() } + +// Write writes to stdout. +func (p *ProgressBarManager) Write(b []byte) (int, error) { + return p.progressManager.Write(b) //nolint:wrapcheck +} diff --git a/internal/generator/cli/progress/interfaces.go b/internal/generator/cli/progress/interfaces.go index 07e3027..57d8329 100644 --- a/internal/generator/cli/progress/interfaces.go +++ b/internal/generator/cli/progress/interfaces.go @@ -10,4 +10,6 @@ type Tracker interface { UpdateProgress(name string, progress usecase.Progress) // Wait function should wait for all tracked tasks to complete. Wait() + // Write function should write to stdout. + Write(b []byte) (int, error) } diff --git a/internal/generator/cli/progress/log/log.go b/internal/generator/cli/progress/log/log.go index 2662a0c..c9d5ed2 100644 --- a/internal/generator/cli/progress/log/log.go +++ b/internal/generator/cli/progress/log/log.go @@ -5,7 +5,9 @@ import ( "fmt" "log/slog" "math" + "os" "sync" + "sync/atomic" "time" "github.com/tarantool/sdvg/internal/generator/cli/progress" @@ -42,13 +44,16 @@ type ProgressLogManager struct { ctx context.Context //nolint:containedctx tasks map[string]*task wg sync.WaitGroup + + isUpdatePaused *atomic.Bool } -// NewProgressLogManager creates NewProgressLogManager object. -func NewProgressLogManager(ctx context.Context) progress.Tracker { +// NewProgressLogManager creates NewProgressLogManager object. isUpdatePaused is used to pause UpdateProgress. +func NewProgressLogManager(ctx context.Context, isUpdatePaused *atomic.Bool) progress.Tracker { return &ProgressLogManager{ - ctx: ctx, - tasks: make(map[string]*task), + ctx: ctx, + tasks: make(map[string]*task), + isUpdatePaused: isUpdatePaused, } } @@ -78,6 +83,12 @@ func (p *ProgressLogManager) UpdateProgress(name string, progress usecase.Progre return } + for p.isUpdatePaused.Load() { + if t.isDone() { + return + } + } + p.updateIntervals(t, progress.Done) t.current = progress.Done @@ -138,3 +149,8 @@ func (p *ProgressLogManager) eta(t *task) string { return fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) } + +// Write writes to default stdout. +func (p *ProgressLogManager) Write(b []byte) (int, error) { + return os.Stdout.Write(b) //nolint:wrapcheck +} diff --git a/internal/generator/cli/render/interfaces.go b/internal/generator/cli/render/interfaces.go index cc479b6..6b0c3a6 100644 --- a/internal/generator/cli/render/interfaces.go +++ b/internal/generator/cli/render/interfaces.go @@ -1,9 +1,13 @@ package render -import "context" +import ( + "context" +) // Renderer interface implementation should render interactive menu. // +// after regenerating mock, do not forget to add call to fn() argument in WithSpinner +// //go:generate go run github.com/vektra/mockery/v2@v2.51.1 --name=Renderer --output=mock --outpkg=mock type Renderer interface { // Logo should display application logo. @@ -16,4 +20,10 @@ type Renderer interface { TextMenu(ctx context.Context, title string) (string, error) // WithSpinner should display spinner. WithSpinner(title string, fn func()) + // IsTerminal should return true if renderer is connected to a terminal. + IsTerminal() bool + // ReadLine should read input from input stream. + ReadLine() (string, error) + // Read should read from input stream. + Read(p []byte) (int, error) } diff --git a/internal/generator/cli/render/mock/renderer.go b/internal/generator/cli/render/mock/renderer.go index 3580615..1114eee 100644 --- a/internal/generator/cli/render/mock/renderer.go +++ b/internal/generator/cli/render/mock/renderer.go @@ -41,11 +41,85 @@ func (_m *Renderer) InputMenu(ctx context.Context, title string, validateFunc fu return r0, r1 } +// IsTerminal provides a mock function with no fields +func (_m *Renderer) IsTerminal() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IsTerminal") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + // Logo provides a mock function with no fields func (_m *Renderer) Logo() { _m.Called() } +// Read provides a mock function with given fields: p +func (_m *Renderer) Read(p []byte) (int, error) { + ret := _m.Called(p) + + if len(ret) == 0 { + panic("no return value specified for Read") + } + + var r0 int + var r1 error + if rf, ok := ret.Get(0).(func([]byte) (int, error)); ok { + return rf(p) + } + if rf, ok := ret.Get(0).(func([]byte) int); ok { + r0 = rf(p) + } else { + r0 = ret.Get(0).(int) + } + + if rf, ok := ret.Get(1).(func([]byte) error); ok { + r1 = rf(p) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ReadLine provides a mock function with no fields +func (_m *Renderer) ReadLine() (string, error) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for ReadLine") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func() (string, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() string); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // SelectionMenu provides a mock function with given fields: ctx, title, items func (_m *Renderer) SelectionMenu(ctx context.Context, title string, items []string) (string, error) { ret := _m.Called(ctx, title, items) diff --git a/internal/generator/cli/render/prompt/prompt.go b/internal/generator/cli/render/prompt/prompt.go index 67832af..91206d3 100644 --- a/internal/generator/cli/render/prompt/prompt.go +++ b/internal/generator/cli/render/prompt/prompt.go @@ -90,7 +90,7 @@ func (r *Renderer) SelectionMenu(ctx context.Context, title string, items []stri for { _, _ = fmt.Fprint(r.out, "Write a number: ") - input, err := r.readLine() + input, err := r.ReadLine() if err != nil { resultChan <- result{err: err} @@ -147,7 +147,7 @@ func (r *Renderer) InputMenu(ctx context.Context, title string, validateFunc fun for { _, _ = fmt.Fprintf(r.out, "%s: ", title) - input, err := r.readLine() + input, err := r.ReadLine() if err != nil { resultChan <- result{err: err} @@ -252,6 +252,28 @@ func (r *Renderer) WithSpinner(title string, fn func()) { fn() } +// ReadLine reads input from stdin. +func (r *Renderer) ReadLine() (string, error) { + if r.scanner.Scan() { + return strings.TrimSpace(r.scanner.Text()), nil + } + + if err := r.scanner.Err(); err != nil { + return "", errors.New(err.Error()) + } + + return "", errors.New(io.EOF.Error()) +} + +// IsTerminal returns true if this stream is connected to a terminal. +func (r *Renderer) IsTerminal() bool { + return r.in.IsTerminal() +} + +func (r *Renderer) Read(p []byte) (int, error) { + return r.in.Read(p) +} + // selectionPrompt returns prompt for selection items. func (r *Renderer) selectionPrompt(title string, items []string) promptui.Select { templates := &promptui.SelectTemplates{ @@ -371,19 +393,6 @@ func (r *Renderer) readFile(filePath string) (string, error) { return strings.TrimSpace(sb.String()), nil } -// readInput reads input from stdin. -func (r *Renderer) readLine() (string, error) { - if r.scanner.Scan() { - return strings.TrimSpace(r.scanner.Text()), nil - } - - if err := r.scanner.Err(); err != nil { - return "", errors.New(err.Error()) - } - - return "", errors.New(io.EOF.Error()) -} - func (r *Renderer) readMultiline() (string, error) { var sb strings.Builder diff --git a/internal/generator/cli/render/prompt/prompt_test.go b/internal/generator/cli/render/prompt/prompt_test.go index 1e4da74..6253d5c 100644 --- a/internal/generator/cli/render/prompt/prompt_test.go +++ b/internal/generator/cli/render/prompt/prompt_test.go @@ -427,7 +427,7 @@ func readLinesTestFunc(t *testing.T, tc readLinesTestCase, mode int) { switch mode { case SingleLine: - actual, err = renderer.readLine() + actual, err = renderer.ReadLine() case MultiLine: actual, err = renderer.readMultiline() } diff --git a/internal/generator/cli/streams/in.go b/internal/generator/cli/streams/in.go index eff6e5a..138d6c4 100644 --- a/internal/generator/cli/streams/in.go +++ b/internal/generator/cli/streams/in.go @@ -1,18 +1,12 @@ -//nolint:dupl package streams import ( "io" "github.com/moby/term" + "github.com/tarantool/sdvg/internal/generator/cli/utils" ) -type nopReadCloser struct { - io.Reader -} - -func (nopReadCloser) Close() error { return nil } - // In is an input stream to read user input. It implements [io.ReadCloser]. type In struct { isTerminal bool @@ -26,7 +20,7 @@ func NewIn(in io.Reader) *In { if readCloser, ok := in.(io.ReadCloser); ok { i.in = readCloser } else { - i.in = nopReadCloser{in} + i.in = utils.DummyReadWriteCloser{Reader: in} } _, i.isTerminal = term.GetFdInfo(in) diff --git a/internal/generator/cli/streams/out.go b/internal/generator/cli/streams/out.go index 58ce498..11cddb8 100644 --- a/internal/generator/cli/streams/out.go +++ b/internal/generator/cli/streams/out.go @@ -1,18 +1,12 @@ -//nolint:dupl package streams import ( "io" "github.com/moby/term" + "github.com/tarantool/sdvg/internal/generator/cli/utils" ) -type nopWriteCloser struct { - io.Writer -} - -func (nopWriteCloser) Close() error { return nil } - // Out is an output stream to write normal program output. It implements an [io.WriteCloser]. type Out struct { isTerminal bool @@ -26,7 +20,7 @@ func NewOut(out io.Writer) *Out { if writeCloser, ok := out.(io.WriteCloser); ok { o.out = writeCloser } else { - o.out = nopWriteCloser{out} + o.out = utils.DummyReadWriteCloser{Writer: out} } _, o.isTerminal = term.GetFdInfo(out) diff --git a/internal/generator/cli/utils/utils.go b/internal/generator/cli/utils/utils.go index a57d26a..aeefb42 100644 --- a/internal/generator/cli/utils/utils.go +++ b/internal/generator/cli/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "io" "path/filepath" "slices" "strings" @@ -98,3 +99,12 @@ func ChooseCommand(cmd *cobra.Command, args []string, renderer render.Renderer) return nil } + +type DummyReadWriteCloser struct { + io.Reader + io.Writer +} + +func (rwc DummyReadWriteCloser) Close() error { + return nil +} diff --git a/internal/generator/models/generator_output.go b/internal/generator/models/generator_output.go index 9f7e729..9d16712 100644 --- a/internal/generator/models/generator_output.go +++ b/internal/generator/models/generator_output.go @@ -18,6 +18,8 @@ const ( tcsTimeoutHeader = "x-tcs-timeout_ms" ParquetDateTimeMillisFormat = "millis" ParquetDateTimeMicrosFormat = "micros" + + PartitionFilesLimitDefault = 1000 ) // DataRow type is used to represent any data row that was generated. @@ -167,10 +169,11 @@ var _ Field = (*CSVConfig)(nil) // CSVConfig type used to describe output config for CSV implementation. type CSVConfig struct { - FloatPrecision int `json:"float_precision" yaml:"float_precision"` - DatetimeFormat string `json:"datetime_format" yaml:"datetime_format"` - Delimiter string `backup:"true" json:"delimiter" yaml:"delimiter"` - WithoutHeaders bool `backup:"true" json:"without_headers" yaml:"without_headers"` + FloatPrecision int `json:"float_precision" yaml:"float_precision"` + DatetimeFormat string `json:"datetime_format" yaml:"datetime_format"` + Delimiter string `backup:"true" json:"delimiter" yaml:"delimiter"` + WithoutHeaders bool `backup:"true" json:"without_headers" yaml:"without_headers"` + PartitionFilesLimit *int `json:"partition_files_limit" yaml:"partition_files_limit"` } func (c *CSVConfig) Parse() error { return nil } @@ -187,6 +190,11 @@ func (c *CSVConfig) FillDefaults() { if c.Delimiter == "" { c.Delimiter = "," } + + if c.PartitionFilesLimit == nil { + c.PartitionFilesLimit = new(int) + *c.PartitionFilesLimit = 1000 + } } func (c *CSVConfig) Validate() []error { @@ -200,6 +208,10 @@ func (c *CSVConfig) Validate() []error { errs = append(errs, errors.Errorf("the delimiter must consist of one character, got %v", c.Delimiter)) } + if c.PartitionFilesLimit != nil && *c.PartitionFilesLimit <= 0 { + errs = append(errs, errors.Errorf("partition files limit should be greater than 0, got: %v", *c.PartitionFilesLimit)) + } + return errs } @@ -295,9 +307,10 @@ var _ Field = (*ParquetConfig)(nil) // ParquetConfig type used to describe output config for parquet implementation. type ParquetConfig struct { - CompressionCodec string `backup:"true" json:"compression_codec" yaml:"compression_codec"` - FloatPrecision int `json:"float_precision" yaml:"float_precision"` - DateTimeFormat string `json:"datetime_format" yaml:"datetime_format"` + CompressionCodec string `backup:"true" json:"compression_codec" yaml:"compression_codec"` + FloatPrecision int `json:"float_precision" yaml:"float_precision"` + DateTimeFormat string `json:"datetime_format" yaml:"datetime_format"` + PartitionFilesLimit *int `json:"partition_files_limit" yaml:"partition_files_limit"` } //nolint:lll @@ -318,6 +331,11 @@ func (c *ParquetConfig) FillDefaults() { if c.DateTimeFormat == "" { c.DateTimeFormat = ParquetDateTimeMillisFormat } + + if c.PartitionFilesLimit == nil { + c.PartitionFilesLimit = new(int) + *c.PartitionFilesLimit = 1000 + } } func (c *ParquetConfig) Validate() []error { @@ -337,5 +355,9 @@ func (c *ParquetConfig) Validate() []error { c.DateTimeFormat, parquetSupportedDateTimeFormats)) } + if c.PartitionFilesLimit != nil && *c.PartitionFilesLimit <= 0 { + errs = append(errs, errors.Errorf("partition files limit should be greater than 0, got: %v", *c.PartitionFilesLimit)) + } + return errs } diff --git a/internal/generator/models/models_test.go b/internal/generator/models/models_test.go index ffb908a..66b35fd 100644 --- a/internal/generator/models/models_test.go +++ b/internal/generator/models/models_test.go @@ -225,9 +225,10 @@ func TestGeneratorConfigYAMLParse(t *testing.T) { OutputConfig: &OutputConfig{ Type: "csv", CSVParams: &CSVConfig{ - FloatPrecision: 2, - DatetimeFormat: "2006-01-02T15:04:05Z07:00", - Delimiter: ",", + FloatPrecision: 2, + DatetimeFormat: "2006-01-02T15:04:05Z07:00", + Delimiter: ",", + PartitionFilesLimit: ptr(1000), }, }, }, @@ -624,15 +625,16 @@ models: Dir: "test_output", CheckpointInterval: time.Second, CSVParams: &CSVConfig{ - FloatPrecision: 2, - DatetimeFormat: "2006-01-02T15:04:05Z07:00", - Delimiter: ",", + FloatPrecision: 2, + DatetimeFormat: "2006-01-02T15:04:05Z07:00", + Delimiter: ",", + PartitionFilesLimit: ptr(1000), }, }, }, }, { - name: "CsvFullConfig", + name: "csv full config", content: ` random_seed: 1 output: @@ -641,6 +643,7 @@ output: datetime_format: "2006-01-02" float_precision: 1 delimiter: ";" + partition_files_limit: 10 models: test: rows_count: 1 @@ -656,9 +659,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, CSVParams: &CSVConfig{ - FloatPrecision: 1, - DatetimeFormat: "2006-01-02", - Delimiter: ";", + FloatPrecision: 1, + DatetimeFormat: "2006-01-02", + Delimiter: ";", + PartitionFilesLimit: ptr(10), }, }, }, @@ -849,6 +853,7 @@ output: datetime_format: micros float_precision: 3 compression_codec: GZIP + partition_files_limit: 1 models: test: rows_count: 1 @@ -864,9 +869,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, ParquetParams: &ParquetConfig{ - FloatPrecision: 3, - DateTimeFormat: ParquetDateTimeMicrosFormat, - CompressionCodec: "GZIP", + FloatPrecision: 3, + DateTimeFormat: ParquetDateTimeMicrosFormat, + CompressionCodec: "GZIP", + PartitionFilesLimit: ptr(1), }, }, }, @@ -892,9 +898,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, ParquetParams: &ParquetConfig{ - FloatPrecision: 2, - DateTimeFormat: ParquetDateTimeMillisFormat, - CompressionCodec: "UNCOMPRESSED", + FloatPrecision: 2, + DateTimeFormat: ParquetDateTimeMillisFormat, + CompressionCodec: "UNCOMPRESSED", + PartitionFilesLimit: ptr(1000), }, }, }, @@ -956,9 +963,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, ParquetParams: &ParquetConfig{ - FloatPrecision: 2, - DateTimeFormat: ParquetDateTimeMillisFormat, - CompressionCodec: "UNCOMPRESSED", + FloatPrecision: 2, + DateTimeFormat: ParquetDateTimeMillisFormat, + CompressionCodec: "UNCOMPRESSED", + PartitionFilesLimit: ptr(1000), }, }, }, @@ -1062,9 +1070,10 @@ models: Dir: DefaultOutputDir, CheckpointInterval: 5 * time.Second, ParquetParams: &ParquetConfig{ - FloatPrecision: 2, - DateTimeFormat: ParquetDateTimeMillisFormat, - CompressionCodec: "UNCOMPRESSED", + FloatPrecision: 2, + DateTimeFormat: ParquetDateTimeMillisFormat, + CompressionCodec: "UNCOMPRESSED", + PartitionFilesLimit: ptr(1000), }, }, }, @@ -1107,6 +1116,7 @@ output: compression_codec: non-existent-codec float_precision: -1 datetime_format: non-existent-datetime-format + partition_files_limit: 0 checkpoint_interval: -1s models_to_ignore: - non-existent-column @@ -1158,7 +1168,8 @@ output config: parquet params: - unknown compression codec non-existent-codec, supported [UNCOMPRESSED SNAPPY GZIP LZ4 LZ4RAW LZO ZSTD BROTLI] - float precision should be grater than 0, got -1 -- unknown datetime format non-existent-datetime-format, supported [millis micros]`, +- unknown datetime format non-existent-datetime-format, supported [millis micros] +- partition files limit should be greater than 0, got: 0`, ), }, } diff --git a/internal/generator/output/general/model_writer.go b/internal/generator/output/general/model_writer.go index 5c3f746..1bfc436 100644 --- a/internal/generator/output/general/model_writer.go +++ b/internal/generator/output/general/model_writer.go @@ -13,6 +13,7 @@ import ( "time" "github.com/pkg/errors" + "github.com/tarantool/sdvg/internal/generator/cli/confirm" "github.com/tarantool/sdvg/internal/generator/common" "github.com/tarantool/sdvg/internal/generator/models" "github.com/tarantool/sdvg/internal/generator/output" @@ -26,6 +27,8 @@ import ( const buffer = 100 +var ErrPartitionFilesLimitExceeded = errors.New("partition files limit exceeded") + // ModelWriter type implements the general logic of writing data. type ModelWriter struct { model *models.Model @@ -48,6 +51,10 @@ type ModelWriter struct { writtenRowsWg *sync.WaitGroup writtenRowsChan chan uint64 stopChan chan struct{} + + partitionFilesCount int + partitionFilesLimit *int + confirm confirm.Confirm } // NewModelWriter creates ModelWriter object. @@ -55,7 +62,16 @@ func newModelWriter( model *models.Model, config *models.OutputConfig, continueGeneration bool, -) (*ModelWriter, error) { + confirm confirm.Confirm) (*ModelWriter, error) { + var partitionFilesLimit *int + + switch config.Type { + case "csv": + partitionFilesLimit = config.CSVParams.PartitionFilesLimit + case "parquet": + partitionFilesLimit = config.ParquetParams.PartitionFilesLimit + } + orderedColumnNames := make([]string, 0, len(model.Columns)) for _, column := range model.Columns { orderedColumnNames = append(orderedColumnNames, column.Name) @@ -108,6 +124,9 @@ func newModelWriter( writtenRowsWg: &sync.WaitGroup{}, writtenRowsChan: make(chan uint64, buffer), stopChan: make(chan struct{}), + partitionFilesCount: 0, + partitionFilesLimit: partitionFilesLimit, + confirm: confirm, } modelWriter.checkpointFilePath = modelWriter.getCheckpointFilePath() @@ -164,6 +183,7 @@ func (w *ModelWriter) updateCheckpoint() error { } // WriteRows function determines the partitioning key and sends the data to the appropriate writer. +// Note that this func should not be called concurrently from multiple goroutines because of confirm func call. func (w *ModelWriter) WriteRows(ctx context.Context, rows []*models.DataRow) error { for _, row := range rows { partitionPath := w.getPartitionPath(row) @@ -173,6 +193,13 @@ func (w *ModelWriter) WriteRows(ctx context.Context, rows []*models.DataRow) err w.writersMutex.RUnlock() if !ok { + w.partitionFilesCount++ + + err := w.shouldContinue(ctx) + if err != nil { + return err + } + newDataWriter, err := w.newWriter(ctx, partitionPath) if err != nil { return err @@ -232,6 +259,22 @@ func (w *ModelWriter) getPartitionPath(row *models.DataRow) string { return sb.String() } +// shouldContinue returns error if user don't want to continue generation. +func (w *ModelWriter) shouldContinue(ctx context.Context) error { + if w.confirm != nil && w.partitionFilesLimit != nil && w.partitionFilesCount == *w.partitionFilesLimit+1 { + shouldContinue, err := w.confirm(ctx, "Number of partitions files reached limit. Continue?") + if err != nil { + return err + } + + if !shouldContinue { + return errors.Wrapf(ErrPartitionFilesLimitExceeded, ": %v", w.partitionFilesCount) + } + } + + return nil +} + // newWriter function creates writer.Writer object based on output type from models.OutputConfig. func (w *ModelWriter) newWriter(ctx context.Context, outPath string) (writer.Writer, error) { var dataWriter writer.Writer diff --git a/internal/generator/output/general/model_writer_test.go b/internal/generator/output/general/model_writer_test.go index e938c96..ed75dc6 100644 --- a/internal/generator/output/general/model_writer_test.go +++ b/internal/generator/output/general/model_writer_test.go @@ -220,7 +220,7 @@ func TestPartitionPaths(t *testing.T) { }, } - writer, err := newModelWriter(tCase.model, devnullConfig, false) + writer, err := newModelWriter(tCase.model, devnullConfig, false, nil) require.NoError(t, err) err = writer.WriteRows(context.Background(), tCase.data) diff --git a/internal/generator/output/general/output.go b/internal/generator/output/general/output.go index b9a7fbe..2f2fc89 100644 --- a/internal/generator/output/general/output.go +++ b/internal/generator/output/general/output.go @@ -9,6 +9,7 @@ import ( "slices" "github.com/pkg/errors" + "github.com/tarantool/sdvg/internal/generator/cli/confirm" "github.com/tarantool/sdvg/internal/generator/models" "github.com/tarantool/sdvg/internal/generator/output" ) @@ -20,13 +21,20 @@ var _ output.Output = (*Output)(nil) type Output struct { config *models.OutputConfig models map[string]*models.Model + writersByModelName map[string]*ModelWriter + continueGeneration bool forceGeneration bool - writersByModelName map[string]*ModelWriter + confirm confirm.Confirm } // NewOutput function creates Output object. -func NewOutput(cfg *models.GenerationConfig, continueGeneration, forceGeneration bool) output.Output { +func NewOutput( + cfg *models.GenerationConfig, + continueGeneration, + forceGeneration bool, + confirm confirm.Confirm, +) output.Output { filteredModels := make(map[string]*models.Model) for modelName, model := range cfg.Models { @@ -41,6 +49,7 @@ func NewOutput(cfg *models.GenerationConfig, continueGeneration, forceGeneration continueGeneration: continueGeneration, forceGeneration: forceGeneration, writersByModelName: make(map[string]*ModelWriter), + confirm: confirm, } } @@ -56,7 +65,7 @@ func (o *Output) Setup() error { writersByModelName := make(map[string]*ModelWriter) for modelName, model := range o.models { - modelWriter, err := newModelWriter(model, o.config, o.continueGeneration) + modelWriter, err := newModelWriter(model, o.config, o.continueGeneration, o.confirm) if err != nil { return err } diff --git a/internal/generator/output/general/test/bench_test.go b/internal/generator/output/general/test/bench_test.go index 450d8c2..68dc9ab 100644 --- a/internal/generator/output/general/test/bench_test.go +++ b/internal/generator/output/general/test/bench_test.go @@ -301,7 +301,7 @@ func runModelsBenches( copyCfg := *genCfg SetOutputParams(©Cfg, uint64(b.N)) - out := general.NewOutput(©Cfg, false, true) + out := general.NewOutput(©Cfg, false, true, nil) require.NoError(b, out.Setup()) b.ResetTimer() diff --git a/internal/generator/output/general/test/unit_test.go b/internal/generator/output/general/test/unit_test.go index 4e7a91f..8ac7016 100644 --- a/internal/generator/output/general/test/unit_test.go +++ b/internal/generator/output/general/test/unit_test.go @@ -7,10 +7,12 @@ import ( "math" "os" "path/filepath" + "strings" "testing" "github.com/pkg/errors" "github.com/stretchr/testify/require" + "github.com/tarantool/sdvg/internal/generator/cli/confirm" "github.com/tarantool/sdvg/internal/generator/common" "github.com/tarantool/sdvg/internal/generator/models" outputGeneral "github.com/tarantool/sdvg/internal/generator/output/general" @@ -63,6 +65,18 @@ models: range_percentage: 0.5 - type_params: to: 5 +` + oneModelConfigWithPartition = ` +models: + model1: + rows_count: 10 + columns: + - name: id + type: integer + distinct_percentage: 1 + partition_columns: + - name: id + write_to_output: true ` ) @@ -94,7 +108,7 @@ func TestContinueGeneration(t *testing.T) { // Generate expected data - require.NoError(t, generate(t, cfg, uc, false, true)) + require.NoError(t, generate(t, cfg, uc, false, true, nil)) expectedFilesData := make(map[string][][]string) @@ -117,7 +131,7 @@ func TestContinueGeneration(t *testing.T) { model.GenerateTo = model.RowsCount / 2 } - require.NoError(t, generate(t, cfg, uc, false, true)) + require.NoError(t, generate(t, cfg, uc, false, true, nil)) for _, model := range cfg.Models { filesCount := int(math.Ceil(float64(model.GenerateTo-model.GenerateFrom) / float64(model.RowsPerFile))) @@ -151,7 +165,7 @@ func TestContinueGeneration(t *testing.T) { require.NoError(t, cfg.ParseFromFile(configPath)) cfg.OutputConfig.Dir = outputDir - require.NoError(t, generate(t, cfg, uc, true, true)) + require.NoError(t, generate(t, cfg, uc, true, true, nil)) for _, model := range cfg.Models { filesCount := math.Ceil(float64(rowsCountByModel[model.Name]) / float64(model.RowsPerFile)) @@ -238,10 +252,10 @@ cause: dir for model is not empty // Generate data in empty output dir - require.NoError(t, generate(t, cfg, uc, false, false)) + require.NoError(t, generate(t, cfg, uc, false, false, nil)) // Try to init new output with conflicts - out := outputGeneral.NewOutput(cfg, false, tc.forceGeneration) + out := outputGeneral.NewOutput(cfg, false, tc.forceGeneration, nil) err := out.Setup() if tc.err != nil { @@ -264,11 +278,98 @@ cause: dir for model is not empty } } +var ( + errMockTest = errors.New("mock test error") + partitionsFileLimit = 2 +) + +func TestConfirmationAsk(t *testing.T) { + testCases := []struct { + name string + shouldContinue bool + wantErr bool + err error + confirm confirm.Confirm + }{ + { + name: "Continue", + shouldContinue: true, + confirm: func(ctx context.Context, question string) (bool, error) { + return true, nil + }, + }, + { + name: "Stop", + shouldContinue: false, + err: outputGeneral.ErrPartitionFilesLimitExceeded, + confirm: func(ctx context.Context, question string) (bool, error) { + return false, nil + }, + }, + { + name: "Error", + shouldContinue: false, + wantErr: true, + err: errMockTest, + confirm: func(ctx context.Context, question string) (bool, error) { + return false, errMockTest + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Write models config + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, configFileName) + require.NoError(t, os.WriteFile(configPath, []byte(oneModelConfigWithPartition), configFilePerm)) + + uc := useCaseGeneral.NewUseCase(useCaseGeneral.UseCaseConfig{}) + require.NoError(t, uc.Setup()) + + // Parse config + + cfg := &models.GenerationConfig{} + + require.NoError(t, cfg.ParseFromFile(configPath)) + + *cfg.OutputConfig.CSVParams.PartitionFilesLimit = partitionsFileLimit + + // Generate data in empty output dir + + err := generate(t, cfg, uc, false, true, tc.confirm) + + // check generated partitions files amount + fileNames, walkErr := common.WalkWithFilter(models.DefaultOutputDir, func(entry os.DirEntry) bool { + return entry.IsDir() && strings.HasPrefix(entry.Name(), "id=") + }) + + require.NoError(t, walkErr, "failed to walk tmpdir: %v", tmpDir) + + if tc.wantErr { + require.Error(t, err) + } else { + if tc.shouldContinue { + require.Len(t, fileNames, 10, "there should be rows_amount dirs") + require.NoError(t, err) + } else { + require.True(t, errors.Is(err, tc.err), "expected error: %v, got: %v", tc.err, err) + require.Len(t, fileNames, partitionsFileLimit, "there should be partitionsFileLimit dirs") + } + } + + // cleanup + + require.NoError(t, os.RemoveAll(models.DefaultOutputDir)) + }) + } +} + //nolint:lll -func generate(t *testing.T, cfg *models.GenerationConfig, uc usecase.UseCase, continueGeneration, forceGeneration bool) error { +func generate(t *testing.T, cfg *models.GenerationConfig, uc usecase.UseCase, continueGeneration, forceGeneration bool, confirm confirm.Confirm) error { t.Helper() - out := outputGeneral.NewOutput(cfg, continueGeneration, forceGeneration) + out := outputGeneral.NewOutput(cfg, continueGeneration, forceGeneration, confirm) taskID, err := uc.CreateTask(context.Background(), usecase.TaskConfig{ GenerationConfig: cfg, @@ -279,8 +380,15 @@ func generate(t *testing.T, cfg *models.GenerationConfig, uc usecase.UseCase, co return err } - require.NoError(t, uc.WaitResult(taskID)) - require.NoError(t, uc.Teardown()) + err = uc.WaitResult(taskID) + if err != nil { + return err + } + + err = uc.Teardown() + if err != nil { + return err + } return nil } diff --git a/internal/generator/usecase/general/backup/backup_test.go b/internal/generator/usecase/general/backup/backup_test.go index e29b433..3168e1f 100644 --- a/internal/generator/usecase/general/backup/backup_test.go +++ b/internal/generator/usecase/general/backup/backup_test.go @@ -35,6 +35,7 @@ func TestHandleBackup(t *testing.T) { }, false, false, + nil, ) type testCase struct { @@ -189,7 +190,7 @@ func TestHandleCheckpoint(t *testing.T) { "model4": 954, } - out := general.NewOutput(cfg, false, false) + out := general.NewOutput(cfg, false, false, nil) require.NoError(t, out.Setup()) for modelName, generateFrom := range checkpoints { diff --git a/internal/generator/usecase/general/task.go b/internal/generator/usecase/general/task.go index 2bf435e..29c0d55 100644 --- a/internal/generator/usecase/general/task.go +++ b/internal/generator/usecase/general/task.go @@ -253,6 +253,7 @@ func (t *Task) skipRows() { } // generateAndSaveBatch function generate batch of values for selected column and send it to output. +// The next batch is written only after the previous one has completed saving. func (t *Task) generateAndSaveBatch( ctx context.Context, outputSync *common.WorkerSyncer, modelName string, generators []*generator.BatchGenerator, count uint64,