From 10f915ef6faf675000baada3b647298e9e28eb74 Mon Sep 17 00:00:00 2001 From: hehaifeng <240657481@qq.com> Date: Wed, 27 May 2026 01:14:02 +0800 Subject: [PATCH 1/4] feat: introduce LoadBalancer interface for extensible MultiPool load balancing - Add LoadBalancer interface with Pick/Fallback methods - Add PoolMetrics interface exposing read-only pool stats - Extract load-balancing logic into lbs.go with built-in strategies: RoundRobinLB, LeastTasksLB, LeastWaitingLB - Add NewMultiPoolWithLB, NewMultiPoolWithFuncAndLB, NewMultiPoolWithFuncGenericAndLB for custom LB injection - Keep NewMultiPool/NewMultiPoolWithFunc/NewMultiPoolWithFuncGeneric unchanged for full backward compatibility - Add tests and benchmarks covering all three strategies and task types --- ants_benchmark_test.go | 156 ++++++++++++++++++++++++++++ ants_test.go | 209 ++++++++++++++++++++++++++++++++++++++ lbs.go | 133 ++++++++++++++++++++++++ multipool.go | 66 +++++------- multipool_func.go | 55 +++++----- multipool_func_generic.go | 56 +++++----- 6 files changed, 574 insertions(+), 101 deletions(-) create mode 100644 lbs.go diff --git a/ants_benchmark_test.go b/ants_benchmark_test.go index 38e25dc..569e851 100644 --- a/ants_benchmark_test.go +++ b/ants_benchmark_test.go @@ -23,6 +23,7 @@ package ants_test import ( + "math/rand" "runtime" "sync" "sync/atomic" @@ -226,3 +227,158 @@ func BenchmarkParallelAntsMultiPoolThroughput(b *testing.B) { } }) } + +// cpuTask simulates a CPU-intensive task. +func cpuTask() { + n := 0 + for i := 0; i < 1000; i++ { + n += i * i + } + _ = n +} + +// ioTask simulates an IO-intensive task with a short sleep. +func ioTask() { + time.Sleep(time.Millisecond) +} + +// mixedTask simulates uneven task durations to stress load-balancing decisions. +func mixedTask() { + if time.Now().UnixNano()%5 == 0 { + time.Sleep(10 * time.Millisecond) + } else { + time.Sleep(time.Millisecond) + } +} + +func benchmarkMultiPoolLBS(b *testing.B, lb ants.LoadBalancer, task func()) { + p, _ := ants.NewMultiPoolWithLB(10, PoolCap/10, lb, ants.WithExpiryDuration(DefaultExpiredTime)) + defer p.ReleaseTimeout(DefaultExpiredTime) //nolint:errcheck + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = p.Submit(task) + } + }) +} + +// CPU-intensive task benchmarks across LBS strategies. + +func BenchmarkMultiPool_RoundRobin_CPUThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, ants.NewRoundRobinLB(), cpuTask) +} + +func BenchmarkMultiPool_LeastTasks_CPUThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, ants.NewLeastTasksLB(), cpuTask) +} + +func BenchmarkMultiPool_LeastWaiting_CPUThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, ants.NewLeastWaitingLB(), cpuTask) +} + +// IO-intensive task benchmarks across LBS strategies. + +func BenchmarkMultiPool_RoundRobin_IOThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, ants.NewRoundRobinLB(), ioTask) +} + +func BenchmarkMultiPool_LeastTasks_IOThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, ants.NewLeastTasksLB(), ioTask) +} + +func BenchmarkMultiPool_LeastWaiting_IOThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, ants.NewLeastWaitingLB(), ioTask) +} + +// Mixed (uneven duration) task benchmarks across LBS strategies. + +func BenchmarkMultiPool_RoundRobin_MixedThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, ants.NewRoundRobinLB(), mixedTask) +} + +func BenchmarkMultiPool_LeastTasks_MixedThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, ants.NewLeastTasksLB(), mixedTask) +} + +func BenchmarkMultiPool_LeastWaiting_MixedThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, ants.NewLeastWaitingLB(), mixedTask) +} + +// randomLB is a custom LoadBalancer that picks a pool at random, +// demonstrating how users can plug in their own strategy via NewMultiPoolWithLB. +type randomLB struct{} + +func newRandomLB() *randomLB { + return &randomLB{} +} + +func (r *randomLB) Pick(pools []ants.PoolMetrics) int { + return rand.Intn(len(pools)) +} + +func (r *randomLB) Fallback(pools []ants.PoolMetrics) int { + return -1 +} + +// Custom random LB benchmarks across task types. + +func BenchmarkMultiPool_Random_CPUThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, newRandomLB(), cpuTask) +} + +func BenchmarkMultiPool_Random_IOThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, newRandomLB(), ioTask) +} + +func BenchmarkMultiPool_Random_MixedThroughput(b *testing.B) { + benchmarkMultiPoolLBS(b, newRandomLB(), mixedTask) +} + +func benchmarkMultiPoolWithFuncLBSThroughput(b *testing.B, lb ants.LoadBalancer) { + p, _ := ants.NewMultiPoolWithFuncAndLB(10, PoolCap/10, demoPoolFunc, lb, ants.WithExpiryDuration(DefaultExpiredTime)) + defer p.ReleaseTimeout(DefaultExpiredTime) //nolint:errcheck + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = p.Invoke(BenchParam) + } + }) +} + +func BenchmarkMultiPoolWithFunc_RoundRobin_Throughput(b *testing.B) { + benchmarkMultiPoolWithFuncLBSThroughput(b, ants.NewRoundRobinLB()) +} + +func BenchmarkMultiPoolWithFunc_LeastTasks_Throughput(b *testing.B) { + benchmarkMultiPoolWithFuncLBSThroughput(b, ants.NewLeastTasksLB()) +} + +func BenchmarkMultiPoolWithFunc_LeastWaiting_Throughput(b *testing.B) { + benchmarkMultiPoolWithFuncLBSThroughput(b, ants.NewLeastWaitingLB()) +} + +func benchmarkMultiPoolWithFuncGenericLBSThroughput(b *testing.B, lb ants.LoadBalancer) { + p, _ := ants.NewMultiPoolWithFuncGenericAndLB(10, PoolCap/10, demoPoolFuncInt, lb, ants.WithExpiryDuration(DefaultExpiredTime)) + defer p.ReleaseTimeout(DefaultExpiredTime) //nolint:errcheck + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = p.Invoke(BenchParam) + } + }) +} + +func BenchmarkMultiPoolWithFuncGeneric_RoundRobin_Throughput(b *testing.B) { + benchmarkMultiPoolWithFuncGenericLBSThroughput(b, ants.NewRoundRobinLB()) +} + +func BenchmarkMultiPoolWithFuncGeneric_LeastTasks_Throughput(b *testing.B) { + benchmarkMultiPoolWithFuncGenericLBSThroughput(b, ants.NewLeastTasksLB()) +} + +func BenchmarkMultiPoolWithFuncGeneric_LeastWaiting_Throughput(b *testing.B) { + benchmarkMultiPoolWithFuncGenericLBSThroughput(b, ants.NewLeastWaitingLB()) +} diff --git a/ants_test.go b/ants_test.go index a03f71d..8ca8a41 100644 --- a/ants_test.go +++ b/ants_test.go @@ -1815,3 +1815,212 @@ func TestRebootNewPoolWithPreAllocCalc(t *testing.T) { wg.Wait() require.EqualValues(t, 499500, sum, "The result should be 499500") } +func TestMultiPoolWithLB_RoundRobin(t *testing.T) { + _, err := ants.NewMultiPoolWithLB(-1, 5, ants.NewRoundRobinLB()) + require.ErrorIs(t, err, ants.ErrInvalidMultiPoolSize) + _, err = ants.NewMultiPoolWithLB(10, 5, nil) + require.ErrorIs(t, err, ants.ErrInvalidLoadBalancingStrategy) + _, err = ants.NewMultiPoolWithLB(10, 5, ants.NewRoundRobinLB(), ants.WithExpiryDuration(-1)) + require.ErrorIs(t, err, ants.ErrInvalidPoolExpiry) + + mp, err := ants.NewMultiPoolWithLB(10, 5, ants.NewRoundRobinLB()) + require.NoError(t, err) + testFn := func() { + for i := 0; i < 50; i++ { + err = mp.Submit(longRunningFunc) + require.NoError(t, err) + } + require.EqualValues(t, 50, mp.Running()) + require.EqualValues(t, 50, mp.Cap()) + require.EqualValues(t, 0, mp.Free()) + require.False(t, mp.IsClosed()) + atomic.StoreInt32(&stopLongRunningFunc, 1) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.Submit(nil), ants.ErrPoolClosed) + require.True(t, mp.IsClosed()) + atomic.StoreInt32(&stopLongRunningFunc, 0) + } + testFn() + mp.Reboot() + testFn() +} + +func TestMultiPoolWithLB_LeastTasks(t *testing.T) { + mp, err := ants.NewMultiPoolWithLB(10, 5, ants.NewLeastTasksLB(), ants.WithNonblocking(true)) + require.NoError(t, err) + testFn := func() { + for i := 0; i < 50; i++ { + _ = mp.Submit(longRunningFunc) + } + require.False(t, mp.IsClosed()) + atomic.StoreInt32(&stopLongRunningFunc, 1) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.Submit(nil), ants.ErrPoolClosed) + require.True(t, mp.IsClosed()) + atomic.StoreInt32(&stopLongRunningFunc, 0) + } + testFn() + mp.Reboot() + testFn() +} + +func TestMultiPoolWithLB_LeastWaiting(t *testing.T) { + mp, err := ants.NewMultiPoolWithLB(10, 5, ants.NewLeastWaitingLB(), ants.WithNonblocking(true)) + require.NoError(t, err) + testFn := func() { + for i := 0; i < 50; i++ { + _ = mp.Submit(longRunningFunc) + } + require.False(t, mp.IsClosed()) + atomic.StoreInt32(&stopLongRunningFunc, 1) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.Submit(nil), ants.ErrPoolClosed) + require.True(t, mp.IsClosed()) + atomic.StoreInt32(&stopLongRunningFunc, 0) + } + testFn() + mp.Reboot() + testFn() +} + +func TestMultiPoolWithFuncAndLB_RoundRobin(t *testing.T) { + _, err := ants.NewMultiPoolWithFuncAndLB(-1, 5, longRunningPoolFunc, ants.NewRoundRobinLB()) + require.ErrorIs(t, err, ants.ErrInvalidMultiPoolSize) + _, err = ants.NewMultiPoolWithFuncAndLB(10, 5, longRunningPoolFunc, nil) + require.ErrorIs(t, err, ants.ErrInvalidLoadBalancingStrategy) + _, err = ants.NewMultiPoolWithFuncAndLB(10, 5, longRunningPoolFunc, ants.NewRoundRobinLB(), ants.WithExpiryDuration(-1)) + require.ErrorIs(t, err, ants.ErrInvalidPoolExpiry) + + ch := make(chan struct{}) + mp, err := ants.NewMultiPoolWithFuncAndLB(10, 5, longRunningPoolFunc, ants.NewRoundRobinLB()) + require.NoError(t, err) + testFn := func() { + for i := 0; i < 50; i++ { + err = mp.Invoke(ch) + require.NoError(t, err) + } + require.EqualValues(t, 50, mp.Running()) + require.EqualValues(t, 50, mp.Cap()) + require.EqualValues(t, 0, mp.Free()) + require.False(t, mp.IsClosed()) + close(ch) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.Invoke(nil), ants.ErrPoolClosed) + require.True(t, mp.IsClosed()) + ch = make(chan struct{}) + } + testFn() + mp.Reboot() + testFn() +} + +func TestMultiPoolWithFuncAndLB_LeastTasks(t *testing.T) { + ch := make(chan struct{}) + mp, err := ants.NewMultiPoolWithFuncAndLB(10, 5, longRunningPoolFunc, ants.NewLeastTasksLB(), ants.WithNonblocking(true)) + require.NoError(t, err) + testFn := func() { + for i := 0; i < 50; i++ { + _ = mp.Invoke(ch) + } + require.False(t, mp.IsClosed()) + close(ch) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.Invoke(nil), ants.ErrPoolClosed) + require.True(t, mp.IsClosed()) + ch = make(chan struct{}) + } + testFn() + mp.Reboot() + testFn() +} + +func TestMultiPoolWithFuncAndLB_LeastWaiting(t *testing.T) { + ch := make(chan struct{}) + mp, err := ants.NewMultiPoolWithFuncAndLB(10, 5, longRunningPoolFunc, ants.NewLeastWaitingLB(), ants.WithNonblocking(true)) + require.NoError(t, err) + testFn := func() { + for i := 0; i < 50; i++ { + _ = mp.Invoke(ch) + } + require.False(t, mp.IsClosed()) + close(ch) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.Invoke(nil), ants.ErrPoolClosed) + require.True(t, mp.IsClosed()) + ch = make(chan struct{}) + } + testFn() + mp.Reboot() + testFn() +} + +func TestMultiPoolWithFuncGenericAndLB_RoundRobin(t *testing.T) { + _, err := ants.NewMultiPoolWithFuncGenericAndLB(-1, 5, longRunningPoolFuncCh, ants.NewRoundRobinLB()) + require.ErrorIs(t, err, ants.ErrInvalidMultiPoolSize) + _, err = ants.NewMultiPoolWithFuncGenericAndLB(10, 5, longRunningPoolFuncCh, nil) + require.ErrorIs(t, err, ants.ErrInvalidLoadBalancingStrategy) + _, err = ants.NewMultiPoolWithFuncGenericAndLB(10, 5, longRunningPoolFuncCh, ants.NewRoundRobinLB(), ants.WithExpiryDuration(-1)) + require.ErrorIs(t, err, ants.ErrInvalidPoolExpiry) + + ch := make(chan struct{}) + mp, err := ants.NewMultiPoolWithFuncGenericAndLB(10, 5, longRunningPoolFuncCh, ants.NewRoundRobinLB()) + require.NoError(t, err) + testFn := func() { + for i := 0; i < 50; i++ { + err = mp.Invoke(ch) + require.NoError(t, err) + } + require.EqualValues(t, 50, mp.Running()) + require.EqualValues(t, 50, mp.Cap()) + require.EqualValues(t, 0, mp.Free()) + require.False(t, mp.IsClosed()) + close(ch) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.Invoke(nil), ants.ErrPoolClosed) + require.True(t, mp.IsClosed()) + ch = make(chan struct{}) + } + testFn() + mp.Reboot() + testFn() +} + +func TestMultiPoolWithFuncGenericAndLB_LeastTasks(t *testing.T) { + ch := make(chan struct{}) + mp, err := ants.NewMultiPoolWithFuncGenericAndLB(10, 5, longRunningPoolFuncCh, ants.NewLeastTasksLB(), ants.WithNonblocking(true)) + require.NoError(t, err) + testFn := func() { + for i := 0; i < 50; i++ { + _ = mp.Invoke(ch) + } + require.False(t, mp.IsClosed()) + close(ch) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.Invoke(nil), ants.ErrPoolClosed) + require.True(t, mp.IsClosed()) + ch = make(chan struct{}) + } + testFn() + mp.Reboot() + testFn() +} + +func TestMultiPoolWithFuncGenericAndLB_LeastWaiting(t *testing.T) { + ch := make(chan struct{}) + mp, err := ants.NewMultiPoolWithFuncGenericAndLB(10, 5, longRunningPoolFuncCh, ants.NewLeastWaitingLB(), ants.WithNonblocking(true)) + require.NoError(t, err) + testFn := func() { + for i := 0; i < 50; i++ { + _ = mp.Invoke(ch) + } + require.False(t, mp.IsClosed()) + close(ch) + require.NoError(t, mp.ReleaseTimeout(3*time.Second)) + require.ErrorIs(t, mp.Invoke(nil), ants.ErrPoolClosed) + require.True(t, mp.IsClosed()) + ch = make(chan struct{}) + } + testFn() + mp.Reboot() + testFn() +} diff --git a/lbs.go b/lbs.go new file mode 100644 index 0000000..0c94f3e --- /dev/null +++ b/lbs.go @@ -0,0 +1,133 @@ +// MIT License + +// Copyright (c) 2023 Andy Pan + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +package ants + +import ( + "math" + "sync/atomic" +) + +// LoadBalancingStrategy represents the type of load-balancing algorithm. +type LoadBalancingStrategy int + +const ( + // RoundRobin distributes task to a list of pools in rotation. + RoundRobin LoadBalancingStrategy = 1 << (iota + 1) + + // LeastTasks always selects the pool with the least number of pending tasks. + LeastTasks +) + +// PoolMetrics exposes the read-only stats a LoadBalancer needs to make a pick decision. +type PoolMetrics interface { + Running() int + Waiting() int + Free() int + Cap() int +} + +// LoadBalancer picks a pool index from a slice of PoolMetrics. +type LoadBalancer interface { + Pick(pools []PoolMetrics) int + // Fallback is called when the pool chosen by Pick is overloaded. + // Return -1 to indicate no fallback is supported. + Fallback(pools []PoolMetrics) int +} + +// roundRobinLB distributes tasks across pools in rotation. +type roundRobinLB struct { + index uint32 +} + +// NewRoundRobinLB returns a RoundRobin load balancer. +func NewRoundRobinLB() LoadBalancer { + return &roundRobinLB{index: math.MaxUint32} +} + +func (r *roundRobinLB) Pick(pools []PoolMetrics) int { + return int(atomic.AddUint32(&r.index, 1) % uint32(len(pools))) +} + +func (r *roundRobinLB) Fallback(pools []PoolMetrics) int { + return leastTasksPick(pools) +} + +func leastTasksPick(pools []PoolMetrics) int { + idx, least := 0, math.MaxInt32 + for i, p := range pools { + if n := p.Running(); n < least { + least = n + idx = i + } + } + return idx +} + +// leastTasksLB picks the pool with the fewest running tasks. +type leastTasksLB struct{} + +// NewLeastTasksLB returns a LeastTasks load balancer. +func NewLeastTasksLB() LoadBalancer { + return &leastTasksLB{} +} + +func (l *leastTasksLB) Pick(pools []PoolMetrics) int { + return leastTasksPick(pools) +} + +func (l *leastTasksLB) Fallback(pools []PoolMetrics) int { + return -1 +} + +// leastWaiting picks the pool with the fewest waiting tasks. +type leastWaiting struct{} + +// NewLeastWaitingLB returns a LeastWaiting load balancer. +func NewLeastWaitingLB() LoadBalancer { + return &leastWaiting{} +} + +func (l *leastWaiting) Pick(pools []PoolMetrics) int { + idx, least := 0, math.MaxInt32 + for i, p := range pools { + if n := p.Waiting(); n < least { + least = n + idx = i + } + } + return idx +} + +func (l *leastWaiting) Fallback(pools []PoolMetrics) int { + return -1 +} + +func newBuiltinLB(lbs LoadBalancingStrategy) (LoadBalancer, error) { + switch lbs { + case RoundRobin: + return NewRoundRobinLB(), nil + case LeastTasks: + return NewLeastTasksLB(), nil + } + return nil, ErrInvalidLoadBalancingStrategy +} diff --git a/multipool.go b/multipool.go index 6116f0b..2517a82 100644 --- a/multipool.go +++ b/multipool.go @@ -26,7 +26,6 @@ import ( "context" "errors" "fmt" - "math" "strings" "sync/atomic" "time" @@ -34,17 +33,6 @@ import ( "golang.org/x/sync/errgroup" ) -// LoadBalancingStrategy represents the type of load-balancing algorithm. -type LoadBalancingStrategy int - -const ( - // RoundRobin distributes task to a list of pools in rotation. - RoundRobin LoadBalancingStrategy = 1 << (iota + 1) - - // LeastTasks always selects the pool with the least number of pending tasks. - LeastTasks -) - type contextReleaser interface { ReleaseContext(ctx context.Context) error } @@ -88,10 +76,10 @@ func releasePools(ctx context.Context, pools []contextReleaser) error { // MultiPool is a good fit for the scenario where you have a large number of // tasks to submit, and you don't want the single pool to be the bottleneck. type MultiPool struct { - pools []*Pool - index uint32 - state int32 - lbs LoadBalancingStrategy + pools []*Pool + metrics []PoolMetrics + lb LoadBalancer + state int32 } // NewMultiPool instantiates a MultiPool with a size of the pool list and a size @@ -100,11 +88,24 @@ func NewMultiPool(size, sizePerPool int, lbs LoadBalancingStrategy, options ...O if size <= 0 { return nil, ErrInvalidMultiPoolSize } + lb, err := newBuiltinLB(lbs) + if err != nil { + return nil, err + } + return NewMultiPoolWithLB(size, sizePerPool, lb, options...) +} - if lbs != RoundRobin && lbs != LeastTasks { +// NewMultiPoolWithLB instantiates a MultiPool with a given LoadBalancer. +func NewMultiPoolWithLB(size, sizePerPool int, lb LoadBalancer, options ...Option) (*MultiPool, error) { + if size <= 0 { + return nil, ErrInvalidMultiPoolSize + } + if lb == nil { return nil, ErrInvalidLoadBalancingStrategy } + pools := make([]*Pool, size) + metrics := make([]PoolMetrics, size) for i := 0; i < size; i++ { pool, err := NewPool(sizePerPool, options...) if err != nil { @@ -115,25 +116,10 @@ func NewMultiPool(size, sizePerPool int, lbs LoadBalancingStrategy, options ...O return nil, err } pools[i] = pool + metrics[i] = pool } - return &MultiPool{pools: pools, index: math.MaxUint32, lbs: lbs}, nil -} -func (mp *MultiPool) next(lbs LoadBalancingStrategy) (idx int) { - switch lbs { - case RoundRobin: - return int(atomic.AddUint32(&mp.index, 1) % uint32(len(mp.pools))) - case LeastTasks: - leastTasks := math.MaxInt32 - for i, pool := range mp.pools { - if n := pool.Running(); n < leastTasks { - leastTasks = n - idx = i - } - } - return - } - return -1 + return &MultiPool{pools: pools, metrics: metrics, lb: lb}, nil } // Submit submits a task to a pool selected by the load-balancing strategy. @@ -141,11 +127,12 @@ func (mp *MultiPool) Submit(task func()) (err error) { if mp.IsClosed() { return ErrPoolClosed } - if err = mp.pools[mp.next(mp.lbs)].Submit(task); err == nil { - return - } - if err == ErrPoolOverload && mp.lbs == RoundRobin { - return mp.pools[mp.next(LeastTasks)].Submit(task) + + idx := mp.lb.Pick(mp.metrics) + if err = mp.pools[idx].Submit(task); err == ErrPoolOverload { + if fb := mp.lb.Fallback(mp.metrics); fb >= 0 { + return mp.pools[fb].Submit(task) + } } return } @@ -246,7 +233,6 @@ func (mp *MultiPool) ReleaseContext(ctx context.Context) error { // Reboot reboots a released multi-pool. func (mp *MultiPool) Reboot() { if atomic.CompareAndSwapInt32(&mp.state, CLOSED, OPENED) { - atomic.StoreUint32(&mp.index, 0) for _, pool := range mp.pools { pool.Reboot() } diff --git a/multipool_func.go b/multipool_func.go index e5c0e80..aa555bc 100644 --- a/multipool_func.go +++ b/multipool_func.go @@ -24,7 +24,6 @@ package ants import ( "context" - "math" "sync/atomic" "time" ) @@ -35,10 +34,10 @@ import ( // MultiPoolWithFunc is a good fit for the scenario where you have a large number of // tasks to submit, and you don't want the single pool to be the bottleneck. type MultiPoolWithFunc struct { - pools []*PoolWithFunc - index uint32 - state int32 - lbs LoadBalancingStrategy + pools []*PoolWithFunc + metrics []PoolMetrics + lb LoadBalancer + state int32 } // NewMultiPoolWithFunc instantiates a MultiPoolWithFunc with a size of the pool list and a size @@ -47,11 +46,24 @@ func NewMultiPoolWithFunc(size, sizePerPool int, fn func(any), lbs LoadBalancing if size <= 0 { return nil, ErrInvalidMultiPoolSize } + lb, err := newBuiltinLB(lbs) + if err != nil { + return nil, err + } + return NewMultiPoolWithFuncAndLB(size, sizePerPool, fn, lb, options...) +} - if lbs != RoundRobin && lbs != LeastTasks { +// NewMultiPoolWithFuncAndLB instantiates a MultiPoolWithFunc with a given LoadBalancer. +func NewMultiPoolWithFuncAndLB(size, sizePerPool int, fn func(any), lb LoadBalancer, options ...Option) (*MultiPoolWithFunc, error) { + if size <= 0 { + return nil, ErrInvalidMultiPoolSize + } + if lb == nil { return nil, ErrInvalidLoadBalancingStrategy } + pools := make([]*PoolWithFunc, size) + metrics := make([]PoolMetrics, size) for i := 0; i < size; i++ { pool, err := NewPoolWithFunc(sizePerPool, fn, options...) if err != nil { @@ -62,25 +74,10 @@ func NewMultiPoolWithFunc(size, sizePerPool int, fn func(any), lbs LoadBalancing return nil, err } pools[i] = pool + metrics[i] = pool } - return &MultiPoolWithFunc{pools: pools, index: math.MaxUint32, lbs: lbs}, nil -} -func (mp *MultiPoolWithFunc) next(lbs LoadBalancingStrategy) (idx int) { - switch lbs { - case RoundRobin: - return int(atomic.AddUint32(&mp.index, 1) % uint32(len(mp.pools))) - case LeastTasks: - leastTasks := math.MaxInt32 - for i, pool := range mp.pools { - if n := pool.Running(); n < leastTasks { - leastTasks = n - idx = i - } - } - return - } - return -1 + return &MultiPoolWithFunc{pools: pools, metrics: metrics, lb: lb}, nil } // Invoke submits a task to a pool selected by the load-balancing strategy. @@ -88,12 +85,11 @@ func (mp *MultiPoolWithFunc) Invoke(args any) (err error) { if mp.IsClosed() { return ErrPoolClosed } - - if err = mp.pools[mp.next(mp.lbs)].Invoke(args); err == nil { - return - } - if err == ErrPoolOverload && mp.lbs == RoundRobin { - return mp.pools[mp.next(LeastTasks)].Invoke(args) + idx := mp.lb.Pick(mp.metrics) + if err = mp.pools[idx].Invoke(args); err == ErrPoolOverload { + if fb := mp.lb.Fallback(mp.metrics); fb >= 0 { + return mp.pools[fb].Invoke(args) + } } return } @@ -194,7 +190,6 @@ func (mp *MultiPoolWithFunc) ReleaseContext(ctx context.Context) error { // Reboot reboots a released multi-pool. func (mp *MultiPoolWithFunc) Reboot() { if atomic.CompareAndSwapInt32(&mp.state, CLOSED, OPENED) { - atomic.StoreUint32(&mp.index, 0) for _, pool := range mp.pools { pool.Reboot() } diff --git a/multipool_func_generic.go b/multipool_func_generic.go index 2b247ff..88699f0 100644 --- a/multipool_func_generic.go +++ b/multipool_func_generic.go @@ -24,17 +24,16 @@ package ants import ( "context" - "math" "sync/atomic" "time" ) // MultiPoolWithFuncGeneric is the generic version of MultiPoolWithFunc. type MultiPoolWithFuncGeneric[T any] struct { - pools []*PoolWithFuncGeneric[T] - index uint32 - state int32 - lbs LoadBalancingStrategy + pools []*PoolWithFuncGeneric[T] + metrics []PoolMetrics + lb LoadBalancer + state int32 } // NewMultiPoolWithFuncGeneric instantiates a MultiPoolWithFunc with a size of the pool list and a size @@ -43,11 +42,24 @@ func NewMultiPoolWithFuncGeneric[T any](size, sizePerPool int, fn func(T), lbs L if size <= 0 { return nil, ErrInvalidMultiPoolSize } + lb, err := newBuiltinLB(lbs) + if err != nil { + return nil, err + } + return NewMultiPoolWithFuncGenericAndLB(size, sizePerPool, fn, lb, options...) +} - if lbs != RoundRobin && lbs != LeastTasks { +// NewMultiPoolWithFuncGenericAndLB instantiates a MultiPoolWithFuncGeneric with a given LoadBalancer. +func NewMultiPoolWithFuncGenericAndLB[T any](size, sizePerPool int, fn func(T), lb LoadBalancer, options ...Option) (*MultiPoolWithFuncGeneric[T], error) { + if size <= 0 { + return nil, ErrInvalidMultiPoolSize + } + if lb == nil { return nil, ErrInvalidLoadBalancingStrategy } + pools := make([]*PoolWithFuncGeneric[T], size) + metrics := make([]PoolMetrics, size) for i := 0; i < size; i++ { pool, err := NewPoolWithFuncGeneric(sizePerPool, fn, options...) if err != nil { @@ -58,25 +70,9 @@ func NewMultiPoolWithFuncGeneric[T any](size, sizePerPool int, fn func(T), lbs L return nil, err } pools[i] = pool + metrics[i] = pool } - return &MultiPoolWithFuncGeneric[T]{pools: pools, index: math.MaxUint32, lbs: lbs}, nil -} - -func (mp *MultiPoolWithFuncGeneric[T]) next(lbs LoadBalancingStrategy) (idx int) { - switch lbs { - case RoundRobin: - return int(atomic.AddUint32(&mp.index, 1) % uint32(len(mp.pools))) - case LeastTasks: - leastTasks := math.MaxInt32 - for i, pool := range mp.pools { - if n := pool.Running(); n < leastTasks { - leastTasks = n - idx = i - } - } - return - } - return -1 + return &MultiPoolWithFuncGeneric[T]{pools: pools, metrics: metrics, lb: lb}, nil } // Invoke submits a task to a pool selected by the load-balancing strategy. @@ -84,12 +80,11 @@ func (mp *MultiPoolWithFuncGeneric[T]) Invoke(args T) (err error) { if mp.IsClosed() { return ErrPoolClosed } - - if err = mp.pools[mp.next(mp.lbs)].Invoke(args); err == nil { - return - } - if err == ErrPoolOverload && mp.lbs == RoundRobin { - return mp.pools[mp.next(LeastTasks)].Invoke(args) + idx := mp.lb.Pick(mp.metrics) + if err = mp.pools[idx].Invoke(args); err == ErrPoolOverload { + if fb := mp.lb.Fallback(mp.metrics); fb >= 0 { + return mp.pools[fb].Invoke(args) + } } return } @@ -190,7 +185,6 @@ func (mp *MultiPoolWithFuncGeneric[T]) ReleaseContext(ctx context.Context) error // Reboot reboots a released multi-pool. func (mp *MultiPoolWithFuncGeneric[T]) Reboot() { if atomic.CompareAndSwapInt32(&mp.state, CLOSED, OPENED) { - atomic.StoreUint32(&mp.index, 0) for _, pool := range mp.pools { pool.Reboot() } From cca4801c37a5aa3336909873bb838c271fd1ad73 Mon Sep 17 00:00:00 2001 From: hehaifeng <240657481@qq.com> Date: Wed, 27 May 2026 01:36:25 +0800 Subject: [PATCH 2/4] feat: add bounds check on Pick/Fallback to guard against invalid indices from custom LoadBalancer implementations --- lbs.go | 5 +++++ multipool.go | 5 ++++- multipool_func.go | 5 ++++- multipool_func_generic.go | 5 ++++- 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/lbs.go b/lbs.go index 0c94f3e..b26d4e4 100644 --- a/lbs.go +++ b/lbs.go @@ -122,6 +122,11 @@ func (l *leastWaiting) Fallback(pools []PoolMetrics) int { return -1 } +// validIdx checks that idx returned by a LoadBalancer is within bounds. +func validIdx(idx, n int) bool { + return idx >= 0 && idx < n +} + func newBuiltinLB(lbs LoadBalancingStrategy) (LoadBalancer, error) { switch lbs { case RoundRobin: diff --git a/multipool.go b/multipool.go index 2517a82..787e19b 100644 --- a/multipool.go +++ b/multipool.go @@ -129,8 +129,11 @@ func (mp *MultiPool) Submit(task func()) (err error) { } idx := mp.lb.Pick(mp.metrics) + if !validIdx(idx, len(mp.pools)) { + return ErrInvalidPoolIndex + } if err = mp.pools[idx].Submit(task); err == ErrPoolOverload { - if fb := mp.lb.Fallback(mp.metrics); fb >= 0 { + if fb := mp.lb.Fallback(mp.metrics); validIdx(fb, len(mp.pools)) { return mp.pools[fb].Submit(task) } } diff --git a/multipool_func.go b/multipool_func.go index aa555bc..a142719 100644 --- a/multipool_func.go +++ b/multipool_func.go @@ -86,8 +86,11 @@ func (mp *MultiPoolWithFunc) Invoke(args any) (err error) { return ErrPoolClosed } idx := mp.lb.Pick(mp.metrics) + if !validIdx(idx, len(mp.pools)) { + return ErrInvalidPoolIndex + } if err = mp.pools[idx].Invoke(args); err == ErrPoolOverload { - if fb := mp.lb.Fallback(mp.metrics); fb >= 0 { + if fb := mp.lb.Fallback(mp.metrics); validIdx(fb, len(mp.pools)) { return mp.pools[fb].Invoke(args) } } diff --git a/multipool_func_generic.go b/multipool_func_generic.go index 88699f0..3f1e8c3 100644 --- a/multipool_func_generic.go +++ b/multipool_func_generic.go @@ -81,8 +81,11 @@ func (mp *MultiPoolWithFuncGeneric[T]) Invoke(args T) (err error) { return ErrPoolClosed } idx := mp.lb.Pick(mp.metrics) + if !validIdx(idx, len(mp.pools)) { + return ErrInvalidPoolIndex + } if err = mp.pools[idx].Invoke(args); err == ErrPoolOverload { - if fb := mp.lb.Fallback(mp.metrics); fb >= 0 { + if fb := mp.lb.Fallback(mp.metrics); validIdx(fb, len(mp.pools)) { return mp.pools[fb].Invoke(args) } } From bf9dc05fc900718fd9a9c42e969b4a9e84020a1e Mon Sep 17 00:00:00 2001 From: hehaifeng <240657481@qq.com> Date: Mon, 1 Jun 2026 17:13:51 +0800 Subject: [PATCH 3/4] fix: fix lint, copyright and rename Fallback to FallBack --- ants_benchmark_test.go | 2 +- lbs.go | 12 ++++++------ multipool.go | 2 +- multipool_func.go | 2 +- multipool_func_generic.go | 2 +- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/ants_benchmark_test.go b/ants_benchmark_test.go index 569e851..4daa7b6 100644 --- a/ants_benchmark_test.go +++ b/ants_benchmark_test.go @@ -317,7 +317,7 @@ func (r *randomLB) Pick(pools []ants.PoolMetrics) int { return rand.Intn(len(pools)) } -func (r *randomLB) Fallback(pools []ants.PoolMetrics) int { +func (r *randomLB) FallBack(_ []ants.PoolMetrics) int { return -1 } diff --git a/lbs.go b/lbs.go index b26d4e4..44d8c97 100644 --- a/lbs.go +++ b/lbs.go @@ -1,6 +1,6 @@ // MIT License -// Copyright (c) 2023 Andy Pan +// Copyright (c) 2026. Ants Authors. All rights reserved. // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal @@ -49,9 +49,9 @@ type PoolMetrics interface { // LoadBalancer picks a pool index from a slice of PoolMetrics. type LoadBalancer interface { Pick(pools []PoolMetrics) int - // Fallback is called when the pool chosen by Pick is overloaded. + // FallBack is called when the pool chosen by Pick is overloaded. // Return -1 to indicate no fallback is supported. - Fallback(pools []PoolMetrics) int + FallBack(pools []PoolMetrics) int } // roundRobinLB distributes tasks across pools in rotation. @@ -68,7 +68,7 @@ func (r *roundRobinLB) Pick(pools []PoolMetrics) int { return int(atomic.AddUint32(&r.index, 1) % uint32(len(pools))) } -func (r *roundRobinLB) Fallback(pools []PoolMetrics) int { +func (r *roundRobinLB) FallBack(pools []PoolMetrics) int { return leastTasksPick(pools) } @@ -95,7 +95,7 @@ func (l *leastTasksLB) Pick(pools []PoolMetrics) int { return leastTasksPick(pools) } -func (l *leastTasksLB) Fallback(pools []PoolMetrics) int { +func (l *leastTasksLB) FallBack(_ []PoolMetrics) int { return -1 } @@ -118,7 +118,7 @@ func (l *leastWaiting) Pick(pools []PoolMetrics) int { return idx } -func (l *leastWaiting) Fallback(pools []PoolMetrics) int { +func (l *leastWaiting) FallBack(_ []PoolMetrics) int { return -1 } diff --git a/multipool.go b/multipool.go index 787e19b..a60257e 100644 --- a/multipool.go +++ b/multipool.go @@ -133,7 +133,7 @@ func (mp *MultiPool) Submit(task func()) (err error) { return ErrInvalidPoolIndex } if err = mp.pools[idx].Submit(task); err == ErrPoolOverload { - if fb := mp.lb.Fallback(mp.metrics); validIdx(fb, len(mp.pools)) { + if fb := mp.lb.FallBack(mp.metrics); validIdx(fb, len(mp.pools)) { return mp.pools[fb].Submit(task) } } diff --git a/multipool_func.go b/multipool_func.go index a142719..546550d 100644 --- a/multipool_func.go +++ b/multipool_func.go @@ -90,7 +90,7 @@ func (mp *MultiPoolWithFunc) Invoke(args any) (err error) { return ErrInvalidPoolIndex } if err = mp.pools[idx].Invoke(args); err == ErrPoolOverload { - if fb := mp.lb.Fallback(mp.metrics); validIdx(fb, len(mp.pools)) { + if fb := mp.lb.FallBack(mp.metrics); validIdx(fb, len(mp.pools)) { return mp.pools[fb].Invoke(args) } } diff --git a/multipool_func_generic.go b/multipool_func_generic.go index 3f1e8c3..4527f82 100644 --- a/multipool_func_generic.go +++ b/multipool_func_generic.go @@ -85,7 +85,7 @@ func (mp *MultiPoolWithFuncGeneric[T]) Invoke(args T) (err error) { return ErrInvalidPoolIndex } if err = mp.pools[idx].Invoke(args); err == ErrPoolOverload { - if fb := mp.lb.Fallback(mp.metrics); validIdx(fb, len(mp.pools)) { + if fb := mp.lb.FallBack(mp.metrics); validIdx(fb, len(mp.pools)) { return mp.pools[fb].Invoke(args) } } From b6608bf46d981eb9fa52f3887088e580ce270eea Mon Sep 17 00:00:00 2001 From: hehaifeng <240657481@qq.com> Date: Mon, 1 Jun 2026 17:23:31 +0800 Subject: [PATCH 4/4] fix: format all files with gofumpt --- ants_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/ants_test.go b/ants_test.go index 8ca8a41..fa46d06 100644 --- a/ants_test.go +++ b/ants_test.go @@ -1815,6 +1815,7 @@ func TestRebootNewPoolWithPreAllocCalc(t *testing.T) { wg.Wait() require.EqualValues(t, 499500, sum, "The result should be 499500") } + func TestMultiPoolWithLB_RoundRobin(t *testing.T) { _, err := ants.NewMultiPoolWithLB(-1, 5, ants.NewRoundRobinLB()) require.ErrorIs(t, err, ants.ErrInvalidMultiPoolSize)