diff --git a/src/net.h b/src/net.h --- a/src/net.h +++ b/src/net.h @@ -1018,10 +1018,10 @@ } } - void AddInventoryKnown(const CInv &inv) { + void AddKnownTx(const TxId &txid) { if (m_tx_relay != nullptr) { LOCK(m_tx_relay->cs_tx_inventory); - m_tx_relay->filterInventoryKnown.insert(inv.hash); + m_tx_relay->filterInventoryKnown.insert(txid); } } diff --git a/src/net_processing.cpp b/src/net_processing.cpp --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -2897,18 +2897,18 @@ pfrom.GetId()); } } else { - pfrom.AddInventoryKnown(inv); + const TxId txid(inv.hash); + pfrom.AddKnownTx(txid); if (fBlocksOnly) { LogPrint(BCLog::NET, "transaction (%s) inv sent in violation of " "protocol, disconnecting peer=%d\n", - inv.hash.ToString(), pfrom.GetId()); + txid.ToString(), pfrom.GetId()); pfrom.fDisconnect = true; return true; } else if (!fAlreadyHave && !fImporting && !fReindex && !::ChainstateActive().IsInitialBlockDownload()) { - RequestTx(State(pfrom.GetId()), TxId(inv.hash), - current_time); + RequestTx(State(pfrom.GetId()), txid, current_time); } } } @@ -3175,9 +3175,7 @@ vRecv >> ptx; const CTransaction &tx = *ptx; const TxId &txid = tx.GetId(); - - CInv inv(MSG_TX, txid); - pfrom.AddInventoryKnown(inv); + pfrom.AddKnownTx(txid); LOCK2(cs_main, g_cs_orphans); @@ -3188,7 +3186,7 @@ nodestate->m_tx_download.m_tx_in_flight.erase(txid); EraseTxRequest(txid); - if (!AlreadyHave(inv) && + if (!AlreadyHave(CInv(MSG_TX, txid)) && AcceptToMemoryPool(config, g_mempool, state, ptx, false /* bypass_limits */, Amount::zero() /* nAbsurdFee */)) { @@ -3231,9 +3229,8 @@ for (const CTxIn &txin : tx.vin) { // FIXME: MSG_TX should use a TxHash, not a TxId. const TxId _txid = txin.prevout.GetTxId(); - CInv _inv(MSG_TX, _txid); - pfrom.AddInventoryKnown(_inv); - if (!AlreadyHave(_inv)) { + pfrom.AddKnownTx(_txid); + if (!AlreadyHave(CInv(MSG_TX, _txid))) { RequestTx(State(pfrom.GetId()), _txid, current_time); } } diff --git a/test/functional/test_framework/messages.py b/test/functional/test_framework/messages.py --- a/test/functional/test_framework/messages.py +++ b/test/functional/test_framework/messages.py @@ -266,6 +266,10 @@ return "CInv(type={} hash={:064x})".format( self.typemap[self.type], self.hash) + def __eq__(self, other): + return isinstance( + other, CInv) and self.hash == other.hash and self.type == other.type + class CBlockLocator: __slots__ = ("nVersion", "vHave")