diff --git a/_typos.toml b/_typos.toml index dfd475b..567e835 100644 --- a/_typos.toml +++ b/_typos.toml @@ -2,3 +2,6 @@ [files] extend-exclude = ["go.mod", "go.sum"] + +[default.extend-words] +typ = "typ" # type \ No newline at end of file diff --git a/connstate/conn.go b/connstate/conn.go new file mode 100644 index 0000000..b7f52ff --- /dev/null +++ b/connstate/conn.go @@ -0,0 +1,103 @@ +// Copyright 2025 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connstate + +import ( + "errors" + "net" + "sync/atomic" + "syscall" + "unsafe" +) + +type ConnState uint32 + +const ( + // StateOK means the connection is normal. + StateOK ConnState = iota + // StateRemoteClosed means the remote side has closed the connection. + StateRemoteClosed + // StateClosed means the connection has been closed by local side. + StateClosed +) + +// ConnStater is the interface to get the ConnState of a connection. +// Must call Close to release it if you're going to close the connection. +type ConnStater interface { + Close() error + State() ConnState +} + +// ListenConnState returns a ConnStater for the given connection. +// It's generally used for availability checks when obtaining connections from a connection pool. +// Conn must be a syscall.Conn. +func ListenConnState(conn net.Conn) (ConnStater, error) { + pollInitOnce.Do(createPoller) + sysConn, ok := conn.(syscall.Conn) + if !ok { + return nil, errors.New("conn is not syscall.Conn") + } + rawConn, err := sysConn.SyscallConn() + if err != nil { + return nil, err + } + var fd *fdOperator + var opAddErr error + err = rawConn.Control(func(fileDescriptor uintptr) { + fd = pollcache.alloc() + fd.fd = int(fileDescriptor) + atomic.StorePointer(&fd.conn, unsafe.Pointer(&connStater{fd: unsafe.Pointer(fd)})) + opAddErr = poll.control(fd, opAdd) + }) + if fd != nil { + if err != nil && opAddErr == nil { + // if rawConn is closed, poller will delete the fd by itself + _ = rawConn.Control(func(_ uintptr) { + _ = poll.control(fd, opDel) + }) + } + if err != nil || opAddErr != nil { + atomic.StorePointer(&fd.conn, nil) + pollcache.freeable(fd) + } + } + if err != nil { + return nil, err + } + if opAddErr != nil { + return nil, opAddErr + } + return (*connStater)(atomic.LoadPointer(&fd.conn)), nil +} + +type connStater struct { + fd unsafe.Pointer // *fdOperator + state uint32 +} + +func (c *connStater) Close() error { + fd := (*fdOperator)(atomic.LoadPointer(&c.fd)) + if fd != nil && atomic.CompareAndSwapPointer(&c.fd, unsafe.Pointer(fd), nil) { + atomic.StoreUint32(&c.state, uint32(StateClosed)) + _ = poll.control(fd, opDel) + atomic.StorePointer(&fd.conn, nil) + pollcache.freeable(fd) + } + return nil +} + +func (c *connStater) State() ConnState { + return ConnState(atomic.LoadUint32(&c.state)) +} diff --git a/connstate/conn_test.go b/connstate/conn_test.go new file mode 100644 index 0000000..8df1db6 --- /dev/null +++ b/connstate/conn_test.go @@ -0,0 +1,349 @@ +// Copyright 2025 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connstate + +import ( + "errors" + "io" + "net" + "runtime" + "sync" + "syscall" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestListenConnState(t *testing.T) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(err) + } + go func() { + for { + conn, err := ln.Accept() + assert.Nil(t, err) + go func(conn net.Conn) { + buf := make([]byte, 11) + _, err := conn.Read(buf) + assert.Nil(t, err) + conn.Close() + }(conn) + } + }() + conn, err := net.Dial("tcp", ln.Addr().String()) + assert.Nil(t, err) + stater, err := ListenConnState(conn) + assert.Nil(t, err) + assert.Equal(t, StateOK, stater.State()) + _, err = conn.Write([]byte("hello world")) + assert.Nil(t, err) + buf := make([]byte, 1) + _, err = conn.Read(buf) + assert.Equal(t, io.EOF, err) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, StateRemoteClosed, stater.State()) + assert.Nil(t, stater.Close()) + assert.Nil(t, conn.Close()) + assert.Equal(t, StateClosed, stater.State()) +} + +type mockPoller struct { + controlFunc func(fd *fdOperator, op op) error +} + +func (m *mockPoller) wait() error { + return nil +} + +func (m *mockPoller) control(fd *fdOperator, op op) error { + return m.controlFunc(fd, op) +} + +type mockConn struct { + net.Conn + controlFunc func(f func(fd uintptr)) error +} + +func (c *mockConn) SyscallConn() (syscall.RawConn, error) { + return &mockRawConn{ + controlFunc: c.controlFunc, + }, nil +} + +type mockRawConn struct { + syscall.RawConn + controlFunc func(f func(fd uintptr)) error +} + +func (r *mockRawConn) Control(f func(fd uintptr)) error { + return r.controlFunc(f) +} + +func TestListenConnState_Err(t *testing.T) { + // replace poll + pollInitOnce.Do(createPoller) + oldPoll := poll + defer func() { + poll = oldPoll + }() + // test detach + var expectDetach bool + defer func() { + assert.True(t, expectDetach) + }() + cases := []struct { + name string + connControlFunc func(f func(fd uintptr)) error + pollControlFunc func(fd *fdOperator, op op) error + expectErr error + }{ + { + name: "err conn control", + connControlFunc: func(f func(fd uintptr)) error { + return errors.New("err conn control") + }, + expectErr: errors.New("err conn control"), + }, + { + name: "err poll control", + connControlFunc: func(f func(fd uintptr)) error { + f(1) + return nil + }, + pollControlFunc: func(fd *fdOperator, op op) error { + assert.Equal(t, fd.fd, 1) + return errors.New("err poll control") + }, + expectErr: errors.New("err poll control"), + }, + { + name: "err conn control after poll add", + connControlFunc: func(f func(fd uintptr)) error { + f(1) + return errors.New("err conn control after poll add") + }, + pollControlFunc: func(fd *fdOperator, op op) error { + if op == opDel { + expectDetach = true + } + assert.Equal(t, fd.fd, 1) + return nil + }, + expectErr: errors.New("err conn control after poll add"), + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + poll = &mockPoller{ + controlFunc: c.pollControlFunc, + } + conn := &mockConn{ + controlFunc: c.connControlFunc, + } + _, err := ListenConnState(conn) + assert.Equal(t, c.expectErr, err) + }) + } +} + +func BenchmarkListenConnState(b *testing.B) { + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(err) + } + go func() { + for { + conn, err := ln.Accept() + assert.Nil(b, err) + go func(conn net.Conn) { + buf := make([]byte, 11) + _, err := conn.Read(buf) + assert.Nil(b, err) + conn.Close() + }(conn) + } + }() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn, err := net.Dial("tcp", ln.Addr().String()) + assert.Nil(b, err) + stater, err := ListenConnState(conn) + assert.Nil(b, err) + assert.Equal(b, StateOK, stater.State()) + _, err = conn.Write([]byte("hello world")) + assert.Nil(b, err) + buf := make([]byte, 1) + _, err = conn.Read(buf) + assert.Equal(b, io.EOF, err) + time.Sleep(100 * time.Millisecond) + assert.Equal(b, StateRemoteClosed, stater.State()) + assert.Nil(b, stater.Close()) + assert.Nil(b, conn.Close()) + assert.Equal(b, StateClosed, stater.State()) + } + }) +} + +type statefulConn struct { + net.Conn + stater ConnStater +} + +func (s *statefulConn) Close() error { + s.stater.Close() + return s.Conn.Close() +} + +type mockStater struct { +} + +func (m *mockStater) State() ConnState { + return StateOK +} + +func (m *mockStater) Close() error { + return nil +} + +type connpool struct { + mu sync.Mutex + conns []*statefulConn +} + +func (p *connpool) get(dialFunc func() *statefulConn) *statefulConn { + p.mu.Lock() + if len(p.conns) == 0 { + p.mu.Unlock() + return dialFunc() + } + for i := len(p.conns) - 1; i >= 0; i-- { + conn := p.conns[i] + if conn.stater.State() == StateOK { + p.conns = p.conns[:i] + p.mu.Unlock() + return conn + } else { + conn.Close() + } + } + p.conns = p.conns[:0] + p.mu.Unlock() + return dialFunc() +} + +func (p *connpool) put(conn *statefulConn) { + p.mu.Lock() + defer p.mu.Unlock() + p.conns = append(p.conns, conn) +} + +func (p *connpool) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + for _, conn := range p.conns { + conn.Close() + } + p.conns = p.conns[:0] + return nil +} + +var withListenConnState bool + +// BenchmarkWithConnState is used to verify the impact of adding ConnState logic on performance. +// To compare with syscall.EpollWait(), you could run `go test -bench=BenchmarkWith -benchtime=10s .` +// to test the first time, and replace isyscall.EpollWait() with syscall.EpollWait() to test the second time. +func BenchmarkWithConnState(b *testing.B) { + withListenConnState = true + benchmarkConnState(b) +} + +func BenchmarkWithoutConnState(b *testing.B) { + withListenConnState = false + benchmarkConnState(b) +} + +func benchmarkConnState(b *testing.B) { + // set GOMAXPROCS to 1 to make P resources scarce + runtime.GOMAXPROCS(1) + ln, err := net.Listen("tcp", "localhost:0") + if err != nil { + panic(err) + } + go func() { + for { + conn, err := ln.Accept() + assert.Nil(b, err) + go func(conn net.Conn) { + var count uint64 + for { + buf := make([]byte, 11) + _, err := conn.Read(buf) + if err != nil { + conn.Close() + return + } + _, err = conn.Write(buf) + if err != nil { + conn.Close() + return + } + count++ + if count == 1000 { + conn.Close() + return + } + } + }(conn) + } + }() + cp := &connpool{} + dialFunc := func() *statefulConn { + conn, err := net.Dial("tcp", ln.Addr().String()) + assert.Nil(b, err) + var stater ConnStater + if withListenConnState { + stater, err = ListenConnState(conn) + assert.Nil(b, err) + } else { + stater = &mockStater{} + } + return &statefulConn{ + Conn: conn, + stater: stater, + } + } + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + conn := cp.get(dialFunc) + buf := make([]byte, 11) + _, err := conn.Write(buf) + if err != nil { + conn.Close() + continue + } + _, err = conn.Read(buf) + if err != nil { + conn.Close() + continue + } + cp.put(conn) + } + }) + _ = cp.Close() +} diff --git a/connstate/poll.go b/connstate/poll.go new file mode 100644 index 0000000..f6c4264 --- /dev/null +++ b/connstate/poll.go @@ -0,0 +1,49 @@ +// Copyright 2025 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connstate + +import ( + "fmt" + "sync" +) + +type op int + +const ( + opAdd op = iota + opDel +) + +var ( + pollInitOnce sync.Once + poll poller +) + +type poller interface { + wait() error + control(fd *fdOperator, op op) error +} + +func createPoller() { + var err error + poll, err = openpoll() + if err != nil { + panic(fmt.Sprintf("gopkg.connstate openpoll failed, err: %v", err)) + } + go func() { + err := poll.wait() + fmt.Printf("gopkg.connstate epoll wait exit, err: %v\n", err) + }() +} diff --git a/connstate/poll_bsd.go b/connstate/poll_bsd.go new file mode 100644 index 0000000..f7a6876 --- /dev/null +++ b/connstate/poll_bsd.go @@ -0,0 +1,101 @@ +// Copyright 2025 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build darwin || netbsd || freebsd || openbsd || dragonfly +// +build darwin netbsd freebsd openbsd dragonfly + +package connstate + +import ( + "sync/atomic" + "syscall" + "time" + "unsafe" +) + +type kqueue struct { + fd int +} + +func (p *kqueue) wait() error { + events := make([]syscall.Kevent_t, 1024) + timeout := &syscall.Timespec{Sec: 0, Nsec: 0} + var n int + var err error + for { + // timeout=0 must be set to avoid getting stuck in a blocking syscall, + // which could occupy a P until runtime.sysmon thread handoff it. + n, err = syscall.Kevent(p.fd, nil, events, timeout) + if err != nil && err != syscall.EINTR { + // exit gracefully + if err == syscall.EBADF { + return nil + } + return err + } + if n <= 0 { + time.Sleep(10 * time.Millisecond) // avoid busy loop + continue + } + for i := 0; i < n; i++ { + ev := &events[i] + op := *(**fdOperator)(unsafe.Pointer(&ev.Udata)) + if conn := (*connStater)(atomic.LoadPointer(&op.conn)); conn != nil { + if ev.Flags&(syscall.EV_EOF) != 0 { + atomic.CompareAndSwapUint32(&conn.state, uint32(StateOK), uint32(StateRemoteClosed)) + } + } + } + // we can make sure that there is no op remaining if finished handling all events + pollcache.free() + } +} + +func (p *kqueue) control(fd *fdOperator, op op) error { + evs := make([]syscall.Kevent_t, 1) + evs[0].Ident = uint64(fd.fd) + *(**fdOperator)(unsafe.Pointer(&evs[0].Udata)) = fd + if op == opAdd { + evs[0].Filter = syscall.EVFILT_READ + evs[0].Flags = syscall.EV_ADD | syscall.EV_ENABLE | syscall.EV_CLEAR + // prevent ordinary data from triggering + evs[0].Flags |= syscall.EV_OOBAND + evs[0].Fflags = syscall.NOTE_LOWAT + evs[0].Data = 0x7FFFFFFF + _, err := syscall.Kevent(p.fd, evs, nil, nil) + return err + } else { + evs[0].Filter = syscall.EVFILT_READ + evs[0].Flags = syscall.EV_DELETE + _, err := syscall.Kevent(p.fd, evs, nil, nil) + return err + } +} + +func openpoll() (p poller, err error) { + fd, err := syscall.Kqueue() + if err != nil { + return nil, err + } + _, err = syscall.Kevent(fd, []syscall.Kevent_t{{ + Ident: 0, + Filter: syscall.EVFILT_USER, + Flags: syscall.EV_ADD | syscall.EV_CLEAR, + }}, nil, nil) + if err != nil { + syscall.Close(fd) + return nil, err + } + return &kqueue{fd: fd}, nil +} diff --git a/connstate/poll_cache.go b/connstate/poll_cache.go new file mode 100644 index 0000000..5a68520 --- /dev/null +++ b/connstate/poll_cache.go @@ -0,0 +1,87 @@ +// Copyright 2025 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connstate + +import ( + "sync" + "sync/atomic" + "unsafe" +) + +const pollBlockSize = 4 * 1024 + +type fdOperator struct { + link *fdOperator // in pollcache, protected by pollcache.lock + index int32 + + fd int + conn unsafe.Pointer // *connStater +} + +var pollcache pollCache + +type pollCache struct { + lock sync.Mutex + first *fdOperator + cache []*fdOperator + // freelist store the freeable operator + // to reduce GC pressure, we only store op index here + freelist []int32 + freeack int32 +} + +func (c *pollCache) alloc() *fdOperator { + c.lock.Lock() + if c.first == nil { + const pdSize = unsafe.Sizeof(fdOperator{}) + n := pollBlockSize / pdSize + if n == 0 { + n = 1 + } + index := int32(len(c.cache)) + for i := uintptr(0); i < n; i++ { + pd := &fdOperator{index: index} + c.cache = append(c.cache, pd) + pd.link = c.first + c.first = pd + index++ + } + } + op := c.first + c.first = op.link + c.lock.Unlock() + return op +} + +// freeable mark the operator that could be freed +// only poller could do the real free action +func (c *pollCache) freeable(op *fdOperator) { + c.lock.Lock() + // reset all state + if atomic.CompareAndSwapInt32(&c.freeack, 1, 0) { + for _, idx := range c.freelist { + op := c.cache[idx] + op.link = c.first + c.first = op + } + c.freelist = c.freelist[:0] + } + c.freelist = append(c.freelist, op.index) + c.lock.Unlock() +} + +func (c *pollCache) free() { + atomic.StoreInt32(&c.freeack, 1) +} diff --git a/connstate/poll_cache_test.go b/connstate/poll_cache_test.go new file mode 100644 index 0000000..e397c08 --- /dev/null +++ b/connstate/poll_cache_test.go @@ -0,0 +1,312 @@ +// Copyright 2025 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connstate + +import ( + "sync" + "sync/atomic" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestPollCacheAlloc(t *testing.T) { + cache := &pollCache{} + + // Test initial allocation + op1 := cache.alloc() + require.NotNil(t, op1) + assert.GreaterOrEqual(t, op1.index, int32(0)) + assert.Equal(t, int(0), op1.fd) + + // Test multiple allocations + op2 := cache.alloc() + require.NotNil(t, op2) + assert.Equal(t, int(0), op2.fd) + + // Test that allocated operators are different + assert.NotEqual(t, op1, op2) + + // Verify they both have valid indices + assert.GreaterOrEqual(t, op1.index, int32(0)) + assert.GreaterOrEqual(t, op2.index, int32(0)) +} + +func TestPollCacheAllocReuse(t *testing.T) { + cache := &pollCache{} + + // Allocate all operators + var ops []*fdOperator + for i := 0; i < 10; i++ { + op := cache.alloc() + require.NotNil(t, op) + ops = append(ops, op) + } + + // Mark some as freeable + for i := 0; i < 5; i++ { + cache.freeable(ops[i]) + } + + // Set freeack to trigger cleanup + cache.free() + + // Allocate again, should reuse freed operators + reusedOp := cache.alloc() + require.NotNil(t, reusedOp) + + // The reused operator should have a high index (from cache) + assert.GreaterOrEqual(t, reusedOp.index, int32(10)) +} + +func TestPollCacheFreeable(t *testing.T) { + cache := &pollCache{} + + // Allocate operators + op1 := cache.alloc() + op2 := cache.alloc() + + require.NotNil(t, op1) + require.NotNil(t, op2) + + // Mark operators as freeable + cache.freeable(op1) + cache.freeable(op2) + + // Verify they are in freelist + cache.lock.Lock() + assert.Len(t, cache.freelist, 2) + assert.Contains(t, cache.freelist, op1.index) + assert.Contains(t, cache.freelist, op2.index) + cache.lock.Unlock() +} + +func TestPollCacheFree(t *testing.T) { + cache := &pollCache{} + + // Allocate and mark operators as freeable + var ops []*fdOperator + for i := 0; i < 5; i++ { + op := cache.alloc() + require.NotNil(t, op) + ops = append(ops, op) + cache.freeable(op) + } + + // Verify they are in freelist + cache.lock.Lock() + freelistLen := len(cache.freelist) + cache.lock.Unlock() + assert.Equal(t, 5, freelistLen) + + // Set freeack flag + cache.free() + + // Verify freeack is set + assert.Equal(t, int32(1), atomic.LoadInt32(&cache.freeack)) + + // Call freeable again to trigger cleanup + cache.freeable(ops[0]) + + // Verify freelist was cleared (should be 1 for the newly added operator) + cache.lock.Lock() + finalFreelistLen := len(cache.freelist) + cache.lock.Unlock() + assert.Equal(t, 1, finalFreelistLen) // Only the newly added operator +} + +func TestPollCacheConcurrent(t *testing.T) { + cache := &pollCache{} + + const numGoroutines = 10 + const numAllocations = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Concurrent allocations + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + + var ops []*fdOperator + for j := 0; j < numAllocations; j++ { + op := cache.alloc() + if op != nil { + ops = append(ops, op) + } + + // Randomly mark some as freeable + if j%3 == 0 && len(ops) > 0 { + freeableOp := ops[0] + ops = ops[1:] + cache.freeable(freeableOp) + } + } + + // Mark remaining as freeable + for _, op := range ops { + cache.freeable(op) + } + }() + } + + wg.Wait() + + // Verify cache is still functional + finalOp := cache.alloc() + require.NotNil(t, finalOp) +} + +func TestFDOperatorFields(t *testing.T) { + op := &fdOperator{ + index: 42, + fd: 123, + } + + assert.Equal(t, int32(42), op.index) + assert.Equal(t, int(123), op.fd) + assert.Nil(t, op.link) + assert.Nil(t, op.conn) +} + +func TestFDOperatorSize(t *testing.T) { + // Test that fdOperator has consistent size + size1 := unsafe.Sizeof(fdOperator{}) + size2 := unsafe.Sizeof(fdOperator{}) + assert.Equal(t, size1, size2) + + // Should have reasonable size (not too large, not too small) + assert.Greater(t, size1, uintptr(16)) // At least contains fields + assert.Less(t, size1, uintptr(256)) // Not excessively large +} + +func TestPollCacheBlockAllocation(t *testing.T) { + cache := &pollCache{} + + // Calculate expected number of operators per block + pdSize := unsafe.Sizeof(fdOperator{}) + expectedPerBlock := pollBlockSize / pdSize + if expectedPerBlock == 0 { + expectedPerBlock = 1 + } + + // Allocate more than one block worth + var ops []*fdOperator + allocations := int(expectedPerBlock) + 10 + + for i := 0; i < allocations; i++ { + op := cache.alloc() + require.NotNil(t, op, "Allocation %d should succeed", i) + ops = append(ops, op) + } + + // Verify all have unique indices + indices := make(map[int32]struct{}) + for _, op := range ops { + _, exists := indices[op.index] + assert.False(t, exists, "Index %d should be unique", op.index) + indices[op.index] = struct{}{} + } +} + +func TestPollCacheFreeAckRace(t *testing.T) { + cache := &pollCache{} + + const numOperations = 1000 + var freeCount int64 + + // Start goroutine that marks operators as freeable + go func() { + for i := 0; i < numOperations; i++ { + op := cache.alloc() + if op != nil { + cache.freeable(op) + atomic.AddInt64(&freeCount, 1) + } + time.Sleep(time.Microsecond) // Small delay to increase race chance + } + }() + + // Start goroutine that calls free() periodically + go func() { + for i := 0; i < numOperations/10; i++ { + cache.free() + time.Sleep(10 * time.Microsecond) + } + }() + + time.Sleep(100 * time.Millisecond) // Let goroutines work + + // Verify no panics or corruption + finalOp := cache.alloc() + require.NotNil(t, finalOp) + + // Verify some operations completed + assert.Greater(t, atomic.LoadInt64(&freeCount), int64(0)) +} + +func BenchmarkPollCacheAlloc(b *testing.B) { + cache := &pollCache{} + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + op := cache.alloc() + if op != nil { + // Simulate some usage + op.fd = 42 + op.index = 1 + } + } + }) +} + +func BenchmarkPollCacheFreeable(b *testing.B) { + cache := &pollCache{} + + // Pre-allocate some operators + ops := make([]*fdOperator, 1000) + for i := range ops { + ops[i] = cache.alloc() + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + cache.freeable(ops[i%len(ops)]) + i++ + } + }) +} + +func BenchmarkPollCacheAllocFreeCycle(b *testing.B) { + cache := &pollCache{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + op := cache.alloc() + if op != nil { + cache.freeable(op) + if i%100 == 0 { + cache.free() // Trigger cleanup occasionally + } + } + } +} diff --git a/connstate/poll_linux.go b/connstate/poll_linux.go new file mode 100644 index 0000000..47e7f13 --- /dev/null +++ b/connstate/poll_linux.go @@ -0,0 +1,79 @@ +// Copyright 2025 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connstate + +import ( + "sync/atomic" + "syscall" + "time" + "unsafe" +) + +const _EPOLLET uint32 = 0x80000000 + +type epoller struct { + epfd int +} + +//go:nocheckptr +func (p *epoller) wait() error { + events := make([]syscall.EpollEvent, 1024) + var n int + var err error + for { + // timeout=0 must be set to avoid getting stuck in a blocking syscall, + // which could occupy a P until runtime.sysmon thread handoff it. + n, err = syscall.EpollWait(p.epfd, events, 0) + if err != nil && err != syscall.EINTR { + return err + } + if n <= 0 { + time.Sleep(10 * time.Millisecond) // avoid busy loop + continue + } + for i := 0; i < n; i++ { + ev := &events[i] + op := *(**fdOperator)(unsafe.Pointer(&ev.Fd)) + if conn := (*connStater)(atomic.LoadPointer(&op.conn)); conn != nil { + if ev.Events&(syscall.EPOLLHUP|syscall.EPOLLRDHUP|syscall.EPOLLERR) != 0 { + atomic.CompareAndSwapUint32(&conn.state, uint32(StateOK), uint32(StateRemoteClosed)) + } + } + } + // we can make sure that there is no op remaining if finished handling all events + pollcache.free() + } +} + +func (p *epoller) control(fd *fdOperator, op op) error { + if op == opAdd { + var ev syscall.EpollEvent + *(**fdOperator)(unsafe.Pointer(&ev.Fd)) = fd + ev.Events = syscall.EPOLLHUP | syscall.EPOLLRDHUP | syscall.EPOLLERR | _EPOLLET + return syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_ADD, fd.fd, &ev) + } else { + var ev syscall.EpollEvent + return syscall.EpollCtl(p.epfd, syscall.EPOLL_CTL_DEL, fd.fd, &ev) + } +} + +func openpoll() (p poller, err error) { + var epfd int + epfd, err = syscall.EpollCreate(1) + if err != nil { + return nil, err + } + return &epoller{epfd: epfd}, nil +} diff --git a/connstate/poll_windows.go b/connstate/poll_windows.go new file mode 100644 index 0000000..ad9404d --- /dev/null +++ b/connstate/poll_windows.go @@ -0,0 +1,35 @@ +// Copyright 2025 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package connstate + +import "errors" + +var ( + errNotSupportedForWindows = errors.New("connstate not supported for windows") +) + +type mockWindowsPoller struct{} + +func (m *mockWindowsPoller) wait() error { + return nil +} + +func (m *mockWindowsPoller) control(fd *fdOperator, op op) error { + return errNotSupportedForWindows +} + +func openpoll() (p poller, err error) { + return &mockWindowsPoller{}, nil +}