Skip to content

Commit 9c2f13e

Browse files
net: use Sock in InterruptibleRecv() and Socks5()
Use the `Sock` class instead of `SOCKET` for `InterruptibleRecv()` and `Socks5()`. This way the `Socks5()` function can be tested by giving it a mocked instance of a socket. Co-authored-by: practicalswift <[email protected]>
1 parent fa7df9b commit 9c2f13e

File tree

3 files changed

+39
-42
lines changed

3 files changed

+39
-42
lines changed

src/net.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
448448
return nullptr;
449449
}
450450
connected = ConnectThroughProxy(proxy, addrConnect.ToStringIP(), addrConnect.GetPort(),
451-
sock->Get(), nConnectTimeout, proxyConnectionFailed);
451+
*sock, nConnectTimeout, proxyConnectionFailed);
452452
} else {
453453
// no proxy needed (none set for target network)
454454
sock = CreateSock(addrConnect);
@@ -472,8 +472,8 @@ CNode* CConnman::ConnectNode(CAddress addrConnect, const char *pszDest, bool fCo
472472
int port = default_port;
473473
SplitHostPort(std::string(pszDest), port, host);
474474
bool proxyConnectionFailed;
475-
connected = ConnectThroughProxy(proxy, host, port, sock->Get(), nConnectTimeout,
476-
proxyConnectionFailed);
475+
connected =
476+
ConnectThroughProxy(proxy, host, port, *sock, nConnectTimeout, proxyConnectionFailed);
477477
}
478478
if (!connected) {
479479
return nullptr;

src/netbase.cpp

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,7 @@ enum class IntrRecvError {
338338
* @param data The buffer where the read bytes should be stored.
339339
* @param len The number of bytes to read into the specified buffer.
340340
* @param timeout The total timeout in milliseconds for this read.
341-
* @param hSocket The socket (has to be in non-blocking mode) from which to read
342-
* bytes.
341+
* @param sock The socket (has to be in non-blocking mode) from which to read bytes.
343342
*
344343
* @returns An IntrRecvError indicating the resulting status of this read.
345344
* IntrRecvError::OK only if all of the specified number of bytes were
@@ -349,15 +348,15 @@ enum class IntrRecvError {
349348
* Sockets can be made non-blocking with SetSocketNonBlocking(const
350349
* SOCKET&, bool).
351350
*/
352-
static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const SOCKET& hSocket)
351+
static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, const Sock& sock)
353352
{
354353
int64_t curTime = GetTimeMillis();
355354
int64_t endTime = curTime + timeout;
356355
// Maximum time to wait for I/O readiness. It will take up until this time
357356
// (in millis) to break off in case of an interruption.
358357
const int64_t maxWait = 1000;
359358
while (len > 0 && curTime < endTime) {
360-
ssize_t ret = recv(hSocket, (char*)data, len, 0); // Optimistically try the recv first
359+
ssize_t ret = sock.Recv(data, len, 0); // Optimistically try the recv first
361360
if (ret > 0) {
362361
len -= ret;
363362
data += ret;
@@ -366,25 +365,10 @@ static IntrRecvError InterruptibleRecv(uint8_t* data, size_t len, int timeout, c
366365
} else { // Other error or blocking
367366
int nErr = WSAGetLastError();
368367
if (nErr == WSAEINPROGRESS || nErr == WSAEWOULDBLOCK || nErr == WSAEINVAL) {
369-
if (!IsSelectableSocket(hSocket)) {
370-
return IntrRecvError::NetworkError;
371-
}
372368
// Only wait at most maxWait milliseconds at a time, unless
373369
// we're approaching the end of the specified total timeout
374370
int timeout_ms = std::min(endTime - curTime, maxWait);
375-
#ifdef USE_POLL
376-
struct pollfd pollfd = {};
377-
pollfd.fd = hSocket;
378-
pollfd.events = POLLIN;
379-
int nRet = poll(&pollfd, 1, timeout_ms);
380-
#else
381-
struct timeval tval = MillisToTimeval(timeout_ms);
382-
fd_set fdset;
383-
FD_ZERO(&fdset);
384-
FD_SET(hSocket, &fdset);
385-
int nRet = select(hSocket + 1, &fdset, nullptr, nullptr, &tval);
386-
#endif
387-
if (nRet == SOCKET_ERROR) {
371+
if (!sock.Wait(std::chrono::milliseconds{timeout_ms}, Sock::RECV)) {
388372
return IntrRecvError::NetworkError;
389373
}
390374
} else {
@@ -438,7 +422,7 @@ static std::string Socks5ErrorString(uint8_t err)
438422
* @param port The destination port.
439423
* @param auth The credentials with which to authenticate with the specified
440424
* SOCKS5 proxy.
441-
* @param hSocket The SOCKS5 proxy socket.
425+
* @param sock The SOCKS5 proxy socket.
442426
*
443427
* @returns Whether or not the operation succeeded.
444428
*
@@ -448,7 +432,10 @@ static std::string Socks5ErrorString(uint8_t err)
448432
* @see <a href="https://www.ietf.org/rfc/rfc1928.txt">RFC1928: SOCKS Protocol
449433
* Version 5</a>
450434
*/
451-
static bool Socks5(const std::string& strDest, int port, const ProxyCredentials *auth, const SOCKET& hSocket)
435+
static bool Socks5(const std::string& strDest,
436+
int port,
437+
const ProxyCredentials* auth,
438+
const Sock& sock)
452439
{
453440
IntrRecvError recvr;
454441
LogPrint(BCLog::NET, "SOCKS5 connecting %s\n", strDest);
@@ -466,12 +453,12 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
466453
vSocks5Init.push_back(0x01); // 1 method identifier follows...
467454
vSocks5Init.push_back(SOCKS5Method::NOAUTH);
468455
}
469-
ssize_t ret = send(hSocket, (const char*)vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL);
456+
ssize_t ret = sock.Send(vSocks5Init.data(), vSocks5Init.size(), MSG_NOSIGNAL);
470457
if (ret != (ssize_t)vSocks5Init.size()) {
471458
return error("Error sending to proxy");
472459
}
473460
uint8_t pchRet1[2];
474-
if ((recvr = InterruptibleRecv(pchRet1, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) {
461+
if ((recvr = InterruptibleRecv(pchRet1, 2, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) {
475462
LogPrintf("Socks5() connect to %s:%d failed: InterruptibleRecv() timeout or other failure\n", strDest, port);
476463
return false;
477464
}
@@ -488,13 +475,13 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
488475
vAuth.insert(vAuth.end(), auth->username.begin(), auth->username.end());
489476
vAuth.push_back(auth->password.size());
490477
vAuth.insert(vAuth.end(), auth->password.begin(), auth->password.end());
491-
ret = send(hSocket, (const char*)vAuth.data(), vAuth.size(), MSG_NOSIGNAL);
478+
ret = sock.Send(vAuth.data(), vAuth.size(), MSG_NOSIGNAL);
492479
if (ret != (ssize_t)vAuth.size()) {
493480
return error("Error sending authentication to proxy");
494481
}
495482
LogPrint(BCLog::PROXY, "SOCKS5 sending proxy authentication %s:%s\n", auth->username, auth->password);
496483
uint8_t pchRetA[2];
497-
if ((recvr = InterruptibleRecv(pchRetA, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) {
484+
if ((recvr = InterruptibleRecv(pchRetA, 2, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) {
498485
return error("Error reading proxy authentication response");
499486
}
500487
if (pchRetA[0] != 0x01 || pchRetA[1] != 0x00) {
@@ -514,12 +501,12 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
514501
vSocks5.insert(vSocks5.end(), strDest.begin(), strDest.end());
515502
vSocks5.push_back((port >> 8) & 0xFF);
516503
vSocks5.push_back((port >> 0) & 0xFF);
517-
ret = send(hSocket, (const char*)vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL);
504+
ret = sock.Send(vSocks5.data(), vSocks5.size(), MSG_NOSIGNAL);
518505
if (ret != (ssize_t)vSocks5.size()) {
519506
return error("Error sending to proxy");
520507
}
521508
uint8_t pchRet2[4];
522-
if ((recvr = InterruptibleRecv(pchRet2, 4, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) {
509+
if ((recvr = InterruptibleRecv(pchRet2, 4, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) {
523510
if (recvr == IntrRecvError::Timeout) {
524511
/* If a timeout happens here, this effectively means we timed out while connecting
525512
* to the remote node. This is very common for Tor, so do not print an
@@ -543,24 +530,24 @@ static bool Socks5(const std::string& strDest, int port, const ProxyCredentials
543530
uint8_t pchRet3[256];
544531
switch (pchRet2[3])
545532
{
546-
case SOCKS5Atyp::IPV4: recvr = InterruptibleRecv(pchRet3, 4, SOCKS5_RECV_TIMEOUT, hSocket); break;
547-
case SOCKS5Atyp::IPV6: recvr = InterruptibleRecv(pchRet3, 16, SOCKS5_RECV_TIMEOUT, hSocket); break;
533+
case SOCKS5Atyp::IPV4: recvr = InterruptibleRecv(pchRet3, 4, SOCKS5_RECV_TIMEOUT, sock); break;
534+
case SOCKS5Atyp::IPV6: recvr = InterruptibleRecv(pchRet3, 16, SOCKS5_RECV_TIMEOUT, sock); break;
548535
case SOCKS5Atyp::DOMAINNAME:
549536
{
550-
recvr = InterruptibleRecv(pchRet3, 1, SOCKS5_RECV_TIMEOUT, hSocket);
537+
recvr = InterruptibleRecv(pchRet3, 1, SOCKS5_RECV_TIMEOUT, sock);
551538
if (recvr != IntrRecvError::OK) {
552539
return error("Error reading from proxy");
553540
}
554541
int nRecv = pchRet3[0];
555-
recvr = InterruptibleRecv(pchRet3, nRecv, SOCKS5_RECV_TIMEOUT, hSocket);
542+
recvr = InterruptibleRecv(pchRet3, nRecv, SOCKS5_RECV_TIMEOUT, sock);
556543
break;
557544
}
558545
default: return error("Error: malformed proxy response");
559546
}
560547
if (recvr != IntrRecvError::OK) {
561548
return error("Error reading from proxy");
562549
}
563-
if ((recvr = InterruptibleRecv(pchRet3, 2, SOCKS5_RECV_TIMEOUT, hSocket)) != IntrRecvError::OK) {
550+
if ((recvr = InterruptibleRecv(pchRet3, 2, SOCKS5_RECV_TIMEOUT, sock)) != IntrRecvError::OK) {
564551
return error("Error reading from proxy");
565552
}
566553
LogPrint(BCLog::NET, "SOCKS5 connected %s\n", strDest);
@@ -903,18 +890,23 @@ bool IsProxy(const CNetAddr &addr) {
903890
* @param proxy The SOCKS5 proxy.
904891
* @param strDest The destination service to which to connect.
905892
* @param port The destination port.
906-
* @param hSocket The socket on which to connect to the SOCKS5 proxy.
893+
* @param sock The socket on which to connect to the SOCKS5 proxy.
907894
* @param nTimeout Wait this many milliseconds for the connection to the SOCKS5
908895
* proxy to be established.
909896
* @param[out] outProxyConnectionFailed Whether or not the connection to the
910897
* SOCKS5 proxy failed.
911898
*
912899
* @returns Whether or not the operation succeeded.
913900
*/
914-
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocket, int nTimeout, bool& outProxyConnectionFailed)
901+
bool ConnectThroughProxy(const proxyType& proxy,
902+
const std::string& strDest,
903+
int port,
904+
const Sock& sock,
905+
int nTimeout,
906+
bool& outProxyConnectionFailed)
915907
{
916908
// first connect to proxy server
917-
if (!ConnectSocketDirectly(proxy.proxy, hSocket, nTimeout, true)) {
909+
if (!ConnectSocketDirectly(proxy.proxy, sock.Get(), nTimeout, true)) {
918910
outProxyConnectionFailed = true;
919911
return false;
920912
}
@@ -923,11 +915,11 @@ bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int
923915
ProxyCredentials random_auth;
924916
static std::atomic_int counter(0);
925917
random_auth.username = random_auth.password = strprintf("%i", counter++);
926-
if (!Socks5(strDest, (uint16_t)port, &random_auth, hSocket)) {
918+
if (!Socks5(strDest, (uint16_t)port, &random_auth, sock)) {
927919
return false;
928920
}
929921
} else {
930-
if (!Socks5(strDest, (uint16_t)port, 0, hSocket)) {
922+
if (!Socks5(strDest, (uint16_t)port, 0, sock)) {
931923
return false;
932924
}
933925
}

src/netbase.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,12 @@ std::unique_ptr<Sock> CreateSockTCP(const CService& address_family);
172172
extern std::function<std::unique_ptr<Sock>(const CService&)> CreateSock;
173173

174174
bool ConnectSocketDirectly(const CService &addrConnect, const SOCKET& hSocketRet, int nTimeout, bool manual_connection);
175-
bool ConnectThroughProxy(const proxyType &proxy, const std::string& strDest, int port, const SOCKET& hSocketRet, int nTimeout, bool& outProxyConnectionFailed);
175+
bool ConnectThroughProxy(const proxyType& proxy,
176+
const std::string& strDest,
177+
int port,
178+
const Sock& sock,
179+
int nTimeout,
180+
bool& outProxyConnectionFailed);
176181
/** Return readable error string for a network error code */
177182
std::string NetworkErrorString(int err);
178183
/** Close socket and set hSocket to INVALID_SOCKET */

0 commit comments

Comments
 (0)