diff --git a/chronos/transports/stream.nim b/chronos/transports/stream.nim index a467d9637..5a2da9ad5 100644 --- a/chronos/transports/stream.nim +++ b/chronos/transports/stream.nim @@ -297,6 +297,25 @@ proc clean(transp: StreamTransport) {.inline.} = transp.future.complete() GC_unref(transp) +proc bindSocket*(sock: AsyncFD, localAddress: TransportAddress, flags: set[ServerFlags] = {}): bool = + ## Returns ``true`` on success, ``false`` on error. + if ServerFlags.ReuseAddr in flags: + if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEADDR, 1)): + return false + if ServerFlags.ReusePort in flags: + if not(setSockOpt(sock, SOL_SOCKET, SO_REUSEPORT, 1)): + return false + if ServerFlags.TcpNoDelay in flags: + if not(setSockOpt(sock, handles.IPPROTO_TCP, handles.TCP_NODELAY, 1)): + return false + var + localAddr: Sockaddr_storage + localAddrLen: SockLen + localAddress.toSAddr(localAddr, localAddrLen) + if bindSocket(SocketHandle(sock), cast[ptr SockAddr](addr localAddr), localAddrLen) != 0: + return false + return true + when defined(nimdoc): proc pauseAccept(server: StreamServer) {.inline.} = discard proc resumeAccept(server: StreamServer) {.inline.} = discard @@ -702,10 +721,18 @@ elif defined(windows): sizeof(saddr).SockLen) != 0'i32: result = false + proc isDomainSet(sock: AsyncFD): bool = + try: + discard getSockDomain(SocketHandle(sock)) + true + except CatchableError as ex: + false + proc connect*(address: TransportAddress, bufferSize = DefaultStreamBufferSize, child: StreamTransport = nil, - flags: set[TransportFlags] = {}): Future[StreamTransport] = + flags: set[TransportFlags] = {}, + sock: AsyncFD = asyncInvalidSocket): Future[StreamTransport] = ## Open new connection to remote peer with address ``address`` and create ## new transport object ``StreamTransport`` for established connection. ## ``bufferSize`` is size of internal buffer for transport. @@ -717,26 +744,37 @@ elif defined(windows): var saddr: Sockaddr_storage slen: SockLen - sock: AsyncFD + localSock: AsyncFD povl: RefCustomOverlapped - proto: Protocol var raddress = windowsAnyAddressFix(address) - toSAddr(raddress, saddr, slen) - proto = Protocol.IPPROTO_TCP - sock = try: createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM, - proto) - except CatchableError as exc: - retFuture.fail(exc) - return retFuture + if sock == asyncInvalidSocket: + localSock = try: createAsyncSocket(raddress.getDomain(), SockType.SOCK_STREAM, + Protocol.IPPROTO_TCP) + except CatchableError as exc: + retFuture.fail(exc) + return retFuture + else: + if not setSocketBlocking(SocketHandle(sock), false): + retFuture.fail(getTransportOsError(osLastError())) + return retFuture + localSock = sock + try: + register(localSock) + except CatchableError as exc: + retFuture.fail(exc) + return retFuture + + if localSock == asyncInvalidSocket: retFuture.fail(getTransportOsError(osLastError())) return retFuture - if not(bindToDomain(sock, raddress.getDomain())): + if not isDomainSet(localSock) and not(bindToDomain(localSock, raddress.getDomain())): let err = wsaGetLastError() - sock.closeSocket() + if sock == asyncInvalidSocket: + localSock.closeSocket() retFuture.fail(getTransportOsError(err)) return retFuture @@ -744,29 +782,32 @@ elif defined(windows): var ovl = cast[RefCustomOverlapped](udata) if not(retFuture.finished()): if ovl.data.errCode == OSErrorCode(-1): - if setsockopt(SocketHandle(sock), cint(SOL_SOCKET), + if setsockopt(SocketHandle(localSock), cint(SOL_SOCKET), cint(SO_UPDATE_CONNECT_CONTEXT), nil, SockLen(0)) != 0'i32: let err = wsaGetLastError() - sock.closeSocket() + if sock == asyncInvalidSocket: + localSock.closeSocket() retFuture.fail(getTransportOsError(err)) else: - let transp = newStreamSocketTransport(sock, bufferSize, child) + let transp = newStreamSocketTransport(localSock, bufferSize, child) # Start tracking transport trackStream(transp) retFuture.complete(transp) else: - sock.closeSocket() + if sock == asyncInvalidSocket: + localSock.closeSocket() retFuture.fail(getTransportOsError(ovl.data.errCode)) GC_unref(ovl) proc cancel(udata: pointer) {.gcsafe.} = - sock.closeSocket() + if sock == asyncInvalidSocket: + localSock.closeSocket() povl = RefCustomOverlapped() GC_ref(povl) povl.data = CompletionData(cb: socketContinuation) - let res = loop.connectEx(SocketHandle(sock), + let res = loop.connectEx(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), DWORD(slen), nil, 0, nil, cast[POVERLAPPED](povl)) @@ -775,7 +816,8 @@ elif defined(windows): let err = osLastError() if int32(err) != ERROR_IO_PENDING: GC_unref(povl) - sock.closeSocket() + if sock == asyncInvalidSocket: + localSock.closeSocket() retFuture.fail(getTransportOsError(err)) retFuture.cancelCallback = cancel @@ -816,7 +858,8 @@ elif defined(windows): trackStream(transp) retFuture.complete(transp) pipeContinuation(nil) - + else: + retFuture.fail(newException(TransportAddressError, "Unsupported address family")) return retFuture proc createAcceptPipe(server: StreamServer) {. @@ -1493,29 +1536,45 @@ else: proc connect*(address: TransportAddress, bufferSize = DefaultStreamBufferSize, child: StreamTransport = nil, - flags: set[TransportFlags] = {}): Future[StreamTransport] = + flags: set[TransportFlags] = {}, + sock: AsyncFD = asyncInvalidSocket + ): Future[StreamTransport] = ## Open new connection to remote peer with address ``address`` and create ## new transport object ``StreamTransport`` for established connection. ## ``bufferSize`` - size of internal buffer for transport. var saddr: Sockaddr_storage slen: SockLen - proto: Protocol + localSock: AsyncFD + var retFuture = newFuture[StreamTransport]("stream.transport.connect") address.toSAddr(saddr, slen) - proto = Protocol.IPPROTO_TCP - if address.family == AddressFamily.Unix: - # `Protocol` enum is missing `0` value, so we making here cast, until - # `Protocol` enum will not support IPPROTO_IP == 0. - proto = cast[Protocol](0) - let sock = try: createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, + if sock == asyncInvalidSocket: + let proto = + if address.family == AddressFamily.Unix: + # `Protocol` enum is missing `0` value, so we making here cast, until + # `Protocol` enum will not support IPPROTO_IP == 0. + cast[Protocol](0) + else: Protocol.IPPROTO_TCP + localSock = + try: createAsyncSocket(address.getDomain(), SockType.SOCK_STREAM, proto) - except CatchableError as exc: - retFuture.fail(exc) - return retFuture + except CatchableError as exc: + retFuture.fail(exc) + return retFuture + else: + if not setSocketBlocking(SocketHandle(sock), false): + retFuture.fail(getTransportOsError(osLastError())) + return retFuture + localSock = sock + try: + register(localSock) + except CatchableError as exc: + retFuture.fail(exc) + return retFuture - if sock == asyncInvalidSocket: + if localSock == asyncInvalidSocket: let err = osLastError() if int(err) == EMFILE: retFuture.fail(getTransportTooManyError()) @@ -1523,13 +1582,13 @@ else: retFuture.fail(getTransportOsError(err)) return retFuture - if address.family in {AddressFamily.IPv4, AddressFamily.IPv6}: if TransportFlags.TcpNoDelay in flags: - if not(setSockOpt(sock, handles.IPPROTO_TCP, + if not(setSockOpt(localSock, handles.IPPROTO_TCP, handles.TCP_NODELAY, 1)): let err = osLastError() - sock.closeSocket() + if sock == asyncInvalidSocket: + localSock.closeSocket() retFuture.fail(getTransportOsError(err)) return retFuture @@ -1537,7 +1596,7 @@ else: if not(retFuture.finished()): var err = 0 try: - sock.removeWriter() + localSock.removeWriter() except IOSelectorsException as exc: retFuture.fail(exc) return @@ -1545,27 +1604,30 @@ else: retFuture.fail(exc) return - if not(sock.getSocketError(err)): - closeSocket(sock) + if not(localSock.getSocketError(err)): + if sock == asyncInvalidSocket: + closeSocket(localSock) retFuture.fail(getTransportOsError(osLastError())) return if err != 0: - closeSocket(sock) + if sock == asyncInvalidSocket: + closeSocket(localSock) retFuture.fail(getTransportOsError(OSErrorCode(err))) return - let transp = newStreamSocketTransport(sock, bufferSize, child) + let transp = newStreamSocketTransport(localSock, bufferSize, child) # Start tracking transport trackStream(transp) retFuture.complete(transp) proc cancel(udata: pointer) = - closeSocket(sock) + if sock == asyncInvalidSocket: + closeSocket(localSock) while true: - var res = posix.connect(SocketHandle(sock), + var res = posix.connect(SocketHandle(localSock), cast[ptr SockAddr](addr saddr), slen) if res == 0: - let transp = newStreamSocketTransport(sock, bufferSize, child) + let transp = newStreamSocketTransport(localSock, bufferSize, child) # Start tracking transport trackStream(transp) retFuture.complete(transp) @@ -1580,16 +1642,18 @@ else: # http://www.madore.org/~david/computers/connect-intr.html if int(err) == EINPROGRESS or int(err) == EINTR: try: - sock.addWriter(continuation) + localSock.addWriter(continuation) except CatchableError as exc: - closeSocket(sock) + if sock == asyncInvalidSocket: + closeSocket(localSock) retFuture.fail(exc) return retFuture retFuture.cancelCallback = cancel break else: - sock.closeSocket() + if sock == asyncInvalidSocket: + localSock.closeSocket() retFuture.fail(getTransportOsError(err)) break diff --git a/tests/teststream.nim b/tests/teststream.nim index 42ac7b155..042cfab29 100644 --- a/tests/teststream.nim +++ b/tests/teststream.nim @@ -5,7 +5,7 @@ # Licensed under either of # Apache License, version 2.0, (LICENSE-APACHEv2) # MIT license (LICENSE-MIT) -import std/[strutils, os] +import std/[strutils, os, nativesockets] import unittest2 import ../chronos @@ -1260,6 +1260,43 @@ suite "Stream Transport test suite": await allFutures(rtransp.closeWait(), wtransp.closeWait()) return buffer == message + proc testConnectReuseLocalPort(): Future[void] {.async.} = + let dst1 = initTAddress("127.0.0.1:33335") + let dst2 = initTAddress("127.0.0.1:33336") + let dst3 = initTAddress("127.0.0.1:33337") + + proc client(server: StreamServer, transp: StreamTransport) {.async.} = + await transp.closeWait() + + let servers = + [createStreamServer(dst1, client, {ReuseAddr}), + createStreamServer(dst2, client, {ReuseAddr}), + createStreamServer(dst3, client, {ReuseAddr})] + + for server in servers: + server.start() + + let ta = initTAddress("0.0.0.0:35000") + + let sock1 = AsyncFD(createNativeSocket(dst1.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP)) + check: bindSocket(sock1, ta, {ReuseAddr}) + var transp1 = await connect(dst1, sock=sock1) + + let sock2 = AsyncFD(createNativeSocket(dst2.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP)) + check: bindSocket(sock2, ta, {ReuseAddr}) + var transp2 = await connect(dst2, sock=sock2) + + let sock3 = AsyncFD(createNativeSocket(dst3.getDomain(), SockType.SOCK_STREAM, Protocol.IPPROTO_TCP)) + check: false == bindSocket(sock3, ta) + + await transp1.closeWait() + await transp2.closeWait() + sock3.closeSocket() + + for server in servers: + server.stop() + await server.closeWait() + markFD = getCurrentFD() for i in 0..