diff --git a/include/libp2p/basic/message_read_writer.hpp b/include/libp2p/basic/message_read_writer.hpp index d78061b5e..e5914c968 100644 --- a/include/libp2p/basic/message_read_writer.hpp +++ b/include/libp2p/basic/message_read_writer.hpp @@ -8,6 +8,7 @@ #include +#include #include namespace libp2p::basic { @@ -37,5 +38,18 @@ namespace libp2p::basic { * Quantity of bytes written is passed as an argument in case of success */ virtual void write(BytesIn buffer, Writer::WriteCallbackFunc cb) = 0; + + /** + * Reads a message that is prepended with its length (coroutine version) + * @return awaitable with result containing read bytes or an error + */ + virtual boost::asio::awaitable read() = 0; + + /** + * Writes a message and preprends its length (coroutine version) + * @param buffer - bytes to be written + * @return awaitable with result containing number of bytes written or an error + */ + virtual boost::asio::awaitable> write(BytesIn buffer) = 0; }; } // namespace libp2p::basic diff --git a/include/libp2p/basic/reader.hpp b/include/libp2p/basic/reader.hpp index 661e03f55..7f525f4c0 100644 --- a/include/libp2p/basic/reader.hpp +++ b/include/libp2p/basic/reader.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -54,6 +55,13 @@ namespace libp2p::basic { * @param res read result * @param cb callback */ + + virtual boost::asio::awaitable> read( + BytesOut out, size_t bytes) = 0; + + virtual boost::asio::awaitable> readSome( + BytesOut out, size_t bytes) = 0; + virtual void deferReadCallback(outcome::result res, ReadCallbackFunc cb) = 0; }; diff --git a/include/libp2p/basic/writer.hpp b/include/libp2p/basic/writer.hpp index 4d152f682..c0ddbe144 100644 --- a/include/libp2p/basic/writer.hpp +++ b/include/libp2p/basic/writer.hpp @@ -6,6 +6,7 @@ #pragma once +#include #include #include @@ -34,6 +35,9 @@ namespace libp2p::basic { */ virtual void writeSome(BytesIn in, size_t bytes, WriteCallbackFunc cb) = 0; + virtual boost::asio::awaitable writeSome(BytesIn in, + size_t bytes) = 0; + /** * @brief Defers reporting error state to callback to avoid reentrancy * (i.e. callback will not be called before initiator function returns) diff --git a/include/libp2p/connection/capable_connection.hpp b/include/libp2p/connection/capable_connection.hpp index 9a75f789a..02dce8525 100644 --- a/include/libp2p/connection/capable_connection.hpp +++ b/include/libp2p/connection/capable_connection.hpp @@ -58,6 +58,13 @@ namespace libp2p::connection { */ virtual void newStream(StreamHandlerFunc cb) = 0; + /** + * @brief Opens new stream in a coroutine manner + * @return Awaitable result of a new Stream or error + */ + virtual boost::asio::awaitable>> + newStreamCoroutine() = 0; + /** * @brief Set a handler, which is called, when a new stream arrives from the * other side diff --git a/include/libp2p/connection/loopback_stream.hpp b/include/libp2p/connection/loopback_stream.hpp index 948ecc5df..5130a6dc4 100644 --- a/include/libp2p/connection/loopback_stream.hpp +++ b/include/libp2p/connection/loopback_stream.hpp @@ -39,6 +39,11 @@ namespace libp2p::connection { outcome::result remoteMultiaddr() const override; + // Coroutine-based methods + boost::asio::awaitable> read(BytesOut out, size_t bytes) override; + boost::asio::awaitable> readSome(BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, size_t bytes) override; + protected: void read(BytesOut out, size_t bytes, ReadCallbackFunc cb) override; diff --git a/include/libp2p/layer/layer_adaptor.hpp b/include/libp2p/layer/layer_adaptor.hpp index be8c3b751..4307892db 100644 --- a/include/libp2p/layer/layer_adaptor.hpp +++ b/include/libp2p/layer/layer_adaptor.hpp @@ -32,6 +32,14 @@ namespace libp2p::layer { std::shared_ptr conn, LayerConnCallbackFunc cb) const = 0; + /** + * Coroutine version of upgradeInbound + * @param conn - connection to be upgraded + * @return result with upgraded connection or error + */ + virtual boost::asio::awaitable>> upgradeInbound( + std::shared_ptr conn) const = 0; + /** * Make a next-layer connection from the current-layer one, using this * adaptor @@ -42,5 +50,15 @@ namespace libp2p::layer { const multi::Multiaddress &address, std::shared_ptr conn, LayerConnCallbackFunc cb) const = 0; + + /** + * Coroutine version of upgradeOutbound + * @param address - multiaddress of the remote peer + * @param conn - connection to be upgraded + * @return result with upgraded connection or error + */ + virtual boost::asio::awaitable>> upgradeOutbound( + const multi::Multiaddress &address, + std::shared_ptr conn) const = 0; }; } // namespace libp2p::layer diff --git a/include/libp2p/layer/websocket/ssl_connection.hpp b/include/libp2p/layer/websocket/ssl_connection.hpp index 549368485..cb80c1c5f 100644 --- a/include/libp2p/layer/websocket/ssl_connection.hpp +++ b/include/libp2p/layer/websocket/ssl_connection.hpp @@ -40,6 +40,13 @@ namespace libp2p::connection { void writeSome(BytesIn in, size_t bytes, WriteCallbackFunc cb) override; void deferWriteCallback(std::error_code ec, WriteCallbackFunc cb) override; + boost::asio::awaitable> read(BytesOut out, + size_t bytes) override; + boost::asio::awaitable> readSome( + BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, + size_t bytes) override; + private: std::shared_ptr connection_; std::shared_ptr ssl_context_; diff --git a/include/libp2p/layer/websocket/ws_adaptor.hpp b/include/libp2p/layer/websocket/ws_adaptor.hpp index 20f4bf6bc..8327e3098 100644 --- a/include/libp2p/layer/websocket/ws_adaptor.hpp +++ b/include/libp2p/layer/websocket/ws_adaptor.hpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace libp2p::layer { @@ -25,10 +26,17 @@ namespace libp2p::layer { void upgradeInbound(std::shared_ptr conn, LayerConnCallbackFunc cb) const override; + boost::asio::awaitable>> + upgradeInbound(std::shared_ptr conn) const override; + void upgradeOutbound(const multi::Multiaddress &address, std::shared_ptr conn, LayerConnCallbackFunc cb) const override; + boost::asio::awaitable>> + upgradeOutbound(const multi::Multiaddress &address, + std::shared_ptr conn) const override; + private: std::shared_ptr scheduler_; std::shared_ptr io_context_; diff --git a/include/libp2p/layer/websocket/ws_connection.hpp b/include/libp2p/layer/websocket/ws_connection.hpp index 139fc69c5..ba086a114 100644 --- a/include/libp2p/layer/websocket/ws_connection.hpp +++ b/include/libp2p/layer/websocket/ws_connection.hpp @@ -7,6 +7,7 @@ #pragma once #include +#include #include #include @@ -79,6 +80,13 @@ namespace libp2p::connection { void deferWriteCallback(std::error_code ec, WriteCallbackFunc cb) override; + boost::asio::awaitable> read(BytesOut out, + size_t bytes) override; + boost::asio::awaitable> readSome( + BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, + size_t bytes) override; + private: void setTimerPing(); void onPong(BytesIn payload); diff --git a/include/libp2p/layer/websocket/wss_adaptor.hpp b/include/libp2p/layer/websocket/wss_adaptor.hpp index 5e170ee61..749e8891f 100644 --- a/include/libp2p/layer/websocket/wss_adaptor.hpp +++ b/include/libp2p/layer/websocket/wss_adaptor.hpp @@ -7,6 +7,7 @@ #pragma once #include +#include namespace boost::asio { class io_context; @@ -36,10 +37,17 @@ namespace libp2p::layer { void upgradeInbound(std::shared_ptr conn, LayerConnCallbackFunc cb) const override; + boost::asio::awaitable>> + upgradeInbound(std::shared_ptr conn) const override; + void upgradeOutbound(const multi::Multiaddress &address, std::shared_ptr conn, LayerConnCallbackFunc cb) const override; + boost::asio::awaitable>> + upgradeOutbound(const multi::Multiaddress &address, + std::shared_ptr conn) const override; + private: std::shared_ptr io_context_; WssCertificate server_certificate_; diff --git a/include/libp2p/muxer/mplex/mplex_stream.hpp b/include/libp2p/muxer/mplex/mplex_stream.hpp index 074c38948..21bb7f1ca 100644 --- a/include/libp2p/muxer/mplex/mplex_stream.hpp +++ b/include/libp2p/muxer/mplex/mplex_stream.hpp @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -61,6 +62,10 @@ namespace libp2p::connection { void deferWriteCallback(std::error_code ec, WriteCallbackFunc cb) override; + boost::asio::awaitable> read(BytesOut out, size_t bytes) override; + boost::asio::awaitable> readSome(BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, size_t bytes) override; + bool isClosed() const override; void close(VoidResultHandlerFunc cb) override; diff --git a/include/libp2p/muxer/mplex/mplexed_connection.hpp b/include/libp2p/muxer/mplex/mplexed_connection.hpp index ea5ac229f..a9003cf38 100644 --- a/include/libp2p/muxer/mplex/mplexed_connection.hpp +++ b/include/libp2p/muxer/mplex/mplexed_connection.hpp @@ -68,6 +68,10 @@ namespace libp2p::connection { void readSome(BytesOut out, size_t bytes, ReadCallbackFunc cb) override; void writeSome(BytesIn in, size_t bytes, WriteCallbackFunc cb) override; + boost::asio::awaitable> read(BytesOut out, size_t bytes) override; + boost::asio::awaitable> readSome(BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, size_t bytes) override; + void deferReadCallback(outcome::result res, ReadCallbackFunc cb) override; void deferWriteCallback(std::error_code ec, WriteCallbackFunc cb) override; diff --git a/include/libp2p/muxer/muxer_adaptor.hpp b/include/libp2p/muxer/muxer_adaptor.hpp index 2a6f17615..f10212234 100644 --- a/include/libp2p/muxer/muxer_adaptor.hpp +++ b/include/libp2p/muxer/muxer_adaptor.hpp @@ -31,5 +31,15 @@ namespace libp2p::muxer { virtual void muxConnection( std::shared_ptr conn, CapConnCallbackFunc cb) const = 0; + + /** + * Make a muxed connection from the secure one, using this adaptor + * (coroutine version) + * @param conn - connection to be upgraded + * @return awaitable with result containing upgraded connection or error + */ + virtual boost::asio::awaitable< + outcome::result>> + muxConnection(std::shared_ptr conn) const = 0; }; } // namespace libp2p::muxer diff --git a/include/libp2p/muxer/yamux/yamux.hpp b/include/libp2p/muxer/yamux/yamux.hpp index 2010194f9..362e48f9b 100644 --- a/include/libp2p/muxer/yamux/yamux.hpp +++ b/include/libp2p/muxer/yamux/yamux.hpp @@ -32,6 +32,17 @@ namespace libp2p::muxer { void muxConnection(std::shared_ptr conn, CapConnCallbackFunc cb) const override; + /** + * Make a muxed connection from the secure one, using this adaptor + * (coroutine version) + * @param conn - connection to be upgraded + * @return awaitable with result containing upgraded connection or error + */ + boost::asio::awaitable< + outcome::result>> + muxConnection( + std::shared_ptr conn) const override; + private: MuxedConnectionConfig config_; std::shared_ptr scheduler_; diff --git a/include/libp2p/muxer/yamux/yamux_stream.hpp b/include/libp2p/muxer/yamux/yamux_stream.hpp index 7ed91e8ff..a37df373c 100644 --- a/include/libp2p/muxer/yamux/yamux_stream.hpp +++ b/include/libp2p/muxer/yamux/yamux_stream.hpp @@ -85,6 +85,14 @@ namespace libp2p::connection { outcome::result remoteMultiaddr() const override; + // Coroutine-based methods + boost::asio::awaitable> read(BytesOut out, + size_t bytes) override; + boost::asio::awaitable> readSome( + BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, + size_t bytes) override; + /// Increases send window. Called from Connection void increaseSendWindow(size_t delta); diff --git a/include/libp2p/muxer/yamux/yamuxed_connection.hpp b/include/libp2p/muxer/yamux/yamuxed_connection.hpp index 6ac7dd594..744201591 100644 --- a/include/libp2p/muxer/yamux/yamuxed_connection.hpp +++ b/include/libp2p/muxer/yamux/yamuxed_connection.hpp @@ -56,6 +56,9 @@ namespace libp2p::connection { void newStream(StreamHandlerFunc cb) override; + boost::asio::awaitable>> + newStreamCoroutine(); + void onStream(NewStreamHandlerFunc cb) override; outcome::result localPeer() const override; @@ -78,6 +81,14 @@ namespace libp2p::connection { ReadCallbackFunc cb) override; void deferWriteCallback(std::error_code ec, WriteCallbackFunc cb) override; + // Coroutine-based methods + boost::asio::awaitable> read(BytesOut out, + size_t bytes) override; + boost::asio::awaitable> readSome( + BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, + size_t bytes) override; + private: using Streams = std::unordered_map>; diff --git a/include/libp2p/network/impl/listener_manager_impl.hpp b/include/libp2p/network/impl/listener_manager_impl.hpp index 08582ba73..386481c9c 100644 --- a/include/libp2p/network/impl/listener_manager_impl.hpp +++ b/include/libp2p/network/impl/listener_manager_impl.hpp @@ -38,6 +38,9 @@ namespace libp2p::network { outcome::result listen(const multi::Multiaddress &ma) override; + boost::asio::awaitable> listenCoroutine( + const multi::Multiaddress &ma) override; + std::vector getListenAddresses() const override; std::vector getListenAddressesInterfaces() @@ -49,6 +52,10 @@ namespace libp2p::network { outcome::result> rconn) override; + void onConnectionCoro( + outcome::result> rconn) + override; + private: bool started = false; diff --git a/include/libp2p/network/listener_manager.hpp b/include/libp2p/network/listener_manager.hpp index df40e5a0b..5c80cb0d4 100644 --- a/include/libp2p/network/listener_manager.hpp +++ b/include/libp2p/network/listener_manager.hpp @@ -73,6 +73,8 @@ namespace libp2p::network { */ virtual outcome::result listen(const multi::Multiaddress &ma) = 0; + virtual boost::asio::awaitable> listenCoroutine( + const multi::Multiaddress &ma) = 0; /** * @brief Returns an unmodified list of addresses, added by user. */ @@ -98,6 +100,10 @@ namespace libp2p::network { virtual void onConnection( outcome::result> rconn) = 0; + + virtual void onConnectionCoro( + outcome::result> + rconn) = 0; }; } // namespace libp2p::network diff --git a/include/libp2p/protocol_muxer/multiselect.hpp b/include/libp2p/protocol_muxer/multiselect.hpp index 1ccc0b234..77dd68039 100644 --- a/include/libp2p/protocol_muxer/multiselect.hpp +++ b/include/libp2p/protocol_muxer/multiselect.hpp @@ -29,6 +29,13 @@ namespace libp2p::protocol_muxer::multiselect { bool negotiate_multiselect, ProtocolHandlerFunc cb) override; + /// Implements coroutine version of ProtocolMuxer API + boost::asio::awaitable> selectOneOf( + std::span protocols, + std::shared_ptr connection, + bool is_initiator, + bool negotiate_multistream) override; + /// Simple single stream negotiate procedure void simpleStreamNegotiate( const std::shared_ptr &stream, diff --git a/include/libp2p/protocol_muxer/multiselect/multiselect_instance.hpp b/include/libp2p/protocol_muxer/multiselect/multiselect_instance.hpp index 46eab6cbe..e9804ae27 100644 --- a/include/libp2p/protocol_muxer/multiselect/multiselect_instance.hpp +++ b/include/libp2p/protocol_muxer/multiselect/multiselect_instance.hpp @@ -7,6 +7,9 @@ #pragma once #include +#include +#include + #include "parser.hpp" namespace soralog { @@ -30,6 +33,13 @@ namespace libp2p::protocol_muxer::multiselect { bool negotiate_multiselect, Multiselect::ProtocolHandlerFunc cb); + /// Coroutine version of ProtocolMuxer API + boost::asio::awaitable> selectOneOf( + std::span protocols, + std::shared_ptr connection, + bool is_initiator, + bool negotiate_multiselect); + private: using Protocols = boost::container::small_vector; using Packet = std::shared_ptr; @@ -74,6 +84,28 @@ namespace libp2p::protocol_muxer::multiselect { /// Handles "na" reply, client-specific MaybeResult handleNA(); + /// Coroutine versions of send and receive operations + boost::asio::awaitable> sendCoro(Packet packet); + boost::asio::awaitable> receiveCoro(size_t bytes_needed); + boost::asio::awaitable processMessagesCoro(); + + /// Coroutine helper methods for protocol negotiation + boost::asio::awaitable> sendProtocolProposalCoro( + std::shared_ptr connection, + bool multistream_negotiated, + const std::string &protocol); + + boost::asio::awaitable> processProtocolMessageCoro( + std::shared_ptr connection, + bool is_initiator, + bool multistream_negotiated, + bool wait_for_protocol_reply, + size_t current_protocol, + boost::optional &wait_for_reply_sent, + const boost::container::small_vector &local_protocols, + const Message &msg, + boost::optional> &na_response); + /// Owner of this object, needed for reuse of instances Multiselect &owner_; diff --git a/include/libp2p/protocol_muxer/protocol_muxer.hpp b/include/libp2p/protocol_muxer/protocol_muxer.hpp index b1f9226f0..b663d66d0 100644 --- a/include/libp2p/protocol_muxer/protocol_muxer.hpp +++ b/include/libp2p/protocol_muxer/protocol_muxer.hpp @@ -49,6 +49,23 @@ namespace libp2p::protocol_muxer { bool negotiate_multistream, ProtocolHandlerFunc cb) = 0; + /** + * Coroutine version of selectOneOf + * @param protocols - set of protocols, one of which should be chosen during + * the negotiation + * @param connection - connection for which the protocol is being chosen + * @param is_initiator - true, if we initiated the connection and thus + * taking lead in the Multiselect protocol; false otherwise + * @param negotiate_multistream - true, if we need to negotiate multistream + * itself, this happens with fresh raw connections + * @return awaitable with chosen protocol or error + */ + virtual boost::asio::awaitable> selectOneOf( + std::span protocols, + std::shared_ptr connection, + bool is_initiator, + bool negotiate_multistream) = 0; + /** * Simple (Yes/No) negotiation of a single protocol on a fresh outbound * stream diff --git a/include/libp2p/security/noise/handshake_coro.hpp b/include/libp2p/security/noise/handshake_coro.hpp new file mode 100644 index 000000000..139ab7af6 --- /dev/null +++ b/include/libp2p/security/noise/handshake_coro.hpp @@ -0,0 +1,83 @@ +/** + * Copyright Quadrivium LLC + * All Rights Reserved + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace libp2p::security::noise { + + /** + * Coroutine version of the Noise handshake protocol. + */ + class HandshakeCoro : public std::enable_shared_from_this { + public: + HandshakeCoro( + std::shared_ptr crypto_provider, + std::unique_ptr noise_marshaller, + crypto::KeyPair local_key, + std::shared_ptr connection, + bool is_initiator, + boost::optional remote_peer_id, + std::shared_ptr key_marshaller); + + /** + * Performs the handshake process using coroutines + * @return awaitable with secure connection on success or error + */ + boost::asio::awaitable>> connect(); + + private: + const std::string kPayloadPrefix = "noise-libp2p-static-key:"; + + void setCipherStates(std::shared_ptr cs1, + std::shared_ptr cs2); + + outcome::result> generateHandshakePayload( + const DHKey &keypair); + + boost::asio::awaitable> sendHandshakeMessage( + BytesIn payload); + + boost::asio::awaitable>> readHandshakeMessage(); + + outcome::result handleRemoteHandshakePayload(BytesIn payload); + + boost::asio::awaitable>> runHandshake(); + + // constructor params + std::shared_ptr crypto_provider_; + std::unique_ptr noise_marshaller_; + const crypto::KeyPair local_key_; + std::shared_ptr conn_; + bool initiator_; /// false for incoming connections + std::shared_ptr key_marshaller_; + std::shared_ptr read_buffer_; + std::shared_ptr rw_; + + // other params + std::unique_ptr handshake_state_; + std::shared_ptr enc_; + std::shared_ptr dec_; + boost::optional remote_peer_id_; + boost::optional remote_peer_pubkey_; + + log::Logger log_ = log::createLogger("NoiseHandshakeCoro"); + }; + +} // namespace libp2p::security::noise diff --git a/include/libp2p/security/noise/insecure_rw.hpp b/include/libp2p/security/noise/insecure_rw.hpp index d12129f61..6c8c9da44 100644 --- a/include/libp2p/security/noise/insecure_rw.hpp +++ b/include/libp2p/security/noise/insecure_rw.hpp @@ -40,6 +40,12 @@ namespace libp2p::security::noise { /// write the given bytes to the network void write(BytesIn buffer, basic::Writer::WriteCallbackFunc cb) override; + /// read next message from the network (coroutine version) + boost::asio::awaitable read() override; + + /// write the given bytes to the network (coroutine version) + boost::asio::awaitable> write(BytesIn buffer) override; + private: std::shared_ptr connection_; std::shared_ptr buffer_; diff --git a/include/libp2p/security/noise/noise.hpp b/include/libp2p/security/noise/noise.hpp index 32fef6bb5..66b485dd9 100644 --- a/include/libp2p/security/noise/noise.hpp +++ b/include/libp2p/security/noise/noise.hpp @@ -12,6 +12,8 @@ #include #include +#include + namespace libp2p::security { class Noise : public SecurityAdaptor, @@ -34,6 +36,27 @@ namespace libp2p::security { const peer::PeerId &p, SecConnCallbackFunc cb) override; + /** + * Coroutine version of secureInbound + * @param inbound - connection to be secured + * @return awaitable with secured connection or error + */ + boost::asio::awaitable< + outcome::result>> + secureInboundCoro( + std::shared_ptr inbound) override; + + /** + * Coroutine version of secureOutbound + * @param outbound - connection to be secured + * @param p - remote peer id + * @return awaitable with secured connection or error + */ + boost::asio::awaitable< + outcome::result>> + secureOutboundCoro(std::shared_ptr outbound, + const peer::PeerId &p) override; + private: log::Logger log_ = log::createLogger("Noise"); libp2p::crypto::KeyPair local_key_; diff --git a/include/libp2p/security/noise/noise_connection.hpp b/include/libp2p/security/noise/noise_connection.hpp index fb92f26bc..77571e67b 100644 --- a/include/libp2p/security/noise/noise_connection.hpp +++ b/include/libp2p/security/noise/noise_connection.hpp @@ -69,6 +69,14 @@ namespace libp2p::connection { outcome::result remotePublicKey() const override; + // Coroutine-based methods + boost::asio::awaitable> read(BytesOut out, + size_t bytes) override; + boost::asio::awaitable> readSome( + BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, + size_t bytes) override; + private: void readSome(BytesOut out, size_t bytes, diff --git a/include/libp2p/security/security_adaptor.hpp b/include/libp2p/security/security_adaptor.hpp index ca0b94640..2cc176dfb 100644 --- a/include/libp2p/security/security_adaptor.hpp +++ b/include/libp2p/security/security_adaptor.hpp @@ -35,6 +35,15 @@ namespace libp2p::security { std::shared_ptr inbound, SecConnCallbackFunc cb) = 0; + /** + * Coroutine version of secureInbound + * @param inbound - connection to be secured + * @return awaitable with secured connection or error + */ + virtual boost::asio::awaitable< + outcome::result>> + secureInboundCoro(std::shared_ptr inbound) = 0; + /** * @brief Secure the connection, either locally or by communicating with * opposing node via outbound connection (we are initiator). @@ -46,5 +55,10 @@ namespace libp2p::security { std::shared_ptr outbound, const peer::PeerId &p, SecConnCallbackFunc cb) = 0; + + virtual boost::asio::awaitable< + outcome::result>> + secureOutboundCoro(std::shared_ptr outbound, + const peer::PeerId &p) = 0; }; } // namespace libp2p::security diff --git a/include/libp2p/transport/impl/upgrader_impl.hpp b/include/libp2p/transport/impl/upgrader_impl.hpp index a2803f25c..6296b6e15 100644 --- a/include/libp2p/transport/impl/upgrader_impl.hpp +++ b/include/libp2p/transport/impl/upgrader_impl.hpp @@ -49,14 +49,30 @@ namespace libp2p::transport { ProtoAddrVec layers, OnLayerCallbackFunc cb) override; + boost::asio::awaitable> upgradeLayersInboundCoro( + RawSPtr conn, ProtoAddrVec layers) override; + + boost::asio::awaitable> upgradeLayersOutboundCoro( + const multi::Multiaddress &address, + RawSPtr conn, + ProtoAddrVec layers) override; + void upgradeToSecureInbound(LayerSPtr conn, OnSecuredCallbackFunc cb) override; void upgradeToSecureOutbound(LayerSPtr conn, const peer::PeerId &remoteId, OnSecuredCallbackFunc cb) override; + boost::asio::awaitable> upgradeToSecureInboundCoro( + LayerSPtr conn) override; + boost::asio::awaitable> upgradeToSecureOutboundCoro( + LayerSPtr conn, const peer::PeerId &remoteId) override; + void upgradeToMuxed(SecSPtr conn, OnMuxedCallbackFunc cb) override; + boost::asio::awaitable> upgradeToMuxedCoro( + SecSPtr conn) override; + enum class Error { SUCCESS = 0, NO_ADAPTOR_FOUND = 1 }; private: @@ -87,6 +103,23 @@ namespace libp2p::transport { size_t layer_index, OnLayerCallbackFunc cb); + /** + * Coroutine version to upgrade to next layer outbound + */ + boost::asio::awaitable> upgradeToNextLayerOutboundCoro( + const multi::Multiaddress &address, + LayerSPtr conn, + ProtoAddrVec layers, + size_t layer_index); + + /** + * Coroutine version to upgrade to next layer inbound + */ + boost::asio::awaitable> upgradeToNextLayerInboundCoro( + LayerSPtr conn, + ProtoAddrVec layers, + size_t layer_index); + std::shared_ptr protocol_muxer_; std::vector layer_adaptors_; diff --git a/include/libp2p/transport/impl/upgrader_session.hpp b/include/libp2p/transport/impl/upgrader_session.hpp index 72bf913dc..d9efcf35e 100644 --- a/include/libp2p/transport/impl/upgrader_session.hpp +++ b/include/libp2p/transport/impl/upgrader_session.hpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -33,6 +34,25 @@ namespace libp2p::transport { void upgradeOutbound(const multi::Multiaddress &address, const peer::PeerId &remoteId); + /** + * Coroutine version of upgradeInbound + * @return awaitable with capable connection or error + */ + boost::asio::awaitable< + outcome::result>> + upgradeInboundCoro(); + + /** + * Coroutine version of upgradeOutbound + * @param address - multiaddress to connect to + * @param remoteId - remote peer ID + * @return awaitable with capable connection or error + */ + boost::asio::awaitable< + outcome::result>> + upgradeOutboundCoro(const multi::Multiaddress &address, + const peer::PeerId &remoteId); + private: std::shared_ptr upgrader_; ProtoAddrVec layers_; @@ -44,9 +64,38 @@ namespace libp2p::transport { void secureInbound(std::shared_ptr conn); + /** + * Coroutine version of secureInbound + * @param conn - connection to be upgraded + * @return awaitable with secured connection or error + */ + boost::asio::awaitable< + outcome::result>> + secureInboundCoro(std::shared_ptr conn); + + /** + * Coroutine version of secureOutbound + * @param conn - connection to be upgraded + * @param remoteId - remote peer ID + * @return awaitable with secured connection or error + */ + boost::asio::awaitable< + outcome::result>> + secureOutboundCoro(std::shared_ptr conn, + const peer::PeerId &remoteId); + void onSecured( outcome::result> res); + /** + * Coroutine version of onSecured + * @param secure_conn - secure connection to be muxed + * @return awaitable with capable connection or error + */ + boost::asio::awaitable< + outcome::result>> + onSecuredCoro(std::shared_ptr secure_conn); + public: LIBP2P_METRICS_INSTANCE_COUNT_IF_ENABLED( libp2p::transport::UpgraderSession); diff --git a/include/libp2p/transport/quic/connection.hpp b/include/libp2p/transport/quic/connection.hpp index b510c46af..a32c801b3 100644 --- a/include/libp2p/transport/quic/connection.hpp +++ b/include/libp2p/transport/quic/connection.hpp @@ -48,6 +48,14 @@ namespace libp2p::transport { void writeSome(BytesIn in, size_t bytes, WriteCallbackFunc cb) override; void deferWriteCallback(std::error_code ec, WriteCallbackFunc cb) override; + // Coroutine-based methods + boost::asio::awaitable> read(BytesOut out, + size_t bytes) override; + boost::asio::awaitable> readSome( + BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, + size_t bytes) override; + // Closeable bool isClosed() const override; outcome::result close() override; diff --git a/include/libp2p/transport/quic/error.hpp b/include/libp2p/transport/quic/error.hpp index 45be18008..c9fb5e9a6 100644 --- a/include/libp2p/transport/quic/error.hpp +++ b/include/libp2p/transport/quic/error.hpp @@ -13,6 +13,7 @@ namespace libp2p { HANDSHAKE_FAILED, CONN_CLOSED, STREAM_CLOSED, + STREAM_READ_IN_PROGRESS, TOO_MANY_STREAMS, CANT_CREATE_CONNECTION, CANT_OPEN_STREAM, @@ -26,6 +27,8 @@ namespace libp2p { return "CONN_CLOSED"; case E::STREAM_CLOSED: return "STREAM_CLOSED"; + case E::STREAM_READ_IN_PROGRESS: + return "STREAM_READ_IN_PROGRESS"; case E::TOO_MANY_STREAMS: return "TOO_MANY_STREAMS"; case E::CANT_CREATE_CONNECTION: diff --git a/include/libp2p/transport/quic/stream.hpp b/include/libp2p/transport/quic/stream.hpp index 42ccc9b8a..5b60660f4 100644 --- a/include/libp2p/transport/quic/stream.hpp +++ b/include/libp2p/transport/quic/stream.hpp @@ -37,6 +37,11 @@ namespace libp2p::connection { void deferReadCallback(outcome::result res, ReadCallbackFunc cb) override; + // Coroutine-based methods + boost::asio::awaitable> read(BytesOut out, size_t bytes) override; + boost::asio::awaitable> readSome(BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, size_t bytes) override; + // Writer void writeSome(BytesIn in, size_t bytes, WriteCallbackFunc cb) override; void deferWriteCallback(std::error_code ec, WriteCallbackFunc cb) override; diff --git a/include/libp2p/transport/tcp/tcp_connection.hpp b/include/libp2p/transport/tcp/tcp_connection.hpp index 57274d07f..5f5782759 100644 --- a/include/libp2p/transport/tcp/tcp_connection.hpp +++ b/include/libp2p/transport/tcp/tcp_connection.hpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -61,6 +62,14 @@ namespace libp2p::transport { ConnectCallbackFunc cb, std::chrono::milliseconds timeout); + // Coroutine-based connect methods + boost::asio::awaitable connect( + const ResolverResultsType &iterator); + + boost::asio::awaitable connect( + const ResolverResultsType &iterator, + std::chrono::milliseconds timeout); + void read(BytesOut out, size_t bytes, ReadCallbackFunc cb) override; void readSome(BytesOut out, size_t bytes, ReadCallbackFunc cb) override; @@ -72,6 +81,15 @@ namespace libp2p::transport { void deferWriteCallback(std::error_code ec, WriteCallbackFunc cb) override; + // Coroutine-based methods + boost::asio::awaitable> read(BytesOut out, + size_t bytes); + + boost::asio::awaitable> readSome(BytesOut out, + size_t bytes); + + boost::asio::awaitable writeSome(BytesIn in, size_t bytes); + outcome::result remoteMultiaddr() override; outcome::result localMultiaddr() override; diff --git a/include/libp2p/transport/tcp/tcp_listener.hpp b/include/libp2p/transport/tcp/tcp_listener.hpp index b1802faca..3e71c9b34 100644 --- a/include/libp2p/transport/tcp/tcp_listener.hpp +++ b/include/libp2p/transport/tcp/tcp_listener.hpp @@ -31,10 +31,20 @@ namespace libp2p::transport { outcome::result getListenMultiaddr() const override; + boost::asio::io_context &getContext() const override; + bool isClosed() const override; outcome::result close() override; + /** + * Asynchronously accept a new connection + * @return Awaitable result of a new CapableConnection or error + */ + boost::asio::awaitable< + outcome::result>> + asyncAccept() override; + private: boost::asio::io_context &context_; std::shared_ptr upgrader_; diff --git a/include/libp2p/transport/transport_listener.hpp b/include/libp2p/transport/transport_listener.hpp index 02bdc1e09..fadab0e8f 100644 --- a/include/libp2p/transport/transport_listener.hpp +++ b/include/libp2p/transport/transport_listener.hpp @@ -52,5 +52,19 @@ namespace libp2p::transport { * @return collection of those addresses */ virtual outcome::result getListenMultiaddr() const = 0; + + /** + * Get the io_context of this listener + * @return reference to the io_context + */ + virtual boost::asio::io_context &getContext() const = 0; + + /** + * Asynchronously accept a new connection + * @return Awaitable result of a new CapableConnection or error + */ + virtual boost::asio::awaitable< + outcome::result>> + asyncAccept() = 0; }; } // namespace libp2p::transport diff --git a/include/libp2p/transport/upgrader.hpp b/include/libp2p/transport/upgrader.hpp index 14d409957..cb041c55b 100644 --- a/include/libp2p/transport/upgrader.hpp +++ b/include/libp2p/transport/upgrader.hpp @@ -12,6 +12,7 @@ #include #include #include +#include namespace libp2p::transport { @@ -46,6 +47,18 @@ namespace libp2p::transport { ProtoAddrVec layers, OnLayerCallbackFunc cb) = 0; + /** + * Coroutine version of upgradeLayersOutbound + * @param address multiaddress to connect to + * @param conn to be upgraded + * @param layers - vector of layer protocols for upgrade + * @return awaitable with layered connection or error + */ + virtual boost::asio::awaitable> + upgradeLayersOutboundCoro(const multi::Multiaddress &address, + RawSPtr conn, + ProtoAddrVec layers) = 0; + /** * Upgrade inbound connection to each required layers * @param conn to be upgraded @@ -58,6 +71,15 @@ namespace libp2p::transport { ProtoAddrVec layers, OnLayerCallbackFunc cb) = 0; + /** + * Coroutine version of upgradeLayersInbound + * @param conn to be upgraded + * @param layers - vector of layer protocols for upgrade + * @return awaitable with layered connection or error + */ + virtual boost::asio::awaitable> + upgradeLayersInboundCoro(RawSPtr conn, ProtoAddrVec layers) = 0; + /** * Upgrade outbound raw connection to the secure one * @param conn to be upgraded @@ -69,6 +91,16 @@ namespace libp2p::transport { const peer::PeerId &remoteId, OnSecuredCallbackFunc cb) = 0; + /** + * Coroutine version of upgradeToSecureOutbound + * @param conn to be upgraded + * @param remoteId peer id of remote peer + * @return awaitable with secured connection or error + */ + virtual boost::asio::awaitable> + upgradeToSecureOutboundCoro(LayerSPtr conn, + const peer::PeerId &remoteId) = 0; + /** * Upgrade inbound raw connection to the secure one * @param conn to be upgraded @@ -78,6 +110,14 @@ namespace libp2p::transport { virtual void upgradeToSecureInbound(LayerSPtr conn, OnSecuredCallbackFunc cb) = 0; + /** + * Coroutine version of upgradeToSecureInbound + * @param conn to be upgraded + * @return awaitable with secured connection or error + */ + virtual boost::asio::awaitable> + upgradeToSecureInboundCoro(LayerSPtr conn) = 0; + /** * Upgrade a secure connection to the muxed (capable) one * @param conn to be upgraded @@ -85,6 +125,14 @@ namespace libp2p::transport { * error happens */ virtual void upgradeToMuxed(SecSPtr conn, OnMuxedCallbackFunc cb) = 0; + + /** + * Coroutine version of upgradeToMuxed + * @param conn to be upgraded + * @return awaitable with capable connection or error + */ + virtual boost::asio::awaitable> + upgradeToMuxedCoro(SecSPtr conn) = 0; }; } // namespace libp2p::transport diff --git a/src/connection/loopback_stream.cpp b/src/connection/loopback_stream.cpp index 8cc92ff1e..3fd1ab320 100644 --- a/src/connection/loopback_stream.cpp +++ b/src/connection/loopback_stream.cpp @@ -84,6 +84,29 @@ namespace libp2p::connection { readReturnSize(shared_from_this(), out, std::move(cb)); } + boost::asio::awaitable> LoopbackStream::read(BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + if (is_reset_) { + co_return Error::STREAM_RESET_BY_HOST; + } + if (!is_readable_) { + co_return Error::STREAM_NOT_READABLE; + } + if (bytes == 0 || out.empty() || static_cast(out.size()) < bytes) { + co_return Error::STREAM_INVALID_ARGUMENT; + } + // Wait until enough data is available + while (buffer_.size() < bytes) { + co_await boost::asio::post(*io_context_, boost::asio::use_awaitable); + } + auto to_read = std::min(buffer_.size(), bytes); + if (boost::asio::buffer_copy(boost::asio::buffer(out.data(), to_read), buffer_.data(), to_read) != to_read) { + co_return Error::STREAM_INTERNAL_ERROR; + } + buffer_.consume(to_read); + co_return to_read; + } + void LoopbackStream::writeSome(BytesIn in, size_t bytes, libp2p::basic::Writer::WriteCallbackFunc cb) { @@ -116,6 +139,23 @@ namespace libp2p::connection { } } + boost::asio::awaitable LoopbackStream::writeSome(BytesIn in, size_t bytes) { + if (is_reset_) { + co_return Error::STREAM_RESET_BY_HOST; + } + if (!is_writable_) { + co_return Error::STREAM_NOT_WRITABLE; + } + if (bytes == 0 || in.empty() || static_cast(in.size()) < bytes) { + co_return Error::STREAM_INVALID_ARGUMENT; + } + if (boost::asio::buffer_copy(buffer_.prepare(bytes), boost::asio::const_buffer(in.data(), bytes)) != bytes) { + co_return Error::STREAM_INTERNAL_ERROR; + } + buffer_.commit(bytes); + co_return std::error_code{}; + } + void LoopbackStream::readSome(BytesOut out, size_t bytes, libp2p::basic::Reader::ReadCallbackFunc cb) { @@ -172,6 +212,29 @@ namespace libp2p::connection { } } + boost::asio::awaitable> LoopbackStream::readSome(BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + if (is_reset_) { + co_return Error::STREAM_RESET_BY_HOST; + } + if (!is_readable_) { + co_return Error::STREAM_NOT_READABLE; + } + if (bytes == 0 || out.empty() || static_cast(out.size()) < bytes) { + co_return Error::STREAM_INVALID_ARGUMENT; + } + // Wait until any data is available + while (buffer_.size() == 0) { + co_await boost::asio::post(*io_context_, boost::asio::use_awaitable); + } + auto to_read = std::min(buffer_.size(), bytes); + if (boost::asio::buffer_copy(boost::asio::buffer(out.data(), to_read), buffer_.data(), to_read) != to_read) { + co_return Error::STREAM_INTERNAL_ERROR; + } + buffer_.consume(to_read); + co_return to_read; + } + void LoopbackStream::deferReadCallback(outcome::result res, basic::Reader::ReadCallbackFunc cb) { deferCallback( diff --git a/src/layer/websocket/ssl_connection.cpp b/src/layer/websocket/ssl_connection.cpp index 96d16b1f8..0ea11984b 100644 --- a/src/layer/websocket/ssl_connection.cpp +++ b/src/layer/websocket/ssl_connection.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include #include #include @@ -73,4 +74,74 @@ namespace libp2p::connection { WriteCallbackFunc cb) { connection_->deferWriteCallback(ec, std::move(cb)); } + + boost::asio::awaitable> SslConnection::read( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + + if (bytes == 0 || out.empty()) { + co_return outcome::success(0); + } + + size_t total_bytes_read = 0; + + while (total_bytes_read < bytes) { + auto result = co_await readSome(BytesOut{out.data() + total_bytes_read, + out.size() - total_bytes_read}, + bytes - total_bytes_read); + + if (!result) { + co_return result.error(); + } + + size_t bytes_read = result.value(); + if (bytes_read == 0) { + break; // EOF reached + } + + total_bytes_read += bytes_read; + } + + co_return outcome::success(total_bytes_read); + } + + boost::asio::awaitable> SslConnection::readSome( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + + if (bytes == 0 || out.empty()) { + co_return outcome::success(0); + } + + try { + auto bytes_read = co_await ssl_.async_read_some( + asioBuffer(out), boost::asio::use_awaitable); + + co_return outcome::success(bytes_read); + } catch (const boost::system::system_error &e) { + co_return outcome::failure(e.code()); + } catch (const std::exception &) { + co_return outcome::failure(AsAsioReadWrite::error()); + } + } + + boost::asio::awaitable SslConnection::writeSome( + BytesIn in, size_t bytes) { + ambigousSize(in, bytes); + + if (bytes == 0 || in.empty()) { + co_return std::error_code{}; + } + + try { + co_await ssl_.async_write_some(asioBuffer(in), + boost::asio::use_awaitable); + + co_return std::error_code{}; + } catch (const boost::system::system_error &e) { + co_return e.code(); + } catch (const std::exception &) { + co_return AsAsioReadWrite::error(); + } + } } // namespace libp2p::connection diff --git a/src/layer/websocket/ws_adaptor.cpp b/src/layer/websocket/ws_adaptor.cpp index 19d7b3192..37f6a8c04 100644 --- a/src/layer/websocket/ws_adaptor.cpp +++ b/src/layer/websocket/ws_adaptor.cpp @@ -8,6 +8,7 @@ #include #include +#include namespace libp2p::layer { @@ -40,6 +41,23 @@ namespace libp2p::layer { }); } + boost::asio::awaitable>> + WsAdaptor::upgradeInbound(std::shared_ptr conn) const { + log_->info("upgrade inbound connection to websocket (coroutine)"); + auto ws = std::make_shared( + config_, io_context_, std::move(conn), scheduler_); + + try { + co_await ws->ws_.async_accept(boost::asio::use_awaitable); + ws->start(); + co_return outcome::success(std::move(ws)); + } catch (const boost::system::system_error &error) { + co_return outcome::failure(error.code()); + } catch (const std::exception &) { + co_return outcome::failure(std::errc::io_error); + } + } + void WsAdaptor::upgradeOutbound( const multi::Multiaddress &address, std::shared_ptr conn, @@ -59,4 +77,26 @@ namespace libp2p::layer { }); } + boost::asio::awaitable>> + WsAdaptor::upgradeOutbound( + const multi::Multiaddress &address, + std::shared_ptr conn) const { + auto host = address.getProtocolsWithValues().begin()->second; + auto ws = std::make_shared( + config_, io_context_, std::move(conn), scheduler_); + + try { + co_await ws->ws_.async_handshake( + host, + "/", + boost::asio::use_awaitable); + ws->start(); + co_return outcome::success(std::move(ws)); + } catch (const boost::system::system_error &error) { + co_return outcome::failure(error.code()); + } catch (const std::exception &) { + co_return outcome::failure(std::errc::io_error); + } + } + } // namespace libp2p::layer diff --git a/src/layer/websocket/ws_connection.cpp b/src/layer/websocket/ws_connection.cpp index aca11b629..0fba5f812 100644 --- a/src/layer/websocket/ws_connection.cpp +++ b/src/layer/websocket/ws_connection.cpp @@ -6,6 +6,7 @@ #include +#include #include #include #include @@ -131,6 +132,89 @@ namespace libp2p::connection { connection_->deferWriteCallback(ec, std::move(cb)); } + boost::asio::awaitable> WsConnection::read( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + SL_TRACE(log_, "read {} bytes (coroutine)", bytes); + + if (bytes == 0 || out.empty()) { + co_return outcome::success(0); + } + + size_t total_bytes_read = 0; + + while (total_bytes_read < bytes) { + auto result = co_await readSome(BytesOut{out.data() + total_bytes_read, + out.size() - total_bytes_read}, + bytes - total_bytes_read); + + if (!result) { + co_return result.error(); + } + + size_t bytes_read = result.value(); + if (bytes_read == 0) { + break; // EOF reached + } + + total_bytes_read += bytes_read; + } + + co_return outcome::success(total_bytes_read); + } + + boost::asio::awaitable> WsConnection::readSome( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + SL_TRACE(log_, "read some upto {} bytes (coroutine)", bytes); + + if (bytes == 0 || out.empty()) { + co_return outcome::success(0); + } + + try { + size_t n = co_await ws_.async_read_some( + asioBuffer(out), boost::asio::use_awaitable); + + if (n != 0) { + co_return outcome::success(n); + } + + // If we got zero bytes and we're still connected, try again + if (!isClosed()) { + auto result = co_await readSome(out, out.size()); + co_return result; + } + + co_return outcome::failure(boost::system::errc::broken_pipe); + } catch (const boost::system::system_error &error) { + co_return outcome::failure(error.code()); + } catch (const std::exception &) { + co_return outcome::failure(boost::system::errc::io_error); + } + } + + boost::asio::awaitable WsConnection::writeSome( + BytesIn in, size_t bytes) { + ambigousSize(in, bytes); + SL_TRACE(log_, "write some upto {} bytes (coroutine)", bytes); + + if (bytes == 0 || in.empty()) { + co_return std::error_code{}; + } + + try { + co_await ws_.async_write_some( + true, asioBuffer(in), boost::asio::use_awaitable); + co_return std::error_code{}; + } catch (const boost::system::system_error &error) { + co_return error.code(); + } catch (const std::exception &) { + co_return boost::system::errc::make_error_code( + boost::system::errc::io_error); + } + } + void WsConnection::setTimerPing() { // Set pong handler using boost::beast::websocket::frame_type; diff --git a/src/layer/websocket/wss_adaptor.cpp b/src/layer/websocket/wss_adaptor.cpp index b9890f516..828be2350 100644 --- a/src/layer/websocket/wss_adaptor.cpp +++ b/src/layer/websocket/wss_adaptor.cpp @@ -6,6 +6,7 @@ #include +#include #include #include @@ -59,6 +60,29 @@ namespace libp2p::layer { }); } + boost::asio::awaitable>> + WssAdaptor::upgradeInbound(std::shared_ptr conn) const { + if (not server_certificate_.context) { + co_return outcome::failure(std::errc::address_family_not_supported); + } + + auto ssl = std::make_shared( + io_context_, std::move(conn), server_certificate_.context); + + try { + co_await ssl->ssl_.async_handshake( + boost::asio::ssl::stream_base::handshake_type::server, + boost::asio::use_awaitable); + + auto result = co_await ws_adaptor_->upgradeInbound(std::move(ssl)); + co_return result; + } catch (const boost::system::system_error &error) { + co_return outcome::failure(error.code()); + } catch (const std::exception &) { + co_return outcome::failure(std::errc::io_error); + } + } + void WssAdaptor::upgradeOutbound( const multi::Multiaddress &address, std::shared_ptr conn, @@ -75,4 +99,25 @@ namespace libp2p::layer { ws->upgradeOutbound(address, std::move(ssl), std::move(cb)); }); } + + boost::asio::awaitable>> + WssAdaptor::upgradeOutbound( + const multi::Multiaddress &address, + std::shared_ptr conn) const { + auto ssl = std::make_shared( + io_context_, std::move(conn), client_context_); + + try { + co_await ssl->ssl_.async_handshake( + boost::asio::ssl::stream_base::handshake_type::client, + boost::asio::use_awaitable); + + auto result = co_await ws_adaptor_->upgradeOutbound(address, std::move(ssl)); + co_return result; + } catch (const boost::system::system_error &error) { + co_return outcome::failure(error.code()); + } catch (const std::exception &) { + co_return outcome::failure(std::errc::io_error); + } + } } // namespace libp2p::layer diff --git a/src/muxer/mplex/mplex_stream.cpp b/src/muxer/mplex/mplex_stream.cpp index 041d767db..3cf1f637d 100644 --- a/src/muxer/mplex/mplex_stream.cpp +++ b/src/muxer/mplex/mplex_stream.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include @@ -39,6 +40,32 @@ namespace libp2p::connection { readReturnSize(shared_from_this(), out, std::move(cb)); } + boost::asio::awaitable> MplexStream::read(BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + if (is_reset_) { + co_return Error::STREAM_RESET_BY_PEER; + } + if (!is_readable_) { + co_return Error::STREAM_NOT_READABLE; + } + if (bytes == 0 || out.empty() || static_cast(out.size()) < bytes) { + co_return Error::STREAM_INVALID_ARGUMENT; + } + while (read_buffer_.size() < bytes) { + co_await boost::asio::post(boost::asio::use_awaitable); + if (is_reset_ || !is_readable_) { + co_return Error::STREAM_RESET_BY_PEER; + } + } + auto size = std::min(read_buffer_.size(), bytes); + if (boost::asio::buffer_copy(boost::asio::buffer(out.data(), size), read_buffer_.data(), size) != size) { + co_return Error::STREAM_INTERNAL_ERROR; + } + read_buffer_.consume(size); + receive_window_size_ += size; + co_return size; + } + void MplexStream::readDone(outcome::result res) { auto cb{std::move(reading_->cb)}; reading_.reset(); @@ -88,6 +115,32 @@ namespace libp2p::connection { } } + boost::asio::awaitable> MplexStream::readSome(BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + if (is_reset_) { + co_return Error::STREAM_RESET_BY_PEER; + } + if (!is_readable_) { + co_return Error::STREAM_NOT_READABLE; + } + if (bytes == 0 || out.empty() || static_cast(out.size()) < bytes) { + co_return Error::STREAM_INVALID_ARGUMENT; + } + while (read_buffer_.size() == 0) { + co_await boost::asio::post(boost::asio::use_awaitable); + if (is_reset_ || !is_readable_) { + co_return Error::STREAM_RESET_BY_PEER; + } + } + auto size = std::min(read_buffer_.size(), bytes); + if (boost::asio::buffer_copy(boost::asio::buffer(out.data(), size), read_buffer_.data(), size) != size) { + co_return Error::STREAM_INTERNAL_ERROR; + } + read_buffer_.consume(size); + receive_window_size_ += size; + co_return size; + } + void MplexStream::writeSome(BytesIn in, size_t bytes, WriteCallbackFunc cb) { ambigousSize(in, bytes); // TODO(107): Reentrancy @@ -136,6 +189,51 @@ namespace libp2p::connection { }); } + boost::asio::awaitable MplexStream::writeSome(BytesIn in, size_t bytes) { + ambigousSize(in, bytes); + if (is_reset_) { + co_return Error::STREAM_RESET_BY_PEER; + } + if (!is_writable_) { + co_return Error::STREAM_NOT_WRITABLE; + } + if (bytes == 0 || in.empty() || static_cast(in.size()) < bytes) { + co_return Error::STREAM_INVALID_ARGUMENT; + } + if (is_writing_) { + // Wait until not writing + while (is_writing_) { + co_await boost::asio::post(boost::asio::use_awaitable); + } + } + if (connection_.expired()) { + co_return Error::STREAM_RESET_BY_HOST; + } + is_writing_ = true; + std::error_code result; + bool done = false; + connection_.lock()->streamWrite( + stream_id_, + in, + bytes, + [self = shared_from_this(), &result, &done](auto &&write_res) mutable { + if (!write_res) { + self->log_->error("write for stream {} failed: {}", + self->stream_id_.toString(), + write_res.error()); + result = write_res.error(); + } else { + result = std::error_code{}; + } + self->is_writing_ = false; + done = true; + }); + while (!done) { + co_await boost::asio::post(boost::asio::use_awaitable); + } + co_return result; + } + void MplexStream::deferReadCallback(outcome::result res, ReadCallbackFunc cb) { if (connection_.expired()) { diff --git a/src/muxer/mplex/mplexed_connection.cpp b/src/muxer/mplex/mplexed_connection.cpp index 339b38620..0f2df05ad 100644 --- a/src/muxer/mplex/mplexed_connection.cpp +++ b/src/muxer/mplex/mplexed_connection.cpp @@ -142,6 +142,18 @@ namespace libp2p::connection { connection_->writeSome(in, bytes, std::move(cb)); } + boost::asio::awaitable> MplexedConnection::read(BytesOut out, size_t bytes) { + co_return co_await connection_->read(out, bytes); + } + + boost::asio::awaitable> MplexedConnection::readSome(BytesOut out, size_t bytes) { + co_return co_await connection_->readSome(out, bytes); + } + + boost::asio::awaitable MplexedConnection::writeSome(BytesIn in, size_t bytes) { + co_return co_await connection_->writeSome(in, bytes); + } + void MplexedConnection::deferReadCallback(outcome::result res, ReadCallbackFunc cb) { connection_->deferReadCallback(res, std::move(cb)); diff --git a/src/muxer/yamux/yamux.cpp b/src/muxer/yamux/yamux.cpp index 96ee3d85f..7d406767e 100644 --- a/src/muxer/yamux/yamux.cpp +++ b/src/muxer/yamux/yamux.cpp @@ -46,4 +46,28 @@ namespace libp2p::muxer { cb(std::make_shared( std::move(conn), scheduler_, close_cb_, config_)); } + + boost::asio::awaitable< + outcome::result>> + Yamux::muxConnection( + std::shared_ptr conn) const { + if (conn == nullptr || conn->isClosed()) { + log::createLogger("Yamux")->error("dead connection passed to muxer"); + co_return std::errc::not_connected; + } + + if (auto res = conn->remotePeer(); res.has_error()) { + log::createLogger("Yamux")->error( + "inactive connection passed to muxer: {}", res.error()); + co_return res.error(); + } + + auto yamuxed_conn = std::make_shared( + std::move(conn), scheduler_, close_cb_, config_); + + // Start the yamuxed connection? + // yamuxed_conn->start(); + + co_return yamuxed_conn; + } } // namespace libp2p::muxer diff --git a/src/muxer/yamux/yamux_stream.cpp b/src/muxer/yamux/yamux_stream.cpp index 3dbe594ea..6238183e9 100644 --- a/src/muxer/yamux/yamux_stream.cpp +++ b/src/muxer/yamux/yamux_stream.cpp @@ -13,6 +13,7 @@ #include #define TRACE_ENABLED 0 +#include #include namespace libp2p::connection { @@ -166,6 +167,142 @@ namespace libp2p::connection { return connection_->remoteMultiaddr(); } + boost::asio::awaitable> YamuxStream::read( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + + if (out.data() == nullptr || out.size() == 0 || bytes == 0) { + co_return Error::STREAM_INVALID_ARGUMENT; + } + + // If the stream is closed, return the error + if (close_reason_) { + co_return *close_reason_; + } + + // If the stream is not readable, return an error + if (!is_readable_) { + co_return Error::STREAM_NOT_READABLE; + } + + // Read the exact number of bytes requested + size_t total_bytes_read = 0; + while (total_bytes_read < bytes) { + auto remaining = bytes - total_bytes_read; + auto result = co_await readSome(out.subspan(total_bytes_read), remaining); + + if (!result) { + co_return result.error(); + } + + size_t bytes_read = result.value(); + if (bytes_read == 0) { + // End of stream reached + break; + } + + total_bytes_read += bytes_read; + } + + co_return total_bytes_read; + } + + boost::asio::awaitable> YamuxStream::readSome( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + + if (out.data() == nullptr || out.size() == 0 || bytes == 0) { + co_return Error::STREAM_INVALID_ARGUMENT; + } + + // If the stream is closed, return the error + if (close_reason_) { + co_return *close_reason_; + } + + // If the stream is not readable, return an error + if (!is_readable_) { + co_return Error::STREAM_NOT_READABLE; + } + + // If something is still in read buffer, the client can consume these bytes + auto bytes_available_now = internal_read_buffer_.size(); + if (bytes_available_now > 0) { + out = out.first(static_cast(std::min(bytes, bytes_available_now))); + size_t consumed = internal_read_buffer_.consume(out); + + if (is_readable_) { + feedback_.ackReceivedBytes(stream_id_, consumed); + } + + co_return consumed; + } + + // No data available, need to set up an async read + struct ReadContext { + std::optional> result; + bool done = false; + }; + + auto ctx = std::make_shared(); + + doRead(out, bytes, [ctx](auto result) { + ctx->result = result; + ctx->done = true; + }); + + // Wait for the read operation to complete + while (!ctx->done) { + co_await boost::asio::post(boost::asio::use_awaitable); + } + + co_return *ctx->result; + } + + boost::asio::awaitable YamuxStream::writeSome( + BytesIn in, size_t bytes) { + ambigousSize(in, bytes); + + if (bytes == 0 || in.empty() || static_cast(in.size()) < bytes) { + co_return Error::STREAM_INVALID_ARGUMENT; + } + + if (!is_writable_) { + co_return Error::STREAM_NOT_WRITABLE; + } + + if (close_reason_) { + co_return *close_reason_; + } + + if (!write_queue_.canEnqueue(bytes)) { + co_return Error::STREAM_WRITE_OVERFLOW; + } + + struct WriteContext { + std::optional result; + bool done = false; + }; + + auto ctx = std::make_shared(); + + doWrite(in.first(bytes), bytes, [ctx](auto result) { + if (result) { + ctx->result = std::error_code{}; + } else { + ctx->result = result.error(); + } + ctx->done = true; + }); + + // Wait for the write operation to complete + while (!ctx->done) { + co_await boost::asio::post(boost::asio::use_awaitable); + } + + co_return *ctx->result; + } + void YamuxStream::increaseSendWindow(size_t delta) { if (delta > 0) { window_size_ += delta; diff --git a/src/muxer/yamux/yamuxed_connection.cpp b/src/muxer/yamux/yamuxed_connection.cpp index 198cbd63f..1808a10d2 100644 --- a/src/muxer/yamux/yamuxed_connection.cpp +++ b/src/muxer/yamux/yamuxed_connection.cpp @@ -196,6 +196,24 @@ namespace libp2p::connection { connection_->deferWriteCallback(ec, std::move(cb)); } + boost::asio::awaitable> YamuxedConnection::read( + BytesOut out, size_t bytes) { + log()->error("YamuxedConnection::read : invalid direct call"); + co_return Error::CONNECTION_DIRECT_IO_FORBIDDEN; + } + + boost::asio::awaitable> YamuxedConnection::readSome( + BytesOut out, size_t bytes) { + log()->error("YamuxedConnection::readSome : invalid direct call"); + co_return Error::CONNECTION_DIRECT_IO_FORBIDDEN; + } + + boost::asio::awaitable YamuxedConnection::writeSome( + BytesIn in, size_t bytes) { + log()->error("YamuxedConnection::writeSome : invalid direct call"); + co_return Error::CONNECTION_DIRECT_IO_FORBIDDEN; + } + void YamuxedConnection::continueReading() { SL_TRACE(log(), "YamuxedConnection::continueReading"); connection_->readSome(*raw_read_buffer_, @@ -779,4 +797,24 @@ namespace libp2p::connection { }, config_.ping_interval); } + + boost::asio::awaitable>> YamuxedConnection::newStreamCoroutine() { + if (!started_) { + co_return Error::CONNECTION_NOT_ACTIVE; + } + + if (streams_.size() >= config_.maximum_streams) { + co_return Error::CONNECTION_TOO_MANY_STREAMS; + } + + auto stream_id = new_stream_id_; + new_stream_id_ += 2; + enqueue(newStreamMsg(stream_id)); + + // Wait for the stream to be acknowledged + co_await boost::asio::this_coro::executor; + + auto stream = createStream(stream_id); + co_return stream; + } } // namespace libp2p::connection diff --git a/src/network/impl/listener_manager_impl.cpp b/src/network/impl/listener_manager_impl.cpp index b22178265..e6dd06792 100644 --- a/src/network/impl/listener_manager_impl.cpp +++ b/src/network/impl/listener_manager_impl.cpp @@ -243,8 +243,59 @@ namespace libp2p::network { this->cmgr_->addConnectionToPeer(id, conn); } + void ListenerManagerImpl::onConnectionCoro( + outcome::result> rconn) { + if (!rconn) { + log()->warn("can not accept valid connection, {}", rconn.error()); + return; // ignore + } + auto &&conn = rconn.value(); + + auto rid = conn->remotePeer(); + if (!rid) { + log()->warn("can not get remote peer id, {}", rid.error()); + return; // ignore + } + auto &&id = rid.value(); + + // IMPLEMENT + } + Router &ListenerManagerImpl::getRouter() { return *router_; } + boost::asio::awaitable> + ListenerManagerImpl::listenCoroutine(const multi::Multiaddress &ma) { + auto tr = this->tmgr_->findBest(ma); + if (tr == nullptr) { + // can not listen on this address + co_return std::errc::address_family_not_supported; + } + + auto it = listeners_.find(ma); + if (it != listeners_.end()) { + // this address is already used + co_return std::errc::address_in_use; + } + + auto listener = tr->createListener( + [](auto &&) { throw std::logic_error("can not listen, placeholder"); }); + + listeners_.insert({ma, std::move(listener)}); + + // process connection in onConnectionCoro in detached coroutine + boost::asio::co_spawn( + listener->getContext(), + [this, listener]() -> boost::asio::awaitable { + while (listener && !listener->isClosed()) { + auto connection = co_await listener->asyncAccept(); + this->onConnectionCoro(std::move(connection)); + } + }, + boost::asio::detached); + + co_return outcome::success(); + } + } // namespace libp2p::network diff --git a/src/protocol_muxer/multiselect.cpp b/src/protocol_muxer/multiselect.cpp index 86778d53e..7cd6f40ce 100644 --- a/src/protocol_muxer/multiselect.cpp +++ b/src/protocol_muxer/multiselect.cpp @@ -33,6 +33,28 @@ namespace libp2p::protocol_muxer::multiselect { std::move(cb)); } + boost::asio::awaitable> + Multiselect::selectOneOf(std::span protocols, + std::shared_ptr connection, + bool is_initiator, + bool negotiate_multistream) { + // Create instance and delegate to its coroutine implementation + auto instance = getInstance(); + + // Get the result from the coroutine implementation + auto result = co_await instance->selectOneOf( + protocols, std::move(connection), is_initiator, negotiate_multistream); + + // Return the instance to the cache regardless of result + active_instances_.erase(instance); + if (cache_.size() < kMaxCacheSize) { + cache_.emplace_back(std::move(instance)); + } + + // Return the result + co_return result; + } + void Multiselect::simpleStreamNegotiate( const std::shared_ptr &stream, const peer::ProtocolName &protocol_id, diff --git a/src/protocol_muxer/multiselect/multiselect_instance.cpp b/src/protocol_muxer/multiselect/multiselect_instance.cpp index abb657bf0..a4cf6925b 100644 --- a/src/protocol_muxer/multiselect/multiselect_instance.cpp +++ b/src/protocol_muxer/multiselect/multiselect_instance.cpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -338,4 +339,295 @@ namespace libp2p::protocol_muxer::multiselect { return MaybeResult(ProtocolMuxer::Error::PROTOCOL_VIOLATION); } + boost::asio::awaitable> + MultiselectInstance::selectOneOf( + std::span protocols, + std::shared_ptr connection, + bool is_initiator, + bool negotiate_multiselect) { + assert(!protocols.empty()); + assert(connection); + + // Use local variables instead of class members for the coroutine + // implementation + bool multistream_negotiated = !negotiate_multiselect; + bool wait_for_protocol_reply = false; + size_t current_protocol = 0; + boost::optional wait_for_reply_sent; + + // Store protocols in a local variable instead of using the member variable + boost::container::small_vector local_protocols( + protocols.begin(), protocols.end()); + + // Only reset the parser since it's used for state tracking + parser_.reset(); + + // Local read buffer - avoid using the shared class member + auto read_buffer = std::make_shared>(); + + // Initial protocol negotiation + if (is_initiator) { + // Send the first protocol proposal + auto result = + co_await sendProtocolProposalCoro(connection, + multistream_negotiated, + local_protocols[current_protocol]); + if (!result) { + co_return result.error(); + } + wait_for_protocol_reply = true; + } else if (negotiate_multiselect) { + // Send opening protocol ID (server side) + auto msg = detail::createMessage(kProtocolId); + if (!msg) { + co_return msg.error(); + } + auto packet = std::make_shared(msg.value()); + try { + auto ec = + co_await connection->writeSome(BytesIn(*packet), packet->size()); + if (ec) { + co_return ec; + } + } catch (const std::exception &e) { + log()->error("Error writing opening protocol ID: {}", e.what()); + co_return std::make_error_code(std::errc::io_error); + } + } + + // Cache for NA response - local to this coroutine + boost::optional> na_response; + + // Main negotiation loop + while (true) { + // Read data from the connection + auto bytes_needed = parser_.bytesNeeded(); + if (bytes_needed > kMaxMessageSize) { + SL_TRACE(log(), + "rejecting incoming traffic, too large message ({})", + bytes_needed); + co_return ProtocolMuxer::Error::PROTOCOL_VIOLATION; + } + + BytesOut span(*read_buffer); + span = span.first(static_cast(bytes_needed)); + + try { + auto read_result = co_await connection->read(span, bytes_needed); + if (!read_result) { + co_return read_result.error(); + } + + auto bytes_read = read_result.value(); + if (bytes_read > read_buffer->size()) { + log()->error("selectOneOfCoro(): invalid state"); + co_return ProtocolMuxer::Error::INTERNAL_ERROR; + } + + BytesIn data_span(*read_buffer); + data_span = data_span.first(bytes_read); + + auto state = parser_.consume(data_span); + if (state == Parser::kOverflow) { + SL_TRACE(log(), "peer error: parser overflow"); + co_return ProtocolMuxer::Error::PROTOCOL_VIOLATION; + } + if (state != Parser::kReady) { + continue; // Need more data + } + + // Process the received messages + for (const auto &msg : parser_.messages()) { + switch (msg.type) { + case Message::kProtocolName: { + // Process protocol proposal/acceptance + auto result = + co_await processProtocolMessageCoro(connection, + is_initiator, + multistream_negotiated, + wait_for_protocol_reply, + current_protocol, + wait_for_reply_sent, + local_protocols, + msg, + na_response); + + // If we got a protocol or an error, return it + if (result) { + co_return result; + } + break; + } + case Message::kRightProtocolVersion: + multistream_negotiated = true; + break; + case Message::kNAMessage: { + // Handle NA + if (is_initiator) { + if (current_protocol < local_protocols.size()) { + SL_DEBUG(log(), + "protocol {} was not accepted by peer", + local_protocols[current_protocol]); + } + + // Try the next protocol + ++current_protocol; + + if (current_protocol < local_protocols.size()) { + auto result = co_await sendProtocolProposalCoro( + connection, + multistream_negotiated, + local_protocols[current_protocol]); + if (!result) { + co_return result.error(); + } + wait_for_protocol_reply = true; + } else { + // No more protocols to propose + SL_DEBUG(log(), + "Failed to negotiate protocols: {}", + fmt::join(local_protocols, ", ")); + co_return ProtocolMuxer::Error::NEGOTIATION_FAILED; + } + } else { + // Server side + SL_DEBUG(log(), "Unexpected NA received by server"); + co_return ProtocolMuxer::Error::PROTOCOL_VIOLATION; + } + break; + } + case Message::kWrongProtocolVersion: + co_return ProtocolMuxer::Error::PROTOCOL_VIOLATION; + break; + default: + co_return ProtocolMuxer::Error::PROTOCOL_VIOLATION; + break; + } + } + + // Reset parser for next messages + parser_.reset(); + } catch (const std::exception &e) { + log()->error("Error reading data: {}", e.what()); + co_return std::make_error_code(std::errc::io_error); + } + } + + // This should not be reached in normal operation + co_return ProtocolMuxer::Error::INTERNAL_ERROR; + } + + boost::asio::awaitable> + MultiselectInstance::sendProtocolProposalCoro( + std::shared_ptr connection, + bool multistream_negotiated, + const std::string &protocol) { + // Create the protocol proposal message based on negotiation state + outcome::result msg_res = + outcome::failure(std::make_error_code(std::errc::invalid_argument)); + if (!multistream_negotiated) { + std::array a({kProtocolId, protocol}); + msg_res = detail::createMessage(a, false); + } else { + msg_res = detail::createMessage(protocol); + } + + if (!msg_res) { + co_return msg_res.error(); + } + + // Send the message + auto packet = std::make_shared(msg_res.value()); + try { + auto ec = + co_await connection->writeSome(BytesIn(*packet), packet->size()); + if (ec) { + co_return ec; + } + } catch (const std::exception &e) { + log()->error("Error writing protocol proposal: {}", e.what()); + co_return std::make_error_code(std::errc::io_error); + } + + co_return outcome::success(); + } + + boost::asio::awaitable> + MultiselectInstance::processProtocolMessageCoro( + std::shared_ptr connection, + bool is_initiator, + bool multistream_negotiated, + bool wait_for_protocol_reply, + size_t current_protocol, + boost::optional &wait_for_reply_sent, + const boost::container::small_vector &local_protocols, + const Message &msg, + boost::optional> &na_response) { + // Handle protocol name message + if (is_initiator) { + // Client side + if (wait_for_protocol_reply) { + if (current_protocol < local_protocols.size() + && local_protocols[current_protocol] == msg.content) { + // Successful client side negotiation + co_return std::string(msg.content); + } + } + co_return ProtocolMuxer::Error::PROTOCOL_VIOLATION; + } + + // Server side + size_t idx = 0; + for (const auto &p : local_protocols) { + if (p == msg.content) { + // Successful server side negotiation + wait_for_reply_sent = idx; + + // Send protocol acceptance + auto accept_msg = detail::createMessage(msg.content); + if (!accept_msg) { + co_return accept_msg.error(); + } + + auto packet = std::make_shared(accept_msg.value()); + try { + auto ec = + co_await connection->writeSome(BytesIn(*packet), packet->size()); + if (ec) { + co_return ec; + } + // Protocol negotiation successful + co_return local_protocols[wait_for_reply_sent.value()]; + } catch (const std::exception &e) { + log()->error("Error writing protocol acceptance: {}", e.what()); + co_return std::make_error_code(std::errc::io_error); + } + } + ++idx; + } + + // Not found, send NA + SL_DEBUG(log(), "unknown protocol {} proposed by client", msg.content); + if (!na_response) { + auto na_msg = detail::createMessage(kNA); + if (!na_msg) { + co_return na_msg.error(); + } + na_response = std::make_shared(na_msg.value()); + } + + try { + auto ec = co_await connection->writeSome(BytesIn(*na_response.value()), + na_response.value()->size()); + if (ec) { + co_return ec; + } + // NA sent successfully, continue with protocol negotiation + co_return outcome::success(); + } catch (const std::exception &e) { + log()->error("Error sending NA response: {}", e.what()); + co_return std::make_error_code(std::errc::io_error); + } + } + } // namespace libp2p::protocol_muxer::multiselect diff --git a/src/security/noise/CMakeLists.txt b/src/security/noise/CMakeLists.txt index cbf9a6e22..04781c51b 100644 --- a/src/security/noise/CMakeLists.txt +++ b/src/security/noise/CMakeLists.txt @@ -10,6 +10,7 @@ libp2p_add_library(p2p_noise noise.cpp noise_connection.cpp handshake.cpp + handshake_coro.cpp crypto/state.cpp crypto/hkdf.cpp crypto/noise_dh.cpp diff --git a/src/security/noise/handshake_coro.cpp b/src/security/noise/handshake_coro.cpp new file mode 100644 index 000000000..5083a0c62 --- /dev/null +++ b/src/security/noise/handshake_coro.cpp @@ -0,0 +1,273 @@ +/** + * Copyright Quadrivium LLC + * All Rights Reserved + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace libp2p::security::noise { + + namespace { + template + void unused(T &&) {} + } // namespace + + HandshakeCoro::HandshakeCoro( + std::shared_ptr crypto_provider, + std::unique_ptr + noise_marshaller, + crypto::KeyPair local_key, + std::shared_ptr connection, + bool is_initiator, + boost::optional remote_peer_id, + std::shared_ptr key_marshaller) + : crypto_provider_{std::move(crypto_provider)}, + noise_marshaller_{std::move(noise_marshaller)}, + local_key_{std::move(local_key)}, + conn_{std::move(connection)}, + initiator_{is_initiator}, + key_marshaller_{std::move(key_marshaller)}, + read_buffer_{std::make_shared(kMaxMsgLen)}, + rw_{std::make_shared(conn_, read_buffer_)}, + handshake_state_{std::make_unique()}, + remote_peer_id_{std::move(remote_peer_id)} { + read_buffer_->resize(kMaxMsgLen); + } + + boost::asio::awaitable< + outcome::result>> + HandshakeCoro::connect() { + return runHandshake(); + } + + void HandshakeCoro::setCipherStates(std::shared_ptr cs1, + std::shared_ptr cs2) { + if (initiator_) { + enc_ = std::move(cs1); + dec_ = std::move(cs2); + } else { + enc_ = std::move(cs2); + dec_ = std::move(cs1); + } + } + + outcome::result> HandshakeCoro::generateHandshakePayload( + const DHKey &keypair) { + const auto &prefix = kPayloadPrefix; + const auto &pubkey = keypair.pub; + std::vector to_sign; + to_sign.reserve(prefix.size() + pubkey.size()); + std::copy(prefix.begin(), prefix.end(), std::back_inserter(to_sign)); + std::copy(pubkey.begin(), pubkey.end(), std::back_inserter(to_sign)); + + OUTCOME_TRY(signed_payload, + crypto_provider_->sign(to_sign, local_key_.privateKey)); + security::noise::HandshakeMessage payload{ + .identity_key = local_key_.publicKey, + .identity_sig = std::move(signed_payload), + .data = {}}; + return noise_marshaller_->marshal(payload); + } + + boost::asio::awaitable> + HandshakeCoro::sendHandshakeMessage(BytesIn payload) { + auto write_result = handshake_state_->writeMessage({}, payload); + if (write_result.has_error()) { + co_return write_result.error(); + } + + auto write_bytes_result = co_await rw_->write(write_result.value().data); + if (write_bytes_result.has_error()) { + co_return write_bytes_result.error(); + } + + if (write_result.value().cs1 && write_result.value().cs2) { + setCipherStates(write_result.value().cs1, write_result.value().cs2); + } + + co_return write_bytes_result.value(); + } + + boost::asio::awaitable>> + HandshakeCoro::readHandshakeMessage() { + auto read_result = co_await rw_->read(); + if (read_result.has_error()) { + co_return read_result.error(); + } + + auto buffer = read_result.value(); + auto read_message_result = handshake_state_->readMessage({}, *buffer); + if (read_message_result.has_error()) { + co_return read_message_result.error(); + } + + if (read_message_result.value().cs1 && read_message_result.value().cs2) { + setCipherStates(read_message_result.value().cs1, + read_message_result.value().cs2); + } + + auto shared_data = + std::make_shared(std::move(read_message_result.value().data)); + co_return shared_data; + } + + outcome::result HandshakeCoro::handleRemoteHandshakePayload( + BytesIn payload) { + OUTCOME_TRY(remote_payload, noise_marshaller_->unmarshal(payload)); + OUTCOME_TRY(remote_id, peer::PeerId::fromPublicKey(remote_payload.second)); + auto &&handy_payload = remote_payload.first; + if (initiator_ && remote_peer_id_ != remote_id) { + SL_DEBUG(log_, + "Remote peer id mismatches already known, expected {}, got {}", + remote_peer_id_->toHex(), + remote_id.toHex()); + return std::errc::bad_address; + } + Bytes to_verify; + to_verify.reserve(kPayloadPrefix.size() + + handy_payload.identity_key.data.size()); + std::copy(kPayloadPrefix.begin(), + kPayloadPrefix.end(), + std::back_inserter(to_verify)); + OUTCOME_TRY(remote_static, handshake_state_->remotePeerStaticPubkey()); + std::copy(remote_static.begin(), + remote_static.end(), + std::back_inserter(to_verify)); + OUTCOME_TRY( + signature_correct, + crypto_provider_->verify( + to_verify, handy_payload.identity_sig, handy_payload.identity_key)); + if (!signature_correct) { + SL_TRACE(log_, "Remote peer's payload signature verification failed"); + return std::errc::owner_dead; + } + remote_peer_id_ = remote_id; + remote_peer_pubkey_ = handy_payload.identity_key; + return outcome::success(); + } + + boost::asio::awaitable< + outcome::result>> + HandshakeCoro::runHandshake() { + auto cipher_suite = defaultCipherSuite(); + + auto keypair_res = cipher_suite->generate(); + if (keypair_res.has_error()) { + co_return keypair_res.error(); + } + auto keypair = keypair_res.value(); + + HandshakeStateConfig config( + defaultCipherSuite(), handshakeXX, initiator_, keypair); + auto init_result = handshake_state_->init(std::move(config)); + if (init_result.has_error()) { + co_return init_result.error(); + } + + auto payload_res = generateHandshakePayload(keypair); + if (payload_res.has_error()) { + co_return payload_res.error(); + } + auto payload = payload_res.value(); + + if (initiator_) { + // Outgoing connection. Stage 0 + SL_TRACE(log_, "outgoing connection. stage 0"); + + auto send_result = co_await sendHandshakeMessage({}); + if (send_result.has_error()) { + co_return send_result.error(); + } + + if (0 == send_result.value()) { + co_return std::errc::bad_message; + } + + // Outgoing connection. Stage 1 + SL_TRACE(log_, "outgoing connection. stage 1"); + + auto read_result = co_await readHandshakeMessage(); + if (read_result.has_error()) { + co_return read_result.error(); + } + + auto handle_result = handleRemoteHandshakePayload(*read_result.value()); + if (handle_result.has_error()) { + co_return handle_result.error(); + } + + // Outgoing connection. Stage 2 + SL_TRACE(log_, "outgoing connection. stage 2"); + + auto send_payload_result = co_await sendHandshakeMessage(payload); + if (send_payload_result.has_error()) { + co_return send_payload_result.error(); + } + + } else { + // Incoming connection. Stage 0 + SL_TRACE(log_, "incoming connection. stage 0"); + + auto read_result = co_await readHandshakeMessage(); + if (read_result.has_error()) { + co_return read_result.error(); + } + + // Incoming connection. Stage 1 + SL_TRACE(log_, "incoming connection. stage 1"); + + auto send_result = co_await sendHandshakeMessage(payload); + if (send_result.has_error()) { + co_return send_result.error(); + } + + // Incoming connection. Stage 2 + SL_TRACE(log_, "incoming connection. stage 2"); + + auto read_payload_result = co_await readHandshakeMessage(); + if (read_payload_result.has_error()) { + co_return read_payload_result.error(); + } + + auto handle_result = + handleRemoteHandshakePayload(*read_payload_result.value()); + if (handle_result.has_error()) { + co_return handle_result.error(); + } + } + + if (!remote_peer_pubkey_) { + log_->error("Remote peer static pubkey remains unknown"); + co_return std::errc::connection_aborted; + } + + auto secured_connection = std::make_shared( + conn_, + local_key_.publicKey, + remote_peer_pubkey_.value(), + key_marshaller_, + enc_, + dec_); + log_->info("Handshake succeeded"); + co_return secured_connection; + } + +} // namespace libp2p::security::noise diff --git a/src/security/noise/insecure_rw.cpp b/src/security/noise/insecure_rw.cpp index 9c99cd80b..f03a56d5f 100644 --- a/src/security/noise/insecure_rw.cpp +++ b/src/security/noise/insecure_rw.cpp @@ -5,6 +5,8 @@ */ #include +#include +#include #include @@ -76,4 +78,78 @@ namespace libp2p::security::noise { }; writeReturnSize(connection_, outbuf_, std::move(write_cb)); } + + boost::asio::awaitable + InsecureReadWriter::read() { + buffer_->resize(kMaxMsgLen); // ensure buffer capacity + + // Read the length prefix + auto read_length_result = co_await connection_->read(*buffer_, kLengthPrefixSize); + if (!read_length_result) { + co_return read_length_result.error(); + } + + if (kLengthPrefixSize != read_length_result.value()) { + co_return std::errc::broken_pipe; + } + + // Extract the frame length from the prefix + uint16_t frame_len = ntohs(common::convert(buffer_->data())); // NOLINT + + // Read the actual data + auto read_data_result = co_await connection_->read(*buffer_, frame_len); + if (!read_data_result) { + co_return read_data_result.error(); + } + + if (frame_len != read_data_result.value()) { + co_return std::errc::broken_pipe; + } + + // Resize buffer to actual data size and return it + buffer_->resize(read_data_result.value()); + co_return buffer_; + } + + boost::asio::awaitable> + InsecureReadWriter::write(BytesIn buffer) { + if (buffer.size() > static_cast(kMaxMsgLen)) { + co_return std::errc::message_size; + } + + // Prepare the output buffer with length prefix + outbuf_.clear(); + outbuf_.reserve(kLengthPrefixSize + buffer.size()); + common::putUint16BE(outbuf_, buffer.size()); + outbuf_.insert(outbuf_.end(), buffer.begin(), buffer.end()); + + // Use a promise to handle the completion + struct ContextState { + std::optional> result; + bool done = false; + }; + + auto state = std::make_shared(); + + writeReturnSize(connection_, outbuf_, [state](auto result) { + state->result = result; + state->done = true; + }); + + // Wait until write operation completes + while (!state->done) { + co_await boost::asio::post(boost::asio::use_awaitable); + } + + if (!state->result->has_value()) { + co_return state->result->error(); + } + + if (outbuf_.size() != state->result->value()) { + co_return std::errc::broken_pipe; + } + + // Return number of actual payload bytes written (excluding length prefix) + co_return state->result->value() - kLengthPrefixSize; + } } // namespace libp2p::security::noise diff --git a/src/security/noise/noise.cpp b/src/security/noise/noise.cpp index 8565bfa64..88f3af2ea 100644 --- a/src/security/noise/noise.cpp +++ b/src/security/noise/noise.cpp @@ -5,6 +5,7 @@ */ #include +#include #include #include @@ -59,4 +60,43 @@ namespace libp2p::security { key_marshaller_); handshake->connect(); } + + boost::asio::awaitable< + outcome::result>> + Noise::secureInboundCoro( + std::shared_ptr inbound) { + log_->info("securing inbound connection (coroutine)"); + auto noise_marshaller = + std::make_unique( + key_marshaller_); + auto handshake = + std::make_shared(crypto_provider_, + std::move(noise_marshaller), + local_key_, + inbound, + false, + boost::none, + key_marshaller_); + co_return co_await handshake->connect(); + } + + boost::asio::awaitable< + outcome::result>> + Noise::secureOutboundCoro( + std::shared_ptr outbound, + const peer::PeerId &p) { + log_->info("securing outbound connection (coroutine)"); + auto noise_marshaller = + std::make_unique( + key_marshaller_); + auto handshake = + std::make_shared(crypto_provider_, + std::move(noise_marshaller), + local_key_, + outbound, + true, + p, + key_marshaller_); + co_return co_await handshake->connect(); + } } // namespace libp2p::security diff --git a/src/security/noise/noise_connection.cpp b/src/security/noise/noise_connection.cpp index e836628e3..39a9049c7 100644 --- a/src/security/noise/noise_connection.cpp +++ b/src/security/noise/noise_connection.cpp @@ -4,6 +4,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include #include #include @@ -188,4 +189,106 @@ namespace libp2p::connection { write_buffers_.erase(iterator); iterator = write_buffers_.end(); } + + boost::asio::awaitable> NoiseConnection::read(BytesOut out, + size_t bytes) { + ambigousSize(out, bytes); + + size_t total_bytes_read = 0; + while (total_bytes_read < bytes) { + auto remaining = bytes - total_bytes_read; + auto result = co_await readSome(out.subspan(total_bytes_read), remaining); + + if (!result) { + co_return result.error(); + } + + size_t bytes_read = result.value(); + if (bytes_read == 0) { + break; // Connection closed by peer + } + + total_bytes_read += bytes_read; + } + + co_return total_bytes_read; + } + + boost::asio::awaitable> NoiseConnection::readSome( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + + // If there's data in the frame buffer, use it directly + if (!frame_buffer_->empty()) { + auto n = std::min(bytes, frame_buffer_->size()); + auto begin = frame_buffer_->begin(); + auto end = begin + static_cast(n); + std::copy(begin, end, out.begin()); + frame_buffer_->erase(begin, end); + co_return n; + } + + // No data in buffer, need to read a new frame using coroutine method + auto frame_result = co_await framer_->read(); + if (!frame_result) { + co_return frame_result.error(); + } + + // Decrypt the received data + auto decrypt_result = decoder_cs_->decrypt({}, *frame_result.value(), {}); + if (!decrypt_result) { + co_return decrypt_result.error(); + } + + // Store decrypted data in frame buffer + frame_buffer_->assign(decrypt_result.value().begin(), decrypt_result.value().end()); + + // Now read from the frame buffer + auto n = std::min(bytes, frame_buffer_->size()); + auto begin = frame_buffer_->begin(); + auto end = begin + static_cast(n); + std::copy(begin, end, out.begin()); + frame_buffer_->erase(begin, end); + + co_return n; + } + + boost::asio::awaitable NoiseConnection::writeSome( + BytesIn in, size_t bytes) { + ambigousSize(in, bytes); + + if (0 == bytes) { + co_return std::error_code{}; + } + + // Process only up to kMaxPlainText bytes at a time + auto n = std::min(bytes, security::noise::kMaxPlainText); + + // Encrypt the data + auto encrypt_result = encoder_cs_->encrypt({}, in.subspan(0, n), {}); + if (!encrypt_result) { + co_return encrypt_result.error(); + } + + // Store the encrypted data in a buffer + Bytes encrypted = std::move(encrypt_result.value()); + + // Write the encrypted data using coroutine method + auto write_result = co_await framer_->write(encrypted); + if (!write_result) { + co_return write_result.error(); + } + + // If there's more data to write, recursively call writeSome + if (n < bytes) { + auto remaining_result = co_await writeSome(in.subspan(n), bytes - n); + if (remaining_result) { + co_return std::error_code{}; + } + co_return remaining_result; + } + + co_return std::error_code{}; + } + } // namespace libp2p::connection diff --git a/src/security/tls/tls_connection.cpp b/src/security/tls/tls_connection.cpp index 5d4b54a59..fde9c92bc 100644 --- a/src/security/tls/tls_connection.cpp +++ b/src/security/tls/tls_connection.cpp @@ -183,4 +183,60 @@ namespace libp2p::connection { outcome::result TlsConnection::close() { return original_connection_->close(); } + + boost::asio::awaitable> TlsConnection::read( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + SL_TRACE(log(), "co_read {} bytes", bytes); + if (isClosed()) { + co_return make_error_code(std::errc::connection_aborted); + } + boost::system::error_code ec; + size_t bytes_transferred = co_await boost::asio::async_read( + socket_, + asioBuffer(out), + boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + if (ec) { + std::ignore = close(); + co_return ec; + } + co_return bytes_transferred; + } + + boost::asio::awaitable> TlsConnection::readSome( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + SL_TRACE(log(), "co_readSome up to {} bytes", bytes); + if (isClosed()) { + co_return make_error_code(std::errc::connection_aborted); + } + boost::system::error_code ec; + size_t bytes_transferred = co_await socket_.async_read_some( + asioBuffer(out), + boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + if (ec) { + std::ignore = close(); + co_return ec; + } + co_return bytes_transferred; + } + + boost::asio::awaitable TlsConnection::writeSome( + BytesIn in, size_t bytes) { + ambigousSize(in, bytes); + SL_TRACE(log(), "co_writeSome up to {} bytes", bytes); + if (isClosed()) { + co_return make_error_code(std::errc::connection_aborted); + } + boost::system::error_code ec; + size_t bytes_transferred = co_await socket_.async_write_some( + asioBuffer(in), + boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + if (ec) { + std::ignore = close(); + co_return ec; + } + co_return ec; + } + } // namespace libp2p::connection diff --git a/src/security/tls/tls_connection.hpp b/src/security/tls/tls_connection.hpp index 41d7e3326..b8e8573bd 100644 --- a/src/security/tls/tls_connection.hpp +++ b/src/security/tls/tls_connection.hpp @@ -98,6 +98,14 @@ namespace libp2p::connection { /// Closes the socket outcome::result close() override; + // Coroutine-based methods + boost::asio::awaitable> read(BytesOut out, + size_t bytes) override; + boost::asio::awaitable> readSome( + BytesOut out, size_t bytes) override; + boost::asio::awaitable writeSome(BytesIn in, + size_t bytes) override; + private: /// Async handshake callback. Performs libp2p-specific verification and /// extraction of remote peer's identity fields diff --git a/src/transport/impl/upgrader_impl.cpp b/src/transport/impl/upgrader_impl.cpp index fc060fa3a..643684bb4 100644 --- a/src/transport/impl/upgrader_impl.cpp +++ b/src/transport/impl/upgrader_impl.cpp @@ -7,6 +7,9 @@ #include #include +#include +#include +#include #include #include @@ -95,6 +98,103 @@ namespace libp2p::transport { address, std::move(conn), std::move(layers), 0, std::move(cb)); } + boost::asio::awaitable> + UpgraderImpl::upgradeLayersInboundCoro(RawSPtr conn, ProtoAddrVec layers) { + auto result = co_await upgradeToNextLayerInboundCoro( + std::move(conn), std::move(layers), 0); + co_return result; + } + + boost::asio::awaitable> + UpgraderImpl::upgradeLayersOutboundCoro(const multi::Multiaddress &address, + RawSPtr conn, + ProtoAddrVec layers) { + auto result = co_await upgradeToNextLayerOutboundCoro( + address, std::move(conn), std::move(layers), 0); + co_return result; + } + + boost::asio::awaitable> + UpgraderImpl::upgradeToNextLayerInboundCoro(LayerSPtr conn, + ProtoAddrVec layers, + size_t layer_index) { + BOOST_ASSERT_MSG(!conn->isInitiator(), + "connection is initiator, and upgrade for inbound is " + "called (should be upgrade for outbound)"); + + if (layer_index >= layers.size()) { + co_return conn; + } + const auto &protocol = layers[layer_index]; + + auto adaptor_it = + std::find_if(layer_adaptors_.begin(), + layer_adaptors_.end(), + [&](const auto &adaptor) { + return adaptor->getProtocol() == protocol.first.code; + }); + + if (adaptor_it == layer_adaptors_.end()) { + co_return outcome::failure( + multi::converters::ConversionError::NOT_IMPLEMENTED); + } + + const auto &adaptor = *adaptor_it; + + auto next_layer_conn_res = co_await adaptor->upgradeInbound(conn); + if (!next_layer_conn_res) { + co_return next_layer_conn_res.error(); + } + + auto next_layer_conn = std::move(next_layer_conn_res.value()); + auto result = co_await upgradeToNextLayerInboundCoro( + std::move(next_layer_conn), std::move(layers), layer_index + 1); + co_return result; + } + + boost::asio::awaitable> + UpgraderImpl::upgradeToNextLayerOutboundCoro( + const multi::Multiaddress &address, + LayerSPtr conn, + ProtoAddrVec layers, + size_t layer_index) { + BOOST_ASSERT_MSG(conn->isInitiator(), + "connection is NOT initiator, and upgrade of outbound is " + "called (should be upgrade of inbound)"); + + if (layer_index >= layers.size()) { + co_return conn; + } + const auto &protocol = layers[layer_index]; + + auto adaptor_it = + std::find_if(layer_adaptors_.begin(), + layer_adaptors_.end(), + [&](const auto &adaptor) { + return adaptor->getProtocol() == protocol.first.code; + }); + + if (adaptor_it == layer_adaptors_.end()) { + co_return outcome::failure( + multi::converters::ConversionError::NOT_IMPLEMENTED); + } + + const auto &adaptor = *adaptor_it; + + auto next_layer_conn_res = co_await adaptor->upgradeOutbound(address, conn); + if (!next_layer_conn_res) { + co_return next_layer_conn_res.error(); + } + + auto &next_layer_conn = next_layer_conn_res.value(); + auto result = + co_await upgradeToNextLayerOutboundCoro(address, + std::move(next_layer_conn), + std::move(layers), + layer_index + 1); + co_return result; + } + void UpgraderImpl::upgradeToNextLayerInbound(LayerSPtr conn, ProtoAddrVec layers, size_t layer_index, @@ -244,6 +344,52 @@ namespace libp2p::transport { }); } + boost::asio::awaitable> + UpgraderImpl::upgradeToSecureInboundCoro(LayerSPtr conn) { + BOOST_ASSERT_MSG(!conn->isInitiator(), + "connection is initiator, and upgrade for inbound is " + "called (should be upgrade for outbound)"); + + auto proto_res = co_await protocol_muxer_->selectOneOf( + security_protocols_, conn, conn->isInitiator(), true); + + if (!proto_res) { + co_return proto_res.error(); + } + + auto adaptor = findAdaptor(security_adaptors_, proto_res.value()); + if (adaptor == nullptr) { + co_return Error::NO_ADAPTOR_FOUND; + } + + auto secure_conn_res = co_await adaptor->secureInboundCoro(std::move(conn)); + co_return secure_conn_res; + } + + boost::asio::awaitable> + UpgraderImpl::upgradeToSecureOutboundCoro(LayerSPtr conn, + const peer::PeerId &remoteId) { + BOOST_ASSERT_MSG(conn->isInitiator(), + "connection is NOT initiator, and upgrade for outbound is " + "called (should be upgrade for inbound)"); + + auto proto_res = co_await protocol_muxer_->selectOneOf( + security_protocols_, conn, conn->isInitiator(), true); + + if (!proto_res) { + co_return proto_res.error(); + } + + auto adaptor = findAdaptor(security_adaptors_, proto_res.value()); + if (adaptor == nullptr) { + co_return Error::NO_ADAPTOR_FOUND; + } + + auto secure_conn_res = + co_await adaptor->secureOutboundCoro(std::move(conn), remoteId); + co_return secure_conn_res; + } + void UpgraderImpl::upgradeToMuxed(SecSPtr conn, OnMuxedCallbackFunc cb) { return protocol_muxer_->selectOneOf( muxer_protocols_, @@ -274,4 +420,28 @@ namespace libp2p::transport { }); }); } + + boost::asio::awaitable> + UpgraderImpl::upgradeToMuxedCoro(SecSPtr conn) { + auto proto_res = co_await protocol_muxer_->selectOneOf( + muxer_protocols_, conn, conn->isInitiator(), true); + + if (!proto_res) { + co_return proto_res.error(); + } + + auto adaptor = findAdaptor(muxer_adaptors_, proto_res.value()); + if (!adaptor) { + co_return Error::NO_ADAPTOR_FOUND; + } + + auto conn_res = co_await adaptor->muxConnection(std::move(conn)); + if (!conn_res) { + co_return conn_res.error(); + } + + auto &&muxed_conn = conn_res.value(); + muxed_conn->start(); + co_return muxed_conn; + } } // namespace libp2p::transport diff --git a/src/transport/impl/upgrader_session.cpp b/src/transport/impl/upgrader_session.cpp index 42e1fa4ba..599ef1682 100644 --- a/src/transport/impl/upgrader_session.cpp +++ b/src/transport/impl/upgrader_session.cpp @@ -5,6 +5,9 @@ */ #include +#include +#include +#include namespace libp2p::transport { @@ -19,40 +22,76 @@ namespace libp2p::transport { handler_(std::move(handler)) {} void UpgraderSession::upgradeInbound() { - if (layers_.empty()) { - return secureInbound(raw_); + auto on_layers_upgraded = + [self{shared_from_this()}]( + outcome::result> + res) { + if (!res) { + return self->handler_(res.as_failure()); + } + auto &conn = res.value(); + self->secureInbound(std::move(conn)); + }; + + upgrader_->upgradeLayersInbound(raw_, layers_, std::move(on_layers_upgraded)); + } + + boost::asio::awaitable>> + UpgraderSession::upgradeInboundCoro() { + // Step 1: Upgrade through layers + auto layer_conn_res = co_await upgrader_->upgradeLayersInboundCoro(raw_, layers_); + if (!layer_conn_res) { + co_return layer_conn_res.error(); } - auto on_layers_upgraded = [self{shared_from_this()}](auto &&res) { - if (res.has_error()) { - return self->handler_(res.as_failure()); - } - auto &conn = res.value(); - self->secureInbound(std::move(conn)); - }; + // Step 2: Upgrade to secure connection + auto secure_conn_res = co_await secureInboundCoro(std::move(layer_conn_res.value())); + if (!secure_conn_res) { + co_return secure_conn_res.error(); + } - upgrader_->upgradeLayersInbound( - raw_, layers_, std::move(on_layers_upgraded)); + // Step 3: Upgrade to muxed (capable) connection + auto capable_conn_res = co_await onSecuredCoro(std::move(secure_conn_res.value())); + co_return capable_conn_res; } void UpgraderSession::upgradeOutbound(const multi::Multiaddress &address, const peer::PeerId &remoteId) { - if (layers_.empty()) { - return secureOutbound(raw_, remoteId); - } - - auto on_layers_upgraded = [self{shared_from_this()}, remoteId](auto &&res) { - if (res.has_error()) { - return self->handler_(res.as_failure()); - } - auto &conn = res.value(); - self->secureOutbound(std::move(conn), remoteId); - }; + auto on_layers_upgraded = + [self{shared_from_this()}, remoteId]( + outcome::result> + res) { + if (!res) { + return self->handler_(res.as_failure()); + } + auto &conn = res.value(); + self->secureOutbound(std::move(conn), remoteId); + }; upgrader_->upgradeLayersOutbound( address, raw_, layers_, std::move(on_layers_upgraded)); } + boost::asio::awaitable>> + UpgraderSession::upgradeOutboundCoro(const multi::Multiaddress &address, + const peer::PeerId &remoteId) { + // Step 1: Upgrade through layers + auto layer_conn_res = co_await upgrader_->upgradeLayersOutboundCoro(address, raw_, layers_); + if (!layer_conn_res) { + co_return layer_conn_res.error(); + } + + // Step 2: Upgrade to secure connection + auto secure_conn_res = co_await secureOutboundCoro(std::move(layer_conn_res.value()), remoteId); + if (!secure_conn_res) { + co_return secure_conn_res.error(); + } + + // Step 3: Upgrade to muxed (capable) connection + auto capable_conn_res = co_await onSecuredCoro(std::move(secure_conn_res.value())); + co_return capable_conn_res; + } + void UpgraderSession::secureInbound( std::shared_ptr conn) { auto on_sec_upgraded = [self{shared_from_this()}](auto &&res) { @@ -82,16 +121,41 @@ namespace libp2p::transport { std::move(conn), remoteId, std::move(on_sec_upgraded)); } + boost::asio::awaitable>> + UpgraderSession::secureInboundCoro( + std::shared_ptr conn) { + auto secure_conn_res = co_await upgrader_->upgradeToSecureInboundCoro(std::move(conn)); + co_return secure_conn_res; + } + + boost::asio::awaitable>> + UpgraderSession::secureOutboundCoro( + std::shared_ptr conn, + const peer::PeerId &remoteId) { + auto secure_conn_res = co_await upgrader_->upgradeToSecureOutboundCoro(std::move(conn), remoteId); + co_return secure_conn_res; + } + void UpgraderSession::onSecured( - outcome::result> rsecure) { - if (!rsecure) { - return handler_(rsecure.error()); + outcome::result> res) { + if (!res) { + return handler_(res.as_failure()); } - upgrader_->upgradeToMuxed( - rsecure.value(), [self{shared_from_this()}](auto &&res) { - self->handler_(std::forward(res)); - }); + auto &secure_conn = res.value(); + + auto on_muxed = [handler{handler_}]( + outcome::result> + conn_res) mutable { handler(conn_res); }; + + upgrader_->upgradeToMuxed(std::move(secure_conn), std::move(on_muxed)); + } + + boost::asio::awaitable>> + UpgraderSession::onSecuredCoro( + std::shared_ptr secure_conn) { + auto capable_conn_res = co_await upgrader_->upgradeToMuxedCoro(std::move(secure_conn)); + co_return capable_conn_res; } } // namespace libp2p::transport diff --git a/src/transport/quic/connection.cpp b/src/transport/quic/connection.cpp index 11ce0e9ab..3c585a3f6 100644 --- a/src/transport/quic/connection.cpp +++ b/src/transport/quic/connection.cpp @@ -36,12 +36,20 @@ namespace libp2p::transport { throw std::logic_error{"QuicConnection::read must not be called"}; } + boost::asio::awaitable> QuicConnection::read(BytesOut, size_t) { + throw std::logic_error{"QuicConnection::read (coroutine) must not be called"}; + } + void QuicConnection::readSome(BytesOut out, size_t bytes, ReadCallbackFunc cb) { throw std::logic_error{"QuicConnection::readSome must not be called"}; } + boost::asio::awaitable> QuicConnection::readSome(BytesOut, size_t) { + throw std::logic_error{"QuicConnection::readSome (coroutine) must not be called"}; + } + void QuicConnection::deferReadCallback(outcome::result res, ReadCallbackFunc cb) { post(*io_context_, [cb{std::move(cb)}, res] { cb(res); }); @@ -53,6 +61,10 @@ namespace libp2p::transport { throw std::logic_error{"QuicConnection::writeSome must not be called"}; } + boost::asio::awaitable QuicConnection::writeSome(BytesIn, size_t) { + throw std::logic_error{"QuicConnection::writeSome (coroutine) must not be called"}; + } + void QuicConnection::deferWriteCallback(std::error_code ec, WriteCallbackFunc cb) { deferReadCallback(ec, std::move(cb)); diff --git a/src/transport/quic/stream.cpp b/src/transport/quic/stream.cpp index 15066a475..e429731fe 100644 --- a/src/transport/quic/stream.cpp +++ b/src/transport/quic/stream.cpp @@ -5,6 +5,9 @@ */ #include +#include +#include +#include #include #include #include @@ -33,6 +36,35 @@ namespace libp2p::connection { readReturnSize(shared_from_this(), out, std::move(cb)); } + boost::asio::awaitable> QuicStream::read(BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + if (not stream_ctx_) { + co_return QuicError::STREAM_CLOSED; + } + if (stream_ctx_->reading) { + co_return QuicError::STREAM_READ_IN_PROGRESS; + } + auto n = lsquic_stream_read(stream_ctx_->ls_stream, out.data(), out.size()); + if (n == -1 && errno == EWOULDBLOCK) { + bool done = false; + outcome::result r = QuicError::STREAM_CLOSED; + stream_ctx_->reading.emplace(transport::lsquic::StreamCtx::Reading{ + out, [&](auto res) { + r = res; + done = true; + }}); + lsquic_stream_wantread(stream_ctx_->ls_stream, 1); + while (!done) { + co_await boost::asio::post(boost::asio::use_awaitable); + } + co_return r; + } + if (n > 0) { + co_return n; + } + co_return QuicError::STREAM_CLOSED; + } + void QuicStream::readSome(BytesOut out, size_t bytes, basic::Reader::ReadCallbackFunc cb) { @@ -56,6 +88,35 @@ namespace libp2p::connection { deferReadCallback(r, std::move(cb)); } + boost::asio::awaitable> QuicStream::readSome(BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + if (not stream_ctx_) { + co_return QuicError::STREAM_CLOSED; + } + if (stream_ctx_->reading) { + co_return QuicError::STREAM_READ_IN_PROGRESS; + } + auto n = lsquic_stream_read(stream_ctx_->ls_stream, out.data(), out.size()); + if (n == -1 && errno == EWOULDBLOCK) { + bool done = false; + outcome::result r = QuicError::STREAM_CLOSED; + stream_ctx_->reading.emplace(transport::lsquic::StreamCtx::Reading{ + out, [&](auto res) { + r = res; + done = true; + }}); + lsquic_stream_wantread(stream_ctx_->ls_stream, 1); + while (!done) { + co_await boost::asio::post(boost::asio::use_awaitable); + } + co_return r; + } + if (n > 0) { + co_return n; + } + co_return QuicError::STREAM_CLOSED; + } + void QuicStream::deferReadCallback(outcome::result res, basic::Reader::ReadCallbackFunc cb) { conn_->deferReadCallback(res, std::move(cb)); @@ -77,58 +138,18 @@ namespace libp2p::connection { deferReadCallback(r, std::move(cb)); } - void QuicStream::deferWriteCallback(std::error_code ec, - WriteCallbackFunc cb) { - conn_->deferWriteCallback(ec, std::move(cb)); - } - - bool QuicStream::isClosedForRead() const { - return false; // deprecated - } - - bool QuicStream::isClosedForWrite() const { - return false; // deprecated - } - - bool QuicStream::isClosed() const { - return false; // deprecated - } - - void QuicStream::close(Stream::VoidResultHandlerFunc cb) { + boost::asio::awaitable QuicStream::writeSome(BytesIn in, size_t bytes) { + ambigousSize(in, bytes); if (not stream_ctx_) { - return cb(outcome::success()); + co_return QuicError::STREAM_CLOSED; } - lsquic_stream_shutdown(stream_ctx_->ls_stream, 1); - cb(outcome::success()); - } - - void QuicStream::reset() { - if (not stream_ctx_) { - return; + auto n = lsquic_stream_write(stream_ctx_->ls_stream, in.data(), in.size()); + if (n > 0 && lsquic_stream_flush(stream_ctx_->ls_stream) == 0) { + stream_ctx_->engine->process(); + co_return std::error_code{}; } - lsquic_stream_close(stream_ctx_->ls_stream); - } - - void QuicStream::adjustWindowSize(uint32_t new_size, - Stream::VoidResultHandlerFunc cb) {} - - outcome::result QuicStream::isInitiator() const { - return initiator_; - } - - outcome::result QuicStream::remotePeerId() const { - return conn_->remotePeer(); - } - - outcome::result QuicStream::localMultiaddr() const { - return conn_->localMultiaddr(); - } - - outcome::result QuicStream::remoteMultiaddr() const { - return conn_->remoteMultiaddr(); + stream_ctx_->engine->process(); + co_return QuicError::STREAM_CLOSED; } +} - void QuicStream::onClose() { - stream_ctx_ = nullptr; - } -} // namespace libp2p::connection diff --git a/src/transport/tcp/tcp_connection.cpp b/src/transport/tcp/tcp_connection.cpp index 9ac7767f8..ea15634f7 100644 --- a/src/transport/tcp/tcp_connection.cpp +++ b/src/transport/tcp/tcp_connection.cpp @@ -172,6 +172,54 @@ namespace libp2p::transport { }); } + boost::asio::awaitable TcpConnection::connect( + const ResolverResultsType &iterator) { + boost::system::error_code ec; + co_await boost::asio::async_connect( + socket_, + iterator, + boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + + if (ec) { + co_return ec; + } + + initiator_ = true; + std::ignore = saveMultiaddresses(); + co_return ec; + } + + boost::asio::awaitable TcpConnection::connect( + const ResolverResultsType &iterator, std::chrono::milliseconds timeout) { + if (timeout == std::chrono::milliseconds::zero()) { + co_return co_await connect(iterator); + } + + boost::system::error_code ec; + using namespace boost::asio::experimental::awaitable_operators; + auto result = co_await ( + boost::asio::async_connect( + socket_, + iterator, + boost::asio::redirect_error(boost::asio::use_awaitable, ec)) + || deadline_timer_.async_wait(boost::asio::use_awaitable)); + + if (ec) { + co_return ec; + } + + // if result is 0, it means connect finished first + // if result is 1, it means timer finished first + if (result.index() == 1) { + socket_.close(); // close the socket if timeout occurred + co_return make_error_code(boost::system::errc::timed_out); + } + + initiator_ = true; + std::ignore = saveMultiaddresses(); + co_return ec; + } + void TcpConnection::read(BytesOut out, size_t bytes, TcpConnection::ReadCallbackFunc cb) { @@ -210,6 +258,78 @@ namespace libp2p::transport { boost::asio::post(context_, [ec, cb{std::move(cb)}] { cb(ec); }); } + boost::asio::awaitable> TcpConnection::read( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + TRACE("{} co_read {}", debug_str_, bytes); + + if (isClosed()) { + co_return close_reason_.value_or( + make_error_code(boost::system::errc::operation_canceled)); + } + + boost::system::error_code ec; + size_t bytes_transferred = co_await boost::asio::async_read( + socket_, + asioBuffer(out), + boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + + if (ec) { + close(ec); + co_return ec; + } + ByteCounter::getInstance().incrementBytesRead( + bytes_transferred); // It's important to count actual bytes + // transferred. + co_return bytes_transferred; + } + + boost::asio::awaitable> TcpConnection::readSome( + BytesOut out, size_t bytes) { + ambigousSize(out, bytes); + TRACE("{} co_readSome up to {}", debug_str_, bytes); + + if (isClosed()) { + co_return close_reason_.value_or( + make_error_code(boost::system::errc::operation_canceled)); + } + + boost::system::error_code ec; + size_t bytes_transferred = co_await socket_.async_read_some( + asioBuffer(out), + boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + + if (ec) { + close(ec); + co_return ec; + } + ByteCounter::getInstance().incrementBytesRead(bytes_transferred); + co_return bytes_transferred; + } + + boost::asio::awaitable TcpConnection::writeSome( + BytesIn in, size_t bytes) { + ambigousSize(in, bytes); + TRACE("{} co_writeSome up to {}", debug_str_, bytes); + + if (isClosed()) { + co_return close_reason_.value_or( + make_error_code(boost::system::errc::operation_canceled)); + } + + boost::system::error_code ec; + size_t bytes_transferred = co_await socket_.async_write_some( + asioBuffer(in), + boost::asio::redirect_error(boost::asio::use_awaitable, ec)); + + if (ec) { + close(ec); + co_return ec; + } + ByteCounter::getInstance().incrementBytesWritten(bytes_transferred); + co_return ec; + } + outcome::result TcpConnection::saveMultiaddresses() { boost::system::error_code ec; if (socket_.is_open()) { @@ -242,12 +362,4 @@ namespace libp2p::transport { return outcome::success(); } - uint64_t TcpConnection::getBytesRead() { - return ByteCounter::getInstance().getBytesRead(); - } - - uint64_t TcpConnection::getBytesWritten() { - return ByteCounter::getInstance().getBytesWritten(); - } - } // namespace libp2p::transport diff --git a/src/transport/tcp/tcp_listener.cpp b/src/transport/tcp/tcp_listener.cpp index efb94a640..04905089f 100644 --- a/src/transport/tcp/tcp_listener.cpp +++ b/src/transport/tcp/tcp_listener.cpp @@ -63,6 +63,10 @@ namespace libp2p::transport { return detail::makeAddress(endpoint, layers_); } + boost::asio::io_context &TcpListener::getContext() const { + return context_; + } + bool TcpListener::isClosed() const { return !acceptor_.is_open(); } @@ -103,4 +107,32 @@ namespace libp2p::transport { }); }; + boost::asio::awaitable< + outcome::result>> + TcpListener::asyncAccept() { + try { + boost::asio::ip::tcp::socket socket(context_); + co_await acceptor_.async_accept(socket, boost::asio::use_awaitable); + + auto connection = + std::make_shared(context_, layers_, std::move(socket)); + auto secure_connection = + co_await upgrader_->upgradeToSecureInboundCoro(connection); + + if (!secure_connection) { + co_return secure_connection.error(); + } + + auto capable_connection = co_await upgrader_->upgradeToMuxedCoro( + std::move(secure_connection.value())); + + if (!capable_connection) { + co_return capable_connection.error(); + } + co_return capable_connection.value(); + } catch (const std::exception &e) { + co_return std::errc::io_error; + } + } + } // namespace libp2p::transport