You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

597 lines
19 KiB

/**
 * @file test_decaf.cxx
 * @author Mike Hamburg
 *
 * @copyright
 *   Copyright (c) 2015 Cryptography Research, Inc. \n
 *   Released under the MIT License. See LICENSE.txt for license information.
 *
 * @brief C++ tests, because that's easier.
 */
#include <limits.h>
#include <stdio.h>
#include <string.h>

#include <exception>
#include <memory>

#include <decaf.hxx>
#include <decaf/spongerng.hxx>
#include <decaf/crypto.h>
#include <decaf/crypto.hxx>
#include <decaf/eddsa.hxx>

using namespace decaf;
using namespace decaf::TOY;
  19. static bool passing = true;
  20. static const long NTESTS = 10000;
  21. class Test {
  22. public:
  23. bool passing_now;
  24. Test(const char *test) {
  25. passing_now = true;
  26. printf("%s...", test);
  27. if (strlen(test) < 27) printf("%*s",int(27-strlen(test)),"");
  28. fflush(stdout);
  29. }
  30. ~Test() {
  31. if (std::uncaught_exception()) {
  32. fail();
  33. printf(" due to uncaught exception.\n");
  34. }
  35. if (passing_now) printf("[PASS]\n");
  36. }
  37. void fail() {
  38. if (!passing_now) return;
  39. passing_now = passing = false;
  40. printf("[FAIL]\n");
  41. }
  42. };
  43. static uint64_t leint(const SecureBuffer &xx) {
  44. uint64_t out = 0;
  45. for (unsigned int i=0; i<xx.size() && i<sizeof(out); i++) {
  46. out |= uint64_t(xx[i]) << (8*i);
  47. }
  48. return out;
  49. }
  50. template<typename Group> struct Tests {
  51. typedef typename Group::Scalar Scalar;
  52. typedef typename Group::Point Point;
  53. typedef typename Group::DhLadder DhLadder;
  54. typedef typename Group::Precomputed Precomputed;
  55. static void print(const char *name, const Scalar &x) {
  56. unsigned char buffer[Scalar::SER_BYTES];
  57. x.serialize_into(buffer);
  58. printf(" %s = 0x", name);
  59. for (int i=sizeof(buffer)-1; i>=0; i--) {
  60. printf("%02x", buffer[i]);
  61. }
  62. printf("\n");
  63. }
  64. static void hexprint(const char *name, const SecureBuffer &buffer) {
  65. printf(" %s = 0x", name);
  66. for (int i=buffer.size()-1; i>=0; i--) {
  67. printf("%02x", buffer[i]);
  68. }
  69. printf("\n");
  70. }
  71. static void print(const char *name, const Point &x) {
  72. unsigned char buffer[Point::SER_BYTES];
  73. x.serialize_into(buffer);
  74. printf(" %s = 0x", name);
  75. for (int i=Point::SER_BYTES-1; i>=0; i--) {
  76. printf("%02x", buffer[i]);
  77. }
  78. printf("\n");
  79. }
  80. static bool arith_check(
  81. Test &test,
  82. const Scalar &x,
  83. const Scalar &y,
  84. const Scalar &z,
  85. const Scalar &l,
  86. const Scalar &r,
  87. const char *name
  88. ) {
  89. if (l == r) return true;
  90. test.fail();
  91. printf(" %s", name);
  92. print("x", x);
  93. print("y", y);
  94. print("z", z);
  95. print("lhs", l);
  96. print("rhs", r);
  97. return false;
  98. }
  99. static bool point_check(
  100. Test &test,
  101. const Point &p,
  102. const Point &q,
  103. const Point &R,
  104. const Scalar &x,
  105. const Scalar &y,
  106. const Point &l,
  107. const Point &r,
  108. const char *name
  109. ) {
  110. bool good = l==r;
  111. if (!p.validate()) { good = false; printf(" p invalid\n"); }
  112. if (!q.validate()) { good = false; printf(" q invalid\n"); }
  113. if (!r.validate()) { good = false; printf(" r invalid\n"); }
  114. if (!l.validate()) { good = false; printf(" l invalid\n"); }
  115. if (good) return true;
  116. test.fail();
  117. printf(" %s", name);
  118. print("x", x);
  119. print("y", y);
  120. print("p", p);
  121. print("q", q);
  122. print("r", R);
  123. print("lhs", r);
  124. print("rhs", l);
  125. return false;
  126. }
  127. static void test_arithmetic() {
  128. SpongeRng rng(Block("test_arithmetic"),SpongeRng::DETERMINISTIC);
  129. Test test("Arithmetic");
  130. Scalar x(0),y(0),z(0);
  131. arith_check(test,x,y,z,INT_MAX,(decaf_word_t)INT_MAX,"cast from max");
  132. arith_check(test,x,y,z,INT_MIN,-Scalar(1+(decaf_word_t)INT_MAX),"cast from min");
  133. for (int i=0; i<NTESTS*10 && test.passing_now; i++) {
  134. size_t sob = i % (2*Group::Scalar::SER_BYTES);
  135. SecureBuffer xx = rng.read(sob), yy = rng.read(sob), zz = rng.read(sob);
  136. Scalar x(xx);
  137. Scalar y(yy);
  138. Scalar z(zz);
  139. arith_check(test,x,y,z,x+y,y+x,"commute add");
  140. arith_check(test,x,y,z,x,x+0,"ident add");
  141. arith_check(test,x,y,z,x,x-0,"ident sub");
  142. arith_check(test,x,y,z,x+-x,0,"inverse add");
  143. arith_check(test,x,y,z,x-x,0,"inverse sub");
  144. arith_check(test,x,y,z,x-(x+1),-1,"inverse add2");
  145. arith_check(test,x,y,z,x+(y+z),(x+y)+z,"assoc add");
  146. arith_check(test,x,y,z,x*(y+z),x*y + x*z,"distributive mul/add");
  147. arith_check(test,x,y,z,x*(y-z),x*y - x*z,"distributive mul/add");
  148. arith_check(test,x,y,z,x*(y*z),(x*y)*z,"assoc mul");
  149. arith_check(test,x,y,z,x*y,y*x,"commute mul");
  150. arith_check(test,x,y,z,x,x*1,"ident mul");
  151. arith_check(test,x,y,z,0,x*0,"mul by 0");
  152. arith_check(test,x,y,z,-x,x*-1,"mul by -1");
  153. arith_check(test,x,y,z,x+x,x*2,"mul by 2");
  154. arith_check(test,x,y,z,-(x*y),(-x)*y,"neg prop mul");
  155. arith_check(test,x,y,z,x*y,(-x)*(-y),"double neg prop mul");
  156. arith_check(test,x,y,z,-(x+y),(-x)+(-y),"neg prop add");
  157. arith_check(test,x,y,z,x-y,(x)+(-y),"add neg sub");
  158. arith_check(test,x,y,z,(-x)-y,-(x+y),"neg add");
  159. if (sob <= 4) {
  160. uint64_t xi = leint(xx), yi = leint(yy);
  161. arith_check(test,x,y,z,x,xi,"parse consistency");
  162. arith_check(test,x,y,z,x+y,xi+yi,"add consistency");
  163. arith_check(test,x,y,z,x*y,xi*yi,"mul consistency");
  164. }
  165. if (i%20) continue;
  166. if (y!=0) arith_check(test,x,y,z,x*y/y,x,"invert");
  167. try {
  168. y = x/0;
  169. test.fail();
  170. printf(" Inverted zero!");
  171. print("x", x);
  172. print("y", y);
  173. } catch(CryptoException) {}
  174. }
  175. }
  176. static const Block sqrt_minus_one;
  177. static const Block minus_sqrt_minus_one;
  178. static const Block elli_patho; /* sqrt(1/(u(1-d))) */
  179. static void test_elligator() {
  180. SpongeRng rng(Block("test_elligator"),SpongeRng::DETERMINISTIC);
  181. Test test("Elligator");
  182. const int NHINTS = 1<<Point::INVERT_ELLIGATOR_WHICH_BITS;
  183. SecureBuffer *alts[NHINTS];
  184. bool successes[NHINTS];
  185. SecureBuffer *alts2[NHINTS];
  186. bool successes2[NHINTS];
  187. for (unsigned int i=0; i<NTESTS/10 && (i<10 || test.passing_now); i++) {
  188. size_t len = (i % (2*Point::HASH_BYTES + 3));
  189. SecureBuffer b1(len);
  190. if (i!=Point::HASH_BYTES) rng.read(b1); /* special test case */
  191. /* Pathological cases */
  192. if (i==1) b1[0] = 1;
  193. if (i==2 && sqrt_minus_one.size()) b1 = sqrt_minus_one;
  194. if (i==3 && minus_sqrt_minus_one.size()) b1 = minus_sqrt_minus_one;
  195. if (i==4 && elli_patho.size()) b1 = elli_patho;
  196. len = b1.size();
  197. Point s = Point::from_hash(b1), ss=s;
  198. for (unsigned int j=0; j<(i&3); j++) ss = ss.debugging_torque();
  199. ss = ss.debugging_pscale(rng);
  200. bool good = false;
  201. for (int j=0; j<NHINTS; j++) {
  202. alts[j] = new SecureBuffer(len);
  203. alts2[j] = new SecureBuffer(len);
  204. if (len > Point::HASH_BYTES)
  205. memcpy(&(*alts[j])[Point::HASH_BYTES], &b1[Point::HASH_BYTES], len-Point::HASH_BYTES);
  206. if (len > Point::HASH_BYTES)
  207. memcpy(&(*alts2[j])[Point::HASH_BYTES], &b1[Point::HASH_BYTES], len-Point::HASH_BYTES);
  208. successes[j] = decaf_successful( s.invert_elligator(*alts[j], j));
  209. successes2[j] = decaf_successful(ss.invert_elligator(*alts2[j],j));
  210. if (successes[j] != successes2[j]
  211. || (successes[j] && successes2[j] && *alts[j] != *alts2[j])
  212. ) {
  213. test.fail();
  214. printf(" Unscalable Elligator inversion: i=%d, hint=%d, s=%d,%d\n",i,j,
  215. -int(successes[j]),-int(successes2[j]));
  216. hexprint("x",b1);
  217. hexprint("X",*alts[j]);
  218. hexprint("X",*alts2[j]);
  219. }
  220. if (successes[j]) {
  221. good = good || (b1 == *alts[j]);
  222. for (int k=0; k<j; k++) {
  223. if (successes[k] && *alts[j] == *alts[k]) {
  224. test.fail();
  225. printf(" Duplicate Elligator inversion: i=%d, hints=%d, %d\n",i,j,k);
  226. hexprint("x",b1);
  227. hexprint("X",*alts[j]);
  228. }
  229. }
  230. if (s != Point::from_hash(*alts[j])) {
  231. test.fail();
  232. printf(" Fail Elligator inversion round-trip: i=%d, hint=%d %s\n",i,j,
  233. (s==-Point::from_hash(*alts[j])) ? "[output was -input]": "");
  234. hexprint("x",b1);
  235. hexprint("X",*alts[j]);
  236. }
  237. }
  238. }
  239. if (!good) {
  240. test.fail();
  241. printf(" %s Elligator inversion: i=%d\n",good ? "Passed" : "Failed", i);
  242. hexprint("B", b1);
  243. for (int j=0; j<NHINTS; j++) {
  244. printf(" %d: %s%s", j, successes[j] ? "succ" : "fail\n", (successes[j] && *alts[j] == b1) ? " [x]" : "");
  245. if (successes[j]) {
  246. hexprint("b", *alts[j]);
  247. }
  248. }
  249. printf("\n");
  250. }
  251. for (int j=0; j<NHINTS; j++) {
  252. delete alts[j];
  253. alts[j] = NULL;
  254. delete alts2[j];
  255. alts2[j] = NULL;
  256. }
  257. Point t(rng);
  258. point_check(test,t,t,t,0,0,t,Point::from_hash(t.steg_encode(rng)),"steg round-trip");
  259. }
  260. }
  261. static void test_ec() {
  262. SpongeRng rng(Block("test_ec"),SpongeRng::DETERMINISTIC);
  263. Test test("EC");
  264. Point id = Point::identity(), base = Point::base();
  265. point_check(test,id,id,id,0,0,Point::from_hash(""),id,"fh0");
  266. unsigned char enc[Point::SER_BYTES] = {0};
  267. if (Group::FIELD_MODULUS_TYPE == 3) {
  268. /* When p == 3 mod 4, the QNR is -1, so u*1^2 = -1 also produces the
  269. * identity.
  270. */
  271. point_check(test,id,id,id,0,0,Point::from_hash("\x01"),id,"fh1");
  272. }
  273. point_check(test,id,id,id,0,0,Point(FixedBlock<sizeof(enc)>(enc)),id,"decode [0]");
  274. try {
  275. enc[0] = 1;
  276. Point f((FixedBlock<sizeof(enc)>(enc)));
  277. test.fail();
  278. printf(" Allowed deserialize of [1]: %d", f==id);
  279. } catch (CryptoException) {
  280. /* ok */
  281. }
  282. if (sqrt_minus_one.size()) {
  283. try {
  284. Point f(sqrt_minus_one);
  285. test.fail();
  286. printf(" Allowed deserialize of [i]: %d", f==id);
  287. } catch (CryptoException) {
  288. /* ok */
  289. }
  290. }
  291. if (minus_sqrt_minus_one.size()) {
  292. try {
  293. Point f(minus_sqrt_minus_one);
  294. test.fail();
  295. printf(" Allowed deserialize of [-i]: %d", f==id);
  296. } catch (CryptoException) {
  297. /* ok */
  298. }
  299. }
  300. for (int i=0; i<NTESTS && test.passing_now; i++) {
  301. Scalar x(rng);
  302. Scalar y(rng);
  303. Point p(rng);
  304. Point q(rng);
  305. Point d1, d2;
  306. SecureBuffer buffer(2*Point::HASH_BYTES);
  307. rng.read(buffer);
  308. Point r = Point::from_hash(buffer);
  309. point_check(test,p,q,r,0,0,p,Point(p.serialize()),"round-trip");
  310. Point pp = p.debugging_torque().debugging_pscale(rng);
  311. if (!memeq(pp.serialize(),p.serialize())) {
  312. test.fail();
  313. printf(" Fail torque seq test\n");
  314. }
  315. if (!memeq((p-pp).serialize(),id.serialize())) {
  316. test.fail();
  317. printf(" Fail torque id test\n");
  318. }
  319. if (!memeq((p-p).serialize(),id.serialize())) {
  320. test.fail();
  321. printf(" Fail id test\n");
  322. }
  323. point_check(test,p,q,r,0,0,p,pp,"torque eq");
  324. point_check(test,p,q,r,0,0,p+q,q+p,"commute add");
  325. point_check(test,p,q,r,0,0,(p-q)+q,p,"correct sub");
  326. point_check(test,p,q,r,0,0,p+(q+r),(p+q)+r,"assoc add");
  327. point_check(test,p,q,r,0,0,p.times_two(),p+p,"dbl add");
  328. if (i%10) continue;
  329. point_check(test,p,q,r,0,0,p.times_two(),p*Scalar(2),"add times two");
  330. point_check(test,p,q,r,x,0,x*(p+q),x*p+x*q,"distr mul");
  331. point_check(test,p,q,r,x,y,(x*y)*p,x*(y*p),"assoc mul");
  332. point_check(test,p,q,r,x,y,x*p+y*q,Point::double_scalarmul(x,p,y,q),"double mul");
  333. p.dual_scalarmul(d1,d2,x,y);
  334. point_check(test,p,q,r,x,y,x*p,d1,"dual mul 1");
  335. point_check(test,p,q,r,x,y,y*p,d2,"dual mul 2");
  336. point_check(test,base,q,r,x,y,x*base+y*q,q.non_secret_combo_with_base(y,x),"ds vt mul");
  337. point_check(test,p,q,r,x,0,Precomputed(p)*x,p*x,"precomp mul");
  338. point_check(test,p,q,r,0,0,r,
  339. Point::from_hash(Buffer(buffer).slice(0,Point::HASH_BYTES))
  340. + Point::from_hash(Buffer(buffer).slice(Point::HASH_BYTES,Point::HASH_BYTES)),
  341. "unih = hash+add"
  342. );
  343. point_check(test,p,q,r,x,0,Point(x.direct_scalarmul(p.serialize())),x*p,"direct mul");
  344. q=p;
  345. for (int j=1; j<Group::REMOVED_COFACTOR; j<<=1) q = q.times_two();
  346. decaf_error_t error = r.decode_like_eddsa_and_ignore_cofactor_noexcept(
  347. p.mul_by_cofactor_and_encode_like_eddsa()
  348. );
  349. if (error != DECAF_SUCCESS) {
  350. test.fail();
  351. printf(" Decode like EdDSA failed.");
  352. }
  353. point_check(test,-q,q,r,i,0,q,r,"Encode like EdDSA round-trip");
  354. }
  355. }
  356. static void test_toy_crypto() {
  357. Test test("Toy crypto");
  358. SpongeRng rng(Block("test_decaf_crypto"),SpongeRng::DETERMINISTIC);
  359. for (int i=0; i<NTESTS && test.passing_now; i++) {
  360. try {
  361. PrivateKey<Group> priv1(rng), priv2(rng);
  362. PublicKey<Group> pub1(priv1), pub2(priv2);
  363. SecureBuffer message = rng.read(i);
  364. SecureBuffer sig(priv1.sign(message));
  365. pub1.verify(message, sig);
  366. SecureBuffer s1(priv1.shared_secret(pub2,32,true));
  367. SecureBuffer s2(priv2.shared_secret(pub1,32,false));
  368. if (!memeq(s1,s2)) {
  369. test.fail();
  370. printf(" Shared secrets disagree on iteration %d.\n",i);
  371. }
  372. } catch (CryptoException) {
  373. test.fail();
  374. printf(" Threw CryptoException.\n");
  375. }
  376. }
  377. }
  378. static const uint8_t rfc7748_1[DhLadder::PUBLIC_BYTES];
  379. static const uint8_t rfc7748_1000[DhLadder::PUBLIC_BYTES];
  380. static const uint8_t rfc7748_1000000[DhLadder::PUBLIC_BYTES];
  381. static void test_cfrg_crypto() {
  382. Test test("CFRG crypto");
  383. SpongeRng rng(Block("test_cfrg_crypto"),SpongeRng::DETERMINISTIC);
  384. for (int i=0; i<NTESTS && test.passing_now; i++) {
  385. FixedArrayBuffer<DhLadder::PUBLIC_BYTES> base(rng);
  386. FixedArrayBuffer<DhLadder::PRIVATE_BYTES> s1(rng), s2(rng);
  387. SecureBuffer p1 = DhLadder::shared_secret(base,s1);
  388. SecureBuffer p2 = DhLadder::shared_secret(base,s2);
  389. SecureBuffer ss1 = DhLadder::shared_secret(p2,s1);
  390. SecureBuffer ss2 = DhLadder::shared_secret(p1,s2);
  391. if (!memeq(ss1,ss2)) {
  392. test.fail();
  393. printf(" Shared secrets disagree on iteration %d.\n",i);
  394. }
  395. if (!memeq(
  396. DhLadder::shared_secret(DhLadder::base_point(),s1),
  397. DhLadder::generate_key(s1)
  398. )) {
  399. test.fail();
  400. printf(" Generated keys disagree on iteration %d.\n",i);
  401. }
  402. }
  403. }
  404. static const bool eddsa_prehashed[];
  405. static const Block eddsa_sk[], eddsa_pk[], eddsa_message[], eddsa_context[], eddsa_sig[];
  406. static void test_cfrg_vectors() {
  407. Test test("CFRG test vectors");
  408. SecureBuffer k = DhLadder::base_point();
  409. SecureBuffer u = DhLadder::base_point();
  410. int the_ntests = (NTESTS < 1000000) ? 1000 : 1000000;
  411. /* EdDSA */
  412. for (unsigned int t=0; eddsa_sk[t].size(); t++) {
  413. typename EdDSA<Group>::PrivateKey priv(eddsa_sk[t]);
  414. SecureBuffer eddsa_pk2 = priv.pub().serialize();
  415. if (!memeq(SecureBuffer(eddsa_pk[t]), eddsa_pk2)) {
  416. test.fail();
  417. printf(" EdDSA PK vectors disagree.");
  418. printf("\n Correct: ");
  419. for (unsigned i=0; i<eddsa_pk[t].size(); i++) printf("%02x", eddsa_pk[t][i]);
  420. printf("\n Incorrect: ");
  421. for (unsigned i=0; i<eddsa_pk2.size(); i++) printf("%02x", eddsa_pk2[i]);
  422. printf("\n");
  423. }
  424. SecureBuffer sig;
  425. if (eddsa_prehashed[t]) {
  426. typename EdDSA<Group>::PrivateKeyPh priv2(eddsa_sk[t]);
  427. if (priv2.SUPPORTS_CONTEXTS) {
  428. sig = priv2.sign_with_prehash(eddsa_message[t],eddsa_context[t]);
  429. } else {
  430. sig = priv2.sign_with_prehash(eddsa_message[t]);
  431. }
  432. } else {
  433. if (priv.SUPPORTS_CONTEXTS) {
  434. sig = priv.sign(eddsa_message[t],eddsa_context[t]);
  435. } else {
  436. sig = priv.sign(eddsa_message[t]);
  437. }
  438. }
  439. if (!memeq(SecureBuffer(eddsa_sig[t]),sig)) {
  440. test.fail();
  441. printf(" EdDSA sig vectors disagree.");
  442. printf("\n Correct: ");
  443. for (unsigned i=0; i<eddsa_sig[t].size(); i++) printf("%02x", eddsa_sig[t][i]);
  444. printf("\n Incorrect: ");
  445. for (unsigned i=0; i<sig.size(); i++) printf("%02x", sig[i]);
  446. printf("\n");
  447. }
  448. }
  449. /* X25519/X448 */
  450. for (int i=0; i<the_ntests && test.passing_now; i++) {
  451. SecureBuffer n = DhLadder::shared_secret(u,k);
  452. u = k; k = n;
  453. if (i==1-1) {
  454. if (!memeq(k,SecureBuffer(FixedBlock<DhLadder::PUBLIC_BYTES>(rfc7748_1)))) {
  455. test.fail();
  456. printf(" Test vectors disagree at 1.");
  457. }
  458. } else if (i==1000-1) {
  459. if (!memeq(k,SecureBuffer(FixedBlock<DhLadder::PUBLIC_BYTES>(rfc7748_1000)))) {
  460. test.fail();
  461. printf(" Test vectors disagree at 1000.");
  462. }
  463. } else if (i==1000000-1) {
  464. if (!memeq(k,SecureBuffer(FixedBlock<DhLadder::PUBLIC_BYTES>(rfc7748_1000000)))) {
  465. test.fail();
  466. printf(" Test vectors disagree at 1000000.");
  467. }
  468. }
  469. }
  470. }
  471. static void test_eddsa() {
  472. Test test("EdDSA");
  473. SpongeRng rng(Block("test_eddsa"),SpongeRng::DETERMINISTIC);
  474. for (int i=0; i<NTESTS && test.passing_now; i++) {
  475. typename EdDSA<Group>::PrivateKey priv(rng);
  476. typename EdDSA<Group>::PublicKey pub(priv);
  477. SecureBuffer message(i);
  478. rng.read(message);
  479. SecureBuffer context(priv.SUPPORTS_CONTEXTS ? i%256 : 0);
  480. rng.read(message);
  481. SecureBuffer sig = priv.sign(message,context);
  482. try {
  483. pub.verify(sig,message,context);
  484. } catch(CryptoException) {
  485. test.fail();
  486. printf(" Signature validation failed on sig %d\n", i);
  487. }
  488. }
  489. }
  490. static void run() {
  491. printf("Testing %s:\n",Group::name());
  492. test_arithmetic();
  493. test_elligator();
  494. test_ec();
  495. test_eddsa();
  496. test_cfrg_crypto();
  497. test_cfrg_vectors();
  498. test_toy_crypto();
  499. printf("\n");
  500. }
  501. }; /* template<GroupId GROUP> struct Tests */
  502. #include "vectors.inc.cxx"
  503. int main(int argc, char **argv) {
  504. (void) argc; (void) argv;
  505. run_for_all_curves<Tests>();
  506. if (passing) printf("Passed all tests.\n");
  507. return passing ? 0 : 1;
  508. }