You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

363 lines
12 KiB

  1. /**
  2. * @file test_decaf.cxx
  3. * @author Mike Hamburg
  4. *
  5. * @copyright
  6. * Copyright (c) 2015 Cryptography Research, Inc. \n
  7. * Released under the MIT License. See LICENSE.txt for license information.
  8. *
  9. * @brief C++ tests, because that's easier.
  10. */
#include <limits.h>
#include <stdio.h>
#include <string.h>

#include <exception>

#include <decaf.hxx>
#include <decaf/shake.hxx>
#include <decaf/crypto.h>
  15. static bool passing = true;
  16. static const long NTESTS = 10000;
  17. class Test {
  18. public:
  19. bool passing_now;
  20. Test(const char *test) {
  21. passing_now = true;
  22. printf("%s...", test);
  23. if (strlen(test) < 27) printf("%*s",int(27-strlen(test)),"");
  24. fflush(stdout);
  25. }
  26. ~Test() {
  27. if (std::uncaught_exception()) {
  28. fail();
  29. printf(" due to uncaught exception.\n");
  30. }
  31. if (passing_now) printf("[PASS]\n");
  32. }
  33. void fail() {
  34. if (!passing_now) return;
  35. passing_now = passing = false;
  36. printf("[FAIL]\n");
  37. }
  38. };
  39. template<typename Group> struct Tests {
  40. typedef typename Group::Scalar Scalar;
  41. typedef typename Group::Point Point;
  42. typedef typename Group::Precomputed Precomputed;
  43. static void print(const char *name, const Scalar &x) {
  44. unsigned char buffer[Scalar::SER_BYTES];
  45. x.encode(buffer);
  46. printf(" %s = 0x", name);
  47. for (int i=sizeof(buffer)-1; i>=0; i--) {
  48. printf("%02x", buffer[i]);
  49. }
  50. printf("\n");
  51. }
  52. static void hexprint(const char *name, const decaf::SecureBuffer &buffer) {
  53. printf(" %s = 0x", name);
  54. for (int i=buffer.size()-1; i>=0; i--) {
  55. printf("%02x", buffer[i]);
  56. }
  57. printf("\n");
  58. }
  59. static void print(const char *name, const Point &x) {
  60. unsigned char buffer[Point::SER_BYTES];
  61. x.encode(buffer);
  62. printf(" %s = 0x", name);
  63. for (int i=sizeof(buffer)-1; i>=0; i--) {
  64. printf("%02x", buffer[i]);
  65. }
  66. printf("\n");
  67. }
  68. static bool arith_check(
  69. Test &test,
  70. const Scalar &x,
  71. const Scalar &y,
  72. const Scalar &z,
  73. const Scalar &l,
  74. const Scalar &r,
  75. const char *name
  76. ) {
  77. if (l == r) return true;
  78. test.fail();
  79. printf(" %s", name);
  80. print("x", x);
  81. print("y", y);
  82. print("z", z);
  83. print("lhs", l);
  84. print("rhs", r);
  85. return false;
  86. }
  87. static bool point_check(
  88. Test &test,
  89. const Point &p,
  90. const Point &q,
  91. const Point &R,
  92. const Scalar &x,
  93. const Scalar &y,
  94. const Point &l,
  95. const Point &r,
  96. const char *name
  97. ) {
  98. bool good = l==r;
  99. if (!p.validate()) { good = false; printf(" p invalid\n"); }
  100. if (!q.validate()) { good = false; printf(" q invalid\n"); }
  101. if (!r.validate()) { good = false; printf(" r invalid\n"); }
  102. if (!l.validate()) { good = false; printf(" l invalid\n"); }
  103. if (good) return true;
  104. test.fail();
  105. printf(" %s", name);
  106. print("x", x);
  107. print("y", y);
  108. print("p", p);
  109. print("q", q);
  110. print("r", R);
  111. print("lhs", r);
  112. print("rhs", l);
  113. return false;
  114. }
  115. static void test_arithmetic() {
  116. decaf::SpongeRng rng(decaf::Block("test_arithmetic"));
  117. Test test("Arithmetic");
  118. Scalar x(0),y(0),z(0);
  119. arith_check(test,x,y,z,INT_MAX,(decaf_word_t)INT_MAX,"cast from max");
  120. arith_check(test,x,y,z,INT_MIN,-Scalar(1+(decaf_word_t)INT_MAX),"cast from min");
  121. for (int i=0; i<NTESTS*10 && test.passing_now; i++) {
  122. /* TODO: pathological cases */
  123. size_t sob = DECAF_255_SCALAR_BYTES + 8 - (i%16);
  124. Scalar x(rng.read(sob));
  125. Scalar y(rng.read(sob));
  126. Scalar z(rng.read(sob));
  127. arith_check(test,x,y,z,x+y,y+x,"commute add");
  128. arith_check(test,x,y,z,x,x+0,"ident add");
  129. arith_check(test,x,y,z,x,x-0,"ident sub");
  130. arith_check(test,x,y,z,x+(y+z),(x+y)+z,"assoc add");
  131. arith_check(test,x,y,z,x*(y+z),x*y + x*z,"distributive mul/add");
  132. arith_check(test,x,y,z,x*(y-z),x*y - x*z,"distributive mul/add");
  133. arith_check(test,x,y,z,x*(y*z),(x*y)*z,"assoc mul");
  134. arith_check(test,x,y,z,x*y,y*x,"commute mul");
  135. arith_check(test,x,y,z,x,x*1,"ident mul");
  136. arith_check(test,x,y,z,0,x*0,"mul by 0");
  137. arith_check(test,x,y,z,-x,x*-1,"mul by -1");
  138. arith_check(test,x,y,z,x+x,x*2,"mul by 2");
  139. if (i%20) continue;
  140. if (y!=0) arith_check(test,x,y,z,x*y/y,x,"invert");
  141. arith_check(test,x,y,z,x/0,0,"invert0");
  142. }
  143. }
  144. static void test_elligator() {
  145. decaf::SpongeRng rng(decaf::Block("test_elligator"));
  146. Test test("Elligator");
  147. const int NHINTS = 1<<4;
  148. decaf::SecureBuffer *alts[NHINTS];
  149. bool successes[NHINTS];
  150. decaf::SecureBuffer *alts2[NHINTS];
  151. bool successes2[NHINTS];
  152. for (int i=0; i<NTESTS/10 && (test.passing_now || i < 100); i++) {
  153. size_t len = (i % (2*Point::HASH_BYTES + 3)); // FIXME: 0
  154. decaf::SecureBuffer b1(len);
  155. if (i!=Point::HASH_BYTES) rng.read(b1); /* special test case */
  156. if (i==1) b1[0] = 1; /* special case test */
  157. if (len >= Point::HASH_BYTES) b1[Point::HASH_BYTES-1] &= 0x7F; // FIXME MAGIC
  158. Point s = Point::from_hash(b1), ss=s;
  159. for (int j=0; j<(i&3); j++) ss = ss.debugging_torque();
  160. ss = ss.debugging_pscale(rng);
  161. bool good = false;
  162. for (int j=0; j<NHINTS; j++) {
  163. alts[j] = new decaf::SecureBuffer(len);
  164. alts2[j] = new decaf::SecureBuffer(len);
  165. if (len > Point::HASH_BYTES)
  166. memcpy(&(*alts[j])[Point::HASH_BYTES], &b1[Point::HASH_BYTES], len-Point::HASH_BYTES);
  167. if (len > Point::HASH_BYTES)
  168. memcpy(&(*alts2[j])[Point::HASH_BYTES], &b1[Point::HASH_BYTES], len-Point::HASH_BYTES);
  169. successes[j] = s.invert_elligator(*alts[j], j);
  170. successes2[j] = ss.invert_elligator(*alts2[j],j);
  171. if (successes[j] != successes2[j]
  172. || (successes[j] && successes2[j] && *alts[j] != *alts2[j])
  173. ) {
  174. test.fail();
  175. printf(" Unscalable Elligator inversion: i=%d, hint=%d, s=%d,%d\n",i,j,
  176. -int(successes[j]),-int(successes2[j]));
  177. hexprint("x",b1);
  178. hexprint("X",*alts[j]);
  179. hexprint("X",*alts2[j]);
  180. }
  181. if (successes[j]) {
  182. good = good || (b1 == *alts[j]);
  183. for (int k=0; k<j; k++) {
  184. if (successes[k] && *alts[j] == *alts[k]) {
  185. test.fail();
  186. printf(" Duplicate Elligator inversion: i=%d, hints=%d, %d\n",i,j,k);
  187. hexprint("x",b1);
  188. hexprint("X",*alts[j]);
  189. }
  190. }
  191. if (s != Point::from_hash(*alts[j])) {
  192. test.fail();
  193. printf(" Fail Elligator inversion round-trip: i=%d, hint=%d %s\n",i,j,
  194. (s==-Point::from_hash(*alts[j])) ? "[output was -input]": "");
  195. hexprint("x",b1);
  196. hexprint("X",*alts[j]);
  197. }
  198. /*
  199. if (i == Point::HASH_BYTES) {
  200. printf("Identity, hint = %d\n", j);
  201. hexprint("einv(0)",*alts[j]);
  202. }
  203. */
  204. }
  205. }
  206. if (!good) {
  207. test.fail();
  208. printf(" %s Elligator inversion: i=%d\n",good ? "Passed" : "Failed", i);
  209. hexprint("B", b1);
  210. for (int j=0; j<NHINTS; j++) {
  211. printf(" %d: %s%s", j, successes[j] ? "succ" : "fail\n", (successes[j] && *alts[j] == b1) ? " [x]" : "");
  212. if (successes[j]) {
  213. hexprint("b", *alts[j]);
  214. }
  215. }
  216. printf("\n");
  217. }
  218. for (int j=0; j<NHINTS; j++) {
  219. delete alts[j];
  220. alts[j] = NULL;
  221. delete alts2[j];
  222. alts2[j] = NULL;
  223. }
  224. Point t(rng);
  225. point_check(test,t,t,t,0,0,t,Point::from_hash(t.steg_encode(rng)),"steg round-trip");
  226. }
  227. }
  228. static void test_ec() {
  229. decaf::SpongeRng rng(decaf::Block("test_ec"));
  230. Test test("EC");
  231. Point id = Point::identity(), base = Point::base();
  232. point_check(test,id,id,id,0,0,Point::from_hash(""),id,"fh0");
  233. //point_check(test,id,id,id,0,0,Point::from_hash("\x01"),id,"fh1"); FIXME
  234. for (int i=0; i<NTESTS && test.passing_now; i++) {
  235. /* TODO: pathological cases */
  236. Scalar x(rng);
  237. Scalar y(rng);
  238. Point p(rng);
  239. Point q(rng);
  240. decaf::SecureBuffer buffer(2*Point::HASH_BYTES);
  241. rng.read(buffer);
  242. Point r = Point::from_hash(buffer);
  243. point_check(test,p,q,r,0,0,p,Point((decaf::SecureBuffer)p),"round-trip");
  244. Point pp = p.debugging_torque().debugging_pscale(rng);
  245. if (decaf::SecureBuffer(pp) != decaf::SecureBuffer(p)) {
  246. test.fail();
  247. printf("Fail torque seq test\n");
  248. }
  249. point_check(test,p,q,r,0,0,p,pp,"torque eq");
  250. point_check(test,p,q,r,0,0,p+q,q+p,"commute add");
  251. point_check(test,p,q,r,0,0,(p-q)+q,p,"correct sub");
  252. point_check(test,p,q,r,0,0,p+(q+r),(p+q)+r,"assoc add");
  253. point_check(test,p,q,r,0,0,p.times_two(),p+p,"dbl add");
  254. if (i%10) continue;
  255. point_check(test,p,q,r,x,0,x*(p+q),x*p+x*q,"distr mul");
  256. point_check(test,p,q,r,x,y,(x*y)*p,x*(y*p),"assoc mul");
  257. point_check(test,p,q,r,x,y,x*p+y*q,Point::double_scalarmul(x,p,y,q),"ds mul");
  258. point_check(test,base,q,r,x,y,x*base+y*q,q.non_secret_combo_with_base(y,x),"ds vt mul");
  259. point_check(test,p,q,r,x,0,Precomputed(p)*x,p*x,"precomp mul");
  260. point_check(test,p,q,r,0,0,r,
  261. Point::from_hash(buffer.slice(0,Point::HASH_BYTES))
  262. + Point::from_hash(buffer.slice(Point::HASH_BYTES,Point::HASH_BYTES)),
  263. "unih = hash+add"
  264. );
  265. point_check(test,p,q,r,x,0,Point(x.direct_scalarmul(decaf::SecureBuffer(p))),x*p,"direct mul");
  266. }
  267. }
  268. }; // template<decaf::GroupId GROUP>
  269. static void test_decaf() {
  270. Test test("Sample crypto");
  271. decaf::SpongeRng rng(decaf::Block("test_decaf"));
  272. decaf_255_symmetric_key_t proto1,proto2;
  273. decaf_255_private_key_t s1,s2;
  274. decaf_255_public_key_t p1,p2;
  275. decaf_255_signature_t sig;
  276. unsigned char shared1[1234],shared2[1234];
  277. const char *message = "Hello, world!";
  278. for (int i=0; i<NTESTS && test.passing_now; i++) {
  279. rng.read(decaf::TmpBuffer(proto1,sizeof(proto1)));
  280. rng.read(decaf::TmpBuffer(proto2,sizeof(proto2)));
  281. decaf_255_derive_private_key(s1,proto1);
  282. decaf_255_private_to_public(p1,s1);
  283. decaf_255_derive_private_key(s2,proto2);
  284. decaf_255_private_to_public(p2,s2);
  285. if (!decaf_255_shared_secret (shared1,sizeof(shared1),s1,p2)) {
  286. test.fail(); printf("Fail ss12\n");
  287. }
  288. if (!decaf_255_shared_secret (shared2,sizeof(shared2),s2,p1)) {
  289. test.fail(); printf("Fail ss21\n");
  290. }
  291. if (memcmp(shared1,shared2,sizeof(shared1))) {
  292. test.fail(); printf("Fail ss21 == ss12\n");
  293. }
  294. decaf_255_sign (sig,s1,(const unsigned char *)message,strlen(message));
  295. if (!decaf_255_verify (sig,p1,(const unsigned char *)message,strlen(message))) {
  296. test.fail(); printf("Fail sig ver\n");
  297. }
  298. }
  299. }
  300. int main(int argc, char **argv) {
  301. (void) argc; (void) argv;
  302. Tests<decaf::Ed255>::test_arithmetic();
  303. Tests<decaf::Ed255>::test_elligator();
  304. Tests<decaf::Ed255>::test_ec();
  305. test_decaf();
  306. if (passing) printf("Passed all tests.\n");
  307. return passing ? 0 : 1;
  308. }