Refactor the core of choosing by weights into a function

This eliminates duplicated code, and lets us test a hairy piece of
functionality.
This commit is contained in:
Nick Mathewson 2012-08-09 13:47:42 -04:00
parent 9bfb274abb
commit 07df4dd52d
5 changed files with 160 additions and 96 deletions

View File

@ -10,3 +10,7 @@
than it ran through the part of the loop before it had made its
choice. Fix for bug 6538.
o Code simplifications and refactoring:
- Move the core of our "choose a weighted element at random" logic
into its own function, and give it unit tests. Now the logic is
testable, and a little less fragile too.

View File

@ -11,6 +11,7 @@
* servers.
**/
#define ROUTERLIST_PRIVATE
#include "or.h"
#include "circuitbuild.h"
#include "config.h"
@ -1652,6 +1653,53 @@ router_get_advertised_bandwidth_capped(const routerinfo_t *router)
return result;
}
/** Pick a random element of <b>n_entries</b>-element array <b>entries</b>,
* choosing each element with a probability proportional to its value, and
* return the index of that element. If all elements are 0, choose an index
* at random. If <b>total_out</b> is provided, set it to the sum of all
* elements in the array. Return -1 on error.
*/
/* private */ int
choose_array_element_by_weight(const uint64_t *entries, int n_entries,
uint64_t *total_out)
{
int i, i_chosen=-1, n_chosen=0;
uint64_t total_so_far = 0;
uint64_t rand_val;
uint64_t total = 0;
for (i = 0; i < n_entries; ++i)
total += entries[i];
if (total_out)
*total_out = total;
if (n_entries < 1)
return -1;
if (total == 0)
return crypto_rand_int(n_entries);
rand_val = crypto_rand_uint64(total);
for (i = 0; i < n_entries; ++i) {
total_so_far += entries[i];
if (total_so_far > rand_val) {
i_chosen = i;
n_chosen++;
/* Set rand_val to UINT_MAX rather than stopping the loop. This way,
* the time we spend in the loop does not leak which element we chose. */
rand_val = UINT64_MAX;
}
}
tor_assert(total_so_far == total);
tor_assert(n_chosen == 1);
tor_assert(i_chosen >= 0);
tor_assert(i_chosen < n_entries);
return i_chosen;
}
/** When weighting bridges, enforce these values as lower and upper
* bound for believable bandwidth, because there is no way for us
* to verify a bridge's bandwidth currently. */
@ -1702,15 +1750,10 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
bandwidth_weight_rule_t rule)
{
int64_t weight_scale;
uint64_t rand_bw;
double Wg = -1, Wm = -1, We = -1, Wd = -1;
double Wgb = -1, Wmb = -1, Web = -1, Wdb = -1;
uint64_t weighted_bw = 0, unweighted_bw = 0;
uint64_t weighted_bw = 0;
uint64_t *bandwidths;
uint64_t tmp;
unsigned int i;
unsigned int i_chosen;
int have_unknown = 0; /* true iff sl contains element not in consensus. */
/* Can't choose exit and guard at same time */
tor_assert(rule == NO_WEIGHTING ||
@ -1814,7 +1857,6 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
} else if (node->ri) {
/* bridge or other descriptor not in our consensus */
this_bw = bridge_get_advertised_bandwidth_bounded(node->ri);
have_unknown = 1;
} else {
/* We can't use this one. */
continue;
@ -1838,69 +1880,22 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl,
weight = 0.0;
bandwidths[node_sl_idx] = tor_llround(weight*this_bw + 0.5);
weighted_bw += bandwidths[node_sl_idx];
unweighted_bw += this_bw;
if (is_me)
sl_last_weighted_bw_of_me = bandwidths[node_sl_idx];
} SMARTLIST_FOREACH_END(node);
/* XXXX this is a kludge to expose these values. */
sl_last_total_weighted_bw = weighted_bw;
log_debug(LD_CIRC, "Choosing node for rule %s based on weights "
"Wg=%f Wm=%f We=%f Wd=%f with total bw "U64_FORMAT,
bandwidth_weight_rule_to_string(rule),
Wg, Wm, We, Wd, U64_PRINTF_ARG(weighted_bw));
/* If there is no bandwidth, choose at random */
if (weighted_bw == 0) {
/* Don't warn when using bridges/relays not in the consensus */
if (!have_unknown) {
#define ZERO_BANDWIDTH_WARNING_INTERVAL (15)
static ratelim_t zero_bandwidth_warning_limit =
RATELIM_INIT(ZERO_BANDWIDTH_WARNING_INTERVAL);
char *msg;
if ((msg = rate_limit_log(&zero_bandwidth_warning_limit,
approx_time()))) {
log_warn(LD_CIRC,
"Weighted bandwidth is "U64_FORMAT" in node selection for "
"rule %s (unweighted was "U64_FORMAT") %s",
U64_PRINTF_ARG(weighted_bw),
bandwidth_weight_rule_to_string(rule),
U64_PRINTF_ARG(unweighted_bw), msg);
}
}
{
int idx = choose_array_element_by_weight(bandwidths,
smartlist_len(sl),
&sl_last_total_weighted_bw);
tor_free(bandwidths);
return smartlist_choose(sl);
return idx < 0 ? NULL : smartlist_get(sl, idx);
}
rand_bw = crypto_rand_uint64(weighted_bw);
/* Last, count through sl until we get to the element we picked */
i_chosen = (unsigned)smartlist_len(sl);
tmp = 0;
for (i=0; i < (unsigned)smartlist_len(sl); i++) {
tmp += bandwidths[i];
if (tmp > rand_bw) {
i_chosen = i;
rand_bw = UINT64_MAX;
}
}
i = i_chosen;
if (i == (unsigned)smartlist_len(sl)) {
/* This was once possible due to round-off error, but shouldn't be able
* to occur any longer. */
tor_fragile_assert();
--i;
log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on "
" which router we chose. Please tell the developers. "
U64_FORMAT" "U64_FORMAT" "U64_FORMAT,
U64_PRINTF_ARG(tmp), U64_PRINTF_ARG(rand_bw),
U64_PRINTF_ARG(weighted_bw));
}
tor_free(bandwidths);
return smartlist_get(sl, i);
}
/** Helper function:
@ -1921,14 +1916,12 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
bandwidth_weight_rule_t rule)
{
unsigned int i;
unsigned int i_chosen;
uint64_t *bandwidths;
int is_exit;
int is_guard;
int is_fast;
uint64_t total_nonexit_bw = 0, total_exit_bw = 0, total_bw = 0;
uint64_t total_nonexit_bw = 0, total_exit_bw = 0;
uint64_t total_nonguard_bw = 0, total_guard_bw = 0;
uint64_t rand_bw, tmp;
double exit_weight;
double guard_weight;
int n_unknown = 0;
@ -2073,7 +2066,6 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
if (guard_weight <= 0.0)
guard_weight = 0.0;
total_bw = 0;
sl_last_weighted_bw_of_me = 0;
for (i=0; i < (unsigned)smartlist_len(sl); i++) {
tor_assert(bandwidths[i] < UINT64_MAX);
@ -2087,15 +2079,12 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
else if (is_exit)
bandwidths[i] = tor_llround(bandwidths[i] * exit_weight);
total_bw += bandwidths[i];
if (i == (unsigned) me_idx)
sl_last_weighted_bw_of_me = bandwidths[i];
}
}
/* XXXX this is a kludge to expose these values. */
sl_last_total_weighted_bw = total_bw;
#if 0
log_debug(LD_CIRC, "Total weighted bw = "U64_FORMAT
", exit bw = "U64_FORMAT
", nonexit bw = "U64_FORMAT", exit weight = %f "
@ -2108,37 +2097,18 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl,
exit_weight, (int)(rule == WEIGHT_FOR_EXIT),
U64_PRINTF_ARG(total_guard_bw), U64_PRINTF_ARG(total_nonguard_bw),
guard_weight, (int)(rule == WEIGHT_FOR_GUARD));
#endif
/* Almost done: choose a random value from the bandwidth weights. */
rand_bw = crypto_rand_uint64(total_bw);
/* Last, count through sl until we get to the element we picked */
tmp = 0;
i_chosen = (unsigned)smartlist_len(sl);
for (i=0; i < (unsigned)smartlist_len(sl); i++) {
tmp += bandwidths[i];
if (tmp > rand_bw) {
i_chosen = i;
rand_bw = UINT64_MAX;
}
{
int idx = choose_array_element_by_weight(bandwidths,
smartlist_len(sl),
&sl_last_total_weighted_bw);
tor_free(bandwidths);
tor_free(fast_bits);
tor_free(exit_bits);
tor_free(guard_bits);
return idx < 0 ? NULL : smartlist_get(sl, idx);
}
i = i_chosen;
if (i == (unsigned)smartlist_len(sl)) {
/* This was once possible due to round-off error, but shouldn't be able
* to occur any longer. */
tor_fragile_assert();
--i;
log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on "
" which router we chose. Please tell the developers. "
U64_FORMAT " " U64_FORMAT " " U64_FORMAT, U64_PRINTF_ARG(tmp),
U64_PRINTF_ARG(rand_bw), U64_PRINTF_ARG(total_bw));
}
tor_free(bandwidths);
tor_free(fast_bits);
tor_free(exit_bits);
tor_free(guard_bits);
return smartlist_get(sl, i);
}
/** Choose a random element of status list <b>sl</b>, weighted by

View File

@ -216,5 +216,10 @@ int hex_digest_nickname_decode(const char *hexdigest,
char *nickname_qualifier_out,
char *nickname_out);
#ifdef ROUTERLIST_PRIVATE
int choose_array_element_by_weight(const uint64_t *entries, int n_entries,
uint64_t *total_out);
#endif
#endif

View File

@ -65,6 +65,10 @@
#define test_memeq_hex(expr1, hex) test_mem_op_hex(expr1, ==, hex)
#define tt_double_op(a,op,b) \
tt_assert_test_type(a,b,#a" "#op" "#b,double,(val1_ op val2_),"%f", \
TT_EXIT_TEST_FUNCTION)
const char *get_fname(const char *name);
crypto_pk_t *pk_generate(int idx);

View File

@ -7,6 +7,7 @@
#define DIRSERV_PRIVATE
#define DIRVOTE_PRIVATE
#define ROUTER_PRIVATE
#define ROUTERLIST_PRIVATE
#define HIBERNATE_PRIVATE
#include "or.h"
#include "directory.h"
@ -1381,6 +1382,85 @@ test_dir_v3_networkstatus(void)
ns_detached_signatures_free(dsig2);
}
static void
test_dir_random_weighted(void *testdata)
{
int histogram[10];
uint64_t vals[10] = {3,1,2,4,6,0,7,5,8,9}, total=0;
uint64_t zeros[5] = {0,0,0,0,0};
int i, choice;
const int n = 50000;
double max_sq_error;
(void) testdata;
/* Try a ten-element array with values from 0 through 10. The values are
* in a scrambled order to make sure we don't depend on order. */
memset(histogram,0,sizeof(histogram));
for (i=0; i<10; ++i)
total += vals[i];
tt_int_op(total, ==, 45);
for (i=0; i<n; ++i) {
uint64_t t;
choice = choose_array_element_by_weight(vals, 10, &t);
tt_int_op(t, ==, total);
tt_int_op(choice, >=, 0);
tt_int_op(choice, <, 10);
histogram[choice]++;
}
/* Now see if we chose things about frequently enough. */
max_sq_error = 0;
for (i=0; i<10; ++i) {
int expected = (int)(n*vals[i]/total);
double frac_diff = 0, sq;
TT_BLATHER((" %d : %5d vs %5d\n", (int)vals[i], histogram[i], expected));
if (expected)
frac_diff = (histogram[i] - expected) / ((double)expected);
else
tt_int_op(histogram[i], ==, 0);
sq = frac_diff * frac_diff;
if (sq > max_sq_error)
max_sq_error = sq;
}
/* It should almost always be much much less than this. If you want to
* figure out the odds, please feel free. */
tt_double_op(max_sq_error, <, .05);
/* Now try a singleton; do we choose it? */
for (i = 0; i < 100; ++i) {
choice = choose_array_element_by_weight(vals, 1, NULL);
tt_int_op(choice, ==, 0);
}
/* Now try an array of zeros. We should choose randomly. */
memset(histogram,0,sizeof(histogram));
for (i = 0; i < n; ++i) {
uint64_t t;
choice = choose_array_element_by_weight(zeros, 5, &t);
tt_int_op(t, ==, 0);
tt_int_op(choice, >=, 0);
tt_int_op(choice, <, 5);
histogram[choice]++;
}
/* Now see if we chose things about frequently enough. */
max_sq_error = 0;
for (i=0; i<5; ++i) {
int expected = n/5;
double frac_diff = 0, sq;
TT_BLATHER((" %d : %5d vs %5d\n", (int)vals[i], histogram[i], expected));
frac_diff = (histogram[i] - expected) / ((double)expected);
sq = frac_diff * frac_diff;
if (sq > max_sq_error)
max_sq_error = sq;
}
/* It should almost always be much much less than this. If you want to
* figure out the odds, please feel free. */
tt_double_op(max_sq_error, <, .05);
done:
;
}
#define DIR_LEGACY(name) \
{ #name, legacy_test_helper, TT_FORK, &legacy_setup, test_dir_ ## name }
@ -1396,6 +1476,7 @@ struct testcase_t dir_tests[] = {
DIR_LEGACY(measured_bw),
DIR_LEGACY(param_voting),
DIR_LEGACY(v3_networkstatus),
DIR(random_weighted),
END_OF_TESTCASES
};