From 55766b76c97e2fc46e1a6cbb861223a814d7d2b7 Mon Sep 17 00:00:00 2001 From: Michael Hamburg Date: Fri, 29 Jan 2016 18:44:04 -0800 Subject: [PATCH] homogenize invsqrt code --- Makefile | 4 +- src/gen_headers/curve_data.py | 1 - src/p25519/f_arithmetic.c | 49 +++++++++----- src/p448/f_arithmetic.c | 9 ++- src/per_curve/decaf.tmpl.c | 93 +++++++++++---------------- src/per_curve/decaf_gen_tables.tmpl.c | 6 +- src/per_field/f_field.tmpl.h | 14 ++-- src/per_field/f_generic.tmpl.c | 20 ++++-- 8 files changed, 107 insertions(+), 89 deletions(-) diff --git a/Makefile b/Makefile index 00e23fd..9271b29 100644 --- a/Makefile +++ b/Makefile @@ -156,8 +156,8 @@ $$(BUILD_OBJ)/$(1)/%.o: src/$(1)/%.c $$(HEADERS_OF_$(1)) $$(BUILD_OBJ)/$(1)/%.o: src/$(1)/$$(ARCH_FOR_$(1))/%.c $$(HEADERS_OF_$(1)) $$(CC) $$(CFLAGS) -I src/$(1) -I src/$(1)/$$(ARCH_FOR_$(1)) -I $(BUILD_H)/$(1) \ - -I $(BUILD_H)/$(1)/$$(ARCH_FOR_$(1)) -I src/include/$$(ARCH_FOR_$(1)) \ - -c -o $$@ $$< + -I $(BUILD_H)/$(1)/$$(ARCH_FOR_$(1)) -I src/include/$$(ARCH_FOR_$(1)) \ + -c -o $$@ $$< endef ################################################################ diff --git a/src/gen_headers/curve_data.py b/src/gen_headers/curve_data.py index 67b8175..73a7bc4 100644 --- a/src/gen_headers/curve_data.py +++ b/src/gen_headers/curve_data.py @@ -84,7 +84,6 @@ def ceil_log2(x): out += 1 return out -# TODO: reduce this because we can now have expressions. for field,data in field_data.iteritems(): if "modulus" not in data: data["modulus"] = eval(data["gf_desc"].replace("^","**")) diff --git a/src/p25519/f_arithmetic.c b/src/p25519/f_arithmetic.c index 8fe653f..e4a1962 100644 --- a/src/p25519/f_arithmetic.c +++ b/src/p25519/f_arithmetic.c @@ -12,21 +12,40 @@ #include "constant_time.h" /* Guarantee: a^2 x = 0 if x = 0; else a^2 x = 1 or SQRT_MINUS_ONE; */ -void gf_isr (gf a, const gf x) { - gf st[3], tmp1, tmp2; - const struct { unsigned char sh, idx; } ops[] = { - {1,2},{1,2},{3,1},{6,0},{1,2},{12,1},{25,1},{25,1},{50,0},{125,0},{2,2},{1,2} - }; - st[0][0] = st[1][0] = st[2][0] = x[0]; - unsigned int i; - for (i=0; ilimb[0]&1); -} - #if COFACTOR==8 /** Return high bit of x = low bit of 2x mod p */ -static mask_t lobit(const gf x) { +static mask_t gf_lobit(const gf x) { gf y; gf_copy(y,x); gf_strong_reduce(y); @@ -424,20 +407,20 @@ deisogenize ( gf_sub(b, p->z, p->y); gf_mul(c, b, a); gf_mulw_sgn(b, c, -EDWARDS_D); /* (a-d)(Z+Y)(Z-Y) */ - mask_t ok = gf_isqrt_chk ( a, b, DECAF_TRUE); /* r in the paper */ - (void)ok; assert(ok); + mask_t ok = gf_isr (a,b); /* r in the paper */ + (void)ok; assert(ok | gf_eq(b,ZERO)); gf_mulw_sgn (b, a, -EDWARDS_D); /* u in the paper */ gf_mul(c,a,d); /* r(aZX-dYT) */ gf_mul(a,b,p->z); /* uZ */ gf_add(a,a,a); /* 2uZ */ - cond_neg(c, toggle_hibit_t_over_s ^ ~hibit(a)); /* u <- -u if negative. */ - cond_neg(a, toggle_hibit_t_over_s ^ ~hibit(a)); /* t/s <-? -t/s */ + cond_neg(c, toggle_hibit_t_over_s ^ ~gf_hibit(a)); /* u <- -u if negative. */ + cond_neg(a, toggle_hibit_t_over_s ^ ~gf_hibit(a)); /* t/s <-? -t/s */ gf_add(d,c,p->y); gf_mul(s,b,d); - cond_neg(s, toggle_hibit_s ^ hibit(s)); + cond_neg(s, toggle_hibit_s ^ gf_hibit(s)); #else /* More complicated because of rotation */ /* MAGIC This code is wrong for certain non-Curve25519 curves; @@ -468,8 +451,8 @@ deisogenize ( gf_mul ( a, p->z, t ); /* "tz" = T*Z */ gf_sqr ( b, a ); gf_mul ( d, b, c ); /* (TZ)^2 * (Z^2-aX^2) */ - mask_t ok = gf_isqrt_chk ( b, d, DECAF_TRUE ); - (void)ok; assert(ok); + mask_t ok = gf_isr(b, d); + (void)ok; assert(ok | gf_eq(d,ZERO)); gf_mul ( d, b, a ); /* "osx" = 1 / sqrt(z^2-ax^2) */ gf_mul ( a, b, c ); gf_mul ( b, a, d ); /* 1/tz */ @@ -479,7 +462,7 @@ deisogenize ( gf e; gf_sqr(e, p->z); gf_mul(a, e, b); /* z^2 / tz = z/t = 1/xy */ - rotate = hibit(a) ^ toggle_rotation; + rotate = gf_hibit(a) ^ toggle_rotation; /* Curve25519: cond select between zx * 1/tz or sqrt(1-d); y=-x */ gf_mul ( a, b, c ); cond_sel ( a, a, SQRT_ONE_MINUS_D, rotate ); @@ -492,13 +475,13 @@ deisogenize ( gf_mul ( c, a, d ); // new "osx" gf_mul ( a, c, p->z ); gf_add ( a, a, a ); // 2 * "osx" * Z - mask_t tg1 = rotate ^ toggle_hibit_t_over_s ^~ hibit(a); + mask_t tg1 = rotate ^ toggle_hibit_t_over_s ^~ gf_hibit(a); cond_neg ( c, tg1 ); cond_neg ( a, rotate ^ tg1 ); gf_mul ( d, b, p->z ); gf_add ( d, d, c ); gf_mul ( b, d, x ); /* here "x" = y unless rotate */ - cond_neg ( b, toggle_hibit_s ^ hibit(b) ); + cond_neg ( b, toggle_hibit_s ^ gf_hibit(b) ); #endif } @@ -506,7 +489,7 @@ deisogenize ( void API_NS(point_encode)( unsigned char ser[SER_BYTES], const point_t p ) { gf s, mtos; deisogenize(s,mtos,p,0,0,0); - gf_serialize ( ser, s ); + gf_serialize(ser,s,0); } decaf_error_t API_NS(point_decode) ( @@ -515,10 +498,9 @@ decaf_error_t API_NS(point_decode) ( decaf_bool_t allow_identity ) { gf s, a, b, c, d, e, f; - mask_t succ = gf_deserialize(s, ser); + mask_t succ = gf_deserialize(s, ser, 0); mask_t zero = gf_eq(s, ZERO); succ &= bool_to_mask(allow_identity) | ~zero; - succ &= ~hibit(s); gf_sqr ( a, s ); #if IMAGINE_TWIST gf_sub ( f, ONE, a ); /* f = 1-as^2 = 1-s^2*/ @@ -533,11 +515,11 @@ decaf_error_t API_NS(point_decode) ( gf_sqr ( e, d ); gf_mul ( b, c, e ); - succ &= gf_isqrt_chk ( e, b, DECAF_TRUE ); /* e = 1/(t s (1-as^2)) */ + succ &= gf_isr(e,b) | gf_eq(b,ZERO); /* e = 1/(t s (1-as^2)) */ gf_mul ( b, e, d ); /* 1/t */ gf_mul ( d, e, c ); /* d = t / (s(1-as^2)) */ gf_mul ( e, d, f ); /* t/s */ - mask_t negtos = hibit(e); + mask_t negtos = gf_hibit(e); cond_neg(b, negtos); cond_neg(d, negtos); @@ -549,7 +531,7 @@ decaf_error_t API_NS(point_decode) ( #if COFACTOR == 8 gf_mul ( a, p->z, d); /* t(1+s^2) / s(1-s^2) = 2/xy */ - succ &= ~lobit(a); /* = ~hibit(a/2), since hibit(x) = lobit(2x) */ + succ &= ~gf_lobit(a); /* = ~gf_hibit(a/2), since gf_hibit(x) = gf_lobit(2x) */ #endif gf_mul ( a, f, b ); /* y = (1-s^2) / t */ @@ -685,7 +667,7 @@ void API_NS(point_negate) ( static INLINE void scalar_decode_short ( scalar_t s, - const unsigned char ser[SER_BYTES], + const unsigned char *ser, unsigned int nbytes ) { unsigned int i,j,k=0; @@ -700,10 +682,10 @@ scalar_decode_short ( decaf_error_t API_NS(scalar_decode)( scalar_t s, - const unsigned char ser[SER_BYTES] + const unsigned char ser[SCALAR_SER_BYTES] ) { unsigned int i; - scalar_decode_short(s, ser, SER_BYTES); + scalar_decode_short(s, ser, SCALAR_SER_BYTES); decaf_dsword_t accum = 0; for (i=0; ilimb[i] - sc_p->limb[i]) >> WBITS; @@ -738,8 +720,8 @@ void API_NS(scalar_decode_long)( size_t i; scalar_t t1, t2; - i = ser_len - (ser_len%SER_BYTES); - if (i==ser_len) i -= SER_BYTES; + i = ser_len - (ser_len%SCALAR_SER_BYTES); + if (i==ser_len) i -= SCALAR_SER_BYTES; scalar_decode_short(t1, &ser[i], ser_len-i); @@ -752,7 +734,7 @@ void API_NS(scalar_decode_long)( } while (i) { - i -= SER_BYTES; + i -= SCALAR_SER_BYTES; sc_montmul(t1,t1,sc_r2); ignore_result( API_NS(scalar_decode)(t2, ser+i) ); API_NS(scalar_add)(t1, t1, t2); @@ -764,7 +746,7 @@ void API_NS(scalar_decode_long)( } void API_NS(scalar_encode)( - unsigned char ser[SER_BYTES], + unsigned char ser[SCALAR_SER_BYTES], const scalar_t s ) { unsigned int i,j,k=0; @@ -1188,7 +1170,7 @@ void API_NS(point_from_hash_nonuniform) ( const unsigned char ser[SER_BYTES] ) { gf r0,r,a,b,c,N,e; - gf_deserialize(r0,ser); + ignore_result(gf_deserialize(r0,ser,0)); gf_strong_reduce(r0); gf_sqr(a,r0); #if P_MOD_8 == 5 @@ -1213,13 +1195,13 @@ void API_NS(point_from_hash_nonuniform) ( /* e = +-sqrt(1/ND) or +-r0 * sqrt(qnr/ND) */ gf_mul(a,c,N); - mask_t square = gf_isqrt_chk(b,a,DECAF_FALSE); + mask_t square = gf_isr(b,a); cond_sel(c,r0,ONE,square); /* r? = square ? 1 : r0 */ gf_mul(e,b,c); /* s@a = +-|N.e| */ gf_mul(a,N,e); - cond_neg(a,hibit(a)^square); /* NB this is - what is listen in the paper */ + cond_neg(a,gf_hibit(a)^square); /* NB this is - what is listen in the paper */ /* t@b = -+ cN(r-1)((a-2d)e)^2 - 1 */ gf_mulw_sgn(c,e,1-2*EDWARDS_D); /* (a-2d)e */ @@ -1257,7 +1239,7 @@ API_NS(invert_elligator_nonuniform) ( mask_t hint = hint_; mask_t sgn_s = -(hint & 1), sgn_t_over_s = -(hint>>1 & 1), - sgn_r0 = -(hint>>2 & 1), + sgn_r0 = -(hint>>2 & 1), /* FIXME: but it's SER_BYTES ... */ sgn_ed_T = -(hint>>3 & 1); gf a, b, c, d; deisogenize(a,c,p,sgn_s,sgn_t_over_s,sgn_ed_T); @@ -1285,16 +1267,16 @@ API_NS(invert_elligator_nonuniform) ( #else gf_sub(d,ZERO,b); #endif - mask_t succ = gf_isqrt_chk(c,d,DECAF_TRUE); + mask_t succ = gf_isr(c,d)|gf_eq(d,ZERO); gf_mul(b,a,c); - cond_neg(b, sgn_r0^hibit(b)); + cond_neg(b, sgn_r0^gf_hibit(b)); succ &= ~(gf_eq(b,ZERO) & sgn_r0); #if COFACTOR == 8 succ &= ~(is_identity & sgn_ed_T); /* NB: there are no preimages of rotated identity. */ #endif - gf_serialize(recovered_hash, b); + gf_serialize(recovered_hash,b,1); /* FIXME: ,0 */ /* TODO: deal with overflow flag */ return decaf_succeed_if(mask_to_bool(succ)); } @@ -1365,7 +1347,8 @@ void API_NS(point_debugging_pscale) ( const uint8_t factor[SER_BYTES] ) { gf gfac,tmp; - ignore_result(gf_deserialize(gfac,factor)); + /* NB this means you'll never pscale by negative numbers for p521 */ + ignore_result(gf_deserialize(gfac,factor,0)); cond_sel(gfac,gfac,ONE,gf_eq(gfac,ZERO)); gf_mul(tmp,p->x,gfac); gf_copy(q->x,tmp); @@ -1593,7 +1576,7 @@ decaf_error_t API_NS(x_direct_scalarmul) ( const uint8_t scalar[X_PRIVATE_BYTES] ) { gf x1, x2, z2, x3, z3, t1, t2; - ignore_result(gf_deserialize(x1,base)); + ignore_result(gf_deserialize(x1,base,1)); gf_copy(x2,ONE); gf_copy(z2,ZERO); gf_copy(x3,x1); @@ -1644,7 +1627,7 @@ decaf_error_t API_NS(x_direct_scalarmul) ( cond_swap(z2,z3,swap); gf_invert(z2,z2); gf_mul(x1,x2,z2); - gf_serialize(out,x1); + gf_serialize(out,x1,1); mask_t nz = ~gf_eq(x1,ZERO); decaf_bzero(x1,sizeof(x1)); @@ -1706,7 +1689,7 @@ void API_NS(x_base_scalarmul) ( #if IMAGINE_TWIST gf_sub(p->y,ZERO,p->y); #endif - gf_serialize(out,p->y); + gf_serialize(out,p->y,1); decaf_bzero(scalar2,sizeof(scalar2)); API_NS(scalar_destroy)(the_scalar); diff --git a/src/per_curve/decaf_gen_tables.tmpl.c b/src/per_curve/decaf_gen_tables.tmpl.c index 44afb7f..46a031c 100644 --- a/src/per_curve/decaf_gen_tables.tmpl.c +++ b/src/per_curve/decaf_gen_tables.tmpl.c @@ -26,12 +26,12 @@ void API_NS(precompute_wnafs) ( const API_NS(point_t) base ); static void field_print(const gf f) { /* UNIFY */ - unsigned char ser[SER_BYTES]; - gf_serialize(ser,f); + unsigned char ser[X_SER_BYTES]; + gf_serialize(ser,f,1); int b=0, i, comma=0; unsigned long long limb = 0; printf("{FIELD_LITERAL("); - for (i=0; i= GF_LIT_LIMB_BITS || i == SER_BYTES-1) { diff --git a/src/per_field/f_field.tmpl.h b/src/per_field/f_field.tmpl.h index d210e46..d03c5ad 100644 --- a/src/per_field/f_field.tmpl.h +++ b/src/per_field/f_field.tmpl.h @@ -8,7 +8,8 @@ #define __DECAF_$(gf_shortname)_GF_DEFINED__ 1 #define NLIMBS ($(gf_impl_bits/8)/sizeof(word_t)) -#define SER_BYTES $(((gf_bits-1)//8 + 1)) /* MAGIC: depends on if high bit known to be clear (eg p521) */ +#define X_SER_BYTES $(((gf_bits-1)/8 + 1)) +#define SER_BYTES $(((gf_bits-2)/8 + 1)) typedef struct gf_$(gf_shortname)_s { word_t limb[NLIMBS]; } __attribute__((aligned(32))) gf_$(gf_shortname)_s, gf_$(gf_shortname)_t[1]; @@ -21,6 +22,7 @@ typedef struct gf_$(gf_shortname)_s { #define gf gf_$(gf_shortname)_t #define gf_s gf_$(gf_shortname)_s #define gf_eq gf_$(gf_shortname)_eq +#define gf_hibit gf_$(gf_shortname)_hibit #define gf_copy gf_$(gf_shortname)_copy #define gf_add gf_$(gf_shortname)_add #define gf_sub gf_$(gf_shortname)_sub @@ -37,7 +39,7 @@ typedef struct gf_$(gf_shortname)_s { #define gf_deserialize gf_$(gf_shortname)_deserialize /* RFC 7748 support */ -#define X_PUBLIC_BYTES $((gf_bits-1)/8 + 1) +#define X_PUBLIC_BYTES X_SER_BYTES #define X_PRIVATE_BYTES X_PUBLIC_BYTES #define X_PRIVATE_BITS $(gf_bits) @@ -62,10 +64,12 @@ void gf_sub (gf out, const gf a, const gf b); void gf_mul (gf_s *__restrict__ out, const gf a, const gf b); void gf_mulw (gf_s *__restrict__ out, const gf a, uint32_t b); void gf_sqr (gf_s *__restrict__ out, const gf a); -void gf_serialize (uint8_t *serial, const gf x); -void gf_isr(gf a, const gf x); /** a^2 x = 1, QNR, or 0 if x=0 */ +mask_t gf_isr(gf a, const gf x); /** a^2 x = 1, QNR, or 0 if x=0. Return true if successful */ mask_t gf_eq (const gf x, const gf y); -mask_t gf_deserialize (gf x, const uint8_t serial[SER_BYTES]); +mask_t gf_hibit (const gf x); + +void gf_serialize (uint8_t *serial, const gf x,int with_highbit); +mask_t gf_deserialize (gf x, const uint8_t serial[SER_BYTES],int with_highbit); #ifdef __cplusplus diff --git a/src/per_field/f_generic.tmpl.c b/src/per_field/f_generic.tmpl.c index 0a1f742..c7ebc00 100644 --- a/src/per_field/f_generic.tmpl.c +++ b/src/per_field/f_generic.tmpl.c @@ -15,14 +15,15 @@ static const gf MODULUS = {FIELD_LITERAL( )}; /** Serialize to wire format. */ -void gf_serialize (uint8_t serial[SER_BYTES], const gf x) { +void gf_serialize (uint8_t serial[SER_BYTES], const gf x, int with_hibit) { gf red; gf_copy(red, x); gf_strong_reduce(red); + if (!with_hibit) { assert(gf_hibit(red) == 0); } unsigned int j=0, fill=0; dword_t buffer = 0; - UNROLL for (unsigned int i=0; ilimb[LIMBPERM(j)]) << fill; fill += LIMB_PLACE_VALUE(LIMBPERM(j)); @@ -34,13 +35,21 @@ void gf_serialize (uint8_t serial[SER_BYTES], const gf x) { } } +/** Return high bit of x = low bit of 2x mod p */ +mask_t gf_hibit(const gf x) { + gf y; + gf_add(y,x,x); + gf_strong_reduce(y); + return -(y->limb[0]&1); +} + /** Deserialize from wire format; return -1 on success and 0 on failure. */ -mask_t gf_deserialize (gf x, const uint8_t serial[SER_BYTES]) { +mask_t gf_deserialize (gf x, const uint8_t serial[SER_BYTES], int with_hibit) { unsigned int j=0, fill=0; dword_t buffer = 0; dsword_t scarry = 0; UNROLL for (unsigned int i=0; i>= LIMB_PLACE_VALUE(LIMBPERM(i)); scarry = (scarry + x->limb[LIMBPERM(i)] - MODULUS->limb[LIMBPERM(i)]) >> (8*sizeof(word_t)); } - return word_is_zero(buffer) & ~word_is_zero(scarry); + mask_t succ = with_hibit ? -1 : ~gf_hibit(x); + return succ & word_is_zero(buffer) & ~word_is_zero(scarry); } /** Reduce to canonical form. */