Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 87 additions & 5 deletions chronos/transports/stream.nim
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,12 @@ type
# get stuck on transport `close()`.
# Please use this flag only if you are making both client and server in
# the same thread.
TcpNoDelay
TcpNoDelay # deprecated: Use SocketFlags.TcpNoDelay

SocketFlags* {.pure.} = enum
TcpNoDelay,
ReuseAddr,
ReusePort


StreamTransportTracker* = ref object of TrackerBase
Expand Down Expand Up @@ -699,7 +704,9 @@ when defined(windows):
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil,
flags: set[TransportFlags] = {}): Future[StreamTransport] =
localAddress = TransportAddress(),
flags: set[SocketFlags] = {},
): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` is size of internal buffer for transport.
Expand All @@ -724,7 +731,35 @@ when defined(windows):
retFuture.fail(getTransportOsError(osLastError()))
return retFuture

if not(bindToDomain(sock, raddress.getDomain())):
if SocketFlags.ReuseAddr in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if SocketFlags.ReusePort in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture

if localAddress != TransportAddress():
if localAddress.family != address.family:
sock.closeSocket()
retFuture.fail(newException(TransportOsError,
"connect local address domain is not equal to target address domain"))
return retFuture
var
localAddr: Sockaddr_storage
localAddrLen: SockLen
localAddress.toSAddr(localAddr, localAddrLen)
if bindSocket(SocketHandle(sock),
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
sock.closeSocket()
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
elif not(bindToDomain(sock, raddress.getDomain())):
let err = wsaGetLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
Expand Down Expand Up @@ -1478,7 +1513,9 @@ else:
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil,
flags: set[TransportFlags] = {}): Future[StreamTransport] =
localAddress = TransportAddress(),
flags: set[SocketFlags] = {},
): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` - size of internal buffer for transport.
Expand All @@ -1505,12 +1542,40 @@ else:
return retFuture

if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
if TransportFlags.TcpNoDelay in flags:
if SocketFlags.TcpNoDelay in flags:
if not(setSockOpt(sock, osdefs.IPPROTO_TCP, osdefs.TCP_NODELAY, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if SocketFlags.ReuseAddr in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture
if SocketFlags.ReusePort in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
let err = osLastError()
sock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture

if localAddress != TransportAddress():
if localAddress.family != address.family:
sock.closeSocket()
retFuture.fail(newException(TransportOsError,
"connect local address domain is not equal to target address domain"))
return retFuture
var
localAddr: Sockaddr_storage
localAddrLen: SockLen
localAddress.toSAddr(localAddr, localAddrLen)
if bindSocket(SocketHandle(sock),
cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
sock.closeSocket()
retFuture.fail(getTransportOsError(osLastError()))
return retFuture

proc continuation(udata: pointer) =
if not(retFuture.finished()):
Expand Down Expand Up @@ -1758,6 +1823,16 @@ proc join*(server: StreamServer): Future[void] =
retFuture.complete()
return retFuture

proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil,
flags: set[TransportFlags],
localAddress = TransportAddress()): Future[StreamTransport] =
# Retro compatibility with TransportFlags
var mappedFlags: set[SocketFlags]
if TcpNoDelay in flags: mappedFlags.incl(SocketFlags.TcpNoDelay)
address.connect(bufferSize, child, localAddress, mappedFlags)

proc close*(server: StreamServer) =
## Release ``server`` resources.
##
Expand Down Expand Up @@ -1846,6 +1921,13 @@ proc createStreamServer*(host: TransportAddress,
if sock == asyncInvalidSocket:
discard closeFd(SocketHandle(serverSocket))
raiseTransportOsError(err)
if ServerFlags.ReusePort in flags:
if not(setSockOpt(serverSocket, osdefs.SOL_SOCKET,
osdefs.SO_REUSEPORT, 1)):
let err = osLastError()
if sock == asyncInvalidSocket:
discard closeFd(SocketHandle(serverSocket))
raiseTransportOsError(err)
# TCP flags are not useful for Unix domain sockets.
if ServerFlags.TcpNoDelay in flags:
if not(setSockOpt(serverSocket, osdefs.IPPROTO_TCP,
Expand Down
2 changes: 1 addition & 1 deletion tests/testasyncstream.nim
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ suite "TLSStream test suite":
key = TLSPrivateKey.init(pemkey)
cert = TLSCertificate.init(pemcert)

var server = createStreamServer(address, serveClient, {ReuseAddr})
var server = createStreamServer(address, serveClient, {ServerFlags.ReuseAddr})
server.start()
var conn = await connect(address)
var creader = newAsyncStreamReader(conn)
Expand Down
43 changes: 43 additions & 0 deletions tests/teststream.nim
Original file line number Diff line number Diff line change
Expand Up @@ -1259,6 +1259,47 @@ suite "Stream Transport test suite":
await allFutures(rtransp.closeWait(), wtransp.closeWait())
return buffer == message

proc testConnectBindLocalAddress() {.async.} =
let dst1 = initTAddress("127.0.0.1:33335")
let dst2 = initTAddress("127.0.0.1:33336")
let dst3 = initTAddress("127.0.0.1:33337")

proc client(server: StreamServer, transp: StreamTransport) {.async.} =
await transp.closeWait()

# We use ReuseAddr here only to be able to reuse the same IP/Port when there's a TIME_WAIT socket. It's useful when
# running the test multiple times or if a test ran previously used the same port.
let servers =
[createStreamServer(dst1, client, {ReuseAddr}),
createStreamServer(dst2, client, {ReuseAddr}),
createStreamServer(dst3, client, {ReusePort})]

for server in servers:
server.start()

let ta = initTAddress("0.0.0.0:35000")

# It works cause there's no active listening socket bound to ta and we are using ReuseAddr
var transp1 = await connect(dst1, localAddress = ta, flags={SocketFlags.ReuseAddr})
var transp2 = await connect(dst2, localAddress = ta, flags={SocketFlags.ReuseAddr})

# It works cause even thought there's an active listening socket bound to dst3, we are using ReusePort
var transp3 = await connect(dst2, localAddress = dst3, flags={SocketFlags.ReusePort})

expect(TransportOsError):
var transp2 = await connect(dst3, localAddress = ta)

expect(TransportOsError):
var transp3 = await connect(dst3, localAddress = initTAddress(":::35000"))

await transp1.closeWait()
await transp2.closeWait()
await transp3.closeWait()

for server in servers:
server.stop()
await server.closeWait()

markFD = getCurrentFD()

for i in 0..<len(addresses):
Expand Down Expand Up @@ -1346,6 +1387,8 @@ suite "Stream Transport test suite":
check waitFor(testReadOnClose(addresses[i])) == true
test "[PIPE] readExactly()/write() test":
check waitFor(testPipe()) == true
test "[IP] bind connect to local address":
waitFor(testConnectBindLocalAddress())
test "Servers leak test":
check getTracker("stream.server").isLeaked() == false
test "Transports leak test":
Expand Down