diff --git a/test/functional/test_framework/schnorr.py b/test/functional/test_framework/schnorr.py
--- a/test/functional/test_framework/schnorr.py
+++ b/test/functional/test_framework/schnorr.py
@@ -41,15 +41,33 @@
 ssl.EC_POINT_free.argtypes = [ctypes.c_void_p]
 
 ssl.EC_POINT_mul.restype = ctypes.c_int
-ssl.EC_POINT_mul.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
-                             ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p]
+ssl.EC_POINT_mul.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_void_p]
 
 ssl.EC_POINT_is_at_infinity.restype = ctypes.c_int
 ssl.EC_POINT_is_at_infinity.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
 
 ssl.EC_POINT_point2oct.restype = ctypes.c_size_t
-ssl.EC_POINT_point2oct.argtypes = [ctypes.c_void_p, ctypes.c_void_p,
-                                   ctypes.c_int, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_void_p]
+ssl.EC_POINT_point2oct.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_int,
+    ctypes.c_void_p,
+    ctypes.c_size_t,
+    ctypes.c_void_p]
+
+ssl.EC_POINT_oct2point.restype = ctypes.c_int
+ssl.EC_POINT_oct2point.argtypes = [
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_void_p,
+    ctypes.c_size_t,
+    ctypes.c_void_p]
 
 # point encodings for EC_POINT_point2oct
 POINT_CONVERSION_COMPRESSED = 2
@@ -166,45 +184,129 @@
 
     ctx = CTX.ptr_for_this_thread()
 
-    # calculate R point and pubkey point, and get them in
+    # calculate R point and P point, and get them in
     # uncompressed/compressed formats respectively.
     R = ssl.EC_POINT_new(group)
     assert R
-    pubkey = ssl.EC_POINT_new(group)
-    assert pubkey
+    P = ssl.EC_POINT_new(group)
+    assert P
     kbn = ssl.BN_bin2bn(k.to_bytes(32, 'big'), 32, None)
     assert kbn
     privbn = ssl.BN_bin2bn(privkeybytes, 32, None)
     assert privbn
     assert ssl.EC_POINT_mul(group, R, kbn, None, None, ctx)
-    assert ssl.EC_POINT_mul(group, pubkey, privbn, None, None, ctx)
+    assert ssl.EC_POINT_mul(group, P, privbn, None, None, ctx)
     # buffer for uncompressed R coord
     Rbuf = ctypes.create_string_buffer(65)
     assert 65 == ssl.EC_POINT_point2oct(
         group, R, POINT_CONVERSION_UNCOMPRESSED, Rbuf, 65, ctx)
-    # buffer for compressed pubkey
+    # buffer for compressed P
     pubkeybuf = ctypes.create_string_buffer(33)
     assert 33 == ssl.EC_POINT_point2oct(
-        group, pubkey, POINT_CONVERSION_COMPRESSED, pubkeybuf, 33, ctx)
+        group, P, POINT_CONVERSION_COMPRESSED, pubkeybuf, 33, ctx)
     ssl.BN_free(kbn)
     ssl.BN_free(privbn)
+    ssl.EC_POINT_free(P)
     ssl.EC_POINT_free(R)
-    ssl.EC_POINT_free(pubkey)
 
-    Ry = int.from_bytes(Rbuf[33:65], 'big')    # y coord
+    # y coord
+    Ry = int.from_bytes(Rbuf[33:65], 'big')
 
     if jacobi(Ry, SECP256K1_FIELDSIZE) == -1:
         k = SECP256K1_ORDER - k
 
-    rbytes = Rbuf[1:33]    # x coord big-endian
+    # x coord big-endian
+    Rx = Rbuf[1:33]
 
     e = int.from_bytes(hashlib.sha256(
-        rbytes + pubkeybuf + msg32).digest(), 'big')
+        Rx + pubkeybuf + msg32).digest(), 'big')
 
     privkey = int.from_bytes(privkeybytes, 'big')
     s = (k + e * privkey) % SECP256K1_ORDER
 
-    return rbytes + s.to_bytes(32, 'big')
+    sig = Rx + s.to_bytes(32, 'big')
+    assert verifypriv(sig, privkeybytes, msg32)
+    return sig
+
+
+def verifypriv(sig, privkeybytes, msg32):
+    """Verify sig over msg32 with the public key of privkeybytes.
+
+    Derives the compressed public key for the private key and returns
+    the result of verify() for that key.
+    """
+    ctx = CTX.ptr_for_this_thread()
+
+    P = ssl.EC_POINT_new(group)
+    assert P
+    privbn = ssl.BN_bin2bn(privkeybytes, 32, None)
+    assert privbn
+    assert ssl.EC_POINT_mul(group, P, privbn, None, None, ctx)
+    ssl.BN_free(privbn)
+
+    # buffer for compressed P
+    pubkeybuf = ctypes.create_string_buffer(33)
+    assert 33 == ssl.EC_POINT_point2oct(
+        group, P, POINT_CONVERSION_COMPRESSED, pubkeybuf, 33, ctx)
+
+    ssl.EC_POINT_free(P)
+    return verify(sig, pubkeybuf[:], msg32)
+
+
+def verify(sig, pubkeybuf, msg32):
+    """Verify a 64-byte signature against a 33-byte compressed pubkey.
+
+    Computes R = s*G - e*P with e = SHA256(Rx || pubkey || msg32) and
+    returns True only if R is finite, its x coordinate equals Rx and
+    jacobi(y, SECP256K1_FIELDSIZE) == 1 for its y coordinate.
+    Wrong input lengths and OpenSSL failures raise AssertionError.
+    """
+    assert len(sig) == 64
+    assert len(pubkeybuf) == 33
+    assert len(msg32) == 32
+
+    Rx = sig[:32]
+    s = int.from_bytes(sig[32:], 'big')
+    e = int.from_bytes(hashlib.sha256(Rx + pubkeybuf + msg32).digest(), 'big')
+    # e is a raw 256-bit hash and may exceed the group order; reduce it
+    # so that nege is never negative (to_bytes() rejects negatives).
+    nege = (-e) % SECP256K1_ORDER
+
+    ctx = CTX.ptr_for_this_thread()
+
+    R = ssl.EC_POINT_new(group)
+    assert R
+
+    P = ssl.EC_POINT_new(group)
+    assert P
+    assert ssl.EC_POINT_oct2point(group, P, pubkeybuf, 33, ctx)
+
+    sbn = ssl.BN_bin2bn(s.to_bytes(32, 'big'), 32, None)
+    assert sbn
+    negebn = ssl.BN_bin2bn(nege.to_bytes(32, 'big'), 32, None)
+    assert negebn
+    assert ssl.EC_POINT_mul(group, R, sbn, P, negebn, ctx)
+    ssl.BN_free(negebn)
+    ssl.BN_free(sbn)
+    ssl.EC_POINT_free(P)
+
+    # s*G == e*P gives the point at infinity, which has no affine
+    # coordinates to encode below; that is an invalid signature.
+    if ssl.EC_POINT_is_at_infinity(group, R):
+        ssl.EC_POINT_free(R)
+        return False
+
+    # buffer for uncompressed R coord
+    Rbuf = ctypes.create_string_buffer(65)
+    assert 65 == ssl.EC_POINT_point2oct(
+        group, R, POINT_CONVERSION_UNCOMPRESSED, Rbuf, 65, ctx)
+    ssl.EC_POINT_free(R)
+
+    if Rbuf[1:33] != Rx:
+        return False
+
+    Ry = int.from_bytes(Rbuf[33:65], 'big')
+    return jacobi(Ry, SECP256K1_FIELDSIZE) == 1
 
 
 def getpubkey(privkeybytes, compressed=True):