diff --git a/src/blockfilter.h b/src/blockfilter.h --- a/src/blockfilter.h +++ b/src/blockfilter.h @@ -142,7 +142,7 @@ uint256 ComputeHeader(const uint256 &prev_header) const; template void Serialize(Stream &s) const { - s << m_block_hash << static_cast(m_filter_type) + s << static_cast(m_filter_type) << m_block_hash << m_filter.GetEncoded(); } @@ -150,7 +150,7 @@ std::vector encoded_filter; uint8_t filter_type; - s >> m_block_hash >> filter_type >> encoded_filter; + s >> filter_type >> m_block_hash >> encoded_filter; m_filter_type = static_cast(filter_type); diff --git a/src/net_processing.cpp b/src/net_processing.cpp --- a/src/net_processing.cpp +++ b/src/net_processing.cpp @@ -165,6 +165,11 @@ * Maximum feefilter broadcast delay after significant change. */ static constexpr unsigned int MAX_FEEFILTER_CHANGE_DELAY = 5 * 60; +/** + * Maximum number of compact filters that may be requested with one + * getcfilters. See BIP 157. + */ +static constexpr uint32_t MAX_GETCFILTERS_SIZE = 1000; /** * Maximum number of cf hashes that may be requested with one getcfheaders. See * BIP 157. @@ -2258,7 +2263,7 @@ * @return True if the request can be serviced. 
*/ static bool PrepareBlockFilterRequest( - CNode *pfrom, const CChainParams &chain_params, BlockFilterType filter_type, + CNode &pfrom, const CChainParams &chain_params, BlockFilterType filter_type, uint32_t start_height, const BlockHash &stop_hash, uint32_t max_height_diff, const CBlockIndex *&stop_index, BlockFilterIndex *&filter_index) { const bool supported_filter_type = @@ -2267,8 +2272,8 @@ if (!supported_filter_type) { LogPrint(BCLog::NET, "peer %d requested unsupported block filter type: %d\n", - pfrom->GetId(), static_cast(filter_type)); - pfrom->fDisconnect = true; + pfrom.GetId(), static_cast(filter_type)); + pfrom.fDisconnect = true; return false; } @@ -2281,8 +2286,8 @@ if (!stop_index || !BlockRequestAllowed(stop_index, chain_params.GetConsensus())) { LogPrint(BCLog::NET, "peer %d requested invalid block hash: %s\n", - pfrom->GetId(), stop_hash.ToString()); - pfrom->fDisconnect = true; + pfrom.GetId(), stop_hash.ToString()); + pfrom.fDisconnect = true; return false; } } @@ -2294,16 +2299,16 @@ "peer %d sent invalid getcfilters/getcfheaders with " /* Continued */ "start height %d and stop height %d\n", - pfrom->GetId(), start_height, stop_height); - pfrom->fDisconnect = true; + pfrom.GetId(), start_height, stop_height); + pfrom.fDisconnect = true; return false; } if (stop_height - start_height >= max_height_diff) { LogPrint(BCLog::NET, "peer %d requested too many cfilters/cfheaders: %d / %d\n", - pfrom->GetId(), stop_height - start_height + 1, + pfrom.GetId(), stop_height - start_height + 1, max_height_diff); - pfrom->fDisconnect = true; + pfrom.fDisconnect = true; return false; } @@ -2317,6 +2322,54 @@ return true; } +/** + * Handle a cfilters request. + * + * May disconnect from the peer in the case of a bad request. 
+ * + * @param[in] pfrom The peer that we received the request from + * @param[in] vRecv The raw message received + * @param[in] chain_params Chain parameters + * @param[in] connman Reference to the connection manager + */ +static void ProcessGetCFilters(CNode &pfrom, CDataStream &vRecv, + const CChainParams &chain_params, + CConnman &connman) { + uint8_t filter_type_ser; + uint32_t start_height; + BlockHash stop_hash; + + vRecv >> filter_type_ser >> start_height >> stop_hash; + + const BlockFilterType filter_type = + static_cast(filter_type_ser); + + const CBlockIndex *stop_index; + BlockFilterIndex *filter_index; + if (!PrepareBlockFilterRequest( + pfrom, chain_params, filter_type, start_height, stop_hash, + MAX_GETCFILTERS_SIZE, stop_index, filter_index)) { + return; + } + + std::vector filters; + + if (!filter_index->LookupFilterRange(start_height, stop_index, filters)) { + LogPrint(BCLog::NET, + "Failed to find block filter in index: filter_type=%s, " + "start_height=%d, stop_hash=%s\n", + BlockFilterTypeName(filter_type), start_height, + stop_hash.ToString()); + return; + } + + for (const auto &filter : filters) { + CSerializedNetMsg msg = CNetMsgMaker(pfrom.GetSendVersion()) + .Make(NetMsgType::CFILTER, filter); + connman.PushMessage(&pfrom, std::move(msg)); + } +} + /** * Handle a cfheaders request. 
* @@ -2327,9 +2380,9 @@ * @param[in] chain_params Chain parameters * @param[in] connman Pointer to the connection manager */ -static void ProcessGetCFHeaders(CNode *pfrom, CDataStream &vRecv, +static void ProcessGetCFHeaders(CNode &pfrom, CDataStream &vRecv, const CChainParams &chain_params, - CConnman *connman) { + CConnman &connman) { uint8_t filter_type_ser; uint32_t start_height; BlockHash stop_hash; @@ -2373,10 +2426,10 @@ } CSerializedNetMsg msg = - CNetMsgMaker(pfrom->GetSendVersion()) + CNetMsgMaker(pfrom.GetSendVersion()) .Make(NetMsgType::CFHEADERS, filter_type_ser, stop_index->GetBlockHash(), prev_header, filter_hashes); - connman->PushMessage(pfrom, std::move(msg)); + connman.PushMessage(&pfrom, std::move(msg)); } /** @@ -2389,9 +2442,9 @@ * @param[in] chain_params Chain parameters * @param[in] connman Pointer to the connection manager */ -static void ProcessGetCFCheckPt(CNode *pfrom, CDataStream &vRecv, +static void ProcessGetCFCheckPt(CNode &pfrom, CDataStream &vRecv, const CChainParams &chain_params, - CConnman *connman) { + CConnman &connman) { uint8_t filter_type_ser; BlockHash stop_hash; @@ -2427,10 +2480,10 @@ } } - CSerializedNetMsg msg = CNetMsgMaker(pfrom->GetSendVersion()) + CSerializedNetMsg msg = CNetMsgMaker(pfrom.GetSendVersion()) .Make(NetMsgType::CFCHECKPT, filter_type_ser, stop_index->GetBlockHash(), headers); - connman->PushMessage(pfrom, std::move(msg)); + connman.PushMessage(&pfrom, std::move(msg)); } bool ProcessMessage(const Config &config, CNode *pfrom, @@ -4055,13 +4108,18 @@ return true; } + if (msg_type == NetMsgType::GETCFILTERS) { + ProcessGetCFilters(*pfrom, vRecv, chainparams, *connman); + return true; + } + if (msg_type == NetMsgType::GETCFHEADERS) { - ProcessGetCFHeaders(pfrom, vRecv, chainparams, connman); + ProcessGetCFHeaders(*pfrom, vRecv, chainparams, *connman); return true; } if (msg_type == NetMsgType::GETCFCHECKPT) { - ProcessGetCFCheckPt(pfrom, vRecv, chainparams, connman); + ProcessGetCFCheckPt(*pfrom, vRecv, 
chainparams, *connman); return true; } diff --git a/src/protocol.h b/src/protocol.h --- a/src/protocol.h +++ b/src/protocol.h @@ -251,6 +251,17 @@ * @since protocol version 70014 as described by BIP 152 */ extern const char *BLOCKTXN; +/** + * getcfilters requests compact filters for a range of blocks. + * Only available with service bit NODE_COMPACT_FILTERS as described by + * BIP 157 & 158. + */ +extern const char *GETCFILTERS; +/** + * cfilter is a response to a getcfilters request containing a single compact + * filter. + */ +extern const char *CFILTER; /** * getcfheaders requests a compact filter header and the filter hashes for a * range of blocks, which can then be used to reconstruct the filter headers diff --git a/src/protocol.cpp b/src/protocol.cpp --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -43,6 +43,8 @@ const char *CMPCTBLOCK = "cmpctblock"; const char *GETBLOCKTXN = "getblocktxn"; const char *BLOCKTXN = "blocktxn"; +const char *GETCFILTERS = "getcfilters"; +const char *CFILTER = "cfilter"; const char *GETCFHEADERS = "getcfheaders"; const char *CFHEADERS = "cfheaders"; const char *GETCFCHECKPT = "getcfcheckpt"; @@ -62,16 +64,17 @@ * above and in protocol.h. 
*/ static const std::string allNetMessageTypes[] = { - NetMsgType::VERSION, NetMsgType::VERACK, NetMsgType::ADDR, - NetMsgType::INV, NetMsgType::GETDATA, NetMsgType::MERKLEBLOCK, - NetMsgType::GETBLOCKS, NetMsgType::GETHEADERS, NetMsgType::TX, - NetMsgType::HEADERS, NetMsgType::BLOCK, NetMsgType::GETADDR, - NetMsgType::MEMPOOL, NetMsgType::PING, NetMsgType::PONG, - NetMsgType::NOTFOUND, NetMsgType::FILTERLOAD, NetMsgType::FILTERADD, - NetMsgType::FILTERCLEAR, NetMsgType::SENDHEADERS, NetMsgType::FEEFILTER, - NetMsgType::SENDCMPCT, NetMsgType::CMPCTBLOCK, NetMsgType::GETBLOCKTXN, - NetMsgType::BLOCKTXN, NetMsgType::GETCFHEADERS, NetMsgType::CFHEADERS, - NetMsgType::GETCFCHECKPT, NetMsgType::CFCHECKPT, + NetMsgType::VERSION, NetMsgType::VERACK, NetMsgType::ADDR, + NetMsgType::INV, NetMsgType::GETDATA, NetMsgType::MERKLEBLOCK, + NetMsgType::GETBLOCKS, NetMsgType::GETHEADERS, NetMsgType::TX, + NetMsgType::HEADERS, NetMsgType::BLOCK, NetMsgType::GETADDR, + NetMsgType::MEMPOOL, NetMsgType::PING, NetMsgType::PONG, + NetMsgType::NOTFOUND, NetMsgType::FILTERLOAD, NetMsgType::FILTERADD, + NetMsgType::FILTERCLEAR, NetMsgType::SENDHEADERS, NetMsgType::FEEFILTER, + NetMsgType::SENDCMPCT, NetMsgType::CMPCTBLOCK, NetMsgType::GETBLOCKTXN, + NetMsgType::BLOCKTXN, NetMsgType::GETCFILTERS, NetMsgType::CFILTER, + NetMsgType::GETCFHEADERS, NetMsgType::CFHEADERS, NetMsgType::GETCFCHECKPT, + NetMsgType::CFCHECKPT, }; static const std::vector allNetMessageTypesVec(allNetMessageTypes, diff --git a/test/functional/p2p_blockfilters.py b/test/functional/p2p_blockfilters.py --- a/test/functional/p2p_blockfilters.py +++ b/test/functional/p2p_blockfilters.py @@ -13,6 +13,7 @@ hash256, msg_getcfcheckpt, msg_getcfheaders, + msg_getcfilters, ser_uint256, uint256_from_str, ) @@ -26,6 +27,22 @@ ) +class CFiltersClient(P2PInterface): + def __init__(self): + super().__init__() + # Store the cfilters received. 
+ self.cfilters = [] + + def pop_cfilters(self): + cfilters = self.cfilters + self.cfilters = [] + return cfilters + + def on_cfilter(self, message): + """Store cfilters received in a list.""" + self.cfilters.append(message) + + class CompactFiltersTest(BitcoinTestFramework): def set_test_params(self): self.setup_clean_chain = True @@ -38,8 +55,8 @@ def run_test(self): # Node 0 supports COMPACT_FILTERS, node 1 does not. - node0 = self.nodes[0].add_p2p_connection(P2PInterface()) - node1 = self.nodes[1].add_p2p_connection(P2PInterface()) + node0 = self.nodes[0].add_p2p_connection(CFiltersClient()) + node1 = self.nodes[1].add_p2p_connection(CFiltersClient()) # Nodes 0 & 1 share the same first 999 blocks in the chain. self.nodes[0].generate(999) @@ -116,7 +133,8 @@ ) node0.send_and_ping(request) response = node0.last_message['cfheaders'] - assert_equal(len(response.hashes), 1000) + main_cfhashes = response.hashes + assert_equal(len(main_cfhashes), 1000) assert_equal( compute_last_header(response.prev_header, response.hashes), int(main_cfcheckpt, 16) @@ -130,12 +148,51 @@ ) node0.send_and_ping(request) response = node0.last_message['cfheaders'] - assert_equal(len(response.hashes), 1000) + stale_cfhashes = response.hashes + assert_equal(len(stale_cfhashes), 1000) assert_equal( compute_last_header(response.prev_header, response.hashes), int(stale_cfcheckpt, 16) ) + self.log.info("Check that peers can fetch cfilters.") + stop_hash = self.nodes[0].getblockhash(10) + request = msg_getcfilters( + filter_type=FILTER_TYPE_BASIC, + start_height=1, + stop_hash=int(stop_hash, 16) + ) + node0.send_message(request) + node0.sync_with_ping() + response = node0.pop_cfilters() + assert_equal(len(response), 10) + + self.log.info("Check that cfilter responses are correct.") + for cfilter, cfhash, height in zip( + response, main_cfhashes, range(1, 11)): + block_hash = self.nodes[0].getblockhash(height) + assert_equal(cfilter.filter_type, FILTER_TYPE_BASIC) + 
assert_equal(cfilter.block_hash, int(block_hash, 16)) + computed_cfhash = uint256_from_str(hash256(cfilter.filter_data)) + assert_equal(computed_cfhash, cfhash) + + self.log.info("Check that peers can fetch cfilters for stale blocks.") + request = msg_getcfilters( + filter_type=FILTER_TYPE_BASIC, + start_height=1000, + stop_hash=int(stale_block_hash, 16) + ) + node0.send_message(request) + node0.sync_with_ping() + response = node0.pop_cfilters() + assert_equal(len(response), 1) + + cfilter = response[0] + assert_equal(cfilter.filter_type, FILTER_TYPE_BASIC) + assert_equal(cfilter.block_hash, int(stale_block_hash, 16)) + computed_cfhash = uint256_from_str(hash256(cfilter.filter_data)) + assert_equal(computed_cfhash, stale_cfhashes[999]) + self.log.info( "Requests to node 1 without NODE_COMPACT_FILTERS results in disconnection.") requests = [ @@ -148,6 +205,11 @@ start_height=1000, stop_hash=int(main_block_hash, 16) ), + msg_getcfilters( + filter_type=FILTER_TYPE_BASIC, + start_height=1000, + stop_hash=int(main_block_hash, 16) + ), ] for request in requests: node1 = self.nodes[1].add_p2p_connection(P2PInterface()) @@ -156,6 +218,12 @@ self.log.info("Check that invalid requests result in disconnection.") requests = [ + # Requesting too many filters results in disconnection. + msg_getcfilters( + filter_type=FILTER_TYPE_BASIC, + start_height=0, + stop_hash=int(main_block_hash, 16) + ), # Requesting too many filter headers results in disconnection. 
msg_getcfheaders( filter_type=FILTER_TYPE_BASIC, diff --git a/test/functional/test_framework/messages.py b/test/functional/test_framework/messages.py --- a/test/functional/test_framework/messages.py +++ b/test/functional/test_framework/messages.py @@ -1471,6 +1471,58 @@ repr(self.block_transactions)) +class msg_getcfilters: + __slots__ = ("filter_type", "start_height", "stop_hash") + msgtype = b"getcfilters" + + def __init__(self, filter_type, start_height, stop_hash): + self.filter_type = filter_type + self.start_height = start_height + self.stop_hash = stop_hash + + def deserialize(self, f): + self.filter_type = struct.unpack("