diff --git a/src/interfaces/wallet.h b/src/interfaces/wallet.h --- a/src/interfaces/wallet.h +++ b/src/interfaces/wallet.h @@ -35,7 +35,6 @@ enum class FeeReason; enum class OutputType; enum class TransactionError; -enum class WalletCreationStatus; enum isminetype : unsigned int; struct CRecipient; struct PartiallySignedTransaction; @@ -311,8 +310,7 @@ //! Create new wallet. virtual std::unique_ptr<Wallet> createWallet(const std::string &name, const SecureString &passphrase, - uint64_t wallet_creation_flags, WalletCreationStatus &status, - bilingual_str &error, + uint64_t wallet_creation_flags, bilingual_str &error, std::vector<bilingual_str> &warnings) = 0; //! Load existing wallet. diff --git a/src/interfaces/wallet.cpp b/src/interfaces/wallet.cpp --- a/src/interfaces/wallet.cpp +++ b/src/interfaces/wallet.cpp @@ -509,21 +509,27 @@ //! WalletClient methods std::unique_ptr<Wallet> createWallet(const std::string &name, const SecureString &passphrase, - uint64_t wallet_creation_flags, - WalletCreationStatus &status, bilingual_str &error, + uint64_t wallet_creation_flags, bilingual_str &error, std::vector<bilingual_str> &warnings) override { - std::shared_ptr<CWallet> wallet; - status = CreateWallet( - *m_context.chain, passphrase, wallet_creation_flags, name, - true /* load_on_start */, error, warnings, wallet); - return MakeWallet(std::move(wallet)); + DatabaseOptions options; + DatabaseStatus status; + options.require_create = true; + options.create_flags = wallet_creation_flags; + options.create_passphrase = passphrase; + + return MakeWallet(CreateWallet(*m_context.chain, name, + true /* load_on_start */, options, + status, error, warnings)); } std::unique_ptr<Wallet> loadWallet(const std::string &name, bilingual_str &error, std::vector<bilingual_str> &warnings) override { + DatabaseOptions options; + DatabaseStatus status; + options.require_existing = true; return MakeWallet(LoadWallet(*m_context.chain, name, - true /* load_on_start */, error, - warnings)); + true
/* load_on_start */, options, + status, error, warnings)); } std::string getWalletDir() override { return GetWalletDir().string(); } std::vector<std::string> listWalletDir() override { diff --git a/src/qt/walletcontroller.cpp b/src/qt/walletcontroller.cpp --- a/src/qt/walletcontroller.cpp +++ b/src/qt/walletcontroller.cpp @@ -252,13 +252,11 @@ } QTimer::singleShot(500, worker(), [this, name, flags] { - WalletCreationStatus status; std::unique_ptr<interfaces::Wallet> wallet = - node().walletClient().createWallet(name, m_passphrase, flags, - status, m_error_message, - m_warning_message); + node().walletClient().createWallet( + name, m_passphrase, flags, m_error_message, m_warning_message); - if (status == WalletCreationStatus::SUCCESS) { + if (wallet) { m_wallet_model = m_wallet_controller->getOrCreateWallet(std::move(wallet)); } diff --git a/src/wallet/db.h b/src/wallet/db.h --- a/src/wallet/db.h +++ b/src/wallet/db.h @@ -9,6 +9,7 @@ #include <clientversion.h> #include <fs.h> #include <streams.h> +#include <support/allocators/secure.h> #include <atomic> #include <memory> @@ -227,6 +228,8 @@ struct DatabaseOptions { bool require_existing = false; bool require_create = false; + uint64_t create_flags = 0; + SecureString create_passphrase; bool verify = true; }; @@ -237,7 +240,9 @@ FAILED_ALREADY_LOADED, FAILED_ALREADY_EXISTS, FAILED_NOT_FOUND, + FAILED_CREATE, FAILED_VERIFY, + FAILED_ENCRYPT, }; std::unique_ptr<WalletDatabase> MakeDatabase(const fs::path &path, diff --git a/src/wallet/rpcwallet.cpp b/src/wallet/rpcwallet.cpp --- a/src/wallet/rpcwallet.cpp +++ b/src/wallet/rpcwallet.cpp @@ -3144,14 +3144,17 @@ } } + DatabaseOptions options; + DatabaseStatus status; + options.require_existing = true; bilingual_str error; std::vector<bilingual_str> warnings; std::optional<bool> load_on_start = request.params[1].isNull() ?
std::nullopt : std::optional<bool>(request.params[1].get_bool()); - std::shared_ptr<CWallet> const wallet = - LoadWallet(*context.chain, name, load_on_start, error, warnings); + std::shared_ptr<CWallet> const wallet = LoadWallet( + *context.chain, name, load_on_start, options, status, error, warnings); if (!wallet) { throw JSONRPCError(RPC_WALLET_ERROR, error.original); } @@ -3318,23 +3321,24 @@ Untranslated("Wallet is an experimental descriptor wallet")); } + DatabaseOptions options; + DatabaseStatus status; + options.require_create = true; + options.create_flags = flags; + options.create_passphrase = passphrase; bilingual_str error; - std::shared_ptr<CWallet> wallet; std::optional<bool> load_on_start = request.params[6].isNull() ?
std::nullopt : std::make_optional<bool>(request.params[6].get_bool()); - WalletCreationStatus status = CreateWallet( - *context.chain, passphrase, flags, request.params[0].get_str(), - load_on_start, error, warnings, wallet); - switch (status) { - case WalletCreationStatus::CREATION_FAILED: - throw JSONRPCError(RPC_WALLET_ERROR, error.original); - case WalletCreationStatus::ENCRYPTION_FAILED: - throw JSONRPCError(RPC_WALLET_ENCRYPTION_FAILED, error.original); - case WalletCreationStatus::SUCCESS: - break; - // no default case, so the compiler can warn about missing cases + std::shared_ptr<CWallet> wallet = + CreateWallet(*context.chain, request.params[0].get_str(), load_on_start, + options, status, error, warnings); + if (!wallet) { + RPCErrorCode code = status == DatabaseStatus::FAILED_ENCRYPT + ?
RPC_WALLET_ENCRYPTION_FAILED + : RPC_WALLET_ERROR; + throw JSONRPCError(code, error.original); } UniValue obj(UniValue::VOBJ); diff --git a/src/wallet/wallet.h b/src/wallet/wallet.h --- a/src/wallet/wallet.h +++ b/src/wallet/wallet.h @@ -62,21 +62,18 @@ std::optional<bool> load_on_start); std::vector<std::shared_ptr<CWallet>> GetWallets(); std::shared_ptr<CWallet> GetWallet(const std::string &name); -std::shared_ptr<CWallet> LoadWallet(interfaces::Chain &chain, - const std::string &name, - std::optional<bool> load_on_start, - bilingual_str &error, - std::vector<bilingual_str> &warnings); +std::shared_ptr<CWallet> +LoadWallet(interfaces::Chain &chain, const std::string &name, + std::optional<bool> load_on_start, const DatabaseOptions &options, + DatabaseStatus &status, bilingual_str &error, + std::vector<bilingual_str> &warnings); +std::shared_ptr<CWallet> +CreateWallet(interfaces::Chain &chain, const std::string &name, + std::optional<bool> load_on_start, const DatabaseOptions &options, + DatabaseStatus &status, bilingual_str &error, + std::vector<bilingual_str> &warnings); std::unique_ptr<interfaces::Handler> HandleLoadWallet(LoadWalletFn load_wallet); -enum class WalletCreationStatus { SUCCESS, CREATION_FAILED, ENCRYPTION_FAILED }; - -WalletCreationStatus -CreateWallet(interfaces::Chain &chain, const SecureString &passphrase, - uint64_t wallet_creation_flags, const std::string &name, - std::optional<bool> load_on_start, bilingual_str &error, - std::vector<bilingual_str> &warnings, - std::shared_ptr<CWallet> &result); //! -paytxfee default constexpr Amount DEFAULT_PAY_TX_FEE = Amount::zero(); //!
-fallbackfee default diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp --- a/src/wallet/wallet.cpp +++ b/src/wallet/wallet.cpp @@ -223,8 +223,9 @@ namespace { std::shared_ptr<CWallet> LoadWalletInternal(interfaces::Chain &chain, const std::string &name, - std::optional<bool> load_on_start, bilingual_str &error, - std::vector<bilingual_str> &warnings) { + std::optional<bool> load_on_start, + const DatabaseOptions &options, DatabaseStatus &status, + bilingual_str &error, std::vector<bilingual_str> &warnings) { try { if (!CWallet::Verify(chain, name, error, warnings)) { error = Untranslated("Wallet file verification failed.") + @@ -253,29 +254,31 @@ } } // namespace -std::shared_ptr<CWallet> LoadWallet(interfaces::Chain &chain, - const std::string &name, - std::optional<bool> load_on_start, - bilingual_str &error, - std::vector<bilingual_str> &warnings) { +std::shared_ptr<CWallet> +LoadWallet(interfaces::Chain &chain, const std::string &name, + std::optional<bool> load_on_start, const DatabaseOptions &options, + DatabaseStatus &status, bilingual_str &error, + std::vector<bilingual_str> &warnings) { auto result = WITH_LOCK(g_loading_wallet_mutex, return g_loading_wallet_set.insert(name)); if (!result.second) { error = Untranslated("Wallet already being loading."); return nullptr; } - auto wallet = - LoadWalletInternal(chain, name, load_on_start, error, warnings); + auto wallet = LoadWalletInternal(chain, name, load_on_start, options, + status, error, warnings); WITH_LOCK(g_loading_wallet_mutex, g_loading_wallet_set.erase(result.first)); return wallet; } -WalletCreationStatus -CreateWallet(interfaces::Chain &chain, const SecureString &passphrase, - uint64_t wallet_creation_flags, const std::string &name, - std::optional<bool> load_on_start, bilingual_str &error, - std::vector<bilingual_str> &warnings, - std::shared_ptr<CWallet> &result) { +std::shared_ptr<CWallet> +CreateWallet(interfaces::Chain &chain, const std::string &name, + std::optional<bool>
load_on_start, const DatabaseOptions &options, + DatabaseStatus &status, bilingual_str &error, + std::vector<bilingual_str> &warnings) { + uint64_t wallet_creation_flags = options.create_flags; + const SecureString &passphrase = options.create_passphrase; + // Indicate that the wallet is actually supposed to be blank and not just // blank to make it encrypted bool create_blank = (wallet_creation_flags & WALLET_FLAG_BLANK_WALLET); @@ -290,7 +293,8 @@ fs::absolute(name.empty() ? "wallet.dat" : name, GetWalletDir())) .type() != fs::file_not_found) { error = strprintf(Untranslated("Wallet %s already exists."), name); - return WalletCreationStatus::CREATION_FAILED; + status = DatabaseStatus::FAILED_CREATE; + return nullptr; } // Wallet::Verify will check if we're trying to create a wallet with a @@ -298,7 +302,8 @@ if (!CWallet::Verify(chain, name, error, warnings)) { error = Untranslated("Wallet file verification failed.") + Untranslated(" ") + error; - return WalletCreationStatus::CREATION_FAILED; + status = DatabaseStatus::FAILED_VERIFY; + return nullptr; } // Do not allow a passphrase when private keys are disabled @@ -308,7 +313,8 @@ "Passphrase provided but private keys are disabled.
A passphrase " "is only used to encrypt private keys, so cannot be used for " "wallets with private keys disabled."); - return WalletCreationStatus::CREATION_FAILED; + status = DatabaseStatus::FAILED_CREATE; + return nullptr; } // Make the wallet @@ -317,7 +323,8 @@ if (!wallet) { error = Untranslated("Wallet creation failed.") + Untranslated(" ") + error; - return WalletCreationStatus::CREATION_FAILED; + status = DatabaseStatus::FAILED_CREATE; + return nullptr; } // Encrypt the wallet @@ -326,14 +333,16 @@ if (!wallet->EncryptWallet(passphrase)) { error = Untranslated("Error: Wallet created but failed to encrypt."); - return WalletCreationStatus::ENCRYPTION_FAILED; + status = DatabaseStatus::FAILED_ENCRYPT; + return nullptr; } if (!create_blank) { // Unlock the wallet if (!wallet->Unlock(passphrase)) { error = Untranslated( "Error: Wallet was encrypted but could not be unlocked"); - return WalletCreationStatus::ENCRYPTION_FAILED; + status = DatabaseStatus::FAILED_ENCRYPT; + return nullptr; } // Set a seed for the wallet @@ -344,9 +353,10 @@ } else { for (auto spk_man : wallet->GetActiveScriptPubKeyMans()) { if (!spk_man->SetupGeneration()) { error = Untranslated("Unable to generate initial keys"); - return WalletCreationStatus::CREATION_FAILED; + status = DatabaseStatus::FAILED_CREATE; + return nullptr; } } } @@ -358,12 +367,12 @@ } AddWallet(wallet); wallet->postInitProcess(); - result = wallet; // Write the wallet settings UpdateWalletSetting(chain, name, load_on_start, warnings); - return WalletCreationStatus::SUCCESS; + status = DatabaseStatus::SUCCESS; + return wallet; } /** @defgroup mapWallet