Skip to content

Commit da98137

Browse files
committed
Improve Copy
1 parent 716ee8a commit da98137

File tree

3 files changed

+71
-9
lines changed

3 files changed

+71
-9
lines changed

common/bufio/copy.go

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,13 @@ import (
1515
"github.com/sagernet/sing/common/task"
1616
)
1717

18+
const DefaultIncreaseBufferAfter = 512 * 1000
19+
1820
func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
21+
return CopyWithIncreateBuffer(destination, source, DefaultIncreaseBufferAfter)
22+
}
23+
24+
func CopyWithIncreateBuffer(destination io.Writer, source io.Reader, increaseBufferAfter int64) (n int64, err error) {
1925
if source == nil {
2026
return 0, E.New("nil reader")
2127
} else if destination == nil {
@@ -46,10 +52,10 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
4652
}
4753
break
4854
}
49-
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
55+
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters, increaseBufferAfter)
5056
}
5157

52-
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
58+
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64) (n int64, err error) {
5359
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
5460
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
5561
if srcIsSyscall && dstIsSyscall {
@@ -59,10 +65,10 @@ func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.R
5965
return
6066
}
6167
}
62-
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
68+
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters, increaseBufferAfter)
6369
}
6470

65-
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
71+
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64) (n int64, err error) {
6672
frontHeadroom := N.CalculateFrontHeadroom(destination)
6773
rearHeadroom := N.CalculateRearHeadroom(destination)
6874
readWaiter, isReadWaiter := CreateReadWaiter(source)
@@ -74,13 +80,13 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
7480
})
7581
if !needCopy || common.LowMemory {
7682
var handled bool
77-
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
83+
handled, n, err = copyWaitWithPool(originSource, destination, source, readWaiter, readCounters, writeCounters, increaseBufferAfter)
7884
if handled {
7985
return
8086
}
8187
}
8288
}
83-
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
89+
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters, increaseBufferAfter)
8490
}
8591

8692
// Deprecated: not used
@@ -121,7 +127,7 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
121127
}
122128
}
123129

124-
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
130+
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64) (n int64, err error) {
125131
options := N.NewReadWaitOptions(source, destination)
126132
var notFirstTime bool
127133
for {
@@ -153,6 +159,53 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
153159
counter(int64(dataLen))
154160
}
155161
notFirstTime = true
162+
if increaseBufferAfter > 0 && n >= increaseBufferAfter {
163+
return CopyExtendedChanWithPool(destination, source, readCounters, writeCounters, options, n)
164+
}
165+
}
166+
}
167+
168+
func CopyExtendedChanWithPool(destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, options N.ReadWaitOptions, inputN int64) (n int64, err error) {
169+
n += inputN
170+
sendChan := make(chan *buf.Buffer, 2)
171+
errChan := make(chan error, 1)
172+
go func() {
173+
for {
174+
buffer := options.NewBufferMax()
175+
readErr := source.ReadBuffer(buffer)
176+
if readErr != nil {
177+
buffer.Release()
178+
if errors.Is(readErr, io.EOF) {
179+
errChan <- nil
180+
} else {
181+
errChan <- readErr
182+
}
183+
return
184+
}
185+
dataLen := buffer.Len()
186+
options.PostReturn(buffer)
187+
sendChan <- buffer
188+
n += int64(dataLen)
189+
for _, counter := range readCounters {
190+
counter(int64(dataLen))
191+
}
192+
}
193+
}()
194+
for {
195+
select {
196+
case buffer := <-sendChan:
197+
dataLen := buffer.Len()
198+
err = destination.WriteBuffer(buffer)
199+
if err != nil {
200+
buffer.Leak()
201+
return
202+
}
203+
for _, counter := range writeCounters {
204+
counter(int64(dataLen))
205+
}
206+
case err = <-errChan:
207+
return
208+
}
156209
}
157210
}
158211

common/bufio/copy_direct.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
2323
return
2424
}
2525

26-
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
26+
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readWaiter N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64) (handled bool, n int64, err error) {
2727
handled = true
2828
var (
2929
buffer *buf.Buffer
3030
notFirstTime bool
3131
)
3232
for {
33-
buffer, err = source.WaitReadBuffer()
33+
buffer, err = readWaiter.WaitReadBuffer()
3434
if err != nil {
3535
if errors.Is(err, io.EOF) {
3636
err = nil
@@ -55,6 +55,10 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
5555
counter(int64(dataLen))
5656
}
5757
notFirstTime = true
58+
if increaseBufferAfter > 0 && n >= increaseBufferAfter {
59+
n, err = CopyExtendedChanWithPool(destination, source, readCounters, writeCounters, N.NewReadWaitOptions(source, destination), n)
60+
return
61+
}
5862
}
5963
}
6064

common/network/direct.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
4343
return o.newBuffer(buf.BufferSize, true)
4444
}
4545

46+
func (o ReadWaitOptions) NewBufferMax() *buf.Buffer {
47+
const maxBufferSize = 64<<10 - 1
48+
return o.newBuffer(maxBufferSize, true)
49+
}
50+
4651
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
4752
return o.newBuffer(buf.UDPBufferSize, true)
4853
}

0 commit comments

Comments
 (0)