Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 110 additions & 46 deletions chronos/transports/stream.nim
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,25 @@ proc clean(transp: StreamTransport) {.inline.} =
transp.future.complete()
GC_unref(transp)

proc bindSocket*(sock: AsyncFD, localAddress: TransportAddress, flags: set[ServerFlags] = {}): bool =
## Returns ``true`` on success, ``false`` on error.
if ServerFlags.ReuseAddr in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)):
return false
if ServerFlags.ReusePort in flags:
if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)):
return false
if ServerFlags.TcpNoDelay in flags:
if not(setSockOpt(sock, handles.IPPROTO_TCP, handles.TCP_NODELAY, 1)):
return false
var
localAddr: Sockaddr_storage
localAddrLen: SockLen
localAddress.toSAddr(localAddr, localAddrLen)
if bindSocket(SocketHandle(sock), cast[ptr SockAddr](addr localAddr), localAddrLen) != 0:
return false
return true

when defined(nimdoc):
proc pauseAccept(server: StreamServer) {.inline.} = discard
proc resumeAccept(server: StreamServer) {.inline.} = discard
Expand Down Expand Up @@ -702,10 +721,18 @@ elif defined(windows):
sizeof(saddr).SockLen) != 0'i32:
result = false

proc isDomainSet(sock: AsyncFD): bool =
try:
discard getSockDomain(SocketHandle(sock))
true
except CatchableError as ex:
false

proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil,
flags: set[TransportFlags] = {}): Future[StreamTransport] =
flags: set[TransportFlags] = {},
sock: AsyncFD = asyncInvalidSocket): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` is size of internal buffer for transport.
Expand All @@ -717,56 +744,70 @@ elif defined(windows):
var
saddr: Sockaddr_storage
slen: SockLen
sock: AsyncFD
localSock: AsyncFD
povl: RefCustomOverlapped
proto: Protocol

var raddress = windowsAnyAddressFix(address)

toSAddr(raddress, saddr, slen)
proto = Protocol.IPPROTO_TCP
sock = try: createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM,
proto)
except CatchableError as exc:
retFuture.fail(exc)
return retFuture

if sock == asyncInvalidSocket:
localSock = try: createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM,
Protocol.IPPROTO_TCP)
except CatchableError as exc:
retFuture.fail(exc)
return retFuture
else:
if not setSocketBlocking(SocketHandle(sock), false):
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
localSock = sock
try:
register(localSock)
except CatchableError as exc:
retFuture.fail(exc)
return retFuture

if localSock == asyncInvalidSocket:
retFuture.fail(getTransportOsError(osLastError()))
return retFuture

if not(bindToDomain(sock, raddress.getDomain())):
if not isDomainSet(localSock) and not(bindToDomain(localSock, raddress.getDomain())):
let err = wsaGetLastError()
sock.closeSocket()
if sock == asyncInvalidSocket:
localSock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture

proc socketContinuation(udata: pointer) {.gcsafe.} =
var ovl = cast[RefCustomOverlapped](udata)
if not(retFuture.finished()):
if ovl.data.errCode == OSErrorCode(-1):
if setsockopt(SocketHandle(sock), cint(SOL_SOCKET),
if setsockopt(SocketHandle(localSock), cint(SOL_SOCKET),
cint(SO_UPDATE_CONNECT_CONTEXT), nil,
SockLen(0)) != 0'i32:
let err = wsaGetLastError()
sock.closeSocket()
if sock == asyncInvalidSocket:
localSock.closeSocket()
retFuture.fail(getTransportOsError(err))
else:
let transp = newStreamSocketTransport(sock, bufferSize, child)
let transp = newStreamSocketTransport(localSock, bufferSize, child)
# Start tracking transport
trackStream(transp)
retFuture.complete(transp)
else:
sock.closeSocket()
if sock == asyncInvalidSocket:
localSock.closeSocket()
retFuture.fail(getTransportOsError(ovl.data.errCode))
GC_unref(ovl)

proc cancel(udata: pointer) {.gcsafe.} =
sock.closeSocket()
if sock == asyncInvalidSocket:
localSock.closeSocket()

povl = RefCustomOverlapped()
GC_ref(povl)
povl.data = CompletionData(cb: socketContinuation)
let res = loop.connectEx(SocketHandle(sock),
let res = loop.connectEx(SocketHandle(localSock),
cast[ptr SockAddr](addr saddr),
DWORD(slen), nil, 0, nil,
cast[POVERLAPPED](povl))
Expand All @@ -775,7 +816,8 @@ elif defined(windows):
let err = osLastError()
if int32(err) != ERROR_IO_PENDING:
GC_unref(povl)
sock.closeSocket()
if sock == asyncInvalidSocket:
localSock.closeSocket()
retFuture.fail(getTransportOsError(err))

retFuture.cancelCallback = cancel
Expand Down Expand Up @@ -816,7 +858,8 @@ elif defined(windows):
trackStream(transp)
retFuture.complete(transp)
pipeContinuation(nil)

else:
retFuture.fail(newException(TransportAddressError, "Unsupported address family"))
return retFuture

proc createAcceptPipe(server: StreamServer) {.
Expand Down Expand Up @@ -1493,79 +1536,98 @@ else:
proc connect*(address: TransportAddress,
bufferSize = DefaultStreamBufferSize,
child: StreamTransport = nil,
flags: set[TransportFlags] = {}): Future[StreamTransport] =
flags: set[TransportFlags] = {},
sock: AsyncFD = asyncInvalidSocket
): Future[StreamTransport] =
## Open new connection to remote peer with address ``address`` and create
## new transport object ``StreamTransport`` for established connection.
## ``bufferSize`` - size of internal buffer for transport.
var
saddr: Sockaddr_storage
slen: SockLen
proto: Protocol
localSock: AsyncFD

var retFuture = newFuture[StreamTransport]("stream.transport.connect")
address.toSAddr(saddr, slen)
proto = Protocol.IPPROTO_TCP
if address.family == AddressFamily.Unix:
# `Protocol` enum is missing `0` value, so we making here cast, until
# `Protocol` enum will not support IPPROTO_IP == 0.
proto = cast[Protocol](0)

let sock = try: createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM,
if sock == asyncInvalidSocket:
let proto =
if address.family == AddressFamily.Unix:
# `Protocol` enum is missing `0` value, so we making here cast, until
# `Protocol` enum will not support IPPROTO_IP == 0.
cast[Protocol](0)
else: Protocol.IPPROTO_TCP
localSock =
try: createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM,
proto)
except CatchableError as exc:
retFuture.fail(exc)
return retFuture
except CatchableError as exc:
retFuture.fail(exc)
return retFuture
else:
if not setSocketBlocking(SocketHandle(sock), false):
retFuture.fail(getTransportOsError(osLastError()))
return retFuture
localSock = sock
try:
register(localSock)
except CatchableError as exc:
retFuture.fail(exc)
return retFuture

if sock == asyncInvalidSocket:
if localSock == asyncInvalidSocket:
let err = osLastError()
if int(err) == EMFILE:
retFuture.fail(getTransportTooManyError())
else:
retFuture.fail(getTransportOsError(err))
return retFuture


if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}:
if TransportFlags.TcpNoDelay in flags:
if not(setSockOpt(sock, handles.IPPROTO_TCP,
if not(setSockOpt(localSock, handles.IPPROTO_TCP,
handles.TCP_NODELAY, 1)):
let err = osLastError()
sock.closeSocket()
if sock == asyncInvalidSocket:
localSock.closeSocket()
retFuture.fail(getTransportOsError(err))
return retFuture

proc continuation(udata: pointer) =
if not(retFuture.finished()):
var err = 0
try:
sock.removeWriter()
localSock.removeWriter()
except IOSelectorsException as exc:
retFuture.fail(exc)
return
except ValueError as exc:
retFuture.fail(exc)
return

if not(sock.getSocketError(err)):
closeSocket(sock)
if not(localSock.getSocketError(err)):
if sock == asyncInvalidSocket:
closeSocket(localSock)
retFuture.fail(getTransportOsError(osLastError()))
return
if err != 0:
closeSocket(sock)
if sock == asyncInvalidSocket:
closeSocket(localSock)
retFuture.fail(getTransportOsError(OSErrorCode(err)))
return
let transp = newStreamSocketTransport(sock, bufferSize, child)
let transp = newStreamSocketTransport(localSock, bufferSize, child)
# Start tracking transport
trackStream(transp)
retFuture.complete(transp)

proc cancel(udata: pointer) =
closeSocket(sock)
if sock == asyncInvalidSocket:
closeSocket(localSock)

while true:
var res = posix.connect(SocketHandle(sock),
var res = posix.connect(SocketHandle(localSock),
cast[ptr SockAddr](addr saddr), slen)
if res == 0:
let transp = newStreamSocketTransport(sock, bufferSize, child)
let transp = newStreamSocketTransport(localSock, bufferSize, child)
# Start tracking transport
trackStream(transp)
retFuture.complete(transp)
Expand All @@ -1580,16 +1642,18 @@ else:
# http://www.madore.org/~david/computers/connect-intr.html
if int(err) == EINPROGRESS or int(err) == EINTR:
try:
sock.addWriter(continuation)
localSock.addWriter(continuation)
except CatchableError as exc:
closeSocket(sock)
if sock == asyncInvalidSocket:
closeSocket(localSock)
retFuture.fail(exc)
return retFuture

retFuture.cancelCallback = cancel
break
else:
sock.closeSocket()
if sock == asyncInvalidSocket:
localSock.closeSocket()

retFuture.fail(getTransportOsError(err))
break
Expand Down
50 changes: 49 additions & 1 deletion tests/teststream.nim
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# Licensed under either of
# Apache License, version 2.0, (LICENSE-APACHEv2)
# MIT license (LICENSE-MIT)
import std/[strutils, os]
import std/[strutils, os, nativesockets]
import unittest2
import ../chronos

Expand Down Expand Up @@ -1260,6 +1260,43 @@ suite "Stream Transport test suite":
await allFutures(rtransp.closeWait(), wtransp.closeWait())
return buffer == message

proc testConnectReuseLocalPort(): Future[void] {.async.} =
let dst1 = initTAddress("127.0.0.1:33335")
let dst2 = initTAddress("127.0.0.1:33336")
let dst3 = initTAddress("127.0.0.1:33337")

proc client(server: StreamServer, transp: StreamTransport) {.async.} =
await transp.closeWait()

let servers =
[createStreamServer(dst1, client, {ReuseAddr}),
createStreamServer(dst2, client, {ReuseAddr}),
createStreamServer(dst3, client, {ReuseAddr})]

for server in servers:
server.start()

let ta = initTAddress("0.0.0.0:35000")

let sock1 = AsyncFD(createNativeSocket(dst1.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP))
check: bindSocket(sock1, ta, {ReuseAddr})
var transp1 = await connect(dst1, sock=sock1)

let sock2 = AsyncFD(createNativeSocket(dst2.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP))
check: bindSocket(sock2, ta, {ReuseAddr})
var transp2 = await connect(dst2, sock=sock2)

let sock3 = AsyncFD(createNativeSocket(dst3.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP))
check: false == bindSocket(sock3, ta)

await transp1.closeWait()
await transp2.closeWait()
sock3.closeSocket()

for server in servers:
server.stop()
await server.closeWait()

markFD = getCurrentFD()

for i in 0..<len(addresses):
Expand Down Expand Up @@ -1347,6 +1384,8 @@ suite "Stream Transport test suite":
check waitFor(testReadOnClose(addresses[i])) == true
test "[PIPE] readExactly()/write() test":
check waitFor(testPipe()) == true
test "Connect reusing same port in local address":
waitFor testConnectReuseLocalPort()
test "Servers leak test":
check getTracker("stream.server").isLeaked() == false
test "Transports leak test":
Expand All @@ -1358,3 +1397,12 @@ suite "Stream Transport test suite":
skip()
else:
check getCurrentFD() == markFD

test "Leaks test":
proc getTrackerLeaks(tracker: string): bool =
let tracker = getTracker(tracker)
if isNil(tracker): false else: tracker.isLeaked()

check:
getTrackerLeaks("stream.server") == false
getTrackerLeaks("stream.transport") == false