From 8629a42cf6e4650b552925f7637761b8e7ee66e3 Mon Sep 17 00:00:00 2001 From: moneromooo-monero Date: Wed, 22 Aug 2018 22:30:14 +0000 Subject: [PATCH] bulletproofs: rework flow to use sarang's fast batch inversion code --- src/ringct/bulletproofs.cc | 231 ++++++++++++++++++++++++------------- 1 file changed, 154 insertions(+), 77 deletions(-) diff --git a/src/ringct/bulletproofs.cc b/src/ringct/bulletproofs.cc index 549e52296..d9961cb20 100644 --- a/src/ringct/bulletproofs.cc +++ b/src/ringct/bulletproofs.cc @@ -29,8 +29,6 @@ // Adapted from Java code by Sarang Noether #include -#include -#include #include #include "misc_log_ex.h" #include "common/perf_timer.h" @@ -289,37 +287,59 @@ static rct::keyV vector_dup(const rct::key &x, size_t N) return rct::keyV(N, x); } -static rct::key switch_endianness(rct::key k) +static rct::key sm(rct::key y, int n, const rct::key &x) { - std::reverse(k.bytes, k.bytes + sizeof(k)); - return k; + while (n--) + sc_mul(y.bytes, y.bytes, y.bytes); + sc_mul(y.bytes, y.bytes, x.bytes); + return y; } -/* Compute the inverse of a scalar, the stupid way */ +/* Compute the inverse of a scalar, the clever way */ static rct::key invert(const rct::key &x) { + rct::key _1, _10, _100, _11, _101, _111, _1001, _1011, _1111; + + _1 = x; + sc_mul(_10.bytes, _1.bytes, _1.bytes); + sc_mul(_100.bytes, _10.bytes, _10.bytes); + sc_mul(_11.bytes, _10.bytes, _1.bytes); + sc_mul(_101.bytes, _10.bytes, _11.bytes); + sc_mul(_111.bytes, _10.bytes, _101.bytes); + sc_mul(_1001.bytes, _10.bytes, _111.bytes); + sc_mul(_1011.bytes, _10.bytes, _1001.bytes); + sc_mul(_1111.bytes, _100.bytes, _1011.bytes); + rct::key inv; + sc_mul(inv.bytes, _1111.bytes, _1.bytes); - BN_CTX *ctx = BN_CTX_new(); - BIGNUM *X = BN_new(); - BIGNUM *L = BN_new(); - BIGNUM *I = BN_new(); - - BN_bin2bn(switch_endianness(x).bytes, sizeof(rct::key), X); - BN_bin2bn(switch_endianness(rct::curveOrder()).bytes, sizeof(rct::key), L); - - CHECK_AND_ASSERT_THROW_MES(BN_mod_inverse(I, X, L, ctx), "Failed to invert"); - - const int len = BN_num_bytes(I); - CHECK_AND_ASSERT_THROW_MES((size_t)len <= sizeof(rct::key), "Invalid number length"); - inv = rct::zero(); - BN_bn2bin(I, inv.bytes); - std::reverse(inv.bytes, inv.bytes + len); - - BN_free(I); - BN_free(L); - BN_free(X); - BN_CTX_free(ctx); + inv = sm(inv, 123 + 3, _101); + inv = sm(inv, 2 + 2, _11); + inv = sm(inv, 1 + 4, _1111); + inv = sm(inv, 1 + 4, _1111); + inv = sm(inv, 4, _1001); + inv = sm(inv, 2, _11); + inv = sm(inv, 1 + 4, _1111); + inv = sm(inv, 1 + 3, _101); + inv = sm(inv, 3 + 3, _101); + inv = sm(inv, 3, _111); + inv = sm(inv, 1 + 4, _1111); + inv = sm(inv, 2 + 3, _111); + inv = sm(inv, 2 + 2, _11); + inv = sm(inv, 1 + 4, _1011); + inv = sm(inv, 2 + 4, _1011); + inv = sm(inv, 6 + 4, _1001); + inv = sm(inv, 2 + 2, _11); + inv = sm(inv, 3 + 2, _11); + inv = sm(inv, 3 + 2, _11); + inv = sm(inv, 1 + 4, _1001); + inv = sm(inv, 1 + 3, _111); + inv = sm(inv, 2 + 4, _1111); + inv = sm(inv, 1 + 4, _1011); + inv = sm(inv, 3, _101); + inv = sm(inv, 2 + 4, _1111); + inv = sm(inv, 3, _101); + inv = sm(inv, 1 + 2, _11); #ifdef DEBUG_BP rct::key tmp; @@ -329,6 +349,34 @@ static rct::key invert(const rct::key &x) return inv; } +static rct::keyV invert(rct::keyV x) +{ + rct::keyV scratch; + scratch.reserve(x.size()); + + rct::key acc = rct::identity(); + for (size_t n = 0; n < x.size(); ++n) + { + scratch.push_back(acc); + if (n == 0) + acc = x[0]; + else + sc_mul(acc.bytes, acc.bytes, x[n].bytes); + } + + acc = invert(acc); + + rct::key tmp; + for (int i = x.size(); i-- > 0; ) + { + sc_mul(tmp.bytes, acc.bytes, x[i].bytes); + sc_mul(x[i].bytes, acc.bytes, scratch[i].bytes); + acc = tmp; + } + + return x; +} + /* Compute the slice of a vector */ static rct::keyV slice(const rct::keyV &a, size_t start, size_t stop) { @@ -702,6 +750,13 @@ Bulletproof bulletproof_PROVE(const std::vector &v, const rct::keyV &g return bulletproof_PROVE(sv, gamma); } +struct proof_data_t +{ + rct::key x, y, z, x_ip; + std::vector w; + size_t logM, inv_offset; +}; + /* Given a range proof, determine if it is valid */ bool bulletproof_VERIFY(const std::vector &proofs) { @@ -709,9 +764,17 @@ bool bulletproof_VERIFY(const std::vector &proofs) PERF_TIMER_START_BP(VERIFY); + const size_t logN = 6; + const size_t N = 1 << logN; + // sanity and figure out which proof is longest size_t max_length = 0; size_t nV = 0; + std::vector proof_data; + proof_data.reserve(proofs.size()); + size_t inv_offset = 0; + std::vector to_invert; + to_invert.reserve(11 * sizeof(proofs)); for (const Bulletproof *p: proofs) { const Bulletproof &proof = *p; @@ -729,46 +792,75 @@ bool bulletproof_VERIFY(const std::vector &proofs) max_length = std::max(max_length, proof.L.size()); nV += proof.V.size(); + + // Reconstruct the challenges + PERF_TIMER_START_BP(VERIFY_start); + proof_data.resize(proof_data.size() + 1); + proof_data_t &pd = proof_data.back(); + rct::key hash_cache = rct::hash_to_scalar(proof.V); + pd.y = hash_cache_mash(hash_cache, proof.A, proof.S); + CHECK_AND_ASSERT_MES(!(pd.y == rct::zero()), false, "y == 0"); + pd.z = hash_cache = rct::hash_to_scalar(pd.y); + CHECK_AND_ASSERT_MES(!(pd.z == rct::zero()), false, "z == 0"); + pd.x = hash_cache_mash(hash_cache, pd.z, proof.T1, proof.T2); + CHECK_AND_ASSERT_MES(!(pd.x == rct::zero()), false, "x == 0"); + pd.x_ip = hash_cache_mash(hash_cache, pd.x, proof.taux, proof.mu, proof.t); + CHECK_AND_ASSERT_MES(!(pd.x_ip == rct::zero()), false, "x_ip == 0"); + PERF_TIMER_STOP(VERIFY_start); + + size_t M; + for (pd.logM = 0; (M = 1< 0, false, "Zero rounds"); + + PERF_TIMER_START_BP(VERIFY_line_21_22); + // PAPER LINES 21-22 + // The inner product challenges are computed per round + pd.w.resize(rounds); + for (size_t i = 0; i < rounds; ++i) + { + pd.w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]); + CHECK_AND_ASSERT_MES(!(pd.w[i] == rct::zero()), false, "w[i] == 0"); + } + PERF_TIMER_STOP(VERIFY_line_21_22); + + pd.inv_offset = inv_offset; + for (size_t i = 0; i < rounds; ++i) + to_invert.push_back(pd.w[i]); + to_invert.push_back(pd.y); + inv_offset += rounds + 1; } CHECK_AND_ASSERT_MES(max_length < 32, false, "At least one proof is too large"); size_t maxMN = 1u << max_length; - const size_t logN = 6; - const size_t N = 1 << logN; rct::key tmp; std::vector multiexp_data; multiexp_data.reserve(nV + (2 * (10/*logM*/ + logN) + 4) * proofs.size() + 2 * maxMN); + PERF_TIMER_START_BP(VERIFY_line_24_25_invert); + const std::vector inverses = invert(to_invert); + PERF_TIMER_STOP(VERIFY_line_24_25_invert); + // setup weighted aggregates rct::key z1 = rct::zero(); rct::key z3 = rct::zero(); rct::keyV z4(maxMN, rct::zero()), z5(maxMN, rct::zero()); rct::key y0 = rct::zero(), y1 = rct::zero(); + int proof_data_index = 0; for (const Bulletproof *p: proofs) { const Bulletproof &proof = *p; + const proof_data_t &pd = proof_data[proof_data_index++]; - size_t M, logM; - for (logM = 0; (M = 1< &proofs) // PAPER LINE 61 sc_muladd(y0.bytes, proof.taux.bytes, weight_y.bytes, y0.bytes); - const rct::keyV zpow = vector_powers(z, M+3); + const rct::keyV zpow = vector_powers(pd.z, M+3); rct::key k; - const rct::key ip1y = vector_power_sum(y, MN); + const rct::key ip1y = vector_power_sum(pd.y, MN); sc_mulsub(k.bytes, zpow[2].bytes, ip1y.bytes, rct::zero().bytes); for (size_t j = 1; j <= M; ++j) { @@ -795,7 +887,7 @@ bool bulletproof_VERIFY(const std::vector &proofs) PERF_TIMER_STOP(VERIFY_line_61); PERF_TIMER_START_BP(VERIFY_line_61rl_new); - sc_muladd(tmp.bytes, z.bytes, ip1y.bytes, k.bytes); + sc_muladd(tmp.bytes, pd.z.bytes, ip1y.bytes, k.bytes); sc_sub(tmp.bytes, proof.t.bytes, tmp.bytes); sc_muladd(y1.bytes, tmp.bytes, weight_y.bytes, y1.bytes); for (size_t j = 0; j < proof8_V.size(); j++) @@ -803,10 +895,10 @@ bool bulletproof_VERIFY(const std::vector &proofs) sc_mul(tmp.bytes, zpow[j+2].bytes, weight_y.bytes); multiexp_data.emplace_back(tmp, proof8_V[j]); } - sc_mul(tmp.bytes, x.bytes, weight_y.bytes); + sc_mul(tmp.bytes, pd.x.bytes, weight_y.bytes); multiexp_data.emplace_back(tmp, proof8_T1); rct::key xsq; - sc_mul(xsq.bytes, x.bytes, x.bytes); + sc_mul(xsq.bytes, pd.x.bytes, pd.x.bytes); sc_mul(tmp.bytes, xsq.bytes, weight_y.bytes); multiexp_data.emplace_back(tmp, proof8_T2); PERF_TIMER_STOP(VERIFY_line_61rl_new); @@ -814,49 +906,34 @@ bool bulletproof_VERIFY(const std::vector &proofs) PERF_TIMER_START_BP(VERIFY_line_62); // PAPER LINE 62 multiexp_data.emplace_back(weight_z, proof8_A); - sc_mul(tmp.bytes, x.bytes, weight_z.bytes); + sc_mul(tmp.bytes, pd.x.bytes, weight_z.bytes); multiexp_data.emplace_back(tmp, proof8_S); PERF_TIMER_STOP(VERIFY_line_62); // Compute the number of rounds for the inner product - const size_t rounds = logM+logN; + const size_t rounds = pd.logM+logN; CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds"); - PERF_TIMER_START_BP(VERIFY_line_21_22); - // PAPER LINES 21-22 - // The inner product challenges are computed per round - rct::keyV w(rounds); - for (size_t i = 0; i < rounds; ++i) - { - w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]); - CHECK_AND_ASSERT_MES(!(w[i] == rct::zero()), false, "w[i] == 0"); - } - PERF_TIMER_STOP(VERIFY_line_21_22); - PERF_TIMER_START_BP(VERIFY_line_24_25); // Basically PAPER LINES 24-25 // Compute the curvepoints from G[i] and H[i] rct::key yinvpow = rct::identity(); rct::key ypow = rct::identity(); - PERF_TIMER_START_BP(VERIFY_line_24_25_invert); - const rct::key yinv = invert(y); - rct::keyV winv(rounds); - for (size_t i = 0; i < rounds; ++i) - winv[i] = invert(w[i]); - PERF_TIMER_STOP(VERIFY_line_24_25_invert); + const rct::key *winv = &inverses[pd.inv_offset]; + const rct::key yinv = inverses[pd.inv_offset + rounds]; // precalc PERF_TIMER_START_BP(VERIFY_line_24_25_precalc); rct::keyV w_cache(1< 0; --s) { - sc_mul(w_cache[s].bytes, w_cache[s/2].bytes, w[j].bytes); + sc_mul(w_cache[s].bytes, w_cache[s/2].bytes, pd.w[j].bytes); sc_mul(w_cache[s-1].bytes, w_cache[s/2].bytes, winv[j].bytes); } } @@ -876,18 +953,18 @@ bool bulletproof_VERIFY(const std::vector &proofs) sc_mul(h_scalar.bytes, h_scalar.bytes, w_cache[(~i) & (MN-1)].bytes); // Adjust the scalars using the exponents from PAPER LINE 62 - sc_add(g_scalar.bytes, g_scalar.bytes, z.bytes); + sc_add(g_scalar.bytes, g_scalar.bytes, pd.z.bytes); CHECK_AND_ASSERT_MES(2+i/N < zpow.size(), false, "invalid zpow index"); CHECK_AND_ASSERT_MES(i%N < twoN.size(), false, "invalid twoN index"); sc_mul(tmp.bytes, zpow[2+i/N].bytes, twoN[i%N].bytes); if (i == 0) { - sc_add(tmp.bytes, tmp.bytes, z.bytes); + sc_add(tmp.bytes, tmp.bytes, pd.z.bytes); sc_sub(h_scalar.bytes, h_scalar.bytes, tmp.bytes); } else { - sc_muladd(tmp.bytes, z.bytes, ypow.bytes, tmp.bytes); + sc_muladd(tmp.bytes, pd.z.bytes, ypow.bytes, tmp.bytes); sc_mulsub(h_scalar.bytes, tmp.bytes, yinvpow.bytes, h_scalar.bytes); } @@ -897,12 +974,12 @@ bool bulletproof_VERIFY(const std::vector &proofs) if (i == 0) { yinvpow = yinv; - ypow = y; + ypow = pd.y; } else if (i != MN-1) { sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); - sc_mul(ypow.bytes, ypow.bytes, y.bytes); + sc_mul(ypow.bytes, ypow.bytes, pd.y.bytes); } } @@ -913,7 +990,7 @@ bool bulletproof_VERIFY(const std::vector &proofs) sc_muladd(z1.bytes, proof.mu.bytes, weight_z.bytes, z1.bytes); for (size_t i = 0; i < rounds; ++i) { - sc_mul(tmp.bytes, w[i].bytes, w[i].bytes); + sc_mul(tmp.bytes, pd.w[i].bytes, pd.w[i].bytes); sc_mul(tmp.bytes, tmp.bytes, weight_z.bytes); multiexp_data.emplace_back(tmp, proof8_L[i]); sc_mul(tmp.bytes, winv[i].bytes, winv[i].bytes); @@ -921,7 +998,7 @@ bool bulletproof_VERIFY(const std::vector &proofs) multiexp_data.emplace_back(tmp, proof8_R[i]); } sc_mulsub(tmp.bytes, proof.a.bytes, proof.b.bytes, proof.t.bytes); - sc_mul(tmp.bytes, tmp.bytes, x_ip.bytes); + sc_mul(tmp.bytes, tmp.bytes, pd.x_ip.bytes); sc_muladd(z3.bytes, tmp.bytes, weight_z.bytes, z3.bytes); PERF_TIMER_STOP(VERIFY_line_26_new); }