diff --git a/aux/ristretto.sage b/aux/ristretto.sage index ac0c450..9c7d078 100644 --- a/aux/ristretto.sage +++ b/aux/ristretto.sage @@ -5,6 +5,7 @@ class SpecException(Exception): pass def lobit(x): return int(x) & 1 def hibit(x): return lobit(2*x) +def negative(x): return lobit(x) def enc_le(x,n): return bytearray([int(x)>>(8*i) & 0xFF for i in xrange(n)]) def dec_le(x): return sum(b<<(8*i) for i,b in enumerate(x)) def randombytes(n): return bytearray([randint(0,255) for _ in range(n)]) @@ -36,7 +37,7 @@ def xsqrt(x,exn=InvalidEncodingException("Not on curve")): """Return sqrt(x)""" if not is_square(x): raise exn s = sqrt(x) - if lobit(s): s=-s + if negative(s): s=-s return s def isqrt(x,exn=InvalidEncodingException("Not on curve")): @@ -106,7 +107,8 @@ class QuotientEdwardsPoint(object): def torque(self): """Apply cofactor group, except keeping the point even""" if self.cofactor == 8: - return self.__class__(self.y*self.i, self.x*self.i) + if self.a == -1: return self.__class__(self.y*self.i, self.x*self.i) + if self.a == 1: return self.__class__(-self.y, self.x) else: return self.__class__(-self.x, -self.y) @@ -120,14 +122,15 @@ class QuotientEdwardsPoint(object): s = dec_le(bytes) if mustBeProper and s >= cls.F.modulus(): raise InvalidEncodingException("%d out of range!" % s) - if mustBePositive and lobit(s): + s = cls.F(s) + if mustBePositive and negative(s): raise InvalidEncodingException("%d is negative!" % s) - return cls.F(s) + return s @classmethod def gfToBytes(cls,x,mustBePositive=False): """Convert little-endian bytes to field element, sanity check length""" - if lobit(x) and mustBePositive: x = -x + if negative(x) and mustBePositive: x = -x return enc_le(x,cls.encLen) class RistrettoPoint(QuotientEdwardsPoint): @@ -135,12 +138,10 @@ class RistrettoPoint(QuotientEdwardsPoint): def encodeSpec(self): """Unoptimized specification for encoding""" x,y = self - if self.cofactor==8 and (lobit(x*y) or y==0): - (x,y) = (self.i*y,self.i*x) - + if self.cofactor==8 and (negative(x*y) or y==0): (x,y) = self.torque() if y == -1: y = 1 # Avoid divide by 0; doesn't affect impl - if lobit(x): x,y = -x,-y + if negative(x): x,y = -x,-y s = xsqrt(self.a*(y-1)/(y+1),exn=Exception("Unimplemented: point is odd: " + str(self))) return self.gfToBytes(s) @@ -153,7 +154,7 @@ class RistrettoPoint(QuotientEdwardsPoint): x = xsqrt(4*s^2 / (a*d*(1+a*s^2)^2 - (1-a*s^2)^2)) y = (1+a*s^2) / (1-a*s^2) - if cls.cofactor==8 and (lobit(x*y) or y==0): + if cls.cofactor==8 and (negative(x*y) or y==0): raise InvalidEncodingException("x*y has high bit") return cls(x,y) @@ -171,13 +172,14 @@ class RistrettoPoint(QuotientEdwardsPoint): i2 = isr*u2 z_inv = i1*i2*t - if self.cofactor==8 and lobit(t*z_inv): - x,y = y*self.i,x*self.i + if self.cofactor==8 and negative(t*z_inv): + if a==-1: x,y = y*self.i,x*self.i + else: x,y = -y,x # TODO: test den_inv = self.magic * i1 else: den_inv = i2 - if lobit(x*z_inv): y = -y + if negative(x*z_inv): y = -y s = (z-y) * den_inv return self.gfToBytes(s,mustBePositive=True) @@ -199,10 +201,10 @@ class RistrettoPoint(QuotientEdwardsPoint): yden_inv = xden_inv * isr * xden_sqr x = 2*s*xden_inv - if lobit(x): x = -x + if negative(x): x = -x y = ynum * yden_inv - if cls.cofactor==8 and (lobit(x*y) or y==0): + if cls.cofactor==8 and (negative(x*y) or y==0): raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y)) return cls(x,y) @@ -213,7 +215,7 @@ class RistrettoPoint(QuotientEdwardsPoint): a,d = cls.a,cls.d assert s^4 - 2*cls.a*(1-2*d/(d-a))*s^2 + 1 == t^2 x = 2*s*cls.magic / t - if lobit(x): x = -x # TODO: doesn't work without resolving x + if negative(x): x = -x # TODO: doesn't work without resolving x y = (1+a*s^2) / (1-a*s^2) return cls(sgn*x,y) @@ -249,20 +251,26 @@ class RistrettoPoint(QuotientEdwardsPoint): return cls.fromJacobiQuartic(s,t,sgn) -class Decaf1Point(QuotientEdwardsPoint): +class Decaf_1_1_Point(QuotientEdwardsPoint): """Like current decaf but tweaked for simplicity""" def encodeSpec(self): """Unoptimized specification for encoding""" a,d = self.a,self.d x,y = self - if x==0: return(self.gfToBytes(0)) + if x==0 or y==0: return(self.gfToBytes(0)) + + if self.cofactor==8 and negative(x*y*self.isoMagic): + x,y = self.torque() isr2 = isqrt(a*(y^2-1)) / self.magic + + sr = xsqrt(1-a*x^2) + assert sr in [isr2*x*y,-isr2*x*y] + altx = 1/isr2*self.isoMagic - if lobit(altx): s = (1+x*y*isr2)/(a*x) - else: s = (1-x*y*isr2)/(a*x) + if negative(altx): s = (1+x*y*isr2)/(a*x) + else: s = (1-x*y*isr2)/(a*x) - # TODO: cofactor 8 return self.gfToBytes(s,mustBePositive=True) @classmethod @@ -274,11 +282,13 @@ class Decaf1Point(QuotientEdwardsPoint): if s==0: return cls() isr = isqrt(s^4 + 2*(a-2*d)*s^2 + 1) altx = 2*s*isr*cls.isoMagic - if lobit(altx): isr = -isr + if negative(altx): isr = -isr x = 2*s / (1+a*s^2) y = (1-a*s^2) * isr - # TODO: cofactor 8 + if cls.cofactor==8 and (negative(x*y*cls.isoMagic) or y==0): + raise InvalidEncodingException("x*y is invalid: %d, %d" % (x,y)) + return cls(x,y) @optimized_version_of("encodeSpec") @@ -325,7 +335,7 @@ class IsoEd448Point(RistrettoPoint): "00000000000000000000000000000000000000000000000000000000"+ "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))) -class TwistedEd448GoldilocksPoint(Decaf1Point): +class TwistedEd448GoldilocksPoint(Decaf_1_1_Point): F = GF(2^448-2^224-1) d = F(-39082) a = F(-1) @@ -341,7 +351,7 @@ class TwistedEd448GoldilocksPoint(Decaf1Point): "00000000000000000000000000000000000000000000000000000000"+ "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))) -class Ed448GoldilocksPoint(Decaf1Point): +class Ed448GoldilocksPoint(Decaf_1_1_Point): F = GF(2^448-2^224-1) d = F(-39081) a = F(1) @@ -357,6 +367,24 @@ class Ed448GoldilocksPoint(Decaf1Point): "00000000000000000000000000000000000000000000000000000000"+ "fdffffffffffffffffffffffffffffffffffffffffffffffffffffff"))) +class IsoEd25519Point(Decaf_1_1_Point): + # TODO: twisted iso too! + # TODO: twisted iso might have to IMAGINE_TWIST or whatever + F = GF(2^255-19) + d = F(-121665) + a = F(1) + i = sqrt(F(-1)) + qnr = i + magic = isqrt(a*d-1) + cofactor = 8 + encLen = 32 + isoMagic = Ed25519Point.magic + isoA = Ed25519Point.a + + @classmethod + def base(cls): + return cls.decodeSpec(Ed25519Point.base().encode()) + class TestFailedException(Exception): pass def test(cls,n): @@ -385,6 +413,7 @@ def test(cls,n): if Q1 + Q0 != Q2: raise TestFailedException("Scalarmul doesn't work") Q = Q1 test(Ed25519Point,100) +test(IsoEd25519Point,100) test(IsoEd448Point,100) test(TwistedEd448GoldilocksPoint,100) test(Ed448GoldilocksPoint,100) @@ -398,6 +427,7 @@ def gangtest(classes,n): print c,binascii.hexlify(ret) print gangtest([IsoEd448Point,TwistedEd448GoldilocksPoint,Ed448GoldilocksPoint],100) +gangtest([Ed25519Point,IsoEd25519Point],100)