Skip to content

Commit acaffe9

Browse files
committed
Add Syscall[Read/Write]Creator interface
1 parent 0512216 commit acaffe9

File tree

3 files changed

+103
-23
lines changed

3 files changed

+103
-23
lines changed

common/bufio/copy.go

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"errors"
66
"io"
77
"net"
8-
"syscall"
98

109
"github.com/sagernet/sing/common"
1110
"github.com/sagernet/sing/common/buf"
@@ -58,25 +57,11 @@ func CopyWithIncreateBuffer(destination io.Writer, source io.Reader, increaseBuf
5857
return CopyWithCounters(destination, source, originSource, readCounters, writeCounters, increaseBufferAfter, batchSize)
5958
}
6059

61-
type syscallReader interface {
62-
io.Reader
63-
syscall.Conn
64-
}
65-
66-
type syscallWriter interface {
67-
io.Writer
68-
syscall.Conn
69-
}
70-
7160
func CopyWithCounters(destination io.Writer, source io.Reader, originSource io.Reader, readCounters []N.CountFunc, writeCounters []N.CountFunc, increaseBufferAfter int64, batchSize int) (n int64, err error) {
72-
srcSyscallConn, srcIsSyscall := N.CastReader[syscallReader](source)
73-
dstSyscallConn, dstIsSyscall := N.CastWriter[syscallWriter](destination)
74-
if srcIsSyscall && dstIsSyscall {
75-
var handled bool
76-
handled, n, err = copyDirect(srcSyscallConn, dstSyscallConn, readCounters, writeCounters)
77-
if handled {
78-
return
79-
}
61+
var handled bool
62+
handled, n, err = copyDirect(source, destination, readCounters, writeCounters)
63+
if handled {
64+
return
8065
}
8166
return CopyExtended(originSource, NewExtendedWriter(destination), NewExtendedReader(source), readCounters, writeCounters, increaseBufferAfter, batchSize)
8267
}

common/bufio/copy_direct.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,26 @@ package bufio
33
import (
44
"errors"
55
"io"
6-
"syscall"
76

87
"github.com/sagernet/sing/common/buf"
98
M "github.com/sagernet/sing/common/metadata"
109
N "github.com/sagernet/sing/common/network"
1110
)
1211

13-
func copyDirect(source syscall.Conn, destination syscall.Conn, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
14-
rawSource, err := source.SyscallConn()
12+
func copyDirect(source io.Reader, destination io.Writer, readCounters []N.CountFunc, writeCounters []N.CountFunc) (handed bool, n int64, err error) {
13+
if !N.SyscallAvailableForRead(source) || !N.SyscallAvailableForWrite(destination) {
14+
return
15+
}
16+
sourceConn := N.SyscallConnForRead(source)
17+
destinationConn := N.SyscallConnForWrite(destination)
18+
if sourceConn == nil || destinationConn == nil {
19+
return
20+
}
21+
rawSource, err := sourceConn.SyscallConn()
1522
if err != nil {
1623
return
1724
}
18-
rawDestination, err := destination.SyscallConn()
25+
rawDestination, err := destinationConn.SyscallConn()
1926
if err != nil {
2027
return
2128
}

common/network/direct.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package network
22

33
import (
4+
"io"
5+
"syscall"
6+
7+
"github.com/sagernet/sing/common"
48
"github.com/sagernet/sing/common/buf"
59
M "github.com/sagernet/sing/common/metadata"
610
)
@@ -109,3 +113,87 @@ type VectorisedPacketReadWaiter interface {
109113
type VectorisedPacketReadWaitCreator interface {
110114
CreateVectorisedPacketReadWaiter() (VectorisedPacketReadWaiter, bool)
111115
}
116+
117+
type SyscallReadCreator interface {
118+
SyscallConnForRead() syscall.Conn
119+
}
120+
121+
func SyscallAvailableForRead(reader io.Reader) bool {
122+
if _, ok := reader.(syscall.Conn); ok {
123+
return true
124+
}
125+
if _, ok := reader.(SyscallReadCreator); ok {
126+
return true
127+
}
128+
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
129+
return false
130+
}
131+
if u, ok := reader.(WithUpstreamReader); ok {
132+
return SyscallAvailableForRead(u.UpstreamReader().(io.Reader))
133+
}
134+
if u, ok := reader.(common.WithUpstream); ok {
135+
return SyscallAvailableForRead(u.Upstream().(io.Reader))
136+
}
137+
return false
138+
}
139+
140+
func SyscallConnForRead(reader io.Reader) syscall.Conn {
141+
if c, ok := reader.(syscall.Conn); ok {
142+
return c
143+
}
144+
if c, ok := reader.(SyscallReadCreator); ok {
145+
return c.SyscallConnForRead()
146+
}
147+
if u, ok := reader.(ReaderWithUpstream); !ok || !u.ReaderReplaceable() {
148+
return nil
149+
}
150+
if u, ok := reader.(WithUpstreamReader); ok {
151+
return SyscallConnForRead(u.UpstreamReader().(io.Reader))
152+
}
153+
if u, ok := reader.(common.WithUpstream); ok {
154+
return SyscallConnForRead(u.Upstream().(io.Reader))
155+
}
156+
return nil
157+
}
158+
159+
type SyscallWriteCreator interface {
160+
SyscallConnForWrite() syscall.Conn
161+
}
162+
163+
func SyscallAvailableForWrite(writer io.Writer) bool {
164+
if _, ok := writer.(syscall.Conn); ok {
165+
return true
166+
}
167+
if _, ok := writer.(SyscallWriteCreator); ok {
168+
return true
169+
}
170+
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
171+
return false
172+
}
173+
if u, ok := writer.(WithUpstreamWriter); ok {
174+
return SyscallAvailableForWrite(u.UpstreamWriter().(io.Writer))
175+
}
176+
if u, ok := writer.(common.WithUpstream); ok {
177+
return SyscallAvailableForWrite(u.Upstream().(io.Writer))
178+
}
179+
return false
180+
}
181+
182+
func SyscallConnForWrite(writer io.Writer) syscall.Conn {
183+
if c, ok := writer.(syscall.Conn); ok {
184+
return c
185+
}
186+
if c, ok := writer.(SyscallWriteCreator); ok {
187+
return c.SyscallConnForWrite()
188+
}
189+
if u, ok := writer.(WriterWithUpstream); !ok || !u.WriterReplaceable() {
190+
return nil
191+
}
192+
if u, ok := writer.(WithUpstreamWriter); ok {
193+
return SyscallConnForWrite(u.UpstreamWriter().(io.Writer))
194+
}
195+
if u, ok := writer.(common.WithUpstream); ok {
196+
return SyscallConnForWrite(u.Upstream().(io.Writer))
197+
}
198+
return nil
199+
}

0 commit comments

Comments
 (0)