Skip to content

Commit 1376f85

Browse files
committed
cover corner case
1 parent d8b8edc commit 1376f85

File tree

3 files changed

+100
-2
lines changed

3 files changed

+100
-2
lines changed

hitless/handoff_worker.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,11 @@ func (hwm *handoffWorkerManager) shutdownWorkers(ctx context.Context) error {
285285
hwm.shutdownOnce.Do(func() {
286286
close(hwm.shutdown)
287287
// workers will exit when they finish their current request
288+
289+
// Shutdown circuit breaker manager cleanup goroutine
290+
if hwm.circuitBreakerManager != nil {
291+
hwm.circuitBreakerManager.Shutdown()
292+
}
288293
})
289294

290295
// Wait for workers to complete

internal/pool/conn.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,8 +244,9 @@ func (cn *Conn) ClearRelaxedTimeout() {
244244
// Atomically decrement counter and check if we should clear
245245
newCount := cn.relaxedCounter.Add(-1)
246246
if newCount <= 0 {
247-
// Use compare-and-swap to ensure only one goroutine clears
248-
if cn.relaxedCounter.CompareAndSwap(newCount, 0) {
247+
// Use atomic load to get current value for CAS to avoid stale value race
248+
current := cn.relaxedCounter.Load()
249+
if current <= 0 && cn.relaxedCounter.CompareAndSwap(current, 0) {
249250
cn.clearRelaxedTimeout()
250251
}
251252
}
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package pool
2+
3+
import (
4+
"net"
5+
"sync"
6+
"testing"
7+
"time"
8+
)
9+
10+
// TestConcurrentRelaxedTimeoutClearing tests the race condition fix in ClearRelaxedTimeout
11+
func TestConcurrentRelaxedTimeoutClearing(t *testing.T) {
12+
// Create a dummy connection for testing
13+
netConn := &net.TCPConn{}
14+
cn := NewConn(netConn)
15+
defer cn.Close()
16+
17+
// Set relaxed timeout multiple times to increase counter
18+
cn.SetRelaxedTimeout(time.Second, time.Second)
19+
cn.SetRelaxedTimeout(time.Second, time.Second)
20+
cn.SetRelaxedTimeout(time.Second, time.Second)
21+
22+
// Verify counter is 3
23+
if count := cn.relaxedCounter.Load(); count != 3 {
24+
t.Errorf("Expected relaxed counter to be 3, got %d", count)
25+
}
26+
27+
// Clear timeouts concurrently to test race condition fix
28+
var wg sync.WaitGroup
29+
for i := 0; i < 10; i++ {
30+
wg.Add(1)
31+
go func() {
32+
defer wg.Done()
33+
cn.ClearRelaxedTimeout()
34+
}()
35+
}
36+
wg.Wait()
37+
38+
// Verify counter is 0 and timeouts are cleared
39+
if count := cn.relaxedCounter.Load(); count != 0 {
40+
t.Errorf("Expected relaxed counter to be 0 after clearing, got %d", count)
41+
}
42+
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
43+
t.Errorf("Expected relaxed read timeout to be 0, got %d", timeout)
44+
}
45+
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
46+
t.Errorf("Expected relaxed write timeout to be 0, got %d", timeout)
47+
}
48+
}
49+
50+
// TestRelaxedTimeoutCounterRaceCondition tests the specific race condition scenario
51+
func TestRelaxedTimeoutCounterRaceCondition(t *testing.T) {
52+
netConn := &net.TCPConn{}
53+
cn := NewConn(netConn)
54+
defer cn.Close()
55+
56+
// Set relaxed timeout once
57+
cn.SetRelaxedTimeout(time.Second, time.Second)
58+
59+
// Verify counter is 1
60+
if count := cn.relaxedCounter.Load(); count != 1 {
61+
t.Errorf("Expected relaxed counter to be 1, got %d", count)
62+
}
63+
64+
// Test concurrent clearing with race condition scenario
65+
var wg sync.WaitGroup
66+
67+
// Multiple goroutines try to clear simultaneously
68+
for i := 0; i < 5; i++ {
69+
wg.Add(1)
70+
go func() {
71+
defer wg.Done()
72+
cn.ClearRelaxedTimeout()
73+
}()
74+
}
75+
wg.Wait()
76+
77+
// Verify final state is consistent
78+
if count := cn.relaxedCounter.Load(); count != 0 {
79+
t.Errorf("Expected relaxed counter to be 0 after concurrent clearing, got %d", count)
80+
}
81+
82+
// Verify timeouts are actually cleared
83+
if timeout := cn.relaxedReadTimeoutNs.Load(); timeout != 0 {
84+
t.Errorf("Expected relaxed read timeout to be cleared, got %d", timeout)
85+
}
86+
if timeout := cn.relaxedWriteTimeoutNs.Load(); timeout != 0 {
87+
t.Errorf("Expected relaxed write timeout to be cleared, got %d", timeout)
88+
}
89+
if deadline := cn.relaxedDeadlineNs.Load(); deadline != 0 {
90+
t.Errorf("Expected relaxed deadline to be cleared, got %d", deadline)
91+
}
92+
}

0 commit comments

Comments
 (0)