# Patch summary. This is an explanatory preamble; git apply and GNU patch
# ignore text that appears before the first "diff --git" header.
#
# Purpose: tag SQLite wallet files with a network-derived value in the SQLite
# header's application_id field. The value is checked when verifying an open
# database and when sniffing a file on disk.
#
# Hunk @@ -4,6 +4,8 @@ adds two #include lines.
#   The header names are missing in this copy.
#   NOTE(review): presumably these are the headers that declare ReadBE32 and
#   Params(); confirm against the real patch.
#
# Hunk @@ -135,8 +137,37 @@ (SQLiteDatabase::Verify) adds a check before the
#   existing "PRAGMA integrity_check":
#   - It prepares and steps "PRAGMA application_id".
#   - It reads column 0 as an int and stores it in a uint32_t.
#   - It compares that value with ReadBE32(Params().DiskMagic().data()).
#   - On a failed prepare, a step that does not return SQLITE_ROW, or a
#     mismatch, it sets `error` to a translated message and returns false.
#   - app_id_stmt is finalized on every path.
#   The new code now declares `ret`, so the integrity_check statement changes
#   from `int ret =` to `ret =`.
#
# Hunk @@ -278,6 +309,17 @@ runs right after the existing "Failed to create
#   new database" check:
#   - It executes "PRAGMA application_id = %d" via sqlite3_exec, passing the
#     same ReadBE32 value through a static_cast for the %d format.
#   - It throws std::runtime_error if that statement fails.
#   NOTE(review): the signed %d form presumably matches SQLite treating
#   application_id as a signed 32-bit integer; confirm.
#
# Hunk @@ -653,9 +695,20 @@ changes the file-magic check (the function name
#   is not visible in this hunk):
#   - After reading the 16-byte "SQLite format 3" magic, it seeks to byte
#     offset 68, reads 4 bytes, and closes the file.
#   - It returns false on a magic mismatch.
#   - Otherwise it returns whether those 4 raw bytes equal the first 4 bytes
#     of Params().DiskMagic() (memcmp).
#   This raw-byte comparison agrees with the ReadBE32 comparison in Verify
#   only if SQLite stores application_id big-endian in the file header.
#   NOTE(review): confirm against the SQLite file-format spec.
#   NOTE(review): this hunk does not check the seekg/read results. A file
#   shorter than 72 bytes would leave app_id unfilled unless an earlier size
#   check (not visible here) rejects such files.
#
# NOTE(review): this copy of the patch appears damaged:
#   - The text is collapsed onto two physical lines.
#   - `static_cast(` appears twice with no template argument. Presumably
#     these are static_cast<uint32_t> in Verify and a signed 32-bit type in
#     the PRAGMA setter.
#   - The "Unexpected application id. Expected %u, got %u" literal is broken
#     across a line break.
#   Restore these from the real patch before applying it; otherwise neither
#   the patch nor the resulting C++ is valid.
diff --git a/src/wallet/sqlite.cpp b/src/wallet/sqlite.cpp --- a/src/wallet/sqlite.cpp +++ b/src/wallet/sqlite.cpp @@ -4,6 +4,8 @@ #include +#include +#include #include #include #include @@ -135,8 +137,37 @@ bool SQLiteDatabase::Verify(bilingual_str &error) { assert(m_db); + // Check the application ID matches our network magic + sqlite3_stmt *app_id_stmt{nullptr}; + int ret = sqlite3_prepare_v2(m_db, "PRAGMA application_id", -1, + &app_id_stmt, nullptr); + if (ret != SQLITE_OK) { + sqlite3_finalize(app_id_stmt); + error = strprintf(_("SQLiteDatabase: Failed to prepare the statement " + "to fetch the application id: %s"), + sqlite3_errstr(ret)); + return false; + } + ret = sqlite3_step(app_id_stmt); + if (ret != SQLITE_ROW) { + sqlite3_finalize(app_id_stmt); + error = strprintf( + _("SQLiteDatabase: Failed to fetch the application id: %s"), + sqlite3_errstr(ret)); + return false; + } + uint32_t app_id = static_cast(sqlite3_column_int(app_id_stmt, 0)); + sqlite3_finalize(app_id_stmt); + uint32_t net_magic = ReadBE32(Params().DiskMagic().data()); + if (app_id != net_magic) { + error = strprintf( + _("SQLiteDatabase: Unexpected application id.
Expected %u, got %u"), + net_magic, app_id); + return false; + } + sqlite3_stmt *stmt{nullptr}; - int ret = + ret = sqlite3_prepare_v2(m_db, "PRAGMA integrity_check", -1, &stmt, nullptr); if (ret != SQLITE_OK) { sqlite3_finalize(stmt); @@ -278,6 +309,17 @@ strprintf("SQLiteDatabase: Failed to create new database: %s\n", sqlite3_errstr(ret))); } + + // Set the application id + uint32_t app_id = ReadBE32(Params().DiskMagic().data()); + std::string set_app_id = strprintf("PRAGMA application_id = %d", + static_cast(app_id)); + ret = sqlite3_exec(m_db, set_app_id.c_str(), nullptr, nullptr, nullptr); + if (ret != SQLITE_OK) { + throw std::runtime_error(strprintf( + "SQLiteDatabase: Failed to set the application id: %s\n", + sqlite3_errstr(ret))); + } } } @@ -653,9 +695,20 @@ // Magic is at beginning and is 16 bytes long char magic[16]; file.read(magic, 16); + + // Application id is at offset 68 and 4 bytes long + file.seekg(68, std::ios::beg); + char app_id[4]; + file.read(app_id, 4); + file.close(); // Check the magic, see https://sqlite.org/fileformat2.html std::string magic_str(magic, 16); - return magic_str == std::string("SQLite format 3", 16); + if (magic_str != std::string("SQLite format 3", 16)) { + return false; + } + + // Check the application id matches our network magic + return memcmp(Params().DiskMagic().data(), app_id, 4) == 0; }