diff --git a/ed25519.py b/ed25519.py index 7f9cf6d..b2062ca 100644 --- a/ed25519.py +++ b/ed25519.py @@ -15,6 +15,8 @@ import hashlib import operator import sys +import struct +import functools __version__ = "1.0.dev0" @@ -27,6 +29,7 @@ indexbytes = operator.getitem intlist2bytes = bytes int2byte = operator.methodcaller("to_bytes", 1, "big") + reduce = functools.reduce else: int2byte = chr range = xrange @@ -201,19 +204,23 @@ def bit(h, i): def publickey(sk): h = H(sk) - a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2)) + a = decodeint(h, b) + a &= (2**((b-2) - 3) - 1) << 3 # isolate b-2 bits, 3 bits in + a += 2 ** (b - 2) # add algorithm constant A = scalarmult_B(a) return encodepoint(A) def Hint(m): h = H(m) - return sum(2 ** i * bit(h, i) for i in range(2 * b)) + return decodeint(h, 2 * b) def signature(m, sk, pk): h = H(sk) - a = 2 ** (b - 2) + sum(2 ** i * bit(h, i) for i in range(3, b - 2)) + a = decodeint(h, b) + a &= (2**((b-2) - 3) - 1) << 3 + a += 2 ** (b - 2) r = Hint( intlist2bytes([indexbytes(h, j) for j in range(b // 8, b // 4)]) + m ) @@ -229,12 +236,15 @@ def isoncurve(P): (y*y - x*x - z*z - d*t*t) % q == 0) -def decodeint(s): - return sum(2 ** i * bit(s, i) for i in range(0, b)) +def decodeint(s, length): + slice_size = length // 8 + num_longs = slice_size // 8 + return reduce(lambda a, x: a << 64 | x, + struct.unpack('<' + 'Q'*num_longs, s[:slice_size])[::-1], 0) def decodepoint(s): - y = sum(2 ** i * bit(s, i) for i in range(0, b - 1)) + y = decodeint(s, b) & (2**(b-1) - 1) x = xrecover(y) if x & 1 != bit(s, b-1): x = q - x @@ -257,7 +267,7 @@ def checkvalid(s, m, pk): R = decodepoint(s[:b // 8]) A = decodepoint(pk) - S = decodeint(s[b // 8:b // 4]) + S = decodeint(s[b // 8:b // 4], b) h = Hint(encodepoint(R) + pk + m) (x1, y1, z1, t1) = P = scalarmult_B(S)