diff --git a/src/protocol.cpp b/src/protocol.cpp index 0504fe95f..827680958 100644 --- a/src/protocol.cpp +++ b/src/protocol.cpp @@ -1,305 +1,301 @@ // Copyright (c) 2009-2010 Satoshi Nakamoto // Copyright (c) 2009-2016 The Bitcoin Core developers // Distributed under the MIT software license, see the accompanying // file COPYING or http://www.opensource.org/licenses/mit-license.php. #include #include #include #include #include #ifndef WIN32 #include #endif #include static std::atomic g_initial_block_download_completed(false); namespace NetMsgType { const char *VERSION = "version"; const char *VERACK = "verack"; const char *ADDR = "addr"; const char *INV = "inv"; const char *GETDATA = "getdata"; const char *MERKLEBLOCK = "merkleblock"; const char *GETBLOCKS = "getblocks"; const char *GETHEADERS = "getheaders"; const char *TX = "tx"; const char *HEADERS = "headers"; const char *BLOCK = "block"; const char *GETADDR = "getaddr"; const char *MEMPOOL = "mempool"; const char *PING = "ping"; const char *PONG = "pong"; const char *NOTFOUND = "notfound"; const char *FILTERLOAD = "filterload"; const char *FILTERADD = "filteradd"; const char *FILTERCLEAR = "filterclear"; const char *SENDHEADERS = "sendheaders"; const char *FEEFILTER = "feefilter"; const char *SENDCMPCT = "sendcmpct"; const char *CMPCTBLOCK = "cmpctblock"; const char *GETBLOCKTXN = "getblocktxn"; const char *BLOCKTXN = "blocktxn"; const char *GETCFILTERS = "getcfilters"; const char *CFILTER = "cfilter"; const char *GETCFHEADERS = "getcfheaders"; const char *CFHEADERS = "cfheaders"; const char *GETCFCHECKPT = "getcfcheckpt"; const char *CFCHECKPT = "cfcheckpt"; const char *AVAPOLL = "avapoll"; const char *AVARESPONSE = "avaresponse"; bool IsBlockLike(const std::string &strCommand) { return strCommand == NetMsgType::BLOCK || strCommand == NetMsgType::CMPCTBLOCK || strCommand == NetMsgType::BLOCKTXN; } }; // namespace NetMsgType /** * All known message types. Keep this in the same order as the list of messages * above and in protocol.h. */ static const std::string allNetMessageTypes[] = { NetMsgType::VERSION, NetMsgType::VERACK, NetMsgType::ADDR, NetMsgType::INV, NetMsgType::GETDATA, NetMsgType::MERKLEBLOCK, NetMsgType::GETBLOCKS, NetMsgType::GETHEADERS, NetMsgType::TX, NetMsgType::HEADERS, NetMsgType::BLOCK, NetMsgType::GETADDR, NetMsgType::MEMPOOL, NetMsgType::PING, NetMsgType::PONG, NetMsgType::NOTFOUND, NetMsgType::FILTERLOAD, NetMsgType::FILTERADD, NetMsgType::FILTERCLEAR, NetMsgType::SENDHEADERS, NetMsgType::FEEFILTER, NetMsgType::SENDCMPCT, NetMsgType::CMPCTBLOCK, NetMsgType::GETBLOCKTXN, NetMsgType::BLOCKTXN, NetMsgType::GETCFILTERS, NetMsgType::CFILTER, NetMsgType::GETCFHEADERS, NetMsgType::CFHEADERS, NetMsgType::GETCFCHECKPT, NetMsgType::CFCHECKPT, }; static const std::vector allNetMessageTypesVec(allNetMessageTypes, allNetMessageTypes + ARRAYLEN(allNetMessageTypes)); CMessageHeader::CMessageHeader(const MessageMagic &pchMessageStartIn) { memcpy(std::begin(pchMessageStart), std::begin(pchMessageStartIn), MESSAGE_START_SIZE); memset(pchCommand.data(), 0, sizeof(pchCommand)); nMessageSize = -1; memset(pchChecksum, 0, CHECKSUM_SIZE); } CMessageHeader::CMessageHeader(const MessageMagic &pchMessageStartIn, const char *pszCommand, unsigned int nMessageSizeIn) { memcpy(std::begin(pchMessageStart), std::begin(pchMessageStartIn), MESSAGE_START_SIZE); // Copy the command name, zero-padding to COMMAND_SIZE bytes size_t i = 0; for (; i < pchCommand.size() && pszCommand[i] != 0; ++i) { pchCommand[i] = pszCommand[i]; } // Assert that the command name passed in is not longer than COMMAND_SIZE assert(pszCommand[i] == 0); for (; i < pchCommand.size(); ++i) { pchCommand[i] = 0; } nMessageSize = nMessageSizeIn; memset(pchChecksum, 0, CHECKSUM_SIZE); } std::string CMessageHeader::GetCommand() const { // return std::string(pchCommand.begin(), pchCommand.end()); return std::string(pchCommand.data(), pchCommand.data() + strnlen(pchCommand.data(), COMMAND_SIZE)); } static bool CheckHeaderMagicAndCommand(const CMessageHeader &header, const CMessageHeader::MessageMagic &magic) { // Check start string if (memcmp(std::begin(header.pchMessageStart), std::begin(magic), CMessageHeader::MESSAGE_START_SIZE) != 0) { return false; } // Check the command string for errors for (const char *p1 = header.pchCommand.data(); p1 < header.pchCommand.data() + CMessageHeader::COMMAND_SIZE; p1++) { if (*p1 == 0) { // Must be all zeros after the first zero for (; p1 < header.pchCommand.data() + CMessageHeader::COMMAND_SIZE; p1++) { if (*p1 != 0) { return false; } } } else if (*p1 < ' ' || *p1 > 0x7E) { return false; } } return true; } bool CMessageHeader::IsValid(const Config &config) const { // Check start string if (!CheckHeaderMagicAndCommand(*this, config.GetChainParams().NetMagic())) { return false; } // Message size if (IsOversized(config)) { LogPrintf("CMessageHeader::IsValid(): (%s, %u bytes) is oversized\n", GetCommand(), nMessageSize); return false; } return true; } /** * This is a transition method in order to stay compatible with older code that * do not use the config. It assumes message will not get too large. This cannot * be used for any piece of code that will download blocks as blocks may be * bigger than the permitted size. Idealy, code that uses this function should * be migrated toward using the config. */ bool CMessageHeader::IsValidWithoutConfig(const MessageMagic &magic) const { // Check start string if (!CheckHeaderMagicAndCommand(*this, magic)) { return false; } // Message size if (nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH) { LogPrintf( "CMessageHeader::IsValidForSeeder(): (%s, %u bytes) is oversized\n", GetCommand(), nMessageSize); return false; } return true; } bool CMessageHeader::IsOversized(const Config &config) const { // If the message doesn't not contain a block content, check against // MAX_PROTOCOL_MESSAGE_LENGTH. if (nMessageSize > MAX_PROTOCOL_MESSAGE_LENGTH && !NetMsgType::IsBlockLike(GetCommand())) { return true; } // Scale the maximum accepted size with the block size. if (nMessageSize > 2 * config.GetMaxBlockSize()) { return true; } return false; } ServiceFlags GetDesirableServiceFlags(ServiceFlags services) { if ((services & NODE_NETWORK_LIMITED) && g_initial_block_download_completed) { return ServiceFlags(NODE_NETWORK_LIMITED); } return ServiceFlags(NODE_NETWORK); } void SetServiceFlagsIBDCache(bool state) { g_initial_block_download_completed = state; } CAddress::CAddress() : CService() { Init(); } CAddress::CAddress(CService ipIn, ServiceFlags nServicesIn) : CService(ipIn) { Init(); nServices = nServicesIn; } void CAddress::Init() { nServices = NODE_NONE; nTime = 100000000; } std::string CInv::GetCommand() const { std::string cmd; switch (GetKind()) { case MSG_TX: return cmd.append(NetMsgType::TX); case MSG_BLOCK: return cmd.append(NetMsgType::BLOCK); case MSG_FILTERED_BLOCK: return cmd.append(NetMsgType::MERKLEBLOCK); case MSG_CMPCT_BLOCK: return cmd.append(NetMsgType::CMPCTBLOCK); default: throw std::out_of_range( strprintf("CInv::GetCommand(): type=%d unknown type", type)); } } std::string CInv::ToString() const { try { return strprintf("%s %s", GetCommand(), hash.ToString()); } catch (const std::out_of_range &) { return strprintf("0x%08x %s", type, hash.ToString()); } } const std::vector &getAllNetMessageTypes() { return allNetMessageTypesVec; } /** * Convert a service flag (NODE_*) to a human readable string. * It supports unknown service flags which will be returned as "UNKNOWN[...]". * @param[in] bit the service flag is calculated as (1 << bit) */ static std::string serviceFlagToStr(const size_t bit) { const uint64_t service_flag = 1ULL << bit; switch (ServiceFlags(service_flag)) { case NODE_NONE: // impossible abort(); case NODE_NETWORK: return "NETWORK"; case NODE_GETUTXO: return "GETUTXO"; case NODE_BLOOM: return "BLOOM"; case NODE_XTHIN: return "XTHIN"; case NODE_NETWORK_LIMITED: return "NETWORK_LIMITED"; case NODE_AVALANCHE: return "AVALANCHE"; default: std::ostringstream stream; stream.imbue(std::locale::classic()); stream << "UNKNOWN["; - if (bit < 8) { - stream << mask; - } else { - stream << "2^" << bit; - } + stream << "2^" << bit; stream << "]"; return stream.str(); } } std::vector serviceFlagsToStr(const uint64_t flags) { std::vector str_flags; for (size_t i = 0; i < sizeof(flags) * 8; ++i) { if (flags & (1ULL << i)) { str_flags.emplace_back(serviceFlagToStr(i)); } } return str_flags; } diff --git a/test/functional/rpc_net.py b/test/functional/rpc_net.py index af88181d4..3370fd239 100755 --- a/test/functional/rpc_net.py +++ b/test/functional/rpc_net.py @@ -1,201 +1,211 @@ #!/usr/bin/env python3 # Copyright (c) 2017 The Bitcoin Core developers # Distributed under the MIT software license, see the accompanying # file COPYING or http://www.opensource.org/licenses/mit-license.php. """Test RPC calls related to net. Tests correspond to code in rpc/net.cpp. """ from decimal import Decimal from test_framework.test_framework import BitcoinTestFramework from test_framework.util import ( assert_equal, assert_greater_than_or_equal, assert_greater_than, assert_raises_rpc_error, connect_nodes, p2p_port, wait_until, ) from test_framework.mininode import P2PInterface import test_framework.messages from test_framework.messages import ( CAddress, msg_addr, NODE_NETWORK, ) def assert_net_servicesnames(servicesflag, servicenames): """Utility that checks if all flags are correctly decoded in `getpeerinfo` and `getnetworkinfo`. :param servicesflag: The services as an integer. :param servicenames: The list of decoded services names, as strings. """ servicesflag_generated = 0 for servicename in servicenames: servicesflag_generated |= getattr( test_framework.messages, 'NODE_' + servicename) assert servicesflag_generated == servicesflag class NetTest(BitcoinTestFramework): def set_test_params(self): self.setup_clean_chain = True self.num_nodes = 2 self.extra_args = [["-minrelaytxfee=0.00001000"], ["-minrelaytxfee=0.00000500"]] self.supports_cli = False def run_test(self): self.log.info('Connect nodes both way') connect_nodes(self.nodes[0], self.nodes[1]) connect_nodes(self.nodes[1], self.nodes[0]) self._test_connection_count() self._test_getnettotals() self._test_getnetworkinfo() self._test_getaddednodeinfo() self._test_getpeerinfo() + self.test_service_flags() self._test_getnodeaddresses() def _test_connection_count(self): # connect_nodes connects each node to the other assert_equal(self.nodes[0].getconnectioncount(), 2) def _test_getnettotals(self): # getnettotals totalbytesrecv and totalbytessent should be # consistent with getpeerinfo. Since the RPC calls are not atomic, # and messages might have been recvd or sent between RPC calls, call # getnettotals before and after and verify that the returned values # from getpeerinfo are bounded by those values. net_totals_before = self.nodes[0].getnettotals() peer_info = self.nodes[0].getpeerinfo() net_totals_after = self.nodes[0].getnettotals() assert_equal(len(peer_info), 2) peers_recv = sum([peer['bytesrecv'] for peer in peer_info]) peers_sent = sum([peer['bytessent'] for peer in peer_info]) assert_greater_than_or_equal( peers_recv, net_totals_before['totalbytesrecv']) assert_greater_than_or_equal( net_totals_after['totalbytesrecv'], peers_recv) assert_greater_than_or_equal( peers_sent, net_totals_before['totalbytessent']) assert_greater_than_or_equal( net_totals_after['totalbytessent'], peers_sent) # test getnettotals and getpeerinfo by doing a ping # the bytes sent/received should change # note ping and pong are 32 bytes each self.nodes[0].ping() wait_until(lambda: (self.nodes[0].getnettotals()[ 'totalbytessent'] >= net_totals_after['totalbytessent'] + 32 * 2), timeout=1) wait_until(lambda: (self.nodes[0].getnettotals()[ 'totalbytesrecv'] >= net_totals_after['totalbytesrecv'] + 32 * 2), timeout=1) peer_info_after_ping = self.nodes[0].getpeerinfo() for before, after in zip(peer_info, peer_info_after_ping): assert_greater_than_or_equal( after['bytesrecv_per_msg'].get( 'pong', 0), before['bytesrecv_per_msg'].get( 'pong', 0) + 32) assert_greater_than_or_equal( after['bytessent_per_msg'].get( 'ping', 0), before['bytessent_per_msg'].get( 'ping', 0) + 32) def _test_getnetworkinfo(self): assert_equal(self.nodes[0].getnetworkinfo()['networkactive'], True) assert_equal(self.nodes[0].getnetworkinfo()['connections'], 2) self.nodes[0].setnetworkactive(state=False) assert_equal(self.nodes[0].getnetworkinfo()['networkactive'], False) # Wait a bit for all sockets to close wait_until(lambda: self.nodes[0].getnetworkinfo()[ 'connections'] == 0, timeout=3) self.nodes[0].setnetworkactive(state=True) self.log.info('Connect nodes both way') connect_nodes(self.nodes[0], self.nodes[1]) connect_nodes(self.nodes[1], self.nodes[0]) assert_equal(self.nodes[0].getnetworkinfo()['networkactive'], True) assert_equal(self.nodes[0].getnetworkinfo()['connections'], 2) # check the `servicesnames` field network_info = [node.getnetworkinfo() for node in self.nodes] for info in network_info: assert_net_servicesnames(int(info["localservices"], 0x10), info["localservicesnames"]) def _test_getaddednodeinfo(self): assert_equal(self.nodes[0].getaddednodeinfo(), []) # add a node (node2) to node0 ip_port = "127.0.0.1:{}".format(p2p_port(2)) self.nodes[0].addnode(node=ip_port, command='add') # check that the node has indeed been added added_nodes = self.nodes[0].getaddednodeinfo(ip_port) assert_equal(len(added_nodes), 1) assert_equal(added_nodes[0]['addednode'], ip_port) # check that a non-existent node returns an error assert_raises_rpc_error(-24, "Node has not been added", self.nodes[0].getaddednodeinfo, '1.1.1.1') def _test_getpeerinfo(self): peer_info = [x.getpeerinfo() for x in self.nodes] # check both sides of bidirectional connection between nodes # the address bound to on one side will be the source address for the # other node assert_equal(peer_info[0][0]['addrbind'], peer_info[1][0]['addr']) assert_equal(peer_info[1][0]['addrbind'], peer_info[0][0]['addr']) assert_equal(peer_info[0][0]['minfeefilter'], Decimal("0.00000500")) assert_equal(peer_info[1][0]['minfeefilter'], Decimal("0.00001000")) # check the `servicesnames` field for info in peer_info: assert_net_servicesnames(int(info[0]["services"], 0x10), info[0]["servicesnames"]) + def test_service_flags(self): + self.nodes[0].add_p2p_connection( + P2PInterface(), services=( + 1 << 5) | ( + 1 << 63)) + assert_equal(['UNKNOWN[2^5]', 'UNKNOWN[2^63]'], + self.nodes[0].getpeerinfo()[-1]['servicesnames']) + self.nodes[0].disconnect_p2ps() + def _test_getnodeaddresses(self): self.nodes[0].add_p2p_connection(P2PInterface()) # send some addresses to the node via the p2p message addr msg = msg_addr() imported_addrs = [] for i in range(256): a = "123.123.123.{}".format(i) imported_addrs.append(a) addr = CAddress() addr.time = 100000000 addr.nServices = NODE_NETWORK addr.ip = a addr.port = 8333 msg.addrs.append(addr) self.nodes[0].p2p.send_and_ping(msg) # obtain addresses via rpc call and check they were ones sent in before REQUEST_COUNT = 10 node_addresses = self.nodes[0].getnodeaddresses(REQUEST_COUNT) assert_equal(len(node_addresses), REQUEST_COUNT) for a in node_addresses: assert_greater_than(a["time"], 1527811200) # 1st June 2018 assert_equal(a["services"], NODE_NETWORK) assert a["address"] in imported_addrs assert_equal(a["port"], 8333) assert_raises_rpc_error(-8, "Address count out of range", self.nodes[0].getnodeaddresses, -1) # addrman's size cannot be known reliably after insertion, as hash collisions may occur # so only test that requesting a large number of addresses returns less # than that LARGE_REQUEST_COUNT = 10000 node_addresses = self.nodes[0].getnodeaddresses(LARGE_REQUEST_COUNT) assert_greater_than(LARGE_REQUEST_COUNT, len(node_addresses)) if __name__ == '__main__': NetTest().main()