Browse Source

move easy scalar computations to python

master
Michael Hamburg 9 years ago
parent
commit
7ee81cf84f
6 changed files with 18 additions and 44 deletions
  1. +0
    -7
      src/gen_headers/curve_data.py
  2. +2
    -3
      src/gen_headers/template.py
  3. +10
    -11
      src/per_curve/decaf.tmpl.c
  4. +0
    -17
      src/per_curve/decaf_gen_tables.tmpl.c
  5. +2
    -6
      src/per_field/f_field.tmpl.h
  6. +4
    -0
      src/per_field/f_generic.tmpl.c

+ 0
- 7
src/gen_headers/curve_data.py View File

@@ -74,8 +74,6 @@ def ceil_log2(x):
for field,data in field_data.iteritems(): for field,data in field_data.iteritems():
if "modulus" not in data: if "modulus" not in data:
data["modulus"] = eval(data["gf_desc"].replace("^","**")) data["modulus"] = eval(data["gf_desc"].replace("^","**"))
data["p_mod_8"] = data["modulus"] % 8
if "gf_bits" not in data: if "gf_bits" not in data:
data["gf_bits"] = ceil_log2(data["modulus"]) data["gf_bits"] = ceil_log2(data["modulus"])
@@ -88,10 +86,6 @@ for field,data in field_data.iteritems():
if "x_priv_bits" not in data: if "x_priv_bits" not in data:
data["x_priv_bits"] = ceil_log2(data["modulus"]*0.99) # not per curve at least in 7748 data["x_priv_bits"] = ceil_log2(data["modulus"]*0.99) # not per curve at least in 7748
data["ser_modulus"] = ser(data["modulus"], data["gf_lit_limb_bits"])
if data["modulus"] % 4 == 1: data["sqrt_minus_one"] = ser(msqrt(-1,data["modulus"]), data["gf_lit_limb_bits"])
else: data["sqrt_minus_one"] = "/* NONE */"


for curve,data in curve_data.iteritems(): for curve,data in curve_data.iteritems():
for key in field_data[data["field"]]: for key in field_data[data["field"]]:
@@ -121,7 +115,6 @@ for curve,data in curve_data.iteritems():
data["q"] = (data["modulus"]+1-data["trace"]) // data["cofactor"] data["q"] = (data["modulus"]+1-data["trace"]) // data["cofactor"]
data["bits"] = ceil_log2(data["modulus"]) data["bits"] = ceil_log2(data["modulus"])
data["decaf_base"] = ser(msqrt(data["mont_base"],data["modulus"]),8) data["decaf_base"] = ser(msqrt(data["mont_base"],data["modulus"]),8)
data["scalar_p"] = ser(data["q"],64,"SC_LIMB")
if data["cofactor"] > 4: data["sqrt_one_minus_d"] = ser(msqrt(1-data["d"],data["modulus"]),data["gf_lit_limb_bits"]) if data["cofactor"] > 4: data["sqrt_one_minus_d"] = ser(msqrt(1-data["d"],data["modulus"]),data["gf_lit_limb_bits"])
else: data["sqrt_one_minus_d"] = "/* NONE */" else: data["sqrt_one_minus_d"] = "/* NONE */"


+ 2
- 3
src/gen_headers/template.py View File

@@ -1,4 +1,5 @@
from textwrap import dedent from textwrap import dedent
from curve_data import field_data,curve_data,ser,msqrt


import os import os
import argparse import argparse
@@ -12,8 +13,6 @@ parser.add_argument('--guard', required = False, default = None, help = "header
parser.add_argument('files', metavar='file', type=str, nargs='+', help='a list of files to fill') parser.add_argument('files', metavar='file', type=str, nargs='+', help='a list of files to fill')
args = parser.parse_args() args = parser.parse_args()


from curve_data import field_data,curve_data

per_map = {"field":field_data, "curve":curve_data, "global":{"global":{}} } per_map = {"field":field_data, "curve":curve_data, "global":{"global":{}} }


def redoc(filename,doc,author): def redoc(filename,doc,author):
@@ -51,7 +50,7 @@ def fillin(template,data):
if template[position] == '(': parens += 1 if template[position] == '(': parens += 1
elif template[position] == ')': parens -= 1 elif template[position] == ')': parens -= 1
position += 1 position += 1
ret += str(eval(template[dollars+2:position-1],data))
ret += str(eval(template[dollars+2:position-1],{'ser':ser,'msqrt':msqrt},data))


author = "Mike Hamburg" # FUTURE author = "Mike Hamburg" # FUTURE
for name in args.files: for name in args.files:


+ 10
- 11
src/per_curve/decaf.tmpl.c View File

@@ -10,6 +10,7 @@


#include <decaf.h> #include <decaf.h>


/* Template stuff */
#define API_NS(_id) $(c_ns)_##_id #define API_NS(_id) $(c_ns)_##_id
#define SCALAR_BITS $(C_NS)_SCALAR_BITS #define SCALAR_BITS $(C_NS)_SCALAR_BITS
#define SCALAR_LIMBS $(C_NS)_SCALAR_LIMBS #define SCALAR_LIMBS $(C_NS)_SCALAR_LIMBS
@@ -20,7 +21,11 @@
#define COFACTOR $(cofactor) #define COFACTOR $(cofactor)


static const int EDWARDS_D = $(d); static const int EDWARDS_D = $(d);
static const scalar_t sc_p = {{{ $(scalar_p) }}};
static const scalar_t sc_p = {{{ $(ser(q,64,"SC_LIMB")) }}};
static const scalar_t sc_r2 = {{{ $(ser(((2**128)**((scalar_bits+63)/64))%q,64,"SC_LIMB")) }}};
extern const scalar_t API_NS(point_scalarmul_adjustment); /* TODO: auto template these too. */
extern const scalar_t API_NS(precomputed_scalarmul_adjustment);
static const decaf_word_t MONTGOMERY_FACTOR = (decaf_word_t)0x$("%x" % pow(-q,2**64-1,2**64))ull;


#if COFACTOR==8 #if COFACTOR==8
static const gf SQRT_ONE_MINUS_D = {FIELD_LITERAL( static const gf SQRT_ONE_MINUS_D = {FIELD_LITERAL(
@@ -49,8 +54,6 @@ extern const gf SQRT_MINUS_ONE;
#define WBITS DECAF_WORD_BITS /* NB this may be different from ARCH_WORD_BITS */ #define WBITS DECAF_WORD_BITS /* NB this may be different from ARCH_WORD_BITS */


const scalar_t API_NS(scalar_one) = {{{1}}}, API_NS(scalar_zero) = {{{0}}}; const scalar_t API_NS(scalar_one) = {{{1}}}, API_NS(scalar_zero) = {{{0}}};
extern const scalar_t API_NS(sc_r2);
extern const decaf_word_t API_NS(MONTGOMERY_FACTOR);
extern const point_t API_NS(point_base); extern const point_t API_NS(point_base);


/* Projective Niels coordinates */ /* Projective Niels coordinates */
@@ -220,7 +223,7 @@ sc_montmul (
} }
accum[j] = chain; accum[j] = chain;
mand = accum[0] * API_NS(MONTGOMERY_FACTOR);
mand = accum[0] * MONTGOMERY_FACTOR;
chain = 0; chain = 0;
mier = sc_p->limb; mier = sc_p->limb;
for (j=0; j<SCALAR_LIMBS; j++) { for (j=0; j<SCALAR_LIMBS; j++) {
@@ -243,7 +246,7 @@ void API_NS(scalar_mul) (
const scalar_t b const scalar_t b
) { ) {
sc_montmul(out,a,b); sc_montmul(out,a,b);
sc_montmul(out,out,API_NS(sc_r2));
sc_montmul(out,out,sc_r2);
} }


/* PERF: could implement this */ /* PERF: could implement this */
@@ -263,7 +266,7 @@ decaf_error_t API_NS(scalar_invert) (
const int LAST = (1<<SCALAR_WINDOW_BITS)-1; const int LAST = (1<<SCALAR_WINDOW_BITS)-1;


/* Precompute precmp = [a^1,a^3,...] */ /* Precompute precmp = [a^1,a^3,...] */
sc_montmul(precmp[0],a,API_NS(sc_r2));
sc_montmul(precmp[0],a,sc_r2);
if (LAST > 0) sc_montmul(precmp[LAST],precmp[0],precmp[0]); if (LAST > 0) sc_montmul(precmp[LAST],precmp[0],precmp[0]);


int i; int i;
@@ -734,7 +737,7 @@ void API_NS(scalar_decode_long)(


while (i) { while (i) {
i -= SER_BYTES; i -= SER_BYTES;
sc_montmul(t1,t1,API_NS(sc_r2));
sc_montmul(t1,t1,sc_r2);
ignore_result( API_NS(scalar_decode)(t2, ser+i) ); ignore_result( API_NS(scalar_decode)(t2, ser+i) );
API_NS(scalar_add)(t1, t1, t2); API_NS(scalar_add)(t1, t1, t2);
} }
@@ -868,8 +871,6 @@ sub_pniels_from_pt (
sub_niels_from_pt( p, pn->n, before_double ); sub_niels_from_pt( p, pn->n, before_double );
} }


extern const scalar_t API_NS(point_scalarmul_adjustment);

static INLINE void static INLINE void
constant_time_lookup_xx ( constant_time_lookup_xx (
void *__restrict__ out_, void *__restrict__ out_,
@@ -1477,8 +1478,6 @@ void API_NS(precompute) (
decaf_bzero(doubles,sizeof(doubles)); decaf_bzero(doubles,sizeof(doubles));
} }


extern const scalar_t API_NS(precomputed_scalarmul_adjustment);

static INLINE void static INLINE void
constant_time_lookup_xx_niels ( constant_time_lookup_xx_niels (
niels_s *__restrict__ ni, niels_s *__restrict__ ni,


+ 0
- 17
src/per_curve/decaf_gen_tables.tmpl.c View File

@@ -19,8 +19,6 @@ static const unsigned char base_point_ser_for_pregen[SER_BYTES] = {
const gf API_NS(precomputed_base_as_fe)[1]; const gf API_NS(precomputed_base_as_fe)[1];
const API_NS(scalar_t) API_NS(precomputed_scalarmul_adjustment); const API_NS(scalar_t) API_NS(precomputed_scalarmul_adjustment);
const API_NS(scalar_t) API_NS(point_scalarmul_adjustment); const API_NS(scalar_t) API_NS(point_scalarmul_adjustment);
const API_NS(scalar_t) API_NS(sc_r2) = {{{0}}};
const decaf_word_t API_NS(MONTGOMERY_FACTOR) = 0;


const API_NS(point_t) API_NS(point_base); const API_NS(point_t) API_NS(point_base);
const uint8_t API_NS(x_base_point)[X_PUBLIC_BYTES] = {0}; const uint8_t API_NS(x_base_point)[X_PUBLIC_BYTES] = {0};
@@ -148,23 +146,8 @@ int main(int argc, char **argv) {
API_NS(scalar_sub)(smadj, smadj, API_NS(scalar_one)); API_NS(scalar_sub)(smadj, smadj, API_NS(scalar_one));
scalar_print("API_NS(point_scalarmul_adjustment)", smadj); scalar_print("API_NS(point_scalarmul_adjustment)", smadj);
API_NS(scalar_copy)(smadj,API_NS(scalar_one));
for (i=0; i<sizeof(API_NS(scalar_t))*8*2; i++) {
API_NS(scalar_add)(smadj,smadj,smadj);
}
scalar_print("API_NS(sc_r2)", smadj);
API_NS(scalar_sub)(smadj,API_NS(scalar_zero),API_NS(scalar_one)); /* get p-1 */ API_NS(scalar_sub)(smadj,API_NS(scalar_zero),API_NS(scalar_one)); /* get p-1 */
unsigned long long w = 1, plo = smadj->limb[0]+1;
#if DECAF_WORD_BITS == 32
plo |= ((unsigned long long)smadj->limb[1]) << 32;
#endif
for (i=0; i<6; i++) {
w *= w*plo + 2;
}
printf("const decaf_word_t API_NS(MONTGOMERY_FACTOR) = (decaf_word_t)0x%016llxull;\n\n", w);


/* Generate the Montgomery ladder version of the base point */ /* Generate the Montgomery ladder version of the base point */
gf base1,base2; gf base1,base2;


+ 2
- 6
src/per_field/f_field.tmpl.h View File

@@ -74,14 +74,10 @@ mask_t gf_deserialize (gf x, const uint8_t serial[SER_BYTES]);


#include "f_impl.h" /* Bring in the inline implementations */ #include "f_impl.h" /* Bring in the inline implementations */


static const gf MODULUS = {FIELD_LITERAL(
$(ser_modulus)
)};

#define P_MOD_8 $(p_mod_8)
#define P_MOD_8 $(modulus % 8)
#if P_MOD_8 == 5 #if P_MOD_8 == 5
static const gf SQRT_MINUS_ONE = {FIELD_LITERAL( /* TODO make not static */ static const gf SQRT_MINUS_ONE = {FIELD_LITERAL( /* TODO make not static */
$(sqrt_minus_one)
$(ser(msqrt(-1,modulus),gf_lit_limb_bits) if modulus % 4 == 1 else "/* NOPE */")
)}; )};
#endif #endif




+ 4
- 0
src/per_field/f_generic.tmpl.c View File

@@ -10,6 +10,10 @@


#include "field.h" #include "field.h"


static const gf MODULUS = {FIELD_LITERAL(
$(ser(modulus,gf_lit_limb_bits))
)};

/** Serialize to wire format. */ /** Serialize to wire format. */
void gf_serialize (uint8_t serial[SER_BYTES], const gf x) { void gf_serialize (uint8_t serial[SER_BYTES], const gf x) {
gf red; gf red;


Loading…
Cancel
Save