diff --git a/src/net.h b/src/net.h
--- a/src/net.h
+++ b/src/net.h
@@ -648,7 +648,36 @@
     uint32_t m_mapped_as;
 };
 
+/**
+ * Transport protocol agnostic message container.
+ * Ideally it should only contain receive time, payload,
+ * command and size.
+ */
 class CNetMessage {
+public:
+    // received message data
+    CDataStream m_recv;
+    // time (in microseconds) of message receipt.
+    int64_t m_time = 0;
+    bool m_valid_netmagic = false;
+    bool m_valid_header = false;
+    bool m_valid_checksum = false;
+    // size of the payload
+    uint32_t m_message_size = 0;
+    std::string m_command;
+
+    // Takes ownership of the payload buffer: it is moved, never copied.
+    explicit CNetMessage(CDataStream &&recv_in) : m_recv(std::move(recv_in)) {}
+
+    void SetVersion(int nVersionIn) { m_recv.SetVersion(nVersionIn); }
+};
+
+/**
+ * The TransportDeserializer takes care of holding and deserializing the
+ * network receive buffer. It can deserialize the network buffer into a
+ * transport protocol agnostic CNetMessage (command & payload)
+ */
+class TransportDeserializer {
 private:
     mutable CHash256 hasher;
     mutable uint256 data_hash;
@@ -670,15 +699,20 @@
-    // Time (in microseconds) of message receipt.
-    int64_t nTime;
-
-    CNetMessage(const CMessageHeader::MessageMagic &pchMessageStartIn,
-                int nTypeIn, int nVersionIn)
+    TransportDeserializer(const CMessageHeader::MessageMagic &pchMessageStartIn,
+                          int nTypeIn, int nVersionIn)
         : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn),
           vRecv(nTypeIn, nVersionIn) {
+        Reset();
+    }
+
+    // Discard any partially received message and prepare for a new one.
+    void Reset() {
+        vRecv.clear();
+        hdrbuf.clear();
         hdrbuf.resize(24);
         in_data = false;
         nHdrPos = 0;
         nDataPos = 0;
-        nTime = 0;
+        data_hash.SetNull();
+        hasher.Reset();
     }
 
     bool complete() const {
@@ -698,6 +732,8 @@
 
     int readHeader(const Config &config, const char *pch, uint32_t nBytes);
     int readData(const char *pch, uint32_t nBytes);
+
+    CNetMessage GetMessage(const Config &config, int64_t time);
 };
 
 /** Information about a peer */
@@ -705,6 +741,9 @@
     friend class CConnman;
 
 public:
+    // Assembles raw received bytes into complete CNetMessages.
+    std::unique_ptr<TransportDeserializer> m_deserializer;
+
     // socket
     std::atomic<ServiceFlags> nServices{NODE_NONE};
     SOCKET hSocket GUARDED_BY(cs_hSocket);
diff --git a/src/net.cpp b/src/net.cpp
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -570,13 +570,14 @@
         addrLocalUnlocked.IsValid() ? addrLocalUnlocked.ToString() : "";
 }
 
-static bool IsOversizedMessage(const Config &config, const CNetMessage &msg) {
-    if (!msg.in_data) {
+static bool IsOversizedMessage(const Config &config,
+                               const TransportDeserializer &deserializer) {
+    if (!deserializer.in_data) {
         // Header only, cannot be oversized.
         return false;
     }
 
-    return msg.hdr.IsOversized(config);
+    return deserializer.hdr.IsOversized(config);
 }
 
 bool CNode::ReceiveMsgBytes(const Config &config, const char *pch,
@@ -587,49 +588,49 @@
     nLastRecv = nTimeMicros / 1000000;
     nRecvBytes += nBytes;
     while (nBytes > 0) {
-        // Get current incomplete message, or create a new one.
-        if (vRecvMsg.empty() || vRecvMsg.back().complete()) {
-            vRecvMsg.push_back(CNetMessage(config.GetChainParams().NetMagic(),
-                                           SER_NETWORK, INIT_PROTO_VERSION));
-        }
-
-        CNetMessage &msg = vRecvMsg.back();
-
         // Absorb network data.
         int handled;
-        if (!msg.in_data) {
-            handled = msg.readHeader(config, pch, nBytes);
+        if (!m_deserializer->in_data) {
+            handled = m_deserializer->readHeader(config, pch, nBytes);
         } else {
-            handled = msg.readData(pch, nBytes);
+            handled = m_deserializer->readData(pch, nBytes);
         }
 
         if (handled < 0) {
+            m_deserializer->Reset();
             return false;
         }
 
-        if (IsOversizedMessage(config, msg)) {
+        if (IsOversizedMessage(config, *m_deserializer)) {
             LogPrint(BCLog::NET,
                      "Oversized message from peer=%i, disconnecting\n",
                      GetId());
+            m_deserializer->Reset();
             return false;
         }
 
         pch += handled;
         nBytes -= handled;
 
-        if (msg.complete()) {
+        if (m_deserializer->complete()) {
+            // Extract the completed message; this resets the deserializer,
+            // so use msg, not the deserializer state, from here on.
+            CNetMessage msg = m_deserializer->GetMessage(config, nTimeMicros);
+
             // Store received bytes per message command to prevent a memory DOS,
             // only allow valid commands.
             mapMsgCmdSize::iterator i =
-                mapRecvBytesPerMsgCmd.find(msg.hdr.pchCommand.data());
+                mapRecvBytesPerMsgCmd.find(msg.m_command);
             if (i == mapRecvBytesPerMsgCmd.end()) {
                 i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER);
             }
 
             assert(i != mapRecvBytesPerMsgCmd.end());
-            i->second += msg.hdr.nMessageSize + CMessageHeader::HEADER_SIZE;
+            i->second += msg.m_message_size + CMessageHeader::HEADER_SIZE;
+
+            // Push the message to the process queue.
+            vRecvMsg.push_back(std::move(msg));
 
-            msg.nTime = nTimeMicros;
             complete = true;
         }
     }
@@ -663,8 +664,8 @@
     return nSendVersion;
 }
 
-int CNetMessage::readHeader(const Config &config, const char *pch,
-                            uint32_t nBytes) {
+int TransportDeserializer::readHeader(const Config &config, const char *pch,
+                                      uint32_t nBytes) {
     // copy data to temporary parsing buffer
     uint32_t nRemaining = 24 - nHdrPos;
     uint32_t nCopy = std::min(nRemaining, nBytes);
@@ -696,7 +697,7 @@
     return nCopy;
 }
 
-int CNetMessage::readData(const char *pch, uint32_t nBytes) {
+int TransportDeserializer::readData(const char *pch, uint32_t nBytes) {
     unsigned int nRemaining = hdr.nMessageSize - nDataPos;
     unsigned int nCopy = std::min(nRemaining, nBytes);
 
@@ -713,7 +714,7 @@
     return nCopy;
 }
 
-const uint256 &CNetMessage::GetMessageHash() const {
+const uint256 &TransportDeserializer::GetMessageHash() const {
     assert(complete());
     if (data_hash.IsNull()) {
         hasher.Finalize(data_hash.begin());
@@ -721,6 +722,44 @@
     return data_hash;
 }
 
+CNetMessage TransportDeserializer::GetMessage(const Config &config,
+                                              int64_t time) {
+    // decompose a single CNetMessage from the TransportDeserializer
+    CNetMessage msg(std::move(vRecv));
+
+    // store state about valid header, netmagic and checksum
+    msg.m_valid_header = hdr.IsValid(config);
+    // FIXME Split CheckHeaderMagicAndCommand() into CheckHeaderMagic() and
+    // CheckCommand() to prevent the net magic check code duplication.
+    msg.m_valid_netmagic =
+        (memcmp(std::begin(hdr.pchMessageStart),
+                std::begin(config.GetChainParams().NetMagic()),
+                CMessageHeader::MESSAGE_START_SIZE) == 0);
+    uint256 hash = GetMessageHash();
+
+    // store command string, payload size
+    msg.m_command = hdr.GetCommand();
+    msg.m_message_size = hdr.nMessageSize;
+
+    msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum,
+                                   CMessageHeader::CHECKSUM_SIZE) == 0);
+    if (!msg.m_valid_checksum) {
+        LogPrint(
+            BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s\n",
+            SanitizeString(msg.m_command), msg.m_message_size,
+            HexStr(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE),
+            HexStr(hdr.pchChecksum,
+                   hdr.pchChecksum + CMessageHeader::CHECKSUM_SIZE));
+    }
+
+    // store receive time
+    msg.m_time = time;
+
+    // reset the network deserializer (prepare for the next message)
+    Reset();
+    return msg;
+}
+
 size_t CConnman::SocketSendData(CNode *pnode) const
     EXCLUSIVE_LOCKS_REQUIRED(pnode->cs_vSend) {
     size_t nSentSize = 0;
@@ -1480,11 +1519,11 @@
                     size_t nSizeAdded = 0;
                     auto it(pnode->vRecvMsg.begin());
                     for (; it != pnode->vRecvMsg.end(); ++it) {
-                        if (!it->complete()) {
-                            break;
-                        }
+                        // vRecvMsg contains only completed CNetMessages; the
+                        // single, possibly partially received message is held
+                        // by the TransportDeserializer.
                         nSizeAdded +=
-                            it->vRecv.size() + CMessageHeader::HEADER_SIZE;
+                            it->m_recv.size() + CMessageHeader::HEADER_SIZE;
                     }
                     {
                         LOCK(pnode->cs_vProcessMsg);
@@ -2908,6 +2947,10 @@
     } else {
         LogPrint(BCLog::NET, "Added connection peer=%d\n", id);
     }
+
+    m_deserializer = std::make_unique<TransportDeserializer>(
+        GetConfig().GetChainParams().NetMagic(), SER_NETWORK,
+        INIT_PROTO_VERSION);
 }
 
 CNode::~CNode() {
diff --git a/src/net_processing.cpp b/src/net_processing.cpp
--- a/src/net_processing.cpp
+++ b/src/net_processing.cpp
@@ -3923,7 +3923,6 @@
 
 bool PeerLogicValidation::ProcessMessages(const Config &config, CNode *pfrom,
                                           std::atomic<bool> &interruptMsgProc) {
-    const CChainParams &chainparams = config.GetChainParams();
     //
     // Message format
     //  (4) message start
@@ -3971,7 +3970,7 @@
         msgs.splice(msgs.begin(), pfrom->vProcessMsg,
                     pfrom->vProcessMsg.begin());
         pfrom->nProcessQueueSize -=
-            msgs.front().vRecv.size() + CMessageHeader::HEADER_SIZE;
+            msgs.front().m_recv.size() + CMessageHeader::HEADER_SIZE;
         pfrom->fPauseRecv =
             pfrom->nProcessQueueSize > connman->GetReceiveFloodSize();
         fMoreWork = !pfrom->vProcessMsg.empty();
@@ -3980,13 +3979,11 @@
 
     msg.SetVersion(pfrom->GetRecvVersion());
 
-    // Scan for message start
-    if (memcmp(std::begin(msg.hdr.pchMessageStart),
-               std::begin(chainparams.NetMagic()),
-               CMessageHeader::MESSAGE_START_SIZE) != 0) {
+    // Check network magic
+    if (!msg.m_valid_netmagic) {
         LogPrint(BCLog::NET,
                  "PROCESSMESSAGE: INVALID MESSAGESTART %s peer=%d\n",
-                 SanitizeString(msg.hdr.GetCommand()), pfrom->GetId());
+                 SanitizeString(msg.m_command), pfrom->GetId());
 
         // Make sure we discourage where that come from for some time.
         if (m_banman) {
@@ -3998,32 +3995,23 @@
         return false;
     }
 
-    // Read header
-    CMessageHeader &hdr = msg.hdr;
-    if (!hdr.IsValid(config)) {
+    // Check header
+    if (!msg.m_valid_header) {
         LogPrint(BCLog::NET, "PROCESSMESSAGE: ERRORS IN HEADER %s peer=%d\n",
-                 SanitizeString(hdr.GetCommand()), pfrom->GetId());
+                 SanitizeString(msg.m_command), pfrom->GetId());
         return fMoreWork;
     }
-    std::string strCommand = hdr.GetCommand();
+    const std::string &strCommand = msg.m_command;
 
     // Message size
-    unsigned int nMessageSize = hdr.nMessageSize;
+    unsigned int nMessageSize = msg.m_message_size;
 
     // Checksum
-    CDataStream &vRecv = msg.vRecv;
-    const uint256 &hash = msg.GetMessageHash();
-    if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) !=
-        0) {
-        LogPrint(
-            BCLog::NET,
-            "%s(%s, %u bytes): CHECKSUM ERROR expected %s was %s from "
-            "peer=%d\n",
-            __func__, SanitizeString(strCommand), nMessageSize,
-            HexStr(hash.begin(), hash.begin() + CMessageHeader::CHECKSUM_SIZE),
-            HexStr(hdr.pchChecksum,
-                   hdr.pchChecksum + CMessageHeader::CHECKSUM_SIZE),
-            pfrom->GetId());
+    CDataStream &vRecv = msg.m_recv;
+    if (!msg.m_valid_checksum) {
+        LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR peer=%d\n",
+                 __func__, SanitizeString(strCommand), nMessageSize,
+                 pfrom->GetId());
         if (m_banman) {
             m_banman->Discourage(pfrom->addr);
         }
@@ -4034,7 +4022,7 @@
     // Process message
     bool fRet = false;
     try {
-        fRet = ProcessMessage(config, pfrom, strCommand, vRecv, msg.nTime,
+        fRet = ProcessMessage(config, pfrom, strCommand, vRecv, msg.m_time,
                               connman, m_banman, interruptMsgProc);
         if (interruptMsgProc) {
             return false;
diff --git a/test/functional/p2p_invalid_messages.py b/test/functional/p2p_invalid_messages.py
--- a/test/functional/p2p_invalid_messages.py
+++ b/test/functional/p2p_invalid_messages.py
@@ -164,7 +164,7 @@
 
     def test_checksum(self):
         conn = self.nodes[0].add_p2p_connection(P2PDataStore())
-        with self.nodes[0].assert_debug_log(['ProcessMessages(badmsg, 2 bytes): CHECKSUM ERROR expected 78df0a04 was ffffffff']):
+        with self.nodes[0].assert_debug_log(['CHECKSUM ERROR (badmsg, 2 bytes), expected 78df0a04 was ffffffff']):
             msg = conn.build_message(msg_unrecognized(str_data="d"))
             cut_len = (
                 # magic