Skip to content
57 changes: 49 additions & 8 deletions iter/iter.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package iter

import (
"context"
"runtime"
"sync/atomic"

"github.com/sourcegraph/conc"
"github.com/sourcegraph/conc/pool"
)

// defaultMaxGoroutines returns the default maximum number of
Expand Down Expand Up @@ -57,29 +58,69 @@ func ForEachIdx[T any](input []T, f func(int, *T)) { Iterator[T]{}.ForEachIdx(in
// ForEachIdx is the same as ForEach except it also provides the
// index of the element to the callback.
func (iter Iterator[T]) ForEachIdx(input []T, f func(int, *T)) {
_ = iter.ForEachIdxCtx(context.Background(), input, func(_ context.Context, idx int, input *T) error {
f(idx, input)
return nil
})
}

// ForEachCtx is the same as ForEach except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func ForEachCtx[T any](ctx context.Context, input []T, f func(context.Context, *T) error) error {
return Iterator[T]{}.ForEachCtx(ctx, input, f)
}

// ForEachCtx is the same as ForEach except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func (iter Iterator[T]) ForEachCtx(ctx context.Context, input []T, f func(context.Context, *T) error) error {
return iter.ForEachIdxCtx(ctx, input, func(innerctx context.Context, _ int, input *T) error {
return f(innerctx, input)
})
}

// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func ForEachIdxCtx[T any](ctx context.Context, input []T, f func(context.Context, int, *T) error) error {
return Iterator[T]{}.ForEachIdxCtx(ctx, input, f)
}

// ForEachIdxCtx is the same as ForEachIdx except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func (iter Iterator[T]) ForEachIdxCtx(ctx context.Context, input []T, f func(context.Context, int, *T) error) error {
if iter.MaxGoroutines == 0 {
// iter is a value receiver and is hence safe to mutate
iter.MaxGoroutines = defaultMaxGoroutines()
}

numInput := len(input)
if iter.MaxGoroutines > numInput {
if iter.MaxGoroutines > numInput && numInput > 0 {
// No more concurrent tasks than the number of input items.
iter.MaxGoroutines = numInput
}

var idx atomic.Int64
// Create the task outside the loop to avoid extra closure allocations.
task := func() {
task := func(innerctx context.Context) error {
i := int(idx.Add(1) - 1)
for ; i < numInput; i = int(idx.Add(1) - 1) {
f(i, &input[i])
for ; i < numInput && innerctx.Err() == nil; i = int(idx.Add(1) - 1) {
if err := f(innerctx, i, &input[i]); err != nil {
return err
}
}
return innerctx.Err() // nil if the context was never cancelled
}

var wg conc.WaitGroup
runner := pool.New().
WithContext(ctx).
WithCancelOnError().
WithFirstError().
WithMaxGoroutines(iter.MaxGoroutines)
for i := 0; i < iter.MaxGoroutines; i++ {
wg.Go(task)
runner.Go(task)
}
wg.Wait()
return runner.Wait()
}
77 changes: 70 additions & 7 deletions iter/iter_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package iter_test

import (
"context"
"errors"
"fmt"
"strconv"
"sync/atomic"
Expand Down Expand Up @@ -72,16 +74,18 @@ func TestIterator(t *testing.T) {
})
}

func TestForEachIdx(t *testing.T) {
func TestForEachIdxCtx(t *testing.T) {
t.Parallel()

bgctx := context.Background()
t.Run("empty", func(t *testing.T) {
t.Parallel()
f := func() {
ints := []int{}
iter.ForEachIdx(ints, func(i int, val *int) {
err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
panic("this should never be called")
})
require.NoError(t, err)
}
require.NotPanics(t, f)
})
Expand All @@ -90,33 +94,57 @@ func TestForEachIdx(t *testing.T) {
t.Parallel()
f := func() {
ints := []int{1}
iter.ForEachIdx(ints, func(i int, val *int) {
panic("super bad thing happened")
})
_ = iter.ForEachIdxCtx(bgctx, ints,
func(ctx context.Context, i int, val *int) error {
panic("super bad thing happened")
})
}
require.Panics(t, f)
})

t.Run("mutating inputs is fine", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
iter.ForEachIdx(ints, func(i int, val *int) {
err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
*val += 1
return nil
})
require.Equal(t, []int{2, 3, 4, 5, 6}, ints)
require.NoError(t, err)
})

t.Run("huge inputs", func(t *testing.T) {
t.Parallel()
ints := make([]int, 10000)
iter.ForEachIdx(ints, func(i int, val *int) {
err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
*val = i
return nil
})
expected := make([]int, 10000)
for i := 0; i < 10000; i++ {
expected[i] = i
}
require.Equal(t, expected, ints)
require.NoError(t, err)
})

err1 := errors.New("error1")
err2 := errors.New("error2")

t.Run("first error is propagated", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
err := iter.ForEachIdxCtx(bgctx, ints, func(ctx context.Context, i int, val *int) error {
if *val == 3 {
return err1
}
if *val == 4 {
return err2
}
return nil
})
require.ErrorIs(t, err, err1)
require.NotErrorIs(t, err, err2)
})
}

Expand Down Expand Up @@ -168,6 +196,41 @@ func TestForEach(t *testing.T) {
})
}

func TestForEachCtx(t *testing.T) {
t.Parallel()

bgctx := context.Background()
t.Run("mutating inputs is fine", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
err := iter.ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error {
*val += 1
return nil
})
require.Equal(t, []int{2, 3, 4, 5, 6}, ints)
require.NoError(t, err)
})

err1 := errors.New("error1")
err2 := errors.New("error2")

t.Run("first error is propagated", func(t *testing.T) {
t.Parallel()
ints := []int{1, 2, 3, 4, 5}
err := iter.ForEachCtx(bgctx, ints, func(ctx context.Context, val *int) error {
if *val == 3 {
return err1
}
if *val == 4 {
return err2
}
return nil
})
require.ErrorIs(t, err, err1)
require.NotErrorIs(t, err, err2)
})
}

func BenchmarkForEach(b *testing.B) {
for _, count := range []int{0, 1, 8, 100, 1000, 10000, 100000} {
b.Run(strconv.Itoa(count), func(b *testing.B) {
Expand Down
35 changes: 28 additions & 7 deletions iter/map.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package iter

import (
"context"
"errors"
"sync"
)
Expand All @@ -24,9 +25,8 @@ func Map[T, R any](input []T, f func(*T) R) []R {
//
// Map uses up to the configured Mapper's maximum number of goroutines.
func (m Mapper[T, R]) Map(input []T, f func(*T) R) []R {
res := make([]R, len(input))
Iterator[T](m).ForEachIdx(input, func(i int, t *T) {
res[i] = f(t)
res, _ := m.MapCtx(context.Background(), input, func(_ context.Context, t *T) (R, error) {
return f(t), nil
})
return res
}
Expand All @@ -46,18 +46,39 @@ func MapErr[T, R any](input []T, f func(*T) (R, error)) ([]R, error) {
// Map uses up to the configured Mapper's maximum number of goroutines.
func (m Mapper[T, R]) MapErr(input []T, f func(*T) (R, error)) ([]R, error) {
var (
res = make([]R, len(input))
errMux sync.Mutex
errs []error
)
Iterator[T](m).ForEachIdx(input, func(i int, t *T) {
var err error
res[i], err = f(t)
// MapErr handles its own errors by accumulating them as a multierror, ignoring the error from MapCtx which is only the first error
res, _ := m.MapCtx(context.Background(), input, func(ctx context.Context, t *T) (R, error) {
ires, err := f(t)
if err != nil {
errMux.Lock()
errs = append(errs, err)
errMux.Unlock()
}
return ires, nil
})
return res, errors.Join(errs...)
}

// MapCtx is the same as Map except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func MapCtx[T, R any](ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) {
return Mapper[T, R]{}.MapCtx(ctx, input, f)
}

// MapCtx is the same as Map except it also accepts a context
// that it uses to manages the execution of tasks.
// The context is cancelled on task failure and the first error is returned.
func (m Mapper[T, R]) MapCtx(ctx context.Context, input []T, f func(context.Context, *T) (R, error)) ([]R, error) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm interested in this as well 🙂 What do you think about using a builder for the Mapper? I would like a WithCancelOnError as well.

var (
res = make([]R, len(input))
)
return res, Iterator[T](m).ForEachIdxCtx(ctx, input, func(innerctx context.Context, i int, t *T) error {
var err error
res[i], err = f(innerctx, t)
return err
})
}
Loading