diff --git a/src/net.cpp b/src/net.cpp
--- a/src/net.cpp
+++ b/src/net.cpp
@@ -47,11 +47,19 @@
 #include
 #include
 #include
+#include <exception>
 #include
 #include
 #include
 #include
 
+/** Base class for fatal errors found while deserializing data from a peer. */
+class NetException : public std::exception {};
+/** The message header did not start with the expected network magic. */
+class InvalidNetMagic : public NetException {};
+/** The message payload did not match the checksum in its header. */
+class ChecksumError : public NetException {};
+
 /** Maximum number of block-relay-only anchor connections */
 static constexpr size_t MAX_BLOCK_RELAY_ONLY_ANCHORS = 2;
 static_assert(MAX_BLOCK_RELAY_ONLY_ANCHORS <=
@@ -659,9 +667,18 @@
         if (m_deserializer->Complete()) {
             // decompose a transport agnostic CNetMessage from the deserializer
             uint32_t out_err_raw_size{0};
-            std::optional<CNetMessage> result{
-                m_deserializer->GetMessage(time, out_err_raw_size)};
+            std::optional<CNetMessage> result{std::nullopt};
+            try {
+                result = m_deserializer->GetMessage(time, out_err_raw_size);
+            } catch (const ChecksumError &) {
+                // Still count the corrupt message's bytes before propagating
+                // the error to the caller.
+                mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER)->second +=
+                    out_err_raw_size;
+                throw;
+            }
+
             if (!result) {
                 // Message deserialization failed. Drop the message but don't
                 // disconnect the peer.
                 // store the size of the corrupt message
@@ -721,6 +738,6 @@
                  "peer=%d\n",
                  hdr.GetCommand(), hdr.nMessageSize,
                  HexStr(hdr.pchMessageStart), m_node_id);
-        return -1;
+        throw InvalidNetMagic{};
     }
 
@@ -794,7 +811,8 @@
                      hash.begin() + CMessageHeader::CHECKSUM_SIZE)),
                  HexStr(hdr.pchChecksum), m_node_id);
         out_err_raw_size = msg->m_raw_message_size;
-        msg = std::nullopt;
+        Reset();
+        throw ChecksumError{};
     } else if (!hdr.IsCommandValid()) {
         LogPrint(BCLog::NET, "HEADER ERROR - COMMAND (%s, %u bytes), peer=%d\n",
                  hdr.GetCommand(), msg->m_message_size, m_node_id);
@@ -1842,8 +1860,19 @@
             }
             if (nBytes > 0) {
                 bool notify = false;
-                if (!pnode->ReceiveMsgBytes(
-                        *config, Span<const char>(pchBuf, nBytes), notify)) {
+                try {
+                    if (!pnode->ReceiveMsgBytes(
+                            *config, Span<const char>(pchBuf, nBytes),
+                            notify)) {
+                        pnode->CloseSocketDisconnect();
+                    }
+                } catch (const NetException &) {
+                    // Local addresses can be shared by many unrelated peers
+                    // (e.g. inbound connections through a local proxy), so
+                    // only disconnect those; discouraging would hit them all.
+                    if (m_banman && !pnode->addr.IsLocal()) {
+                        m_banman->Discourage(pnode->addr);
+                    }
                     pnode->CloseSocketDisconnect();
                 }
                 RecordBytesRecv(nBytes);