bulletproofs: rework flow to use sarang's fast batch inversion code

This commit is contained in:
moneromooo-monero 2018-08-22 22:30:14 +00:00
parent fc9f7d9c81
commit 8629a42cf6
No known key found for this signature in database
GPG Key ID: 686F07454D6CEFC3

View File

@ -29,8 +29,6 @@
// Adapted from Java code by Sarang Noether // Adapted from Java code by Sarang Noether
#include <stdlib.h> #include <stdlib.h>
#include <openssl/ssl.h>
#include <openssl/bn.h>
#include <boost/thread/mutex.hpp> #include <boost/thread/mutex.hpp>
#include "misc_log_ex.h" #include "misc_log_ex.h"
#include "common/perf_timer.h" #include "common/perf_timer.h"
@ -289,37 +287,59 @@ static rct::keyV vector_dup(const rct::key &x, size_t N)
return rct::keyV(N, x); return rct::keyV(N, x);
} }
static rct::key switch_endianness(rct::key k) static rct::key sm(rct::key y, int n, const rct::key &x)
{ {
std::reverse(k.bytes, k.bytes + sizeof(k)); while (n--)
return k; sc_mul(y.bytes, y.bytes, y.bytes);
sc_mul(y.bytes, y.bytes, x.bytes);
return y;
} }
/* Compute the inverse of a scalar, the stupid way */ /* Compute the inverse of a scalar, the clever way */
static rct::key invert(const rct::key &x) static rct::key invert(const rct::key &x)
{ {
rct::key _1, _10, _100, _11, _101, _111, _1001, _1011, _1111;
_1 = x;
sc_mul(_10.bytes, _1.bytes, _1.bytes);
sc_mul(_100.bytes, _10.bytes, _10.bytes);
sc_mul(_11.bytes, _10.bytes, _1.bytes);
sc_mul(_101.bytes, _10.bytes, _11.bytes);
sc_mul(_111.bytes, _10.bytes, _101.bytes);
sc_mul(_1001.bytes, _10.bytes, _111.bytes);
sc_mul(_1011.bytes, _10.bytes, _1001.bytes);
sc_mul(_1111.bytes, _100.bytes, _1011.bytes);
rct::key inv; rct::key inv;
sc_mul(inv.bytes, _1111.bytes, _1.bytes);
BN_CTX *ctx = BN_CTX_new(); inv = sm(inv, 123 + 3, _101);
BIGNUM *X = BN_new(); inv = sm(inv, 2 + 2, _11);
BIGNUM *L = BN_new(); inv = sm(inv, 1 + 4, _1111);
BIGNUM *I = BN_new(); inv = sm(inv, 1 + 4, _1111);
inv = sm(inv, 4, _1001);
BN_bin2bn(switch_endianness(x).bytes, sizeof(rct::key), X); inv = sm(inv, 2, _11);
BN_bin2bn(switch_endianness(rct::curveOrder()).bytes, sizeof(rct::key), L); inv = sm(inv, 1 + 4, _1111);
inv = sm(inv, 1 + 3, _101);
CHECK_AND_ASSERT_THROW_MES(BN_mod_inverse(I, X, L, ctx), "Failed to invert"); inv = sm(inv, 3 + 3, _101);
inv = sm(inv, 3, _111);
const int len = BN_num_bytes(I); inv = sm(inv, 1 + 4, _1111);
CHECK_AND_ASSERT_THROW_MES((size_t)len <= sizeof(rct::key), "Invalid number length"); inv = sm(inv, 2 + 3, _111);
inv = rct::zero(); inv = sm(inv, 2 + 2, _11);
BN_bn2bin(I, inv.bytes); inv = sm(inv, 1 + 4, _1011);
std::reverse(inv.bytes, inv.bytes + len); inv = sm(inv, 2 + 4, _1011);
inv = sm(inv, 6 + 4, _1001);
BN_free(I); inv = sm(inv, 2 + 2, _11);
BN_free(L); inv = sm(inv, 3 + 2, _11);
BN_free(X); inv = sm(inv, 3 + 2, _11);
BN_CTX_free(ctx); inv = sm(inv, 1 + 4, _1001);
inv = sm(inv, 1 + 3, _111);
inv = sm(inv, 2 + 4, _1111);
inv = sm(inv, 1 + 4, _1011);
inv = sm(inv, 3, _101);
inv = sm(inv, 2 + 4, _1111);
inv = sm(inv, 3, _101);
inv = sm(inv, 1 + 2, _11);
#ifdef DEBUG_BP #ifdef DEBUG_BP
rct::key tmp; rct::key tmp;
@ -329,6 +349,34 @@ static rct::key invert(const rct::key &x)
return inv; return inv;
} }
/* Invert every scalar of x with a single scalar inversion (batch inversion).
 *
 * Forward pass: scratch[i] receives the running product x[0]*...*x[i-1]
 * (scratch[0] is rct::identity(), used here as the multiplicative seed),
 * and acc ends up holding the product of all elements.
 * That product is inverted once with the single-scalar invert().
 * Backward pass: on entry to index i, acc is the inverse of x[0]*...*x[i],
 * so acc*scratch[i] is the inverse of x[i], and acc*x[i] is the inverse of
 * the shorter prefix product needed for index i-1.
 *
 * x is taken by value and reused as the output buffer; the returned vector
 * has the same size and order as the input. An empty input is returned as is.
 *
 * NOTE(review): elements are not checked against zero here. A zero element
 * zeroes the whole product, so presumably every returned value is then
 * meaningless - callers should reject zero scalars beforehand; confirm.
 */
static rct::keyV invert(rct::keyV x)
{
  // nothing to invert: skip the (expensive) scalar inversion altogether
  if (x.empty())
    return x;

  rct::keyV scratch;
  scratch.reserve(x.size());

  rct::key acc = rct::identity();
  for (size_t n = 0; n < x.size(); ++n)
  {
    // prefix product of everything strictly before index n
    scratch.push_back(acc);
    if (n == 0)
      acc = x[0];
    else
      sc_mul(acc.bytes, acc.bytes, x[n].bytes);
  }

  // the only real inversion: inverse of the product of all elements
  acc = invert(acc);

  rct::key tmp;
  // walk backwards with an unsigned index, so the element count is never
  // narrowed to int; "i-- > 0" stops after index 0 has been processed
  for (size_t i = x.size(); i-- > 0; )
  {
    // tmp must be computed before x[i] is overwritten with its inverse
    sc_mul(tmp.bytes, acc.bytes, x[i].bytes);
    sc_mul(x[i].bytes, acc.bytes, scratch[i].bytes);
    acc = tmp;
  }

  return x;
}
/* Compute the slice of a vector */ /* Compute the slice of a vector */
static rct::keyV slice(const rct::keyV &a, size_t start, size_t stop) static rct::keyV slice(const rct::keyV &a, size_t start, size_t stop)
{ {
@ -702,6 +750,13 @@ Bulletproof bulletproof_PROVE(const std::vector<uint64_t> &v, const rct::keyV &g
return bulletproof_PROVE(sv, gamma); return bulletproof_PROVE(sv, gamma);
} }
/* Per-proof values computed in the first loop of bulletproof_VERIFY and read
 * back in the second loop, after all inversions have been done in one batch. */
struct proof_data_t
{
  // challenges rebuilt from the proof contents (hash_cache_mash / hash_to_scalar)
  rct::key x;
  rct::key y;
  rct::key z;
  rct::key x_ip;

  // inner product challenges, one per round (logM + logN of them)
  std::vector<rct::key> w;

  // exponent such that M = 1 << logM is the number of outputs the proof is sized for
  size_t logM;

  // index of this proof's first entry in the batch inversion vector:
  // the inverses of w[0..rounds-1] start there, followed by the inverse of y
  size_t inv_offset;
};
/* Given a range proof, determine if it is valid */ /* Given a range proof, determine if it is valid */
bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs) bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
{ {
@ -709,9 +764,17 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
PERF_TIMER_START_BP(VERIFY); PERF_TIMER_START_BP(VERIFY);
const size_t logN = 6;
const size_t N = 1 << logN;
// sanity and figure out which proof is longest // sanity and figure out which proof is longest
size_t max_length = 0; size_t max_length = 0;
size_t nV = 0; size_t nV = 0;
std::vector<proof_data_t> proof_data;
proof_data.reserve(proofs.size());
size_t inv_offset = 0;
std::vector<rct::key> to_invert;
to_invert.reserve(11 * sizeof(proofs));
for (const Bulletproof *p: proofs) for (const Bulletproof *p: proofs)
{ {
const Bulletproof &proof = *p; const Bulletproof &proof = *p;
@ -729,46 +792,75 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
max_length = std::max(max_length, proof.L.size()); max_length = std::max(max_length, proof.L.size());
nV += proof.V.size(); nV += proof.V.size();
// Reconstruct the challenges
PERF_TIMER_START_BP(VERIFY_start);
proof_data.resize(proof_data.size() + 1);
proof_data_t &pd = proof_data.back();
rct::key hash_cache = rct::hash_to_scalar(proof.V);
pd.y = hash_cache_mash(hash_cache, proof.A, proof.S);
CHECK_AND_ASSERT_MES(!(pd.y == rct::zero()), false, "y == 0");
pd.z = hash_cache = rct::hash_to_scalar(pd.y);
CHECK_AND_ASSERT_MES(!(pd.z == rct::zero()), false, "z == 0");
pd.x = hash_cache_mash(hash_cache, pd.z, proof.T1, proof.T2);
CHECK_AND_ASSERT_MES(!(pd.x == rct::zero()), false, "x == 0");
pd.x_ip = hash_cache_mash(hash_cache, pd.x, proof.taux, proof.mu, proof.t);
CHECK_AND_ASSERT_MES(!(pd.x_ip == rct::zero()), false, "x_ip == 0");
PERF_TIMER_STOP(VERIFY_start);
size_t M;
for (pd.logM = 0; (M = 1<<pd.logM) <= maxM && M < proof.V.size(); ++pd.logM);
CHECK_AND_ASSERT_MES(proof.L.size() == 6+pd.logM, false, "Proof is not the expected size");
const size_t rounds = pd.logM+logN;
CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds");
PERF_TIMER_START_BP(VERIFY_line_21_22);
// PAPER LINES 21-22
// The inner product challenges are computed per round
pd.w.resize(rounds);
for (size_t i = 0; i < rounds; ++i)
{
pd.w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]);
CHECK_AND_ASSERT_MES(!(pd.w[i] == rct::zero()), false, "w[i] == 0");
}
PERF_TIMER_STOP(VERIFY_line_21_22);
pd.inv_offset = inv_offset;
for (size_t i = 0; i < rounds; ++i)
to_invert.push_back(pd.w[i]);
to_invert.push_back(pd.y);
inv_offset += rounds + 1;
} }
CHECK_AND_ASSERT_MES(max_length < 32, false, "At least one proof is too large"); CHECK_AND_ASSERT_MES(max_length < 32, false, "At least one proof is too large");
size_t maxMN = 1u << max_length; size_t maxMN = 1u << max_length;
const size_t logN = 6;
const size_t N = 1 << logN;
rct::key tmp; rct::key tmp;
std::vector<MultiexpData> multiexp_data; std::vector<MultiexpData> multiexp_data;
multiexp_data.reserve(nV + (2 * (10/*logM*/ + logN) + 4) * proofs.size() + 2 * maxMN); multiexp_data.reserve(nV + (2 * (10/*logM*/ + logN) + 4) * proofs.size() + 2 * maxMN);
PERF_TIMER_START_BP(VERIFY_line_24_25_invert);
const std::vector<rct::key> inverses = invert(to_invert);
PERF_TIMER_STOP(VERIFY_line_24_25_invert);
// setup weighted aggregates // setup weighted aggregates
rct::key z1 = rct::zero(); rct::key z1 = rct::zero();
rct::key z3 = rct::zero(); rct::key z3 = rct::zero();
rct::keyV z4(maxMN, rct::zero()), z5(maxMN, rct::zero()); rct::keyV z4(maxMN, rct::zero()), z5(maxMN, rct::zero());
rct::key y0 = rct::zero(), y1 = rct::zero(); rct::key y0 = rct::zero(), y1 = rct::zero();
int proof_data_index = 0;
for (const Bulletproof *p: proofs) for (const Bulletproof *p: proofs)
{ {
const Bulletproof &proof = *p; const Bulletproof &proof = *p;
const proof_data_t &pd = proof_data[proof_data_index++];
size_t M, logM; CHECK_AND_ASSERT_MES(proof.L.size() == 6+pd.logM, false, "Proof is not the expected size");
for (logM = 0; (M = 1<<logM) <= maxM && M < proof.V.size(); ++logM); const size_t M = 1 << pd.logM;
CHECK_AND_ASSERT_MES(proof.L.size() == 6+logM, false, "Proof is not the expected size");
const size_t MN = M*N; const size_t MN = M*N;
const rct::key weight_y = rct::skGen(); const rct::key weight_y = rct::skGen();
const rct::key weight_z = rct::skGen(); const rct::key weight_z = rct::skGen();
// Reconstruct the challenges
PERF_TIMER_START_BP(VERIFY_start);
rct::key hash_cache = rct::hash_to_scalar(proof.V);
rct::key y = hash_cache_mash(hash_cache, proof.A, proof.S);
CHECK_AND_ASSERT_MES(!(y == rct::zero()), false, "y == 0");
rct::key z = hash_cache = rct::hash_to_scalar(y);
CHECK_AND_ASSERT_MES(!(z == rct::zero()), false, "z == 0");
rct::key x = hash_cache_mash(hash_cache, z, proof.T1, proof.T2);
CHECK_AND_ASSERT_MES(!(x == rct::zero()), false, "x == 0");
rct::key x_ip = hash_cache_mash(hash_cache, x, proof.taux, proof.mu, proof.t);
CHECK_AND_ASSERT_MES(!(x_ip == rct::zero()), false, "x_ip == 0");
PERF_TIMER_STOP(VERIFY_start);
// pre-multiply some points by 8 // pre-multiply some points by 8
rct::keyV proof8_V = proof.V; for (rct::key &k: proof8_V) k = rct::scalarmult8(k); rct::keyV proof8_V = proof.V; for (rct::key &k: proof8_V) k = rct::scalarmult8(k);
rct::keyV proof8_L = proof.L; for (rct::key &k: proof8_L) k = rct::scalarmult8(k); rct::keyV proof8_L = proof.L; for (rct::key &k: proof8_L) k = rct::scalarmult8(k);
@ -782,10 +874,10 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
// PAPER LINE 61 // PAPER LINE 61
sc_muladd(y0.bytes, proof.taux.bytes, weight_y.bytes, y0.bytes); sc_muladd(y0.bytes, proof.taux.bytes, weight_y.bytes, y0.bytes);
const rct::keyV zpow = vector_powers(z, M+3); const rct::keyV zpow = vector_powers(pd.z, M+3);
rct::key k; rct::key k;
const rct::key ip1y = vector_power_sum(y, MN); const rct::key ip1y = vector_power_sum(pd.y, MN);
sc_mulsub(k.bytes, zpow[2].bytes, ip1y.bytes, rct::zero().bytes); sc_mulsub(k.bytes, zpow[2].bytes, ip1y.bytes, rct::zero().bytes);
for (size_t j = 1; j <= M; ++j) for (size_t j = 1; j <= M; ++j)
{ {
@ -795,7 +887,7 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
PERF_TIMER_STOP(VERIFY_line_61); PERF_TIMER_STOP(VERIFY_line_61);
PERF_TIMER_START_BP(VERIFY_line_61rl_new); PERF_TIMER_START_BP(VERIFY_line_61rl_new);
sc_muladd(tmp.bytes, z.bytes, ip1y.bytes, k.bytes); sc_muladd(tmp.bytes, pd.z.bytes, ip1y.bytes, k.bytes);
sc_sub(tmp.bytes, proof.t.bytes, tmp.bytes); sc_sub(tmp.bytes, proof.t.bytes, tmp.bytes);
sc_muladd(y1.bytes, tmp.bytes, weight_y.bytes, y1.bytes); sc_muladd(y1.bytes, tmp.bytes, weight_y.bytes, y1.bytes);
for (size_t j = 0; j < proof8_V.size(); j++) for (size_t j = 0; j < proof8_V.size(); j++)
@ -803,10 +895,10 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
sc_mul(tmp.bytes, zpow[j+2].bytes, weight_y.bytes); sc_mul(tmp.bytes, zpow[j+2].bytes, weight_y.bytes);
multiexp_data.emplace_back(tmp, proof8_V[j]); multiexp_data.emplace_back(tmp, proof8_V[j]);
} }
sc_mul(tmp.bytes, x.bytes, weight_y.bytes); sc_mul(tmp.bytes, pd.x.bytes, weight_y.bytes);
multiexp_data.emplace_back(tmp, proof8_T1); multiexp_data.emplace_back(tmp, proof8_T1);
rct::key xsq; rct::key xsq;
sc_mul(xsq.bytes, x.bytes, x.bytes); sc_mul(xsq.bytes, pd.x.bytes, pd.x.bytes);
sc_mul(tmp.bytes, xsq.bytes, weight_y.bytes); sc_mul(tmp.bytes, xsq.bytes, weight_y.bytes);
multiexp_data.emplace_back(tmp, proof8_T2); multiexp_data.emplace_back(tmp, proof8_T2);
PERF_TIMER_STOP(VERIFY_line_61rl_new); PERF_TIMER_STOP(VERIFY_line_61rl_new);
@ -814,49 +906,34 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
PERF_TIMER_START_BP(VERIFY_line_62); PERF_TIMER_START_BP(VERIFY_line_62);
// PAPER LINE 62 // PAPER LINE 62
multiexp_data.emplace_back(weight_z, proof8_A); multiexp_data.emplace_back(weight_z, proof8_A);
sc_mul(tmp.bytes, x.bytes, weight_z.bytes); sc_mul(tmp.bytes, pd.x.bytes, weight_z.bytes);
multiexp_data.emplace_back(tmp, proof8_S); multiexp_data.emplace_back(tmp, proof8_S);
PERF_TIMER_STOP(VERIFY_line_62); PERF_TIMER_STOP(VERIFY_line_62);
// Compute the number of rounds for the inner product // Compute the number of rounds for the inner product
const size_t rounds = logM+logN; const size_t rounds = pd.logM+logN;
CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds"); CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds");
PERF_TIMER_START_BP(VERIFY_line_21_22);
// PAPER LINES 21-22
// The inner product challenges are computed per round
rct::keyV w(rounds);
for (size_t i = 0; i < rounds; ++i)
{
w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]);
CHECK_AND_ASSERT_MES(!(w[i] == rct::zero()), false, "w[i] == 0");
}
PERF_TIMER_STOP(VERIFY_line_21_22);
PERF_TIMER_START_BP(VERIFY_line_24_25); PERF_TIMER_START_BP(VERIFY_line_24_25);
// Basically PAPER LINES 24-25 // Basically PAPER LINES 24-25
// Compute the curvepoints from G[i] and H[i] // Compute the curvepoints from G[i] and H[i]
rct::key yinvpow = rct::identity(); rct::key yinvpow = rct::identity();
rct::key ypow = rct::identity(); rct::key ypow = rct::identity();
PERF_TIMER_START_BP(VERIFY_line_24_25_invert); const rct::key *winv = &inverses[pd.inv_offset];
const rct::key yinv = invert(y); const rct::key yinv = inverses[pd.inv_offset + rounds];
rct::keyV winv(rounds);
for (size_t i = 0; i < rounds; ++i)
winv[i] = invert(w[i]);
PERF_TIMER_STOP(VERIFY_line_24_25_invert);
// precalc // precalc
PERF_TIMER_START_BP(VERIFY_line_24_25_precalc); PERF_TIMER_START_BP(VERIFY_line_24_25_precalc);
rct::keyV w_cache(1<<rounds); rct::keyV w_cache(1<<rounds);
w_cache[0] = winv[0]; w_cache[0] = winv[0];
w_cache[1] = w[0]; w_cache[1] = pd.w[0];
for (size_t j = 1; j < rounds; ++j) for (size_t j = 1; j < rounds; ++j)
{ {
const size_t slots = 1<<(j+1); const size_t slots = 1<<(j+1);
for (size_t s = slots; s-- > 0; --s) for (size_t s = slots; s-- > 0; --s)
{ {
sc_mul(w_cache[s].bytes, w_cache[s/2].bytes, w[j].bytes); sc_mul(w_cache[s].bytes, w_cache[s/2].bytes, pd.w[j].bytes);
sc_mul(w_cache[s-1].bytes, w_cache[s/2].bytes, winv[j].bytes); sc_mul(w_cache[s-1].bytes, w_cache[s/2].bytes, winv[j].bytes);
} }
} }
@ -876,18 +953,18 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
sc_mul(h_scalar.bytes, h_scalar.bytes, w_cache[(~i) & (MN-1)].bytes); sc_mul(h_scalar.bytes, h_scalar.bytes, w_cache[(~i) & (MN-1)].bytes);
// Adjust the scalars using the exponents from PAPER LINE 62 // Adjust the scalars using the exponents from PAPER LINE 62
sc_add(g_scalar.bytes, g_scalar.bytes, z.bytes); sc_add(g_scalar.bytes, g_scalar.bytes, pd.z.bytes);
CHECK_AND_ASSERT_MES(2+i/N < zpow.size(), false, "invalid zpow index"); CHECK_AND_ASSERT_MES(2+i/N < zpow.size(), false, "invalid zpow index");
CHECK_AND_ASSERT_MES(i%N < twoN.size(), false, "invalid twoN index"); CHECK_AND_ASSERT_MES(i%N < twoN.size(), false, "invalid twoN index");
sc_mul(tmp.bytes, zpow[2+i/N].bytes, twoN[i%N].bytes); sc_mul(tmp.bytes, zpow[2+i/N].bytes, twoN[i%N].bytes);
if (i == 0) if (i == 0)
{ {
sc_add(tmp.bytes, tmp.bytes, z.bytes); sc_add(tmp.bytes, tmp.bytes, pd.z.bytes);
sc_sub(h_scalar.bytes, h_scalar.bytes, tmp.bytes); sc_sub(h_scalar.bytes, h_scalar.bytes, tmp.bytes);
} }
else else
{ {
sc_muladd(tmp.bytes, z.bytes, ypow.bytes, tmp.bytes); sc_muladd(tmp.bytes, pd.z.bytes, ypow.bytes, tmp.bytes);
sc_mulsub(h_scalar.bytes, tmp.bytes, yinvpow.bytes, h_scalar.bytes); sc_mulsub(h_scalar.bytes, tmp.bytes, yinvpow.bytes, h_scalar.bytes);
} }
@ -897,12 +974,12 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
if (i == 0) if (i == 0)
{ {
yinvpow = yinv; yinvpow = yinv;
ypow = y; ypow = pd.y;
} }
else if (i != MN-1) else if (i != MN-1)
{ {
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes); sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
sc_mul(ypow.bytes, ypow.bytes, y.bytes); sc_mul(ypow.bytes, ypow.bytes, pd.y.bytes);
} }
} }
@ -913,7 +990,7 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
sc_muladd(z1.bytes, proof.mu.bytes, weight_z.bytes, z1.bytes); sc_muladd(z1.bytes, proof.mu.bytes, weight_z.bytes, z1.bytes);
for (size_t i = 0; i < rounds; ++i) for (size_t i = 0; i < rounds; ++i)
{ {
sc_mul(tmp.bytes, w[i].bytes, w[i].bytes); sc_mul(tmp.bytes, pd.w[i].bytes, pd.w[i].bytes);
sc_mul(tmp.bytes, tmp.bytes, weight_z.bytes); sc_mul(tmp.bytes, tmp.bytes, weight_z.bytes);
multiexp_data.emplace_back(tmp, proof8_L[i]); multiexp_data.emplace_back(tmp, proof8_L[i]);
sc_mul(tmp.bytes, winv[i].bytes, winv[i].bytes); sc_mul(tmp.bytes, winv[i].bytes, winv[i].bytes);
@ -921,7 +998,7 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
multiexp_data.emplace_back(tmp, proof8_R[i]); multiexp_data.emplace_back(tmp, proof8_R[i]);
} }
sc_mulsub(tmp.bytes, proof.a.bytes, proof.b.bytes, proof.t.bytes); sc_mulsub(tmp.bytes, proof.a.bytes, proof.b.bytes, proof.t.bytes);
sc_mul(tmp.bytes, tmp.bytes, x_ip.bytes); sc_mul(tmp.bytes, tmp.bytes, pd.x_ip.bytes);
sc_muladd(z3.bytes, tmp.bytes, weight_z.bytes, z3.bytes); sc_muladd(z3.bytes, tmp.bytes, weight_z.bytes, z3.bytes);
PERF_TIMER_STOP(VERIFY_line_26_new); PERF_TIMER_STOP(VERIFY_line_26_new);
} }