diff --git a/src/wallet/db.h b/src/wallet/db.h
--- a/src/wallet/db.h
+++ b/src/wallet/db.h
@@ -31,6 +31,8 @@
     bool operator==(const WalletDatabaseFileId &rhs) const;
 };
 
+class BerkeleyDatabase;
+
 class BerkeleyEnvironment {
 private:
     bool fDbEnvInit;
@@ -43,7 +45,7 @@
 public:
     std::unique_ptr<DbEnv> dbenv;
     std::map<std::string, int> mapFileUseCount;
-    std::map<std::string, Db *> mapDb;
+    std::map<std::string, std::reference_wrapper<BerkeleyDatabase>> m_databases;
     std::unordered_map<std::string, WalletDatabaseFileId> m_fileids;
     std::condition_variable_any m_db_in_use;
 
@@ -54,6 +56,9 @@
     void MakeMock();
     bool IsMock() const { return fMockDb; }
     bool IsInitialized() const { return fDbEnvInit; }
+    bool IsDatabaseLoaded(const std::string &db_filename) const {
+        return m_databases.find(db_filename) != m_databases.end();
+    }
     fs::path Directory() const { return strPath; }
 
     /**
@@ -97,6 +102,9 @@
     }
 };
 
+/** Return whether a wallet database is currently loaded. */
+bool IsWalletLoaded(const fs::path &wallet_path);
+
 /** Get BerkeleyEnvironment and database filename given a wallet path. */
 BerkeleyEnvironment *GetWalletEnv(const fs::path &wallet_path,
                                   std::string &database_filename);
@@ -119,6 +127,8 @@
         : nUpdateCounter(0), nLastSeen(0), nLastFlushed(0),
           nLastWalletUpdate(0) {
         env = GetWalletEnv(wallet_path, strFile);
+        auto inserted = env->m_databases.emplace(strFile, std::ref(*this));
+        assert(inserted.second);
         if (mock) {
             env->Close();
             env->Reset();
@@ -126,6 +136,13 @@
         }
     }
 
+    ~BerkeleyDatabase() {
+        if (env) {
+            size_t erased = env->m_databases.erase(strFile);
+            assert(erased == 1);
+        }
+    }
+
     /** Return object for accessing database at specified path. */
     static std::unique_ptr<BerkeleyDatabase> Create(const fs::path &path) {
         return std::make_unique<BerkeleyDatabase>(path);
@@ -171,6 +188,12 @@
     unsigned int nLastFlushed;
     int64_t nLastWalletUpdate;
 
+    /**
+     * Database pointer. This is initialized lazily and reset during flushes,
+     * so it can be null.
+     */
+    std::unique_ptr<Db> m_db;
+
 private:
     /** BerkeleyDB specific */
     BerkeleyEnvironment *env;
diff --git a/src/wallet/db.cpp b/src/wallet/db.cpp
--- a/src/wallet/db.cpp
+++ b/src/wallet/db.cpp
@@ -67,9 +67,9 @@
     return memcmp(value, &rhs.value, sizeof(value)) == 0;
 }
 
-BerkeleyEnvironment *GetWalletEnv(const fs::path &wallet_path,
-                                  std::string &database_filename) {
-    fs::path env_directory;
+static void SplitWalletPath(const fs::path &wallet_path,
+                            fs::path &env_directory,
+                            std::string &database_filename) {
     if (fs::is_regular_file(wallet_path)) {
         // Special case for backwards compatibility: if wallet path points to an
         // existing file, treat it as the path to a BDB data file in a parent
@@ -82,6 +82,26 @@
         env_directory = wallet_path;
         database_filename = "wallet.dat";
     }
+}
+
+bool IsWalletLoaded(const fs::path &wallet_path) {
+    fs::path env_directory;
+    std::string database_filename;
+    SplitWalletPath(wallet_path, env_directory, database_filename);
+
+    LOCK(cs_db);
+    auto env = g_dbenvs.find(env_directory.string());
+    if (env == g_dbenvs.end()) {
+        return false;
+    }
+
+    return env->second.IsDatabaseLoaded(database_filename);
+}
+
+BerkeleyEnvironment *GetWalletEnv(const fs::path &wallet_path,
+                                  std::string &database_filename) {
+    fs::path env_directory;
+    SplitWalletPath(wallet_path, env_directory, database_filename);
     LOCK(cs_db);
     // Note: An ununsed temporary BerkeleyEnvironment object may be created
     // inside the emplace function if the key already exists. This is a little
@@ -105,13 +125,13 @@
 
     fDbEnvInit = false;
 
-    for (auto &db : mapDb) {
+    for (auto &db : m_databases) {
         auto count = mapFileUseCount.find(db.first);
         assert(count == mapFileUseCount.end() || count->second == 0);
-        if (db.second) {
-            db.second->close(0);
-            delete db.second;
-            db.second = nullptr;
+        BerkeleyDatabase &database = db.second.get();
+        if (database.m_db) {
+            database.m_db->close(0);
+            database.m_db.reset();
         }
     }
 
@@ -507,7 +527,7 @@
                 "BerkeleyBatch: Failed to open database environment.");
         }
 
-        pdb = env->mapDb[strFilename];
+        pdb = database.m_db.get();
         if (pdb == nullptr) {
             int ret;
             std::unique_ptr<Db> pdb_temp =
@@ -560,7 +580,7 @@
             }
 
             pdb = pdb_temp.release();
-            env->mapDb[strFilename] = pdb;
+            database.m_db.reset(pdb);
 
             if (fCreate && !Exists(std::string("version"))) {
                 bool fTmp = fReadOnly;
@@ -618,12 +638,13 @@
 
 void BerkeleyEnvironment::CloseDb(const std::string &strFile) {
     LOCK(cs_db);
-    if (mapDb[strFile] != nullptr) {
+    auto it = m_databases.find(strFile);
+    assert(it != m_databases.end());
+    BerkeleyDatabase &database = it->second.get();
+    if (database.m_db) {
         // Close the database handle
-        Db *pdb = mapDb[strFile];
-        pdb->close(0);
-        delete pdb;
-        mapDb[strFile] = nullptr;
+        database.m_db->close(0);
+        database.m_db.reset();
     }
 }
 
@@ -641,7 +662,7 @@
     });
 
     std::vector<std::string> filenames;
-    for (auto it : mapDb) {
+    for (const auto &it : m_databases) {
         filenames.push_back(it.first);
     }
     // Close the individual Db's
diff --git a/src/wallet/wallet.cpp b/src/wallet/wallet.cpp
--- a/src/wallet/wallet.cpp
+++ b/src/wallet/wallet.cpp
@@ -4325,13 +4325,11 @@
     }
 
     // Make sure that the wallet path doesn't clash with an existing wallet path
-    for (auto wallet : GetWallets()) {
-        if (wallet->GetLocation().GetPath() == wallet_path) {
-            error_string = strprintf("Error loading wallet %s. Duplicate "
-                                     "-wallet filename specified.",
-                                     location.GetName());
-            return false;
-        }
-    }
+    if (IsWalletLoaded(wallet_path)) {
+        error_string = strprintf(
+            "Error loading wallet %s. Duplicate -wallet filename specified.",
+            location.GetName());
+        return false;
+    }
 
     try {
diff --git a/test/functional/wallet_multiwallet.py b/test/functional/wallet_multiwallet.py
--- a/test/functional/wallet_multiwallet.py
+++ b/test/functional/wallet_multiwallet.py
@@ -221,6 +221,10 @@
         assert_raises_rpc_error(-4, 'Wallet file verification failed: Error loading wallet w1. Duplicate -wallet filename specified.',
                                 self.nodes[0].loadwallet, wallet_names[0])
 
+        # Fail to load duplicate wallets by different ways (directory and filepath)
+        assert_raises_rpc_error(-4, "Wallet file verification failed: Error loading wallet wallet.dat. Duplicate -wallet filename specified.",
+                                self.nodes[0].loadwallet, 'wallet.dat')
+
         # Fail to load if one wallet is a copy of another
         assert_raises_rpc_error(-1, "BerkeleyBatch: Can't open database w8_copy (duplicates fileid",
                                 self.nodes[0].loadwallet, 'w8_copy')