From 8b4767221c9b0ff3015229b167e5be1331a16c12 Mon Sep 17 00:00:00 2001 From: moneromooo-monero Date: Sat, 25 Aug 2018 18:37:21 +0000 Subject: [PATCH] bulletproofs: speedup prover --- src/ringct/bulletproofs.cc | 54 ++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/src/ringct/bulletproofs.cc b/src/ringct/bulletproofs.cc index a3af9264e..f1b821978 100644 --- a/src/ringct/bulletproofs.cc +++ b/src/ringct/bulletproofs.cc @@ -146,13 +146,14 @@ static rct::key vector_exponent(const rct::keyV &a, const rct::keyV &b) } /* Compute a custom vector-scalar commitment */ -static rct::key cross_vector_exponent8(size_t size, const std::vector &A, size_t Ao, const std::vector &B, size_t Bo, const rct::keyV &a, size_t ao, const rct::keyV &b, size_t bo, const ge_p3 *extra_point, const rct::key *extra_scalar) +static rct::key cross_vector_exponent8(size_t size, const std::vector &A, size_t Ao, const std::vector &B, size_t Bo, const rct::keyV &a, size_t ao, const rct::keyV &b, size_t bo, const rct::keyV *scale, const ge_p3 *extra_point, const rct::key *extra_scalar) { CHECK_AND_ASSERT_THROW_MES(size + Ao <= A.size(), "Incompatible size for A"); CHECK_AND_ASSERT_THROW_MES(size + Bo <= B.size(), "Incompatible size for B"); CHECK_AND_ASSERT_THROW_MES(size + ao <= a.size(), "Incompatible size for a"); CHECK_AND_ASSERT_THROW_MES(size + bo <= b.size(), "Incompatible size for b"); CHECK_AND_ASSERT_THROW_MES(size <= maxN*maxM, "size is too large"); + CHECK_AND_ASSERT_THROW_MES(!scale || size == scale->size() / 2, "Incompatible size for scale"); CHECK_AND_ASSERT_THROW_MES(!!extra_point == !!extra_scalar, "only one of extra point/scalar present"); std::vector multiexp_data; @@ -162,6 +163,8 @@ static rct::key cross_vector_exponent8(size_t size, const std::vector &A, sc_mul(multiexp_data[i*2].scalar.bytes, a[ao+i].bytes, INV_EIGHT.bytes);; multiexp_data[i*2].point = A[Ao+i]; sc_mul(multiexp_data[i*2+1].scalar.bytes, b[bo+i].bytes, INV_EIGHT.bytes); + if (scale) + sc_mul(multiexp_data[i*2+1].scalar.bytes, multiexp_data[i*2+1].scalar.bytes, (*scale)[Bo+i].bytes); multiexp_data[i*2+1].point = B[Bo+i]; } if (extra_point) @@ -232,7 +235,7 @@ static rct::keyV hadamard(const rct::keyV &a, const rct::keyV &b) } /* folds a curvepoint array using a two way scaled Hadamard product */ -static void hadamard_fold(std::vector &v, const rct::key &a, const rct::key &b) +static void hadamard_fold(std::vector &v, const rct::keyV *scale, const rct::key &a, const rct::key &b) { CHECK_AND_ASSERT_THROW_MES((v.size() & 1) == 0, "Vector size should be even"); const size_t sz = v.size() / 2; @@ -241,7 +244,10 @@ static void hadamard_fold(std::vector &v, const rct::key &a, const rct::k ge_dsmp c[2]; ge_dsm_precomp(c[0], &v[n]); ge_dsm_precomp(c[1], &v[sz + n]); - ge_double_scalarmult_precomp_vartime2_p3(&v[n], a.bytes, c[0], b.bytes, c[1]); + rct::key sa, sb; + if (scale) sc_mul(sa.bytes, a.bytes, (*scale)[n].bytes); else sa = a; + if (scale) sc_mul(sb.bytes, b.bytes, (*scale)[sz + n].bytes); else sb = b; + ge_double_scalarmult_precomp_vartime2_p3(&v[n], sa.bytes, c[0], sb.bytes, c[1]); } v.resize(sz); } @@ -258,14 +264,24 @@ static rct::keyV vector_add(const rct::keyV &a, const rct::keyV &b) return res; } -/* Subtract two vectors */ -static rct::keyV vector_subtract(const rct::keyV &a, const rct::keyV &b) +/* Add a scalar to all elements of a vector */ +static rct::keyV vector_add(const rct::keyV &a, const rct::key &b) { - CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b"); rct::keyV res(a.size()); for (size_t i = 0; i < a.size(); ++i) { - sc_sub(res[i].bytes, a[i].bytes, b[i].bytes); + sc_add(res[i].bytes, a[i].bytes, b.bytes); + } + return res; +} + +/* Subtract a scalar from all elements of a vector */ +static rct::keyV vector_subtract(const rct::keyV &a, const rct::key &b) +{ + rct::keyV res(a.size()); + for (size_t i = 0; i < a.size(); ++i) + { + sc_sub(res[i].bytes, a[i].bytes, b.bytes); } return res; } @@ -549,8 +565,7 @@ try_again: } // Polynomial construction by coefficients - const auto zMN = vector_dup(z, MN); - rct::keyV l0 = vector_subtract(aL, zMN); + rct::keyV l0 = vector_subtract(aL, z); const rct::keyV &l1 = sL; // This computes the ugly sum/concatenation from PAPER LINE 65 @@ -570,7 +585,7 @@ try_again: } } - rct::keyV r0 = vector_add(aR, zMN); + rct::keyV r0 = vector_add(aR, z); const auto yMN = vector_powers(y, MN); r0 = hadamard(r0, yMN); r0 = vector_add(r0, zero_twos); @@ -658,12 +673,15 @@ try_again: rct::keyV aprime(MN); rct::keyV bprime(MN); const rct::key yinv = invert(y); - rct::key yinvpow = rct::identity(); + rct::keyV yinvpow(MN); + yinvpow[0] = rct::identity(); + yinvpow[1] = yinv; for (size_t i = 0; i < MN; ++i) { Gprime[i] = Gi_p3[i]; - ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]); - sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); + Hprime[i] = Hi_p3[i]; + if (i > 1) + sc_mul(yinvpow[i].bytes, yinvpow[i-1].bytes, yinv.bytes); aprime[i] = l[i]; bprime[i] = r[i]; } @@ -675,6 +693,7 @@ try_again: PERF_TIMER_START_BP(PROVE_step4); // PAPER LINE 13 + const rct::keyV *scale = &yinvpow; while (nprime > 1) { // PAPER LINE 15 @@ -689,9 +708,9 @@ try_again: // PAPER LINES 18-19 PERF_TIMER_START_BP(PROVE_LR); sc_mul(tmp.bytes, cL.bytes, x_ip.bytes); - L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp); + L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, scale, &ge_p3_H, &tmp); sc_mul(tmp.bytes, cR.bytes, x_ip.bytes); - R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp); + R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, scale, &ge_p3_H, &tmp); PERF_TIMER_STOP(PROVE_LR); // PAPER LINES 21-22 @@ -708,8 +727,8 @@ try_again: if (nprime > 1) { PERF_TIMER_START_BP(PROVE_hadamard2); - hadamard_fold(Gprime, winv, w[round]); - hadamard_fold(Hprime, w[round], winv); + hadamard_fold(Gprime, NULL, winv, w[round]); + hadamard_fold(Hprime, scale, w[round], winv); PERF_TIMER_STOP(PROVE_hadamard2); } @@ -719,6 +738,7 @@ try_again: bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round])); PERF_TIMER_STOP(PROVE_prime); + scale = NULL; ++round; } PERF_TIMER_STOP(PROVE_step4);