diff --git a/src/rpc/command.h b/src/rpc/command.h
--- a/src/rpc/command.h
+++ b/src/rpc/command.h
@@ -5,6 +5,8 @@
 #ifndef BITCOIN_RPC_COMMAND_H
 #define BITCOIN_RPC_COMMAND_H
 
+#include <rpc/jsonrpcrequest.h>
+
 #include <univalue.h>
 
 #include <boost/noncopyable.hpp>
@@ -12,9 +14,12 @@
 #include <string>
+#include <utility>
 
 /**
- * Base class for all RPC commands.
+ * Base class for all RPC commands. RPCCommandBase should only be inherited
+ * from directly if access to the entire request context is necessary. For
+ * more typical cases, see the RPCCommand class below.
  */
-class RPCCommand : public boost::noncopyable {
+class RPCCommandBase : public boost::noncopyable {
 private:
     const std::string name;
 
@@ -28,12 +33,32 @@
     // messages as well)
 
 public:
-    RPCCommand(std::string nameIn) : name(nameIn) {}
-    virtual ~RPCCommand() {}
+    RPCCommandBase(std::string nameIn) : name(std::move(nameIn)) {}
+    virtual ~RPCCommandBase() {}
 
-    virtual UniValue Execute(const UniValue &args) const = 0;
+    /**
+     * It is recommended to override Execute(JSONRPCRequest) only if the entire
+     * request context is required. Otherwise, use RPCCommand instead.
+     */
+    virtual UniValue Execute(const JSONRPCRequest &request) const = 0;
 
     std::string GetName() const { return name; };
 };
 
+/**
+ * By default, use RPCCommand as the parent class for new RPC command classes.
+ * Subclasses need only implement Execute(const UniValue &args): the overload
+ * taking a JSONRPCRequest is final and forwards request.params to it.
+ */
+class RPCCommand : public RPCCommandBase {
+public:
+    RPCCommand(std::string nameIn) : RPCCommandBase(std::move(nameIn)) {}
+
+    virtual UniValue Execute(const UniValue &args) const = 0;
+
+    UniValue Execute(const JSONRPCRequest &request) const final {
+        return Execute(request.params);
+    }
+};
+
 #endif // BITCOIN_RPC_COMMAND_H
diff --git a/src/rpc/server.h b/src/rpc/server.h
--- a/src/rpc/server.h
+++ b/src/rpc/server.h
@@ -49,7 +49,7 @@
     UniValue::VType type;
 };
 
-typedef std::map<std::string, std::unique_ptr<RPCCommand>> RPCCommandMap;
+typedef std::map<std::string, std::unique_ptr<RPCCommandBase>> RPCCommandMap;
 
 /**
  * Class for registering and managing all RPC calls.
@@ -71,7 +71,7 @@
     /**
      * Register an RPC command.
      */
-    void RegisterCommand(std::unique_ptr<RPCCommand> command);
+    void RegisterCommand(std::unique_ptr<RPCCommandBase> command);
 };
 
 /**
diff --git a/src/rpc/server.cpp b/src/rpc/server.cpp
--- a/src/rpc/server.cpp
+++ b/src/rpc/server.cpp
@@ -55,7 +55,7 @@
         auto commandsReadView = commands.getReadView();
         auto iter = commandsReadView->find(commandName);
         if (iter != commandsReadView.end()) {
-            return iter->second.get()->Execute(request.params);
+            return iter->second->Execute(request);
         }
     }
 
@@ -66,7 +66,7 @@
     return tableRPC.execute(config, request);
 }
 
-void RPCServer::RegisterCommand(std::unique_ptr<RPCCommand> command) {
+void RPCServer::RegisterCommand(std::unique_ptr<RPCCommandBase> command) {
     if (command != nullptr) {
         commands.getWriteView()->insert(
             std::make_pair(command->GetName(), std::move(command)));
diff --git a/src/test/rpc_server_tests.cpp b/src/test/rpc_server_tests.cpp
--- a/src/test/rpc_server_tests.cpp
+++ b/src/test/rpc_server_tests.cpp
@@ -53,4 +53,36 @@
                           UniValue, isRpcMethodNotFound);
 }
 
+class RequestContextRPCCommand : public RPCCommandBase {
+public:
+    RequestContextRPCCommand(std::string nameIn) : RPCCommandBase(nameIn) {}
+
+    // Sanity check that Execute(JSONRPCRequest) is called correctly from
+    // RPCServer
+    UniValue Execute(const JSONRPCRequest &request) const override {
+        const UniValue &args = request.params;
+        BOOST_CHECK_EQUAL(request.strMethod, "testcommand");
+        BOOST_CHECK_EQUAL(args["arg1"].get_str(), "value1");
+        return UniValue("testing");
+    }
+};
+
+BOOST_AUTO_TEST_CASE(rpc_server_execute_command_from_request_context) {
+    DummyConfig config;
+    RPCServer rpcServer;
+    const std::string commandName = "testcommand";
+    rpcServer.RegisterCommand(
+        MakeUnique<RequestContextRPCCommand>(commandName));
+
+    UniValue args(UniValue::VOBJ);
+    args.pushKV("arg1", "value1");
+
+    // Registered commands execute and return values correctly
+    JSONRPCRequest request;
+    request.strMethod = commandName;
+    request.params = args;
+    UniValue output = rpcServer.ExecuteCommand(config, request);
+    BOOST_CHECK_EQUAL(output.get_str(), "testing");
+}
+
 BOOST_AUTO_TEST_SUITE_END()