diff --git a/src/addrdb.cpp b/src/addrdb.cpp
--- a/src/addrdb.cpp
+++ b/src/addrdb.cpp
@@ -15,22 +15,31 @@
 #include
 #include
 
-CBanDB::CBanDB(const CChainParams &chainParamsIn) : chainParams(chainParamsIn) {
-    pathBanlist = GetDataDir() / "banlist.dat";
+namespace {
+
+template <typename Stream, typename Data>
+bool SerializeDB(const CChainParams &chainParams, Stream &stream,
+                 const Data &data) {
+    // Write and commit header, data
+    try {
+        CHashWriter hasher(SER_DISK, CLIENT_VERSION);
+        stream << FLATDATA(chainParams.DiskMagic()) << data;
+        hasher << FLATDATA(chainParams.DiskMagic()) << data;
+        stream << hasher.GetHash();
+    } catch (const std::exception &e) {
+        return error("%s: Serialize or I/O error - %s", __func__, e.what());
+    }
+
+    return true;
 }
 
-bool CBanDB::Write(const banmap_t &banSet) {
+template <typename Data>
+bool SerializeFileDB(const CChainParams &chainParams, const std::string &prefix,
+                     const fs::path &path, const Data &data) {
     // Generate random temporary filename
     unsigned short randv = 0;
     GetRandBytes((uint8_t *)&randv, sizeof(randv));
-    std::string tmpfn = strprintf("banlist.dat.%04x", randv);
-
-    // serialize banlist, checksum data up to that point, then append csum
-    CDataStream ssBanlist(SER_DISK, CLIENT_VERSION);
-    ssBanlist << FLATDATA(chainParams.DiskMagic());
-    ssBanlist << banSet;
-    uint256 hash = Hash(ssBanlist.begin(), ssBanlist.end());
-    ssBanlist << hash;
+    std::string tmpfn = strprintf("%s.%04x", prefix, randv);
 
     // open temp output file, and associate with CAutoFile
     fs::path pathTmp = GetDataDir() / tmpfn;
@@ -39,68 +48,43 @@
     if (fileout.IsNull())
         return error("%s: Failed to open file %s", __func__, pathTmp.string());
 
-    // Write and commit header, data
-    try {
-        fileout << ssBanlist;
-    } catch (const std::exception &e) {
-        return error("%s: Serialize or I/O error - %s", __func__, e.what());
-    }
+    // Serialize
+    if (!SerializeDB(chainParams, fileout, data)) return false;
     FileCommit(fileout.Get());
     fileout.fclose();
 
-    // replace existing banlist.dat, if any, with new banlist.dat.XXXX
-    if (!RenameOver(pathTmp, pathBanlist))
+    // replace existing file, if any, with new file
+    if (!RenameOver(pathTmp, path))
         return error("%s: Rename-into-place failed", __func__);
 
     return true;
 }
 
-bool CBanDB::Read(banmap_t &banSet) {
-    // open input file, and associate with CAutoFile
-    FILE *file = fsbridge::fopen(pathBanlist, "rb");
-    CAutoFile filein(file, SER_DISK, CLIENT_VERSION);
-    if (filein.IsNull())
-        return error("%s: Failed to open file %s", __func__,
-                     pathBanlist.string());
-
-    // use file size to size memory buffer
-    uint64_t fileSize = fs::file_size(pathBanlist);
-    uint64_t dataSize = 0;
-    // Don't try to resize to a negative number if file is small
-    if (fileSize >= sizeof(uint256)) dataSize = fileSize - sizeof(uint256);
-    std::vector<uint8_t> vchData;
-    vchData.resize(dataSize);
-    uint256 hashIn;
-
-    // read data and checksum from file
-    try {
-        filein.read((char *)&vchData[0], dataSize);
-        filein >> hashIn;
-    } catch (const std::exception &e) {
-        return error("%s: Deserialize or I/O error - %s", __func__, e.what());
-    }
-    filein.fclose();
-
-    CDataStream ssBanlist(vchData, SER_DISK, CLIENT_VERSION);
-
-    // verify stored checksum matches input data
-    uint256 hashTmp = Hash(ssBanlist.begin(), ssBanlist.end());
-    if (hashIn != hashTmp)
-        return error("%s: Checksum mismatch, data corrupted", __func__);
-
-    uint8_t pchMsgTmp[4];
+template <typename Stream, typename Data>
+bool DeserializeDB(const CChainParams &chainParams, Stream &stream, Data &data,
+                   bool fCheckSum = true) {
     try {
+        CHashVerifier<Stream> verifier(&stream);
         // de-serialize file header (network specific magic number) and ..
-        ssBanlist >> FLATDATA(pchMsgTmp);
-
+        unsigned char pchMsgTmp[4];
+        verifier >> FLATDATA(pchMsgTmp);
         // ... verify the network matches ours
         if (memcmp(pchMsgTmp, std::begin(chainParams.DiskMagic()),
                    sizeof(pchMsgTmp))) {
             return error("%s: Invalid network magic number", __func__);
         }
 
-        // de-serialize ban data
-        ssBanlist >> banSet;
+        // de-serialize data
+        verifier >> data;
+
+        // verify checksum
+        if (fCheckSum) {
+            uint256 hashTmp;
+            stream >> hashTmp;
+            if (hashTmp != verifier.GetHash()) {
+                return error("%s: Checksum mismatch, data corrupted", __func__);
+            }
+        }
     } catch (const std::exception &e) {
         return error("%s: Deserialize or I/O error - %s", __func__, e.what());
     }
@@ -108,101 +92,50 @@
     return true;
 }
 
-CAddrDB::CAddrDB(const CChainParams &chainParamsIn)
-    : chainParams(chainParamsIn) {
-    pathAddr = GetDataDir() / "peers.dat";
+template <typename Data>
+bool DeserializeFileDB(const CChainParams &chainParams, const fs::path &path,
+                       Data &data) {
+    // open input file, and associate with CAutoFile
+    FILE *file = fsbridge::fopen(path, "rb");
+    CAutoFile filein(file, SER_DISK, CLIENT_VERSION);
+    if (filein.IsNull())
+        return error("%s: Failed to open file %s", __func__, path.string());
+
+    return DeserializeDB(chainParams, filein, data);
 }
 
-bool CAddrDB::Write(const CAddrMan &addr) {
-    // Generate random temporary filename
-    unsigned short randv = 0;
-    GetRandBytes((uint8_t *)&randv, sizeof(randv));
-    std::string tmpfn = strprintf("peers.dat.%04x", randv);
+}
 
-    // serialize addresses, checksum data up to that point, then append csum
-    CDataStream ssPeers(SER_DISK, CLIENT_VERSION);
-    ssPeers << FLATDATA(chainParams.DiskMagic());
-    ssPeers << addr;
-    uint256 hash = Hash(ssPeers.begin(), ssPeers.end());
-    ssPeers << hash;
+CBanDB::CBanDB(const CChainParams &chainParamsIn) : chainParams(chainParamsIn) {
+    pathBanlist = GetDataDir() / "banlist.dat";
+}
 
-    // open temp output file, and associate with CAutoFile
-    fs::path pathTmp = GetDataDir() / tmpfn;
-    FILE *file = fsbridge::fopen(pathTmp, "wb");
-    CAutoFile fileout(file, SER_DISK, CLIENT_VERSION);
-    if (fileout.IsNull())
-        return error("%s: Failed to open file %s", __func__, pathTmp.string());
+bool CBanDB::Write(const banmap_t &banSet) {
+    return SerializeFileDB(chainParams, "banlist", pathBanlist, banSet);
+}
 
-    // Write and commit header, data
-    try {
-        fileout << ssPeers;
-    } catch (const std::exception &e) {
-        return error("%s: Serialize or I/O error - %s", __func__, e.what());
-    }
-    FileCommit(fileout.Get());
-    fileout.fclose();
+bool CBanDB::Read(banmap_t &banSet) {
+    return DeserializeFileDB(chainParams, pathBanlist, banSet);
+}
 
-    // replace existing peers.dat, if any, with new peers.dat.XXXX
-    if (!RenameOver(pathTmp, pathAddr))
-        return error("%s: Rename-into-place failed", __func__);
+CAddrDB::CAddrDB(const CChainParams &chainParamsIn)
+    : chainParams(chainParamsIn) {
+    pathAddr = GetDataDir() / "peers.dat";
+}
 
-    return true;
+bool CAddrDB::Write(const CAddrMan &addr) {
+    return SerializeFileDB(chainParams, "peers", pathAddr, addr);
 }
 
 bool CAddrDB::Read(CAddrMan &addr) {
-    // open input file, and associate with CAutoFile
-    FILE *file = fsbridge::fopen(pathAddr, "rb");
-    CAutoFile filein(file, SER_DISK, CLIENT_VERSION);
-    if (filein.IsNull())
-        return error("%s: Failed to open file %s", __func__, pathAddr.string());
-
-    // use file size to size memory buffer
-    uint64_t fileSize = fs::file_size(pathAddr);
-    uint64_t dataSize = 0;
-    // Don't try to resize to a negative number if file is small
-    if (fileSize >= sizeof(uint256)) dataSize = fileSize - sizeof(uint256);
-    std::vector<uint8_t> vchData;
-    vchData.resize(dataSize);
-    uint256 hashIn;
-
-    // read data and checksum from file
-    try {
-        filein.read((char *)&vchData[0], dataSize);
-        filein >> hashIn;
-    } catch (const std::exception &e) {
-        return error("%s: Deserialize or I/O error - %s", __func__, e.what());
-    }
-    filein.fclose();
-
-    CDataStream ssPeers(vchData, SER_DISK, CLIENT_VERSION);
-
-    // verify stored checksum matches input data
-    uint256 hashTmp = Hash(ssPeers.begin(), ssPeers.end());
-    if (hashIn != hashTmp)
-        return error("%s: Checksum mismatch, data corrupted", __func__);
-
-    return Read(addr, ssPeers);
+    return DeserializeFileDB(chainParams, pathAddr, addr);
 }
 
 bool CAddrDB::Read(CAddrMan &addr, CDataStream &ssPeers) {
-    uint8_t pchMsgTmp[4];
-    try {
-        // de-serialize file header (network specific magic number) and ..
-        ssPeers >> FLATDATA(pchMsgTmp);
-
-        // ... verify the network matches ours
-        if (memcmp(pchMsgTmp, std::begin(chainParams.DiskMagic()),
-                   sizeof(pchMsgTmp))) {
-            return error("%s: Invalid network magic number", __func__);
-        }
-
-        // de-serialize address data into one CAddrMan object
-        ssPeers >> addr;
-    } catch (const std::exception &e) {
-        // de-serialization has failed, ensure addrman is left in a clean state
+    bool ret = DeserializeDB(chainParams, ssPeers, addr, false);
+    if (!ret) {
+        // Ensure addrman is left in a clean state
         addr.Clear();
-        return error("%s: Deserialize or I/O error - %s", __func__, e.what());
     }
-
-    return true;
+    return ret;
 }