diff --git a/src/psbt.h b/src/psbt.h
--- a/src/psbt.h
+++ b/src/psbt.h
@@ -323,6 +323,8 @@
      */
     NODISCARD bool Merge(const PartiallySignedTransaction &psbt);
     bool IsSane() const;
+    bool AddInput(const CTxIn &txin, PSBTInput &psbtin);
+    bool AddOutput(const CTxOut &txout, const PSBTOutput &psbtout);
     PartiallySignedTransaction() {}
     PartiallySignedTransaction(const PartiallySignedTransaction &psbt_in)
         : tx(psbt_in.tx), inputs(psbt_in.inputs), outputs(psbt_in.outputs),
diff --git a/src/psbt.cpp b/src/psbt.cpp
--- a/src/psbt.cpp
+++ b/src/psbt.cpp
@@ -41,6 +41,37 @@
     return true;
 }
 
+/**
+ * Append txin, together with its PSBT metadata psbtin, to this PSBT.
+ *
+ * Returns false without modifying anything if an equal txin is already in
+ * tx->vin; otherwise appends the input and returns true. On success psbtin
+ * itself is modified: its partial_sigs and final_script_sig are cleared
+ * before it is copied into inputs.
+ */
+bool PartiallySignedTransaction::AddInput(const CTxIn &txin,
+                                          PSBTInput &psbtin) {
+    if (std::find(tx->vin.begin(), tx->vin.end(), txin) != tx->vin.end()) {
+        return false;
+    }
+    tx->vin.push_back(txin);
+    psbtin.partial_sigs.clear();
+    psbtin.final_script_sig.clear();
+    inputs.push_back(psbtin);
+    return true;
+}
+
+/**
+ * Append txout, together with its PSBT metadata psbtout, to this PSBT.
+ * No duplicate check is performed and the call always returns true.
+ */
+bool PartiallySignedTransaction::AddOutput(const CTxOut &txout,
+                                           const PSBTOutput &psbtout) {
+    tx->vout.push_back(txout);
+    outputs.push_back(psbtout);
+    return true;
+}
+
 bool PSBTInput::IsNull() const {
     return utxo.IsNull() && partial_sigs.empty() && unknown.empty() &&
            hd_keypaths.empty() && redeem_script.empty();
diff --git a/src/rpc/client.cpp b/src/rpc/client.cpp
--- a/src/rpc/client.cpp
+++ b/src/rpc/client.cpp
@@ -102,6 +102,7 @@
     {"createpsbt", 1, "outputs"},
     {"createpsbt", 2, "locktime"},
     {"combinepsbt", 0, "txs"},
+    {"joinpsbts", 0, "txs"},
     {"finalizepsbt", 1, "extract"},
     {"converttopsbt", 1, "permitsigdata"},
     {"gettxout", 1, "n"},
diff --git a/src/rpc/rawtransaction.cpp b/src/rpc/rawtransaction.cpp
--- a/src/rpc/rawtransaction.cpp
+++ b/src/rpc/rawtransaction.cpp
@@ -1768,6 +1768,96 @@
     return EncodeBase64((uint8_t *)ssTx.data(), ssTx.size());
 }
 
+/**
+ * joinpsbts RPC: build one new PSBT holding the inputs and outputs of every
+ * PSBT in the "txs" array and return it base64-encoded.
+ *
+ * The joined transaction takes the highest nVersion (never below 1) and the
+ * lowest nLockTime among the given PSBTs. A JSONRPCError is thrown when fewer
+ * than two PSBTs are given, when one cannot be decoded, or when AddInput
+ * rejects an input because it was already added.
+ */
+UniValue joinpsbts(const Config &config, const JSONRPCRequest &request) {
+    if (request.fHelp || request.params.size() != 1) {
+        throw std::runtime_error(RPCHelpMan{
+            "joinpsbts",
+            "\nJoins multiple distinct PSBTs with different inputs and outputs "
+            "into one PSBT with inputs and outputs from all of the PSBTs\n"
+            "No input in any of the PSBTs can be in more than one of the "
+            "PSBTs.\n",
+            {{"txs",
+              RPCArg::Type::ARR,
+              /* opt */ false,
+              /* default_val */ "",
+              "A json array of base64 strings of partially signed transactions",
+              {{"psbt", RPCArg::Type::STR, /* opt */ false,
+                /* default_val */ "", "A base64 string of a PSBT"}}}},
+            RPCResult{" \"psbt\" (string) The base64-encoded "
+                      "partially signed transaction\n"},
+            RPCExamples{HelpExampleCli("joinpsbts", "\"psbt\"")}}
+                                     .ToString());
+    }
+
+    RPCTypeCheck(request.params, {UniValue::VARR}, true);
+
+    // Unserialize the transactions
+    std::vector<PartiallySignedTransaction> psbtxs;
+    UniValue txs = request.params[0].get_array();
+
+    if (txs.size() <= 1) {
+        throw JSONRPCError(RPC_INVALID_PARAMETER,
+                           "At least two PSBTs are required to join PSBTs.");
+    }
+
+    int32_t best_version = 1;
+    uint32_t best_locktime = 0xffffffff;
+    for (size_t i = 0; i < txs.size(); ++i) {
+        PartiallySignedTransaction psbtx;
+        std::string error;
+        if (!DecodeBase64PSBT(psbtx, txs[i].get_str(), error)) {
+            throw JSONRPCError(RPC_DESERIALIZATION_ERROR,
+                               strprintf("TX decode failed %s", error));
+        }
+        psbtxs.push_back(psbtx);
+        // Choose the highest version number
+        if (psbtx.tx->nVersion > best_version) {
+            best_version = psbtx.tx->nVersion;
+        }
+        // Choose the lowest lock time
+        if (psbtx.tx->nLockTime < best_locktime) {
+            best_locktime = psbtx.tx->nLockTime;
+        }
+    }
+
+    // Create a blank psbt where everything will be added
+    PartiallySignedTransaction merged_psbt;
+    merged_psbt.tx = CMutableTransaction();
+    merged_psbt.tx->nVersion = best_version;
+    merged_psbt.tx->nLockTime = best_locktime;
+
+    // Merge
+    for (auto &psbt : psbtxs) {
+        for (size_t i = 0; i < psbt.tx->vin.size(); ++i) {
+            if (!merged_psbt.AddInput(psbt.tx->vin[i], psbt.inputs[i])) {
+                throw JSONRPCError(
+                    RPC_INVALID_PARAMETER,
+                    strprintf(
+                        "Input %s:%d exists in multiple PSBTs",
+                        psbt.tx->vin[i].prevout.GetTxId().ToString().c_str(),
+                        psbt.tx->vin[i].prevout.GetN()));
+            }
+        }
+        for (size_t i = 0; i < psbt.tx->vout.size(); ++i) {
+            merged_psbt.AddOutput(psbt.tx->vout[i], psbt.outputs[i]);
+        }
+        merged_psbt.unknown.insert(psbt.unknown.begin(), psbt.unknown.end());
+    }
+
+    CDataStream ssTx(SER_NETWORK, PROTOCOL_VERSION);
+    ssTx << merged_psbt;
+    return EncodeBase64((uint8_t *)ssTx.data(), ssTx.size());
+}
+
 // clang-format off
 static const CRPCCommand commands[] = {
     // category name actor (function) argNames
@@ -1786,6 +1876,7 @@
     { "rawtransactions", "createpsbt", createpsbt, {"inputs","outputs","locktime"} },
     { "rawtransactions", "converttopsbt", converttopsbt, {"hexstring","permitsigdata"} },
     { "rawtransactions", "utxoupdatepsbt", utxoupdatepsbt, {"psbt"} },
+    { "rawtransactions", "joinpsbts", joinpsbts, {"txs"} },
     { "blockchain", "gettxoutproof", gettxoutproof, {"txids", "blockhash"} },
     { "blockchain", "verifytxoutproof", verifytxoutproof, {"proof"} },
 };
diff --git a/test/functional/rpc_psbt.py b/test/functional/rpc_psbt.py
--- a/test/functional/rpc_psbt.py
+++ b/test/functional/rpc_psbt.py
@@ -8,6 +8,7 @@
 import json
 import os
+from decimal import Decimal
 
 from test_framework.test_framework import BitcoinTestFramework
 from test_framework.util import (
     assert_equal,
@@ -249,6 +250,34 @@
         updated = self.nodes[1].utxoupdatepsbt(psbt)
         decoded = self.nodes[1].decodepsbt(updated)
 
+        # Two PSBTs with a common input should not be joinable
+        psbt1 = self.nodes[1].createpsbt([{"txid": txid1, "vout": vout1}], {
+                                         self.nodes[0].getnewaddress(): Decimal('10.999')})
+        assert_raises_rpc_error(-8, "exists in multiple PSBTs",
+                                self.nodes[1].joinpsbts, [psbt1, updated])
+
+        # Join two distinct PSBTs
+        addr4 = self.nodes[1].getnewaddress("")
+        txid4 = self.nodes[0].sendtoaddress(addr4, 5)
+        vout4 = find_output(self.nodes[0], txid4, 5)
+        self.nodes[0].generate(6)
+        self.sync_all()
+        psbt2 = self.nodes[1].createpsbt([{"txid": txid4, "vout": vout4}], {
+                                         self.nodes[0].getnewaddress(): Decimal('4.999')})
+        psbt2 = self.nodes[1].walletprocesspsbt(psbt2)['psbt']
+        psbt2_decoded = self.nodes[0].decodepsbt(psbt2)
+        assert "final_scriptSig" in psbt2_decoded['inputs'][0]
+        joined = self.nodes[0].joinpsbts([psbt, psbt2])
+        joined_decoded = self.nodes[0].decodepsbt(joined)
+        assert len(joined_decoded['inputs']) == 4 and len(
+            joined_decoded['outputs']) == 2 and "final_scriptSig" not in joined_decoded['inputs'][3]
+
+        # Fail when trying to join less than two PSBTs
+        assert_raises_rpc_error(-8,
+                                "At least two PSBTs are required to join PSBTs.", self.nodes[1].joinpsbts, [])
+        assert_raises_rpc_error(-8,
+                                "At least two PSBTs are required to join PSBTs.", self.nodes[1].joinpsbts, [psbt2])
+
 
 if __name__ == '__main__':
     PSBTTest().main()