66 "errors"
77 "io"
88 "net/netip"
9+ "os"
910 "syscall"
1011
1112 "github.com/sagernet/sing/common/buf"
@@ -25,10 +26,11 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
2526 bufferSize = buf .BufferSize
2627 }
2728 var (
28- buffer * buf.Buffer
29- readBuffer * buf.Buffer
29+ buffer * buf.Buffer
30+ readBuffer * buf.Buffer
31+ notFirstTime bool
3032 )
31- newBuffer := func () * buf.Buffer {
33+ source . InitializeReadWaiter ( func () * buf.Buffer {
3234 if buffer != nil {
3335 buffer .Release ()
3436 }
@@ -37,10 +39,10 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
3739 readBuffer = buf .With (readBufferRaw [:len (readBufferRaw )- rearHeadroom ])
3840 readBuffer .Resize (frontHeadroom , 0 )
3941 return readBuffer
40- }
41- var notFirstTime bool
42+ })
43+ defer source . InitializeReadWaiter ( nil )
4244 for {
43- err = source .WaitReadBuffer (newBuffer )
45+ err = source .WaitReadBuffer ()
4446 if err != nil {
4547 buffer .Release ()
4648 if errors .Is (err , io .EOF ) {
@@ -55,10 +57,8 @@ func copyWaitWithPool(originDestination io.Writer, destination N.ExtendedWriter,
5557 dataLen := readBuffer .Len ()
5658 buffer .Resize (readBuffer .Start (), dataLen )
5759 err = destination .WriteBuffer (buffer )
60+ buffer .Release ()
5861 if err != nil {
59- if buffer != nil {
60- buffer .Release ()
61- }
6262 return
6363 }
6464 n += int64 (dataLen )
@@ -83,10 +83,12 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
8383 bufferSize = buf .UDPBufferSize
8484 }
8585 var (
86- buffer * buf.Buffer
87- readBuffer * buf.Buffer
86+ buffer * buf.Buffer
87+ readBuffer * buf.Buffer
88+ destination M.Socksaddr
89+ notFirstTime bool
8890 )
89- newBuffer := func () * buf.Buffer {
91+ source . InitializeReadWaiter ( func () * buf.Buffer {
9092 if buffer != nil {
9193 buffer .Release ()
9294 }
@@ -95,11 +97,10 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
9597 readBuffer = buf .With (readBufferRaw [:len (readBufferRaw )- rearHeadroom ])
9698 readBuffer .Resize (frontHeadroom , 0 )
9799 return readBuffer
98- }
99- var destination M.Socksaddr
100- var notFirstTime bool
100+ })
101+ defer source .InitializeReadWaiter (nil )
101102 for {
102- destination , err = source .WaitReadPacket (newBuffer )
103+ destination , err = source .WaitReadPacket ()
103104 if err != nil {
104105 buffer .Release ()
105106 if ! notFirstTime {
@@ -113,9 +114,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
113114 if err != nil {
114115 buffer .Release ()
115116 return
116- } else {
117- buffer = nil
118117 }
118+ buffer = nil
119119 n += int64 (dataLen )
120120 for _ , counter := range readCounters {
121121 counter (int64 (dataLen ))
@@ -127,6 +127,8 @@ func copyPacketWaitWithPool(destinationConn N.PacketWriter, source N.PacketReadW
127127 }
128128}
129129
130+ var _ N.ReadWaiter = (* syscallReadWaiter )(nil )
131+
130132type syscallReadWaiter struct {
131133 rawConn syscall.RawConn
132134 readErr error
@@ -143,8 +145,11 @@ func createSyscallReadWaiter(reader any) (*syscallReadWaiter, bool) {
143145 return nil , false
144146}
145147
146- func (w * syscallReadWaiter ) WaitReadBuffer (newBuffer func () * buf.Buffer ) error {
147- if w .readFunc == nil {
148+ func (w * syscallReadWaiter ) InitializeReadWaiter (newBuffer func () * buf.Buffer ) {
149+ w .readErr = nil
150+ if newBuffer == nil {
151+ w .readFunc = nil
152+ } else {
148153 w .readFunc = func (fd uintptr ) (done bool ) {
149154 buffer := newBuffer ()
150155 var readN int
@@ -164,16 +169,27 @@ func (w *syscallReadWaiter) WaitReadBuffer(newBuffer func() *buf.Buffer) error {
164169 return true
165170 }
166171 }
172+ }
173+
174+ func (w * syscallReadWaiter ) WaitReadBuffer () error {
175+ if w .readFunc == nil {
176+ return os .ErrInvalid
177+ }
167178 err := w .rawConn .Read (w .readFunc )
168179 if err != nil {
169180 return err
170181 }
171182 if w .readErr != nil {
183+ if w .readErr == io .EOF {
184+ return io .EOF
185+ }
172186 return E .Cause (w .readErr , "raw read" )
173187 }
174188 return nil
175189}
176190
191+ var _ N.PacketReadWaiter = (* syscallPacketReadWaiter )(nil )
192+
177193type syscallPacketReadWaiter struct {
178194 rawConn syscall.RawConn
179195 readErr error
@@ -191,8 +207,12 @@ func createSyscallPacketReadWaiter(reader any) (*syscallPacketReadWaiter, bool)
191207 return nil , false
192208}
193209
194- func (w * syscallPacketReadWaiter ) WaitReadPacket (newBuffer func () * buf.Buffer ) (destination M.Socksaddr , err error ) {
195- if w .readFunc == nil {
210+ func (w * syscallPacketReadWaiter ) InitializeReadWaiter (newBuffer func () * buf.Buffer ) {
211+ w .readErr = nil
212+ w .readFrom = M.Socksaddr {}
213+ if newBuffer == nil {
214+ w .readFunc = nil
215+ } else {
196216 w .readFunc = func (fd uintptr ) (done bool ) {
197217 buffer := newBuffer ()
198218 var readN int
@@ -221,6 +241,12 @@ func (w *syscallPacketReadWaiter) WaitReadPacket(newBuffer func() *buf.Buffer) (
221241 return true
222242 }
223243 }
244+ }
245+
246+ func (w * syscallPacketReadWaiter ) WaitReadPacket () (destination M.Socksaddr , err error ) {
247+ if w .readFunc == nil {
248+ return M.Socksaddr {}, os .ErrInvalid
249+ }
224250 err = w .rawConn .Read (w .readFunc )
225251 if err != nil {
226252 return
0 commit comments