diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp
--- a/src/wallet/wallet.cpp
+++ b/src/wallet/wallet.cpp
@@ -100,9 +100,13 @@
     });
 }
 
+static Mutex g_loading_wallet_mutex;
 static Mutex g_wallet_release_mutex;
 static std::condition_variable g_wallet_release_cv;
-static std::set<std::string> g_unloading_wallet_set;
+static std::set<std::string>
+    g_loading_wallet_set GUARDED_BY(g_loading_wallet_mutex);
+static std::set<std::string>
+    g_unloading_wallet_set GUARDED_BY(g_wallet_release_mutex);
 
 // Custom deleter for shared_ptr<CWallet>.
 static void ReleaseWallet(CWallet *wallet) {
@@ -146,11 +150,11 @@
 
 static const size_t OUTPUT_GROUP_MAX_ENTRIES = 10;
 
-std::shared_ptr<CWallet> LoadWallet(const CChainParams &chainParams,
-                                    interfaces::Chain &chain,
-                                    const WalletLocation &location,
-                                    bilingual_str &error,
-                                    std::vector<bilingual_str> &warnings) {
+namespace {
+std::shared_ptr<CWallet>
+LoadWalletInternal(const CChainParams &chainParams, interfaces::Chain &chain,
+                   const WalletLocation &location, bilingual_str &error,
+                   std::vector<bilingual_str> &warnings) {
     try {
         if (!CWallet::Verify(chainParams, chain, location, error, warnings)) {
             error = Untranslated("Wallet file verification failed.") +
@@ -173,6 +177,33 @@
         return nullptr;
     }
 }
+} // namespace
+
+std::shared_ptr<CWallet> LoadWallet(const CChainParams &chainParams,
+                                    interfaces::Chain &chain,
+                                    const WalletLocation &location,
+                                    bilingual_str &error,
+                                    std::vector<bilingual_str> &warnings) {
+    // Only the caller that inserts the wallet name into g_loading_wallet_set
+    // proceeds; a concurrent load of the same name is rejected.
+    auto result =
+        WITH_LOCK(g_loading_wallet_mutex,
+                  return g_loading_wallet_set.insert(location.GetName()));
+    if (!result.second) {
+        error = Untranslated("Wallet already loading.");
+        return nullptr;
+    }
+    // Remove the name again on every exit path, including an exception that
+    // escapes LoadWalletInternal; otherwise this wallet could never be loaded
+    // again for the lifetime of the process.
+    struct LoadingWalletGuard {
+        std::set<std::string>::iterator it;
+        ~LoadingWalletGuard() {
+            WITH_LOCK(g_loading_wallet_mutex, g_loading_wallet_set.erase(it));
+        }
+    } guard{result.first};
+    return LoadWalletInternal(chainParams, chain, location, error, warnings);
+}
 
 std::shared_ptr<CWallet> LoadWallet(const CChainParams &chainParams,
                                     interfaces::Chain &chain,
diff --git a/test/functional/wallet_multiwallet.py b/test/functional/wallet_multiwallet.py
--- a/test/functional/wallet_multiwallet.py
+++ b/test/functional/wallet_multiwallet.py
@@ -7,19 +7,46 @@
 Verify that a bitcoind node can load multiple wallet files
 """
 from decimal import Decimal
+from threading import Thread
 import os
 import shutil
 import time
 
+from test_framework.authproxy import JSONRPCException
 from test_framework.test_framework import BitcoinTestFramework
 from test_framework.test_node import ErrorMatch
 from test_framework.util import (
     assert_equal,
     assert_raises_rpc_error,
+    get_rpc_proxy,
 )
 
 FEATURE_LATEST = 200300
 
+got_loading_error = False
+
+
+def test_load_unload(node, name):
+    """Load and unload wallet `name` up to 10 times through RPC handle `node`.
+
+    Sets the module-level got_loading_error flag once the node rejects a load
+    because the same wallet is already being loaded concurrently, and returns
+    early as soon as that flag is set (by this or any other thread).
+    """
+    global got_loading_error
+    for _ in range(10):
+        if got_loading_error:
+            return
+        try:
+            node.loadwallet(name)
+            node.unloadwallet(name)
+        except JSONRPCException as e:
+            # Other RPC errors are deliberately ignored: with several threads
+            # racing on the same wallet, individual calls may fail.
+            message = e.error['message']
+            if e.error['code'] == -4 and 'Wallet already loading' in message:
+                got_loading_error = True
+                return
+
 
 class MultiWalletTest(BitcoinTestFramework):
     def set_test_params(self):
@@ -257,6 +284,21 @@
         w2 = node.get_wallet_rpc(wallet_names[1])
         w2.getwalletinfo()
 
+        self.log.info("Concurrent wallet loading")
+        # Race several clients loading and unloading the same wallet; at least
+        # one load should be rejected because that wallet is already loading.
+        threads = []
+        for _ in range(3):
+            n = node.cli if self.options.usecli else get_rpc_proxy(
+                node.url, 1, timeout=600, coveragedir=node.coverage_dir)
+            t = Thread(target=test_load_unload, args=(n, wallet_names[2], ))
+            t.start()
+            threads.append(t)
+        for t in threads:
+            t.join()
+        global got_loading_error
+        assert_equal(got_loading_error, True)
+
         self.log.info("Load remaining wallets")
         for wallet_name in wallet_names[2:]:
             loadwallet_name = self.nodes[0].loadwallet(wallet_name)