@@ -15,7 +15,13 @@ import (
1515 "github.com/sagernet/sing/common/task"
1616)
1717
18+ const DefaultIncreaseBufferAfter = 512 * 1000
19+
1820func Copy (destination io.Writer , source io.Reader ) (n int64 , err error ) {
21+ return CopyWithIncreateBuffer (destination , source , DefaultIncreaseBufferAfter )
22+ }
23+
24+ func CopyWithIncreateBuffer (destination io.Writer , source io.Reader , increaseBufferAfter int64 ) (n int64 , err error ) {
1925 if source == nil {
2026 return 0 , E .New ("nil reader" )
2127 } else if destination == nil {
@@ -46,10 +52,10 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
4652 }
4753 break
4854 }
49- return CopyWithCounters (destination , source , originSource , readCounters , writeCounters )
55+ return CopyWithCounters (destination , source , originSource , readCounters , writeCounters , increaseBufferAfter )
5056}
5157
52- func CopyWithCounters (destination io.Writer , source io.Reader , originSource io.Reader , readCounters []N.CountFunc , writeCounters []N.CountFunc ) (n int64 , err error ) {
58+ func CopyWithCounters (destination io.Writer , source io.Reader , originSource io.Reader , readCounters []N.CountFunc , writeCounters []N.CountFunc , increaseBufferAfter int64 ) (n int64 , err error ) {
5359 srcSyscallConn , srcIsSyscall := source .(syscall.Conn )
5460 dstSyscallConn , dstIsSyscall := destination .(syscall.Conn )
5561 if srcIsSyscall && dstIsSyscall {
@@ -59,10 +65,10 @@ func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.R
5965 return
6066 }
6167 }
62- return CopyExtended (originSource , NewExtendedWriter (destination ), NewExtendedReader (source ), readCounters , writeCounters )
68+ return CopyExtended (originSource , NewExtendedWriter (destination ), NewExtendedReader (source ), readCounters , writeCounters , increaseBufferAfter )
6369}
6470
65- func CopyExtended (originSource io.Reader , destination N.ExtendedWriter , source N.ExtendedReader , readCounters []N.CountFunc , writeCounters []N.CountFunc ) (n int64 , err error ) {
71+ func CopyExtended (originSource io.Reader , destination N.ExtendedWriter , source N.ExtendedReader , readCounters []N.CountFunc , writeCounters []N.CountFunc , increaseBufferAfter int64 ) (n int64 , err error ) {
6672 frontHeadroom := N .CalculateFrontHeadroom (destination )
6773 rearHeadroom := N .CalculateRearHeadroom (destination )
6874 readWaiter , isReadWaiter := CreateReadWaiter (source )
@@ -74,13 +80,13 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
7480 })
7581 if ! needCopy || common .LowMemory {
7682 var handled bool
77- handled , n , err = copyWaitWithPool (originSource , destination , readWaiter , readCounters , writeCounters )
83+ handled , n , err = copyWaitWithPool (originSource , destination , source , readWaiter , readCounters , writeCounters , increaseBufferAfter )
7884 if handled {
7985 return
8086 }
8187 }
8288 }
83- return CopyExtendedWithPool (originSource , destination , source , readCounters , writeCounters )
89+ return CopyExtendedWithPool (originSource , destination , source , readCounters , writeCounters , increaseBufferAfter )
8490}
8591
8692// Deprecated: not used
@@ -121,7 +127,7 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
121127 }
122128}
123129
124- func CopyExtendedWithPool (originSource io.Reader , destination N.ExtendedWriter , source N.ExtendedReader , readCounters []N.CountFunc , writeCounters []N.CountFunc ) (n int64 , err error ) {
130+ func CopyExtendedWithPool (originSource io.Reader , destination N.ExtendedWriter , source N.ExtendedReader , readCounters []N.CountFunc , writeCounters []N.CountFunc , increaseBufferAfter int64 ) (n int64 , err error ) {
125131 options := N .NewReadWaitOptions (source , destination )
126132 var notFirstTime bool
127133 for {
@@ -153,6 +159,53 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
153159 counter (int64 (dataLen ))
154160 }
155161 notFirstTime = true
162+ if increaseBufferAfter > 0 && n >= increaseBufferAfter {
163+ return CopyExtendedChanWithPool (destination , source , readCounters , writeCounters , options , n )
164+ }
165+ }
166+ }
167+
168+ func CopyExtendedChanWithPool (destination N.ExtendedWriter , source N.ExtendedReader , readCounters []N.CountFunc , writeCounters []N.CountFunc , options N.ReadWaitOptions , inputN int64 ) (n int64 , err error ) {
169+ n += inputN
170+ sendChan := make (chan * buf.Buffer , 2 )
171+ errChan := make (chan error , 1 )
172+ go func () {
173+ for {
174+ buffer := options .NewBufferMax ()
175+ readErr := source .ReadBuffer (buffer )
176+ if readErr != nil {
177+ buffer .Release ()
178+ if errors .Is (readErr , io .EOF ) {
179+ errChan <- nil
180+ } else {
181+ errChan <- readErr
182+ }
183+ return
184+ }
185+ dataLen := buffer .Len ()
186+ options .PostReturn (buffer )
187+ sendChan <- buffer
188+ n += int64 (dataLen )
189+ for _ , counter := range readCounters {
190+ counter (int64 (dataLen ))
191+ }
192+ }
193+ }()
194+ for {
195+ select {
196+ case buffer := <- sendChan :
197+ dataLen := buffer .Len ()
198+ err = destination .WriteBuffer (buffer )
199+ if err != nil {
200+ buffer .Leak ()
201+ return
202+ }
203+ for _ , counter := range writeCounters {
204+ counter (int64 (dataLen ))
205+ }
206+ case err = <- errChan :
207+ return
208+ }
156209 }
157210}
158211
0 commit comments