diff --git a/x/mongo/driver/topology/pool.go b/x/mongo/driver/topology/pool.go
index 1dd236dfdc..34c3c604d6 100644
--- a/x/mongo/driver/topology/pool.go
+++ b/x/mongo/driver/topology/pool.go
@@ -128,6 +128,8 @@ type pool struct {
 	idleConns      []*connection // idleConns holds all idle connections.
 	idleConnWait   wantConnQueue // idleConnWait holds all wantConn requests for idle connections.
 	connectTimeout time.Duration
+
+	connectionSem chan struct{}
 }
 
 // getState returns the current state of the pool. Callers must not hold the stateMu lock.
@@ -226,6 +228,7 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool {
 		conns:          make(map[int64]*connection, config.MaxPoolSize),
 		idleConns:      make([]*connection, 0, config.MaxPoolSize),
 		connectTimeout: config.ConnectTimeout,
+		connectionSem:  make(chan struct{}, maxConnecting),
 	}
 	// minSize must not exceed maxSize if maxSize is not 0
 	if pool.maxSize != 0 && pool.minSize > pool.maxSize {
@@ -241,11 +244,6 @@ func newPool(config poolConfig, connOpts ...ConnectionOption) *pool {
 	var ctx context.Context
 	ctx, pool.cancelBackgroundCtx = context.WithCancel(context.Background())
 
-	for i := 0; i < int(pool.maxConnecting); i++ {
-		pool.backgroundDone.Add(1)
-		go pool.createConnections(ctx, pool.backgroundDone)
-	}
-
 	// If maintainInterval is not positive, don't start the maintain() goroutine. Expect that
 	// negative values are only used in testing; this config value is not user-configurable.
 	if maintainInterval > 0 {
@@ -610,7 +608,7 @@ func (p *pool) checkOut(ctx context.Context) (conn *connection, err error) {
 
 	// If we didn't get an immediately available idle connection, also get in the queue for a new
 	// connection while we're waiting for an idle connection.
-	p.queueForNewConn(w)
+	p.queueForNewConn(ctx, w)
 	p.stateMu.RUnlock()
 
 	// Wait for either the wantConn to be ready or for the Context to time out.
@@ -936,6 +934,18 @@ func (p *pool) checkInNoEvent(conn *connection) error {
 		go func() {
 			_ = p.closeConnection(conn)
 		}()
+
+		// Since we are removing the connection, we should try to queue another
+		// in good faith in case the current idle wait queue is being awaited
+		// in a checkOut() call.
+		p.createConnectionsCond.L.Lock()
+		w := p.newConnWait.popFront()
+		p.createConnectionsCond.L.Unlock()
+
+		if w != nil {
+			p.queueForNewConn(context.Background(), w)
+		}
+
 		return nil
 	}
 
@@ -1130,13 +1140,30 @@ func (p *pool) getOrQueueForIdleConn(w *wantConn) bool {
 	return false
 }
 
-func (p *pool) queueForNewConn(w *wantConn) {
+// queueForNewConn enqueues a checkout request and signals the
+// connection-creation state machine. It does NOT initiate dialing directly,
+// but places the wantConn into the pending queue and wakes a background worker
+// using sync.Cond. That worker then dequeues in FIFO order and performs the
+// actual dial under its own synchronization, preserving order.
+func (p *pool) queueForNewConn(ctx context.Context, w *wantConn) {
 	p.createConnectionsCond.L.Lock()
 	defer p.createConnectionsCond.L.Unlock()
 
+	// Remove any wantConn entries at the front that are no longer waiting. This
+	// keeps the queue clean and avoids delivering to canceled requests.
 	p.newConnWait.cleanFront()
+
+	// Enqueue this wantConn for allocation of a new connection.
 	p.newConnWait.pushBack(w)
+
+	// Signal any goroutine waiting in waitForNewConn that the pool state changed
+	// and a new wantConn is available. That goroutine will then dequeue under lock.
 	p.createConnectionsCond.Signal()
+
+	// Spawn a background worker to service the queue without blocking callers. We
+	// do NOT pass "w" here because the worker must re-acquire the queue lock and
+	// pick the next available wantConn in FIFO order via waitForNewConn.
+	go p.spawnConnectionIfNeeded(ctx)
 }
 
 func (p *pool) totalConnectionCount() int {
@@ -1153,143 +1180,6 @@ func (p *pool) availableConnectionCount() int {
 	return len(p.idleConns)
 }
 
-// createConnections creates connections for wantConn requests on the newConnWait queue.
-func (p *pool) createConnections(ctx context.Context, wg *sync.WaitGroup) {
-	defer wg.Done()
-
-	// condition returns true if the createConnections() loop should continue and false if it should
-	// wait. Note that the condition also listens for Context cancellation, which also causes the
-	// loop to continue, allowing for a subsequent check to return from createConnections().
-	condition := func() bool {
-		checkOutWaiting := p.newConnWait.len() > 0
-		poolHasSpace := p.maxSize == 0 || uint64(len(p.conns)) < p.maxSize
-		cancelled := ctx.Err() != nil
-		return (checkOutWaiting && poolHasSpace) || cancelled
-	}
-
-	// wait waits for there to be an available wantConn and for the pool to have space for a new
-	// connection. When the condition becomes true, it creates a new connection and returns the
-	// waiting wantConn and new connection. If the Context is cancelled or there are any
-	// errors, wait returns with "ok = false".
-	wait := func() (*wantConn, *connection, bool) {
-		p.createConnectionsCond.L.Lock()
-		defer p.createConnectionsCond.L.Unlock()
-
-		for !condition() {
-			p.createConnectionsCond.Wait()
-		}
-
-		if ctx.Err() != nil {
-			return nil, nil, false
-		}
-
-		p.newConnWait.cleanFront()
-		w := p.newConnWait.popFront()
-		if w == nil {
-			return nil, nil, false
-		}
-
-		conn := newConnection(p.address, p.connOpts...)
-		conn.pool = p
-		conn.driverConnectionID = atomic.AddInt64(&p.nextID, 1)
-		p.conns[conn.driverConnectionID] = conn
-
-		return w, conn, true
-	}
-
-	for ctx.Err() == nil {
-		w, conn, ok := wait()
-		if !ok {
-			continue
-		}
-
-		if mustLogPoolMessage(p) {
-			keysAndValues := logger.KeyValues{
-				logger.KeyDriverConnectionID, conn.driverConnectionID,
-			}
-
-			logPoolMessage(p, logger.ConnectionCreated, keysAndValues...)
-		}
-
-		if p.monitor != nil {
-			p.monitor.Event(&event.PoolEvent{
-				Type:         event.ConnectionCreated,
-				Address:      p.address.String(),
-				ConnectionID: conn.driverConnectionID,
-			})
-		}
-
-		start := time.Now()
-		// Pass the createConnections context to connect to allow pool close to
-		// cancel connection establishment so shutdown doesn't block indefinitely if
-		// connectTimeout=0.
-		//
-		// Per the specifications, an explicit value of connectTimeout=0 means the
-		// timeout is "infinite".
-
-		var cancel context.CancelFunc
-
-		connctx := context.Background()
-		if p.connectTimeout != 0 {
-			connctx, cancel = context.WithTimeout(ctx, p.connectTimeout)
-		}
-
-		err := conn.connect(connctx)
-
-		if cancel != nil {
-			cancel()
-		}
-
-		if err != nil {
-			w.tryDeliver(nil, err)
-
-			// If there's an error connecting the new connection, call the handshake error handler
-			// that implements the SDAM handshake error handling logic. This must be called after
-			// delivering the connection error to the waiting wantConn. If it's called before, the
-			// handshake error handler may clear the connection pool, leading to a different error
-			// message being delivered to the same waiting wantConn in idleConnWait when the wait
-			// queues are cleared.
-			if p.handshakeErrFn != nil {
-				p.handshakeErrFn(err, conn.generation, conn.desc.ServiceID)
-			}
-
-			_ = p.removeConnection(conn, reason{
-				loggerConn: logger.ReasonConnClosedError,
-				event:      event.ReasonError,
-			}, err)
-
-			_ = p.closeConnection(conn)
-
-			continue
-		}
-
-		duration := time.Since(start)
-		if mustLogPoolMessage(p) {
-			keysAndValues := logger.KeyValues{
-				logger.KeyDriverConnectionID, conn.driverConnectionID,
-				logger.KeyDurationMS, duration.Milliseconds(),
-			}
-
-			logPoolMessage(p, logger.ConnectionReady, keysAndValues...)
-		}
-
-		if p.monitor != nil {
-			p.monitor.Event(&event.PoolEvent{
-				Type:         event.ConnectionReady,
-				Address:      p.address.String(),
-				ConnectionID: conn.driverConnectionID,
-				Duration:     duration,
-			})
-		}
-
-		if w.tryDeliver(conn, nil) {
-			continue
-		}
-
-		_ = p.checkInNoEvent(conn)
-	}
-}
-
 func (p *pool) maintain(ctx context.Context, wg *sync.WaitGroup) {
 	defer wg.Done()
 
@@ -1364,7 +1254,7 @@ func (p *pool) maintain(ctx context.Context, wg *sync.WaitGroup) {
 	for i := 0; i < n; i++ {
 		w := newWantConn()
-		p.queueForNewConn(w)
+		p.queueForNewConn(ctx, w)
 		wantConns = append(wantConns, w)
 
 		// Start a goroutine for each new wantConn, waiting for it to be ready.
@@ -1551,3 +1441,158 @@ func (q *wantConnQueue) cleanFront() {
 		q.popFront()
 	}
 }
+
+// spawnConnection establishes a new connection and delivers it to a waiting
+// request. It handles dialing, handshaking, and error handling. This function
+// is intended to be run in its own goroutine.
+func (p *pool) spawnConnection(w *wantConn, conn *connection) {
+	// Release a slot from the connection semaphore when this function returns.
+	// This ensures that another connection can be spawned.
+	defer func() { <-p.connectionSem }()
+
+	// Record the start time to calculate the total connection setup duration.
+	start := time.Now()
+
+	// Create a context for the dial operation. If a connection timeout is
+	// configured, the context will be set to time out after that duration.
+	dialCtx := context.Background()
+	var cancel context.CancelFunc
+	if p.connectTimeout > 0 {
+		dialCtx, cancel = context.WithTimeout(dialCtx, p.connectTimeout)
+		defer cancel()
+	}
+
+	// Attempt to dial and handshake the connection.
+	if err := conn.connect(dialCtx); err != nil {
+		// If the connection attempt fails, deliver the error to the waiting requester.
+		w.tryDeliver(nil, err)
+
+		// If a handshake error handler is defined, invoke it to handle SDAM state
+		// changes. This is done after delivering the error to prevent race
+		// conditions where the pool might be cleared before the error is delivered.
+		if p.handshakeErrFn != nil {
+			p.handshakeErrFn(err, conn.generation, conn.desc.ServiceID)
+		}
+
+		_ = p.removeConnection(conn, reason{
+			loggerConn: logger.ReasonConnClosedError,
+			event:      event.ReasonError,
+		}, err)
+
+		_ = p.closeConnection(conn)
+
+		return
+	}
+
+	// Emit "ConnectionReady" log and monitoring events.
+	duration := time.Since(start)
+	if mustLogPoolMessage(p) {
+		keysAndValues := logger.KeyValues{
+			logger.KeyDriverConnectionID, conn.driverConnectionID,
+			logger.KeyDurationMS, duration.Milliseconds(),
+		}
+
+		logPoolMessage(p, logger.ConnectionReady, keysAndValues...)
+	}
+
+	if p.monitor != nil {
+		p.monitor.Event(&event.PoolEvent{
+			Type:         event.ConnectionReady,
+			Address:      p.address.String(),
+			ConnectionID: conn.driverConnectionID,
+			Duration:     duration,
+		})
+	}
+
+	// Deliver the connection to the waiting checkOut, or check it back into the
+	// pool if the request has already been delivered or canceled.
+	if !w.tryDeliver(conn, nil) {
+		_ = p.checkInNoEvent(conn)
+	}
+}
+
+// hasSpace checks if the pool has space for a new connection. It returns
+// "true" if the maximum size is unlimited (0) or if the current number of
+// connections is less than the maximum size.
+func (p *pool) hasSpace() bool {
+	return p.maxSize == 0 || uint64(len(p.conns)) < p.maxSize
+}
+
+// checkOutWaiting reports whether there are any checkOut requests waiting for
+// a new connection.
+func (p *pool) checkOutWaiting() bool {
+	return p.newConnWait.len() > 0
+}
+
+// waitForNewConn blocks until there is both work and room in the pool (or the
+// context is canceled), then pops exactly one wantConn and creates and
+// registers its connection.
+func (p *pool) waitForNewConn(ctx context.Context) (*wantConn, *connection, bool) {
+	p.createConnectionsCond.L.Lock()
+	defer p.createConnectionsCond.L.Unlock()
+
+	for !(p.checkOutWaiting() && p.hasSpace()) && ctx.Err() == nil {
+		p.createConnectionsCond.Wait()
+	}
+
+	if ctx.Err() != nil {
+		return nil, nil, false
+	}
+
+	p.newConnWait.cleanFront()
+	w := p.newConnWait.popFront()
+	if w == nil {
+		return nil, nil, false
+	}
+
+	conn := newConnection(p.address, p.connOpts...)
+	conn.pool = p
+	conn.driverConnectionID = atomic.AddInt64(&p.nextID, 1)
+	p.conns[conn.driverConnectionID] = conn
+
+	return w, conn, true
+}
+
+// spawnConnectionIfNeeded takes one waiting wantConn (if any) and starts its
+// connection creation, subject to the semaphore limit.
+func (p *pool) spawnConnectionIfNeeded(ctx context.Context) {
+	p.createConnectionsCond.L.Lock()
+	openSlot := p.hasSpace()
+	p.createConnectionsCond.L.Unlock()
+
+	if !openSlot {
+		// If the pool is full, we can't spawn a new connection. This guard prevents
+		// spawning an unbounded number of goroutines.
+		return
+	}
+
+	// Block until we're allowed to start another connection.
+	p.connectionSem <- struct{}{}
+
+	// Wait until there is a waiting request and pool space, or the context is
+	// canceled.
+	w, conn, ok := p.waitForNewConn(ctx)
+	if !ok {
+		<-p.connectionSem // Release slot on failure.

+		return
+	}
+
+	// Emit "ConnectionCreated" log and monitoring events.
+	if mustLogPoolMessage(p) {
+		keysAndValues := logger.KeyValues{
+			logger.KeyDriverConnectionID, conn.driverConnectionID,
+		}
+
+		logPoolMessage(p, logger.ConnectionCreated, keysAndValues...)
+	}
+
+	if p.monitor != nil {
+		p.monitor.Event(&event.PoolEvent{
+			Type:         event.ConnectionCreated,
+			Address:      p.address.String(),
+			ConnectionID: conn.driverConnectionID,
+		})
+	}
+
+	// Dial and handshake the connection in a background goroutine.
+	go p.spawnConnection(w, conn)
+}
diff --git a/x/mongo/driver/topology/pool_test.go b/x/mongo/driver/topology/pool_test.go
index 17e803ea49..52b5ef7d01 100644
--- a/x/mongo/driver/topology/pool_test.go
+++ b/x/mongo/driver/topology/pool_test.go
@@ -11,6 +11,7 @@ import (
 	"errors"
 	"net"
 	"regexp"
+	"runtime"
 	"sync"
 	"testing"
 	"time"
@@ -1608,3 +1609,50 @@ func TestPool_Error(t *testing.T) {
 		p.close(context.Background())
 	})
 }
+
+// Test that if the pool is already at MaxPoolSize, a flood of checkOuts with
+// a background context does not spin up an unbounded number of goroutines.
+func TestPool_unboundedGoroutines(t *testing.T) {
+	// Start a server that never responds so no connection ever frees up.
+	addr := bootstrapConnections(t, 1, func(net.Conn) {
+		<-make(chan struct{})
+	})
+
+	// Create a pool with exactly 1 connection slot and 1 dial slot.
+	p := newPool(poolConfig{
+		Address:        address.Address(addr.String()),
+		MaxPoolSize:    1,
+		MaxConnecting:  1,
+		ConnectTimeout: defaultConnectionTimeout,
+	})
+	require.NoError(t, p.ready(), "pool ready error")
+
+	// Drain the only connection so the pool is full.
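+	// With MaxPoolSize=1, holding this connection keeps hasSpace() false, so each
+	// flooded checkOut below should return from spawnConnectionIfNeeded without
+	// parking a goroutine on the semaphore.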
+	c, err := p.checkOut(context.Background())
+	require.NoError(t, err, "checkOut error")
+
+	defer func() {
+		_ = p.checkIn(c)
+		p.close(context.Background())
+	}()
+
+	// Snapshot the baseline goroutine count.
+	before := runtime.NumGoroutine()
+
+	// Flood the pool with N background checkOuts.
+	const N = 100
+	for i := 0; i < N; i++ {
+		go func() {
+			_, _ = p.checkOut(context.Background())
+		}()
+	}
+
+	// Give them a moment to spin up.
+	time.Sleep(1000 * time.Millisecond)
+
+	after := runtime.NumGoroutine()
+
+	// Subtract the N goroutines blocked in checkOut; the remainder approximates
+	// the number of connection-creating goroutines the flood triggered.
+	delta := after - before - N
+
+	assert.LessOrEqual(t, delta, int(p.maxConnecting))
+}