diff --git a/src/httprpc.cpp b/src/httprpc.cpp
--- a/src/httprpc.cpp
+++ b/src/httprpc.cpp
@@ -198,10 +198,9 @@
 
             // Send reply
             strReply = JSONRPCReply(result, NullUniValue, jreq.id);
-
-            // array of requests
         } else if (valRequest.isArray()) {
-            strReply = JSONRPCExecBatch(config, valRequest.get_array());
+            // array of requests
+            strReply = JSONRPCExecBatch(config, jreq, valRequest.get_array());
         } else {
             throw JSONRPCError(RPC_PARSE_ERROR, "Top-level object parse error");
         }
diff --git a/src/rpc/server.h b/src/rpc/server.h
--- a/src/rpc/server.h
+++ b/src/rpc/server.h
@@ -227,7 +227,8 @@
 bool StartRPC();
 void InterruptRPC();
 void StopRPC();
-std::string JSONRPCExecBatch(Config &config, const UniValue &vReq);
+std::string JSONRPCExecBatch(Config &config, const JSONRPCRequest &jreq,
+                             const UniValue &vReq);
 void RPCNotifyBlockChange(bool ibd, const CBlockIndex *);
 
 // Retrieves any serialization flags requested in command line argument
diff --git a/src/rpc/server.cpp b/src/rpc/server.cpp
--- a/src/rpc/server.cpp
+++ b/src/rpc/server.cpp
@@ -393,10 +393,12 @@
                            "Params must be an array or object");
 }
 
-static UniValue JSONRPCExecOne(Config &config, const UniValue &req) {
+// jreq is taken by value: parse() below fills it in, and every batch entry
+// needs its own copy of the caller's request.
+static UniValue JSONRPCExecOne(Config &config, JSONRPCRequest jreq,
+                               const UniValue &req) {
     UniValue rpc_result(UniValue::VOBJ);
 
-    JSONRPCRequest jreq;
     try {
         jreq.parse(req);
 
@@ -412,10 +414,11 @@
     return rpc_result;
 }
 
-std::string JSONRPCExecBatch(Config &config, const UniValue &vReq) {
+std::string JSONRPCExecBatch(Config &config, const JSONRPCRequest &jreq,
+                             const UniValue &vReq) {
     UniValue ret(UniValue::VARR);
-    for (unsigned int reqIdx = 0; reqIdx < vReq.size(); reqIdx++) {
-        ret.push_back(JSONRPCExecOne(config, vReq[reqIdx]));
+    for (size_t i = 0; i < vReq.size(); i++) {
+        ret.push_back(JSONRPCExecOne(config, jreq, vReq[i]));
     }
 
     return ret.write() + "\n";
diff --git a/test/functional/multiwallet.py b/test/functional/multiwallet.py
--- a/test/functional/multiwallet.py
+++ b/test/functional/multiwallet.py
@@ -90,6 +90,11 @@
         assert_equal(w3.getbalance(), 2)
         assert_equal(w4.getbalance(), 3)
 
+        batch = w1.batch([w1.getblockchaininfo.get_request(),
+                          w1.getwalletinfo.get_request()])
+        assert_equal(batch[0]["result"]["chain"], "regtest")
+        assert_equal(batch[1]["result"]["walletname"], "w1")
+
 
 if __name__ == '__main__':
     MultiWalletTest().main()
diff --git a/test/functional/test_framework/authproxy.py b/test/functional/test_framework/authproxy.py
--- a/test/functional/test_framework/authproxy.py
+++ b/test/functional/test_framework/authproxy.py
@@ -126,7 +126,7 @@
         self.__conn.request(method, path, postdata, headers)
         return self._get_response()
 
-    def __call__(self, *args, **argsn):
+    def get_request(self, *args, **argsn):
         AuthServiceProxy.__id_count += 1
 
         log.debug("-%s-> %s %s" % (AuthServiceProxy.__id_count, self._service_name,
@@ -134,10 +134,14 @@
         if args and argsn:
             raise ValueError(
                 'Cannot handle both named and positional arguments')
-        postdata = json.dumps({'version': '1.1',
-                               'method': self._service_name,
-                               'params': args or argsn,
-                               'id': AuthServiceProxy.__id_count}, default=EncodeDecimal, ensure_ascii=self.ensure_ascii)
+        return {'version': '1.1',
+                'method': self._service_name,
+                'params': args or argsn,
+                'id': AuthServiceProxy.__id_count}
+
+    def __call__(self, *args, **argsn):
+        postdata = json.dumps(self.get_request(
+            *args, **argsn), default=EncodeDecimal, ensure_ascii=self.ensure_ascii)
         response = self._request(
             'POST', self.__url.path, postdata.encode('utf-8'))
         if response['error'] is not None:
diff --git a/test/functional/test_framework/coverage.py b/test/functional/test_framework/coverage.py
--- a/test/functional/test_framework/coverage.py
+++ b/test/functional/test_framework/coverage.py
@@ -35,10 +35,11 @@
         self.auth_service_proxy_instance = auth_service_proxy_instance
         self.coverage_logfile = coverage_logfile
 
-    def __getattr__(self, *args, **kwargs):
-        return_val = self.auth_service_proxy_instance.__getattr__(
-            *args, **kwargs)
-
+    def __getattr__(self, name):
+        return_val = getattr(self.auth_service_proxy_instance, name)
+        if not isinstance(return_val, type(self.auth_service_proxy_instance)):
+            # If proxy getattr returned an unwrapped value, do the same here.
+            return return_val
         return AuthServiceProxyWrapper(return_val, self.coverage_logfile)
 
     def __call__(self, *args, **kwargs):
@@ -48,20 +49,23 @@
 
         """
         return_val = self.auth_service_proxy_instance.__call__(*args, **kwargs)
+        self._log_call()
+        return return_val
+
+    def _log_call(self):
         rpc_method = self.auth_service_proxy_instance._service_name
 
         if self.coverage_logfile:
             with open(self.coverage_logfile, 'a+', encoding='utf8') as f:
                 f.write("%s\n" % rpc_method)
 
-        return return_val
-
-    @property
-    def url(self):
-        return self.auth_service_proxy_instance.url
-
     def __truediv__(self, relative_uri):
-        return AuthServiceProxyWrapper(self.auth_service_proxy_instance / relative_uri)
+        return AuthServiceProxyWrapper(self.auth_service_proxy_instance / relative_uri,
+                                       self.coverage_logfile)
+
+    def get_request(self, *args, **kwargs):
+        self._log_call()
+        return self.auth_service_proxy_instance.get_request(*args, **kwargs)
 
 
 def get_filename(dirname, n_node):