Skip to content

Commit 751d9b9

Browse files
committed
Improve Copy
1 parent 716ee8a commit 751d9b9

File tree

3 files changed

+78
-13
lines changed

3 files changed

+78
-13
lines changed

common/bufio/copy.go

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ import (
1616
)
1717

1818
func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
19+
return CopyWithIncreateBuffer(destination, source, 0)
20+
}
21+
22+
func CopyWithIncreateBuffer(destination io.Writer, source io.Reader, increaseBufferAfter int64) (n int64, err error) {
1923
if source == nil {
2024
return 0, E.New("nil reader")
2125
} else if destination == nil {
@@ -46,10 +50,10 @@ func Copy(destination io.Writer, source io.Reader) (n int64, err error) {
4650
}
4751
break
4852
}
49-
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters)
53+
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters, increaseBufferAfter)
5054
}
5155

52-
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
56+
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64) (n int64, err error) {
5357
srcSyscallConn, srcIsSyscall := source.(syscall.Conn)
5458
dstSyscallConn, dstIsSyscall := destination.(syscall.Conn)
5559
if srcIsSyscall && dstIsSyscall {
@@ -59,10 +63,10 @@ func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.R
5963
return
6064
}
6165
}
62-
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters)
66+
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters, increaseBufferAfter)
6367
}
6468

65-
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
69+
func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64) (n int64, err error) {
6670
frontHeadroom := N.CalculateFrontHeadroom(destination)
6771
rearHeadroom := N.CalculateRearHeadroom(destination)
6872
readWaiter, isReadWaiter := CreateReadWaiter(source)
@@ -74,13 +78,13 @@ func CopyExtended(originSource io.Reader, destination N.ExtendedWriter, source N
7478
})
7579
if !needCopy || common.LowMemory {
7680
var handled bool
77-
handled, n, err = copyWaitWithPool(originSource, destination, readWaiter, readCounters, writeCounters)
81+
handled, n, err = copyWaitWithPool(originSource, destination, source, readWaiter, readCounters, writeCounters, increaseBufferAfter)
7882
if handled {
7983
return
8084
}
8185
}
8286
}
83-
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters)
87+
return CopyExtendedWithPool(originSource, destination, source, readCounters, writeCounters, increaseBufferAfter)
8488
}
8589

8690
// Deprecated: not used
@@ -121,7 +125,7 @@ func CopyExtendedBuffer(originSource io.Writer, destination N.ExtendedWriter, so
121125
}
122126
}
123127

124-
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc) (n int64, err error) {
128+
func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64) (n int64, err error) {
125129
options := N.NewReadWaitOptions(source, destination)
126130
var notFirstTime bool
127131
for {
@@ -153,14 +157,66 @@ func CopyExtendedWithPool(originSource io.Reader, destination N.ExtendedWriter,
153157
counter(int64(dataLen))
154158
}
155159
notFirstTime = true
160+
if increaseBufferAfter > 0 && n >= increaseBufferAfter {
161+
return CopyExtendedChanWithPool(destination, source, readCounters, writeCounters, options, n)
162+
}
156163
}
157164
}
158165

166+
func CopyExtendedChanWithPool(destination N.ExtendedWriter, source N.ExtendedReader, readCounters []N.CountFunc, writeCounters []N.CountFunc, options N.ReadWaitOptions, inputN int64) (n int64, err error) {
167+
n += inputN
168+
sendChan := make(chan *buf.Buffer, 2)
169+
errChan := make(chan error, 1)
170+
go func() {
171+
for {
172+
buffer := options.NewBufferMax()
173+
readErr := source.ReadBuffer(buffer)
174+
if readErr != nil {
175+
buffer.Release()
176+
if errors.Is(readErr, io.EOF) {
177+
errChan <- nil
178+
} else {
179+
errChan <- readErr
180+
}
181+
return
182+
}
183+
dataLen := buffer.Len()
184+
options.PostReturn(buffer)
185+
sendChan <- buffer
186+
n += int64(dataLen)
187+
for _, counter := range readCounters {
188+
counter(int64(dataLen))
189+
}
190+
}
191+
}()
192+
for {
193+
select {
194+
case buffer := <-sendChan:
195+
dataLen := buffer.Len()
196+
err = destination.WriteBuffer(buffer)
197+
if err != nil {
198+
buffer.Leak()
199+
return
200+
}
201+
for _, counter := range writeCounters {
202+
counter(int64(dataLen))
203+
}
204+
case err = <-errChan:
205+
return
206+
}
207+
}
208+
}
209+
210+
const (
211+
DefaultIncreaseUploadBufferAfter = 16 * 1000 // 16KB
212+
DefaultIncreaseDownloadBufferAfter = 1 * 1000 * 1000 // 1MB
213+
)
214+
159215
func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error {
160216
var group task.Group
161217
if _, dstDuplex := common.Cast[N.WriteCloser](destination); dstDuplex {
162218
group.Append("upload", func(ctx context.Context) error {
163-
err := common.Error(Copy(destination, source))
219+
err := common.Error(CopyWithIncreateBuffer(destination, source, DefaultIncreaseUploadBufferAfter))
164220
if err == nil {
165221
N.CloseWrite(destination)
166222
} else {
@@ -171,12 +227,12 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error
171227
} else {
172228
group.Append("upload", func(ctx context.Context) error {
173229
defer common.Close(destination)
174-
return common.Error(Copy(destination, source))
230+
return common.Error(CopyWithIncreateBuffer(destination, source, DefaultIncreaseUploadBufferAfter))
175231
})
176232
}
177233
if _, srcDuplex := common.Cast[N.WriteCloser](source); srcDuplex {
178234
group.Append("download", func(ctx context.Context) error {
179-
err := common.Error(Copy(source, destination))
235+
err := common.Error(CopyWithIncreateBuffer(source, destination, DefaultIncreaseDownloadBufferAfter))
180236
if err == nil {
181237
N.CloseWrite(source)
182238
} else {
@@ -187,7 +243,7 @@ func CopyConn(ctx context.Context, source net.Conn, destination net.Conn) error
187243
} else {
188244
group.Append("download", func(ctx context.Context) error {
189245
defer common.Close(source)
190-
return common.Error(Copy(source, destination))
246+
return common.Error(CopyWithIncreateBuffer(source, destination, DefaultIncreaseDownloadBufferAfter))
191247
})
192248
}
193249
group.Cleanup(func() {

common/bufio/copy_direct.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@ func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.
2323
return
2424
}
2525

26-
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handled bool, n int64, err error) {
26+
func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, source N.ExtendedReader, readWaiter N.ReadWaiter, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64) (handled bool, n int64, err error) {
2727
handled = true
2828
var (
2929
buffer *buf.Buffer
3030
notFirstTime bool
3131
)
3232
for {
33-
buffer, err = source.WaitReadBuffer()
33+
buffer, err = readWaiter.WaitReadBuffer()
3434
if err != nil {
3535
if errors.Is(err, io.EOF) {
3636
err = nil
@@ -55,6 +55,10 @@ func copyWaitWithPool(originSource io.Reader, destination N.ExtendedWriter, sour
5555
counter(int64(dataLen))
5656
}
5757
notFirstTime = true
58+
if increaseBufferAfter > 0 && n >= increaseBufferAfter {
59+
n, err = CopyExtendedChanWithPool(destination, source, readCounters, writeCounters, N.NewReadWaitOptions(source, destination), n)
60+
return
61+
}
5862
}
5963
}
6064

common/network/direct.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ func (o ReadWaitOptions) NewBuffer() *buf.Buffer {
4343
return o.newBuffer(buf.BufferSize, true)
4444
}
4545

46+
func (o ReadWaitOptions) NewBufferMax() *buf.Buffer {
47+
const maxBufferSize = 64<<10 - 1
48+
return o.newBuffer(maxBufferSize, true)
49+
}
50+
4651
func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer {
4752
return o.newBuffer(buf.UDPBufferSize, true)
4853
}

0 commit comments

Comments
 (0)