diff --git a/src/net.h b/src/net.h --- a/src/net.h +++ b/src/net.h @@ -679,29 +679,47 @@ * transport protocol agnostic CNetMessage (command & payload) */ class TransportDeserializer { +public: + // prepare for next message + virtual void Reset() = 0; + // returns true if the current deserialization is complete + virtual bool Complete() const = 0; + // checks if the potential message in deserialization is oversized + virtual bool OversizedMessageDetected(const Config &config) const = 0; + // set the serialization context version + virtual void SetVersion(int version) = 0; + // read and deserialize data + virtual int Read(const Config &config, const char *data, + uint32_t bytes) = 0; + // decomposes a message from the context + virtual CNetMessage GetMessage(const Config &config, int64_t time) = 0; + virtual ~TransportDeserializer() {} +}; + +class V1TransportDeserializer : public TransportDeserializer { private: mutable CHash256 hasher; mutable uint256 data_hash; -public: // Parsing header (false) or data (true) bool in_data; - // Partially received header. CDataStream hdrbuf; // Complete header. CMessageHeader hdr; - uint32_t nHdrPos; - // Received message data. CDataStream vRecv; + uint32_t nHdrPos; uint32_t nDataPos; - // Time (in microseconds) of message receipt. - int64_t nTime; + const uint256 &GetMessageHash() const; + int readHeader(const Config &config, const char *pch, uint32_t nBytes); + int readData(const char *pch, uint32_t nBytes); - TransportDeserializer(const CMessageHeader::MessageMagic &pchMessageStartIn, - int nTypeIn, int nVersionIn) +public: + V1TransportDeserializer( + const CMessageHeader::MessageMagic &pchMessageStartIn, int nTypeIn, + int nVersionIn) : hdrbuf(nTypeIn, nVersionIn), hdr(pchMessageStartIn), vRecv(nTypeIn, nVersionIn) { Reset(); @@ -718,7 +736,7 @@ hasher.Reset(); } - bool complete() const { + bool Complete() const { if (!in_data) { return false; } @@ -726,15 +744,17 @@ return (hdr.nMessageSize == nDataPos); } - const uint256 &GetMessageHash() const; - void SetVersion(int nVersionIn) { hdrbuf.SetVersion(nVersionIn); vRecv.SetVersion(nVersionIn); } - - int readHeader(const Config &config, const char *pch, uint32_t nBytes); - int readData(const char *pch, uint32_t nBytes); + bool OversizedMessageDetected(const Config &config) const { + return (in_data && hdr.IsOversized(config)); + } + int Read(const Config &config, const char *pch, uint32_t nBytes) { + return in_data ? readData(pch, nBytes) + : readHeader(config, pch, nBytes); + } CNetMessage GetMessage(const Config &config, int64_t time); }; diff --git a/src/net.cpp b/src/net.cpp --- a/src/net.cpp +++ b/src/net.cpp @@ -570,16 +570,6 @@ addrLocalUnlocked.IsValid() ? addrLocalUnlocked.ToString() : ""; } -static bool IsOversizedMessage(const Config &config, - const TransportDeserializer &deserializer) { - if (!deserializer.in_data) { - // Header only, cannot be oversized. - return false; - } - - return deserializer.hdr.IsOversized(config); -} - bool CNode::ReceiveMsgBytes(const Config &config, const char *pch, uint32_t nBytes, bool &complete) { complete = false; @@ -589,19 +579,14 @@ nRecvBytes += nBytes; while (nBytes > 0) { // Absorb network data. - int handled; - if (!m_deserializer->in_data) { - handled = m_deserializer->readHeader(config, pch, nBytes); - } else { - handled = m_deserializer->readData(pch, nBytes); - } + int handled = m_deserializer->Read(config, pch, nBytes); if (handled < 0) { m_deserializer->Reset(); return false; } - if (IsOversizedMessage(config, *m_deserializer)) { + if (m_deserializer->OversizedMessageDetected(config)) { LogPrint(BCLog::NET, "Oversized message from peer=%i, disconnecting\n", GetId()); @@ -612,14 +597,14 @@ pch += handled; nBytes -= handled; - if (m_deserializer->complete()) { + if (m_deserializer->Complete()) { // decompose a transport agnostic CNetMessage from the deserializer CNetMessage msg = m_deserializer->GetMessage(config, nTimeMicros); // Store received bytes per message command to prevent a memory DOS, // only allow valid commands. - mapMsgCmdSize::iterator i = mapRecvBytesPerMsgCmd.find( - m_deserializer->hdr.pchCommand.data()); + mapMsgCmdSize::iterator i = + mapRecvBytesPerMsgCmd.find(msg.m_command); if (i == mapRecvBytesPerMsgCmd.end()) { i = mapRecvBytesPerMsgCmd.find(NET_MESSAGE_COMMAND_OTHER); } @@ -663,8 +648,8 @@ return nSendVersion; } -int TransportDeserializer::readHeader(const Config &config, const char *pch, - uint32_t nBytes) { +int V1TransportDeserializer::readHeader(const Config &config, const char *pch, + uint32_t nBytes) { // copy data to temporary parsing buffer uint32_t nRemaining = 24 - nHdrPos; uint32_t nCopy = std::min(nRemaining, nBytes); @@ -696,7 +681,7 @@ return nCopy; } -int TransportDeserializer::readData(const char *pch, uint32_t nBytes) { +int V1TransportDeserializer::readData(const char *pch, uint32_t nBytes) { unsigned int nRemaining = hdr.nMessageSize - nDataPos; unsigned int nCopy = std::min(nRemaining, nBytes); @@ -713,16 +698,16 @@ return nCopy; } -const uint256 &TransportDeserializer::GetMessageHash() const { - assert(complete()); +const uint256 &V1TransportDeserializer::GetMessageHash() const { + assert(Complete()); if (data_hash.IsNull()) { hasher.Finalize(data_hash.begin()); } return data_hash; } -CNetMessage TransportDeserializer::GetMessage(const Config &config, - int64_t time) { +CNetMessage V1TransportDeserializer::GetMessage(const Config &config, + int64_t time) { // decompose a single CNetMessage from the TransportDeserializer CNetMessage msg(std::move(vRecv)); @@ -2947,9 +2932,9 @@ LogPrint(BCLog::NET, "Added connection peer=%d\n", id); } - m_deserializer = std::make_unique( - TransportDeserializer(GetConfig().GetChainParams().NetMagic(), - SER_NETWORK, INIT_PROTO_VERSION)); + m_deserializer = std::make_unique( + V1TransportDeserializer(GetConfig().GetChainParams().NetMagic(), + SER_NETWORK, INIT_PROTO_VERSION)); } CNode::~CNode() {