diff --git a/src/net.h b/src/net.h --- a/src/net.h +++ b/src/net.h @@ -333,9 +333,6 @@ CDataStream m_recv; //! time of message receipt std::chrono::microseconds m_time{0}; - bool m_valid_netmagic = false; - bool m_valid_header = false; - bool m_valid_checksum = false; //! size of the payload uint32_t m_message_size{0}; //! used wire size of the message (including header/checksum) @@ -361,13 +358,18 @@ /** read and deserialize data, advances msg_bytes data pointer */ virtual int Read(const Config &config, Span &msg_bytes) = 0; // decomposes a message from the context - virtual CNetMessage GetMessage(const Config &config, - std::chrono::microseconds time) = 0; + // Returns std::nullopt if the message is dropped (V1: bad checksum or + // invalid command); out_err then receives its raw wire size. + virtual std::optional + GetMessage(std::chrono::microseconds time, uint32_t &out_err) = 0; virtual ~TransportDeserializer() {} }; class V1TransportDeserializer final : public TransportDeserializer { private: + const CChainParams &m_chain_params; + // Only for logging + const NodeId m_node_id; mutable CHash256 hasher; mutable uint256 data_hash; @@ -398,11 +400,10 @@ } public: - V1TransportDeserializer( - const CMessageHeader::MessageMagic &pchMessageStartIn, int nTypeIn, - int nVersionIn) - : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn), - vRecv(nTypeIn, nVersionIn) { + V1TransportDeserializer(const CChainParams &chain_params, + const NodeId node_id, int nTypeIn, int nVersionIn) + : m_chain_params(chain_params), m_node_id(node_id), + hdrbuf(nTypeIn, nVersionIn), vRecv(nTypeIn, nVersionIn) { Reset(); } @@ -428,8 +429,8 @@ return ret; } - CNetMessage GetMessage(const Config &config, - std::chrono::microseconds time) override; + std::optional GetMessage(std::chrono::microseconds time, + uint32_t &out_err) override; }; /** diff --git a/src/net.cpp b/src/net.cpp --- a/src/net.cpp +++ b/src/net.cpp @@ -652,26 +652,36 @@ // Absorb network data. int handled = m_deserializer->Read(config, msg_bytes); if (handled < 0) { + // Serious header problem, disconnect from the peer.
return false; } if (m_deserializer->Complete()) { // decompose a transport agnostic CNetMessage from the deserializer - CNetMessage msg = m_deserializer->GetMessage(config, time); + uint32_t out_err_raw_size{0}; + std::optional result{ + m_deserializer->GetMessage(time, out_err_raw_size)}; + if (!result) { + // Message deserialization failed. Drop the message but don't + // disconnect the peer. Store the size of the corrupt message. + mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second += + out_err_raw_size; + continue; + } // Store received bytes per message command to prevent a memory DOS, // only allow valid commands. mapMsgCmdSize::iterator i = - mapRecvBytesPerMsgCmd.find(msg.m_command); + mapRecvBytesPerMsgCmd.find(result->m_command); if (i == mapRecvBytesPerMsgCmd.end()) { i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER); } assert(i != mapRecvBytesPerMsgCmd.end()); - i->second += msg.m_raw_message_size; + i->second += result->m_raw_message_size; // push the message to the process queue, - vRecvMsg.push_back(std::move(msg)); + vRecvMsg.push_back(std::move(*result)); complete = true; } @@ -698,12 +708,26 @@ try { hdrbuf >> hdr; } catch (const std::exception &) { + LogPrint(BCLog::NET, "HEADER ERROR - UNABLE TO DESERIALIZE, peer=%d\n", + m_node_id); + return -1; + } + + // Check start string, network magic + if (memcmp(hdr.pchMessageStart.begin(), m_chain_params.NetMagic().begin(), + CMessageHeader::MESSAGE_START_SIZE) != 0) { + LogPrint(BCLog::NET, + "HEADER ERROR - MESSAGESTART (%s, %u bytes), received %s, " + "peer=%d\n", + SanitizeString(hdr.GetCommand()), hdr.nMessageSize, + HexStr(hdr.pchMessageStart), m_node_id); return -1; } // Reject oversized messages if (hdr.IsOversized(config)) { - LogPrint(BCLog::NET, "Oversized header detected\n"); + LogPrint(BCLog::NET, "HEADER ERROR - SIZE (%s, %u bytes), peer=%d\n", + SanitizeString(hdr.GetCommand()), hdr.nMessageSize, m_node_id); return -1; } @@ -738,47 +762,49 @@ return data_hash; } -CNetMessage 
-V1TransportDeserializer::GetMessage(const Config &config, - const std::chrono::microseconds time) { +std::optional +V1TransportDeserializer::GetMessage(const std::chrono::microseconds time, + uint32_t &out_err_raw_size) { // decompose a single CNetMessage from the TransportDeserializer - CNetMessage msg(std::move(vRecv)); + std::optional msg(std::move(vRecv)); - // store state about valid header, netmagic and checksum - msg.m_valid_header = hdr.IsValid(config); - // FIXME Split CheckHeaderMagicAndCommand() into CheckHeaderMagic() and - // CheckCommand() to prevent the net magic check code duplication. + // Network magic and message size were already validated when the header + // was read; only the checksum and command string are checked here. + - msg.m_valid_netmagic = - (memcmp(std::begin(hdr.pchMessageStart), - std::begin(config.GetChainParams().NetMagic()), - CMessageHeader::MESSAGE_START_SIZE) == 0); - uint256 hash = GetMessageHash(); - // store command string, payload size - msg.m_command = hdr.GetCommand(); - msg.m_message_size = hdr.nMessageSize; - msg.m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE; + // store command string, time, and sizes + msg->m_command = hdr.GetCommand(); + msg->m_time = time; + msg->m_message_size = hdr.nMessageSize; + msg->m_raw_message_size = hdr.nMessageSize + CMessageHeader::HEADER_SIZE; + + uint256 hash = GetMessageHash(); // We just received a message off the wire, harvest entropy from the time // (and the message checksum) RandAddEvent(ReadLE32(hash.begin())); - msg.m_valid_checksum = (memcmp(hash.begin(), hdr.pchChecksum, - CMessageHeader::CHECKSUM_SIZE) == 0); - - if (!msg.m_valid_checksum) { + // Check checksum and header command string + if (memcmp(hash.begin(), hdr.pchChecksum, CMessageHeader::CHECKSUM_SIZE) != + 0) { LogPrint( - BCLog::NET, "CHECKSUM ERROR (%s, %u bytes), expected %s was %s\n", - SanitizeString(msg.m_command), msg.m_message_size, + BCLog::NET, + "CHECKSUM ERROR (%s, %u bytes), expected %s was %s, peer=%d\n", + SanitizeString(msg->m_command), msg->m_message_size, HexStr(Span(hash.begin(), hash.begin() + 
CMessageHeader::CHECKSUM_SIZE)), - HexStr(hdr.pchChecksum)); + HexStr(hdr.pchChecksum), m_node_id); + out_err_raw_size = msg->m_raw_message_size; + msg = std::nullopt; + } else if (!hdr.IsCommandValid()) { + LogPrint(BCLog::NET, "HEADER ERROR - COMMAND (%s, %u bytes), peer=%d\n", + SanitizeString(msg->m_command), msg->m_message_size, + m_node_id); + out_err_raw_size = msg->m_raw_message_size; + msg = std::nullopt; } - // store receive time - msg.m_time = time; - - // reset the network deserializer (prepare for the next message) + // Always reset the network deserializer (prepare for the next message) Reset(); return msg; } @@ -3473,9 +3499,9 @@ LogPrint(BCLog::NET, "Added connection peer=%d\n", id); } - m_deserializer = std::make_unique( - V1TransportDeserializer(GetConfig().GetChainParams().NetMagic(), - SER_NETWORK, INIT_PROTO_VERSION)); + m_deserializer = + std::make_unique(V1TransportDeserializer( + Params(), GetId(), SER_NETWORK, INIT_PROTO_VERSION)); m_serializer = std::make_unique(V1TransportSerializer()); } diff --git a/src/net_processing.cpp b/src/net_processing.cpp --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -5846,14 +5846,6 @@ bool PeerManagerImpl::ProcessMessages(const Config &config, CNode *pfrom, std::atomic &interruptMsgProc) { - // - // Message format - // (4) message start - // (12) command - // (4) size - // (4) checksum - // (x) data - // bool fMoreWork = false; PeerRef peer = GetPeerRef(pfrom->GetId()); @@ -5926,49 +5918,13 @@ } msg.SetVersion(pfrom->GetCommonVersion()); - - // Check network magic - if (!msg.m_valid_netmagic) { - LogPrint(BCLog::NET, - "PROCESSMESSAGE: INVALID MESSAGESTART %s peer=%d\n", - SanitizeString(msg.m_command), pfrom->GetId()); - - // Make sure we discourage where that come from for some time.
- if (m_banman) { - m_banman->Discourage(pfrom->addr); - } - m_connman.DisconnectNode(pfrom->addr); - - pfrom->fDisconnect = true; - return false; - } - - // Check header - if (!msg.m_valid_header) { - LogPrint(BCLog::NET, "PROCESSMESSAGE: ERRORS IN HEADER %s peer=%d\n", - SanitizeString(msg.m_command), pfrom->GetId()); - return fMoreWork; - } const std::string &msg_type = msg.m_command; // Message size unsigned int nMessageSize = msg.m_message_size; - // Checksum - CDataStream &vRecv = msg.m_recv; - if (!msg.m_valid_checksum) { - LogPrint(BCLog::NET, "%s(%s, %u bytes): CHECKSUM ERROR peer=%d\n", - __func__, SanitizeString(msg_type), nMessageSize, - pfrom->GetId()); - if (m_banman) { - m_banman->Discourage(pfrom->addr); - } - m_connman.DisconnectNode(pfrom->addr); - return fMoreWork; - } - try { - ProcessMessage(config, *pfrom, msg_type, vRecv, msg.m_time, + ProcessMessage(config, *pfrom, msg_type, msg.m_recv, msg.m_time, interruptMsgProc); if (interruptMsgProc) { return false; diff --git a/src/protocol.h b/src/protocol.h --- a/src/protocol.h +++ b/src/protocol.h @@ -48,7 +48,7 @@ MESSAGE_START_SIZE + COMMAND_SIZE + MESSAGE_SIZE_SIZE + CHECKSUM_SIZE; typedef std::array MessageMagic; - explicit CMessageHeader(const MessageMagic &pchMessageStartIn); + explicit CMessageHeader(); /** * Construct a P2P message header from message-start characters, a command @@ -60,7 +60,8 @@ const char *pszCommand, unsigned int nMessageSizeIn); std::string GetCommand() const; - bool IsValid(const Config &config) const; + //! Check only the command string, not the network magic or message size. + bool IsCommandValid() const; bool IsValidWithoutConfig(const MessageMagic &magic) const; bool IsOversized(const Config &config) const; diff --git a/src/protocol.cpp b/src/protocol.cpp --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -84,9 +84,8 @@ allNetMessageTypesVec(std::begin(allNetMessageTypes), std::end(allNetMessageTypes)); -CMessageHeader::CMessageHeader(const MessageMagic &pchMessageStartIn) { - memcpy(std::begin(pchMessageStart), std::begin(pchMessageStartIn), 
- MESSAGE_START_SIZE); +CMessageHeader::CMessageHeader() { + memset(std::begin(pchMessageStart), 0, MESSAGE_START_SIZE); memset(pchCommand.data(), 0, sizeof(pchCommand)); nMessageSize = -1; memset(pchChecksum, 0, CHECKSUM_SIZE); @@ -119,15 +118,7 @@ strnlen(pchCommand.data(), COMMAND_SIZE)); } -static bool -CheckHeaderMagicAndCommand(const CMessageHeader &header, - const CMessageHeader::MessageMagic &magic) { - // Check start string - if (memcmp(std::begin(header.pchMessageStart), std::begin(magic), - CMessageHeader::MESSAGE_START_SIZE) != 0) { - return false; - } - +static bool CheckHeaderMagicAndCommand(const CMessageHeader &header) { // Check the command string for errors for (const char *p1 = header.pchCommand.data(); p1 < header.pchCommand.data() + CMessageHeader::COMMAND_SIZE; p1++) { @@ -147,17 +138,9 @@ return true; } -bool CMessageHeader::IsValid(const Config &config) const { - // Check start string +bool CMessageHeader::IsCommandValid() const { + // Check the command string - if (!CheckHeaderMagicAndCommand(*this, - config.GetChainParams().NetMagic())) { - return false; - } - - // Message size - if (IsOversized(config)) { - LogPrintf("CMessageHeader::IsValid(): (%s, %u bytes) is oversized\n", - GetCommand(), nMessageSize); + if (!CheckHeaderMagicAndCommand(*this)) { return false; } @@ -173,7 +156,13 @@ */ bool CMessageHeader::IsValidWithoutConfig(const MessageMagic &magic) const { // Check start string - if (!CheckHeaderMagicAndCommand(*this, magic)) { + if (memcmp(std::begin(pchMessageStart), std::begin(magic), + MESSAGE_START_SIZE) != 0) { + return false; + } + + // Check the command string + if (!CheckHeaderMagicAndCommand(*this)) { return false; } diff --git a/src/seeder/bitcoin.cpp b/src/seeder/bitcoin.cpp --- a/src/seeder/bitcoin.cpp +++ b/src/seeder/bitcoin.cpp @@ -123,7 +123,7 @@ CDataStream::iterator pstart = std::search( vRecv.begin(), vRecv.end(), BEGIN(netMagic), END(netMagic)); uint32_t nHeaderSize =
- GetSerializeSize(CMessageHeader(netMagic), vRecv.GetVersion()); + GetSerializeSize(CMessageHeader(), vRecv.GetVersion()); if (vRecv.end() - pstart < nHeaderSize) { if (vRecv.size() > nHeaderSize) { vRecv.erase(vRecv.begin(), vRecv.end() - nHeaderSize); @@ -133,7 +133,7 @@ vRecv.erase(vRecv.begin(), pstart); std::vector vHeaderSave(vRecv.begin(), vRecv.begin() + nHeaderSize); - CMessageHeader hdr(netMagic); + CMessageHeader hdr; vRecv >> hdr; if (!hdr.IsValidWithoutConfig(netMagic)) { // tfm::format(std::cout, "%s: BAD (invalid header)\n", diff --git a/src/seeder/test/p2p_messaging_tests.cpp b/src/seeder/test/p2p_messaging_tests.cpp --- a/src/seeder/test/p2p_messaging_tests.cpp +++ b/src/seeder/test/p2p_messaging_tests.cpp @@ -99,7 +99,7 @@ // Seeder should respond with an ADDR message const CMessageHeader::MessageMagic netMagic = Params().NetMagic(); - CMessageHeader header(netMagic); + CMessageHeader header; CDataStream sendBuffer = testNode->getSendBuffer(); sendBuffer >> header; BOOST_CHECK(header.IsValidWithoutConfig(netMagic)); diff --git a/src/test/fuzz/deserialize.cpp b/src/test/fuzz/deserialize.cpp --- a/src/test/fuzz/deserialize.cpp +++ b/src/test/fuzz/deserialize.cpp @@ -180,11 +180,12 @@ DeserializeFromFuzzingInput(buffer, s); AssertEqualAfterSerializeDeserialize(s); #elif MESSAGEHEADER_DESERIALIZE const CMessageHeader::MessageMagic pchMessageStart = { {0x00, 0x00, 0x00, 0x00}}; - CMessageHeader mh(pchMessageStart); + CMessageHeader mh; DeserializeFromFuzzingInput(buffer, mh); + (void)mh.IsCommandValid(); (void)mh.IsValidWithoutConfig(pchMessageStart); #elif ADDRESS_DESERIALIZE CAddress a; DeserializeFromFuzzingInput(buffer, a); diff --git a/src/test/fuzz/p2p_transport_deserializer.cpp b/src/test/fuzz/p2p_transport_deserializer.cpp --- a/src/test/fuzz/p2p_transport_deserializer.cpp +++ b/src/test/fuzz/p2p_transport_deserializer.cpp @@ -20,7 +20,8 @@ void test_one_input(const std::vector &buffer) { const Config &config = GetConfig();
- V1TransportDeserializer deserializer{config.GetChainParams().NetMagic(), + // Construct deserializer, with a dummy NodeId + V1TransportDeserializer deserializer{Params(), static_cast(0), SER_NETWORK, INIT_PROTO_VERSION}; Span msg_bytes{buffer}; while (msg_bytes.size() > 0) { @@ -31,17 +32,20 @@ if (deserializer.Complete()) { const std::chrono::microseconds m_time{ std::numeric_limits::max()}; - const CNetMessage msg = deserializer.GetMessage(config, m_time); - assert(msg.m_command.size() <= CMessageHeader::COMMAND_SIZE); - assert(msg.m_raw_message_size <= buffer.size()); - assert(msg.m_raw_message_size == - CMessageHeader::HEADER_SIZE + msg.m_message_size); - assert(msg.m_time == m_time); - if (msg.m_valid_header) { - assert(msg.m_valid_netmagic); - } - if (!msg.m_valid_netmagic) { - assert(!msg.m_valid_header); + uint32_t out_err_raw_size{0}; + std::optional result{ + deserializer.GetMessage(m_time, out_err_raw_size)}; + if (result) { + assert(result->m_command.size() <= + CMessageHeader::COMMAND_SIZE); + assert(result->m_raw_message_size <= buffer.size()); + assert(result->m_raw_message_size == + CMessageHeader::HEADER_SIZE + result->m_message_size); + assert(result->m_time == m_time); + } else { + // A dropped message still reports its raw wire size. + assert(out_err_raw_size >= CMessageHeader::HEADER_SIZE); + assert(out_err_raw_size <= buffer.size()); } } } diff --git a/test/functional/p2p_invalid_messages.py b/test/functional/p2p_invalid_messages.py --- a/test/functional/p2p_invalid_messages.py +++ b/test/functional/p2p_invalid_messages.py @@ -93,7 +93,7 @@ def test_magic_bytes(self): self.log.info("Test message with invalid magic bytes disconnects peer") conn = self.nodes[0].add_p2p_connection(P2PDataStore()) - with self.nodes[0].assert_debug_log(['PROCESSMESSAGE: INVALID MESSAGESTART badmsg']): + with self.nodes[0].assert_debug_log(['HEADER ERROR - MESSAGESTART (badmsg, 2 bytes), received ffffffff']): msg = conn.build_message(msg_unrecognized(str_data="d")) # modify magic bytes msg = b'\xff' * 4 + msg[4:] @@ -117,7 +117,7 @@ def test_size(self): self.log.info("Test message with oversized payload disconnects peer") conn =
self.nodes[0].add_p2p_connection(P2PDataStore()) - with self.nodes[0].assert_debug_log(['']): + with self.nodes[0].assert_debug_log(['HEADER ERROR - SIZE (badmsg, 2097153 bytes)']): msg = msg_unrecognized(str_data="d" * (VALID_DATA_LIMIT + 1)) msg = conn.build_message(msg) conn.send_raw_message(msg) @@ -127,9 +127,8 @@ def test_msgtype(self): self.log.info("Test message with invalid message type logs an error") conn = self.nodes[0].add_p2p_connection(P2PDataStore()) - with self.nodes[0].assert_debug_log(['PROCESSMESSAGE: ERRORS IN HEADER']): + with self.nodes[0].assert_debug_log(['HEADER ERROR - COMMAND']): msg = msg_unrecognized(str_data="d") - msg.msgtype = b'\xff' * 12 msg = conn.build_message(msg) # Modify msgtype msg = msg[:7] + b'\x00' + msg[7 + 1:]