diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h
--- a/src/wallet/wallet.h
+++ b/src/wallet/wallet.h
@@ -948,7 +948,17 @@
         EXCLUSIVE_LOCKS_REQUIRED(cs_wallet);
 
+    /**
+     * Group the spendable outputs for coin selection. Unless single_coin is
+     * set, outputs whose destination can be extracted are grouped by
+     * destination, with at most OUTPUT_GROUP_MAX_ENTRIES outputs per group;
+     * every other output becomes a group of its own. When a destination
+     * fills one or more groups, its remaining partial group gets an ancestor
+     * count of at least max_ancestors - 1 so that coin selection avoids it
+     * when possible.
+     */
     std::vector<OutputGroup> GroupOutputs(const std::vector<COutput> &outputs,
-                                          bool single_coin) const;
+                                          bool single_coin,
+                                          const size_t max_ancestors) const;
 
     bool IsLockedCoin(const COutPoint &outpoint) const
         EXCLUSIVE_LOCKS_REQUIRED(cs_wallet);
diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp
--- a/src/wallet/wallet.cpp
+++ b/src/wallet/wallet.cpp
@@ -2640,6 +2640,12 @@
         }
     }
 
+    size_t max_ancestors{0};
+    size_t max_descendants{0};
+    chain().getPackageLimits(max_ancestors, max_descendants);
+    bool fRejectLongChains = gArgs.GetBoolArg(
+        "-walletrejectlongchains", DEFAULT_WALLET_REJECT_LONG_CHAINS);
+
     // form groups from remaining coins; note that preset coins will not
     // automatically have their associated (same address) coins included
     if (coin_control.m_avoid_partial_spends &&
@@ -2650,14 +2656,9 @@
         // outputs before processing
         Shuffle(vCoins.begin(), vCoins.end(), FastRandomContext());
     }
-    std::vector<OutputGroup> groups =
-        GroupOutputs(vCoins, !coin_control.m_avoid_partial_spends);
 
-    size_t max_ancestors{0};
-    size_t max_descendants{0};
-    chain().getPackageLimits(max_ancestors, max_descendants);
-    bool fRejectLongChains = gArgs.GetBoolArg(
-        "-walletrejectlongchains", DEFAULT_WALLET_REJECT_LONG_CHAINS);
+    std::vector<OutputGroup> groups = GroupOutputs(
+        vCoins, !coin_control.m_avoid_partial_spends, max_ancestors);
 
     bool res =
         nTargetValue <= nValueFromPresetInputs ||
@@ -4534,13 +4535,15 @@
 }
 
 std::vector<OutputGroup>
-CWallet::GroupOutputs(const std::vector<COutput> &outputs,
-                      bool single_coin) const {
+CWallet::GroupOutputs(const std::vector<COutput> &outputs, bool single_coin,
+                      const size_t max_ancestors) const {
     std::vector<OutputGroup> groups;
     std::map<CTxDestination, OutputGroup> gmap;
-    CTxDestination dst;
+    std::set<CTxDestination> full_groups;
+
     for (const auto &output : outputs) {
         if (output.fSpendable) {
+            CTxDestination dst;
             CInputCoin input_coin = output.GetInputCoin();
 
             size_t ancestors, descendants;
@@ -4549,16 +4552,28 @@
             if (!single_coin &&
                 ExtractDestination(output.tx->tx->vout[output.i].scriptPubKey,
                                    dst)) {
-                // Limit output groups to no more than 10 entries, to protect
-                // against inadvertently creating a too-large transaction
-                // when using -avoidpartialspends
-                if (gmap[dst].m_outputs.size() >= OUTPUT_GROUP_MAX_ENTRIES) {
-                    groups.push_back(gmap[dst]);
-                    gmap.erase(dst);
+                auto it = gmap.find(dst);
+                if (it != gmap.end()) {
+                    // Limit output groups to no more than
+                    // OUTPUT_GROUP_MAX_ENTRIES number of entries, to protect
+                    // against inadvertently creating a too-large transaction
+                    // when using -avoidpartialspends to prevent breaking
+                    // consensus or surprising users with a very high amount of
+                    // fees.
+                    if (it->second.m_outputs.size() >=
+                        OUTPUT_GROUP_MAX_ENTRIES) {
+                        groups.push_back(it->second);
+                        it->second = OutputGroup{};
+                        full_groups.insert(dst);
+                    }
+                    it->second.Insert(input_coin, output.nDepth,
+                                      output.tx->IsFromMe(ISMINE_ALL),
+                                      ancestors, descendants);
+                } else {
+                    gmap[dst].Insert(input_coin, output.nDepth,
+                                     output.tx->IsFromMe(ISMINE_ALL), ancestors,
+                                     descendants);
                 }
-                gmap[dst].Insert(input_coin, output.nDepth,
-                                 output.tx->IsFromMe(ISMINE_ALL), ancestors,
-                                 descendants);
             } else {
                 groups.emplace_back(input_coin, output.nDepth,
                                     output.tx->IsFromMe(ISMINE_ALL), ancestors,
@@ -4567,8 +4582,18 @@
         }
     }
     if (!single_coin) {
-        for (const auto &it : gmap) {
-            groups.push_back(it.second);
+        for (auto &it : gmap) {
+            auto &group = it.second;
+            if (full_groups.count(it.first) > 0) {
+                // Make this unattractive as we want coin selection to avoid it
+                // if possible. Never lower the ancestor count: it accumulates
+                // the ancestors of every coin in the group, so setting it
+                // below its real value could let a too-long chain pass the
+                // ancestor eligibility filters.
+                group.m_ancestors =
+                    std::max(group.m_ancestors, max_ancestors - 1);
+            }
+            groups.push_back(group);
         }
     }
     return groups;
diff --git a/test/functional/wallet_avoidreuse.py b/test/functional/wallet_avoidreuse.py
--- a/test/functional/wallet_avoidreuse.py
+++ b/test/functional/wallet_avoidreuse.py
@@ -95,11 +95,15 @@
         self.sync_all()
         self.test_change_remains_change(self.nodes[1])
         reset_balance(self.nodes[1], self.nodes[0].getnewaddress())
-        self.test_fund_send_fund_senddirty()
+        self.test_sending_from_reused_address_without_avoid_reuse()
         reset_balance(self.nodes[1], self.nodes[0].getnewaddress())
-        self.test_fund_send_fund_send()
+        self.test_sending_from_reused_address_fails()
         reset_balance(self.nodes[1], self.nodes[0].getnewaddress())
         self.test_getbalances_used()
+        reset_balance(self.nodes[1], self.nodes[0].getnewaddress())
+        self.test_full_destination_group_is_preferred()
+        reset_balance(self.nodes[1], self.nodes[0].getnewaddress())
+        self.test_all_destination_groups_are_used()
 
     def test_persistence(self):
         '''Test that wallet files persist the avoid_reuse flag.'''
@@ -184,13 +188,14 @@
         for logical_tx in node.listtransactions():
             assert logical_tx.get('address') != changeaddr
 
-    def test_fund_send_fund_senddirty(self):
+    def test_sending_from_reused_address_without_avoid_reuse(self):
         '''
-        Test the same as test_fund_send_fund_send, except send the 10 BCH with
+        Test the same as test_sending_from_reused_address_fails, except send the 10 BCH with
         the avoid_reuse flag set to false. This means the 10 BCH send should succeed,
-        where it fails in test_fund_send_fund_send.
+        where it fails in test_sending_from_reused_address_fails.
         '''
-        self.log.info("Test fund send fund send dirty")
+        self.log.info(
+            "Test sending from reused address with avoid_reuse=false")
 
         fundaddr = self.nodes[1].getnewaddress()
         retaddr = self.nodes[0].getnewaddress()
@@ -257,7 +262,7 @@
         assert_approx(self.nodes[1].getbalance(), 5, 0.001)
         assert_approx(self.nodes[1].getbalance(avoid_reuse=False), 5, 0.001)
 
-    def test_fund_send_fund_send(self):
+    def test_sending_from_reused_address_fails(self):
         '''
         Test the simple case where [1] generates a new address A, then
         [0] sends 10 BCH to A.
@@ -266,7 +271,7 @@
         [1] tries to spend 10 BCH (fails; dirty).
         [1] tries to spend 4 BCH (succeeds; change address sufficient)
         '''
-        self.log.info("Test fund send fund send")
+        self.log.info("Test sending from reused address fails")
 
         fundaddr = self.nodes[1].getnewaddress(label="", address_type="legacy")
         retaddr = self.nodes[0].getnewaddress()
@@ -381,6 +386,67 @@
                        reused_sum=1)
         assert_balances(self.nodes[1], mine={"used": 1, "trusted": 5})
 
+    def test_full_destination_group_is_preferred(self):
+        '''
+        Test the case where [1] only has 11 outputs of 1 BCH in the same reused
+        address and tries to send a small payment of 0.5 BCH. The wallet
+        should use 10 outputs from the reused address as inputs and not a
+        single 1 BCH input, in order to join several outputs from the reused
+        address.
+        '''
+        self.log.info(
+            "Test that full destination groups are preferred in coin selection")
+
+        # Node under test should be empty
+        assert_equal(self.nodes[1].getbalance(avoid_reuse=False), 0)
+
+        new_addr = self.nodes[1].getnewaddress()
+        ret_addr = self.nodes[0].getnewaddress()
+
+        # Send 11 outputs of 1 BCH to the same, reused address in the wallet
+        for _ in range(11):
+            self.nodes[0].sendtoaddress(new_addr, 1)
+
+        self.nodes[0].generate(1)
+        self.sync_all()
+
+        # Sending a transaction that is smaller than each one of the
+        # available outputs
+        txid = self.nodes[1].sendtoaddress(address=ret_addr, amount=0.5)
+        inputs = self.nodes[1].getrawtransaction(txid, 1)["vin"]
+
+        # The transaction should use 10 inputs exactly
+        assert_equal(len(inputs), 10)
+
+    def test_all_destination_groups_are_used(self):
+        '''
+        Test the case where [1] only has 22 outputs of 1 BCH in the same reused
+        address and tries to send a payment of 20.5 BCH. The wallet
+        should use all 22 outputs from the reused address as inputs.
+        '''
+        self.log.info("Test that all destination groups are used")
+
+        # Node under test should be empty
+        assert_equal(self.nodes[1].getbalance(avoid_reuse=False), 0)
+
+        new_addr = self.nodes[1].getnewaddress()
+        ret_addr = self.nodes[0].getnewaddress()
+
+        # Send 22 outputs of 1 BCH to the same, reused address in the wallet
+        for _ in range(22):
+            self.nodes[0].sendtoaddress(new_addr, 1)
+
+        self.nodes[0].generate(1)
+        self.sync_all()
+
+        # Sending a transaction that needs to use the full groups
+        # of 10 inputs but also the incomplete group of 2 inputs.
+        txid = self.nodes[1].sendtoaddress(address=ret_addr, amount=20.5)
+        inputs = self.nodes[1].getrawtransaction(txid, 1)["vin"]
+
+        # The transaction should use 22 inputs exactly
+        assert_equal(len(inputs), 22)
+
 
 if __name__ == '__main__':
     AvoidReuseTest().main()