diff options
-rw-r--r-- | changes/bug6538 | 4 | ||||
-rw-r--r-- | src/or/routerlist.c | 162 | ||||
-rw-r--r-- | src/or/routerlist.h | 5 | ||||
-rw-r--r-- | src/test/test.h | 4 | ||||
-rw-r--r-- | src/test/test_dir.c | 81 |
5 files changed, 160 insertions, 96 deletions
diff --git a/changes/bug6538 b/changes/bug6538 index fc9e583d52..03c168b60a 100644 --- a/changes/bug6538 +++ b/changes/bug6538 @@ -10,3 +10,7 @@ than it ran through the part of the loop before it had made its choice. Fix for bug 6538. + o Code simplifications and refactoring: + - Move the core of our "choose a weighted element at random" logic + into its own function, and give it unit tests. Now the logic is + testable, and a little less fragile too. diff --git a/src/or/routerlist.c b/src/or/routerlist.c index 801c4965ea..1c0aca8ad1 100644 --- a/src/or/routerlist.c +++ b/src/or/routerlist.c @@ -11,6 +11,7 @@ * servers. **/ +#define ROUTERLIST_PRIVATE #include "or.h" #include "circuitbuild.h" #include "config.h" @@ -1652,6 +1653,53 @@ router_get_advertised_bandwidth_capped(const routerinfo_t *router) return result; } +/** Pick a random element of <b>n_entries</b>-element array <b>entries</b>, + * choosing each element with a probability proportional to its value, and + * return the index of that element. If all elements are 0, choose an index + * at random. If <b>total_out</b> is provided, set it to the sum of all + * elements in the array. Return -1 on error. + */ +/* private */ int +choose_array_element_by_weight(const uint64_t *entries, int n_entries, + uint64_t *total_out) +{ + int i, i_chosen=-1, n_chosen=0; + uint64_t total_so_far = 0; + uint64_t rand_val; + uint64_t total = 0; + + for (i = 0; i < n_entries; ++i) + total += entries[i]; + + if (total_out) + *total_out = total; + + if (n_entries < 1) + return -1; + + if (total == 0) + return crypto_rand_int(n_entries); + + rand_val = crypto_rand_uint64(total); + + for (i = 0; i < n_entries; ++i) { + total_so_far += entries[i]; + if (total_so_far > rand_val) { + i_chosen = i; + n_chosen++; + /* Set rand_val to UINT_MAX rather than stopping the loop. This way, + * the time we spend in the loop does not leak which element we chose. */ + rand_val = UINT64_MAX; + } + } + tor_assert(total_so_far == total); + tor_assert(n_chosen == 1); + tor_assert(i_chosen >= 0); + tor_assert(i_chosen < n_entries); + + return i_chosen; +} + /** When weighting bridges, enforce these values as lower and upper * bound for believable bandwidth, because there is no way for us * to verify a bridge's bandwidth currently. */ @@ -1702,15 +1750,10 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl, bandwidth_weight_rule_t rule) { int64_t weight_scale; - uint64_t rand_bw; double Wg = -1, Wm = -1, We = -1, Wd = -1; double Wgb = -1, Wmb = -1, Web = -1, Wdb = -1; - uint64_t weighted_bw = 0, unweighted_bw = 0; + uint64_t weighted_bw = 0; uint64_t *bandwidths; - uint64_t tmp; - unsigned int i; - unsigned int i_chosen; - int have_unknown = 0; /* true iff sl contains element not in consensus. */ /* Can't choose exit and guard at same time */ tor_assert(rule == NO_WEIGHTING || @@ -1814,7 +1857,6 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl, } else if (node->ri) { /* bridge or other descriptor not in our consensus */ this_bw = bridge_get_advertised_bandwidth_bounded(node->ri); - have_unknown = 1; } else { /* We can't use this one. */ continue; @@ -1838,69 +1880,22 @@ smartlist_choose_node_by_bandwidth_weights(smartlist_t *sl, weight = 0.0; bandwidths[node_sl_idx] = tor_llround(weight*this_bw + 0.5); - weighted_bw += bandwidths[node_sl_idx]; - unweighted_bw += this_bw; if (is_me) sl_last_weighted_bw_of_me = bandwidths[node_sl_idx]; } SMARTLIST_FOREACH_END(node); - /* XXXX this is a kludge to expose these values. */ - sl_last_total_weighted_bw = weighted_bw; - log_debug(LD_CIRC, "Choosing node for rule %s based on weights " "Wg=%f Wm=%f We=%f Wd=%f with total bw "U64_FORMAT, bandwidth_weight_rule_to_string(rule), Wg, Wm, We, Wd, U64_PRINTF_ARG(weighted_bw)); - /* If there is no bandwidth, choose at random */ - if (weighted_bw == 0) { - /* Don't warn when using bridges/relays not in the consensus */ - if (!have_unknown) { -#define ZERO_BANDWIDTH_WARNING_INTERVAL (15) - static ratelim_t zero_bandwidth_warning_limit = - RATELIM_INIT(ZERO_BANDWIDTH_WARNING_INTERVAL); - char *msg; - if ((msg = rate_limit_log(&zero_bandwidth_warning_limit, - approx_time()))) { - log_warn(LD_CIRC, - "Weighted bandwidth is "U64_FORMAT" in node selection for " - "rule %s (unweighted was "U64_FORMAT") %s", - U64_PRINTF_ARG(weighted_bw), - bandwidth_weight_rule_to_string(rule), - U64_PRINTF_ARG(unweighted_bw), msg); - } - } + { + int idx = choose_array_element_by_weight(bandwidths, + smartlist_len(sl), + &sl_last_total_weighted_bw); tor_free(bandwidths); - return smartlist_choose(sl); - } - - rand_bw = crypto_rand_uint64(weighted_bw); - - /* Last, count through sl until we get to the element we picked */ - i_chosen = (unsigned)smartlist_len(sl); - tmp = 0; - for (i=0; i < (unsigned)smartlist_len(sl); i++) { - tmp += bandwidths[i]; - if (tmp > rand_bw) { - i_chosen = i; - rand_bw = UINT64_MAX; - } - } - i = i_chosen; - - if (i == (unsigned)smartlist_len(sl)) { - /* This was once possible due to round-off error, but shouldn't be able - * to occur any longer. */ - tor_fragile_assert(); - --i; - log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on " - " which router we chose. Please tell the developers. " - U64_FORMAT" "U64_FORMAT" "U64_FORMAT, - U64_PRINTF_ARG(tmp), U64_PRINTF_ARG(rand_bw), - U64_PRINTF_ARG(weighted_bw)); + return idx < 0 ? NULL : smartlist_get(sl, idx); } - tor_free(bandwidths); - return smartlist_get(sl, i); } /** Helper function: @@ -1921,14 +1916,12 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl, bandwidth_weight_rule_t rule) { unsigned int i; - unsigned int i_chosen; uint64_t *bandwidths; int is_exit; int is_guard; int is_fast; - uint64_t total_nonexit_bw = 0, total_exit_bw = 0, total_bw = 0; + uint64_t total_nonexit_bw = 0, total_exit_bw = 0; uint64_t total_nonguard_bw = 0, total_guard_bw = 0; - uint64_t rand_bw, tmp; double exit_weight; double guard_weight; int n_unknown = 0; @@ -2073,7 +2066,6 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl, if (guard_weight <= 0.0) guard_weight = 0.0; - total_bw = 0; sl_last_weighted_bw_of_me = 0; for (i=0; i < (unsigned)smartlist_len(sl); i++) { tor_assert(bandwidths[i] < UINT64_MAX); @@ -2087,15 +2079,12 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl, else if (is_exit) bandwidths[i] = tor_llround(bandwidths[i] * exit_weight); - total_bw += bandwidths[i]; if (i == (unsigned) me_idx) sl_last_weighted_bw_of_me = bandwidths[i]; } } - /* XXXX this is a kludge to expose these values. */ - sl_last_total_weighted_bw = total_bw; - +#if 0 log_debug(LD_CIRC, "Total weighted bw = "U64_FORMAT ", exit bw = "U64_FORMAT ", nonexit bw = "U64_FORMAT", exit weight = %f " @@ -2108,37 +2097,18 @@ smartlist_choose_node_by_bandwidth(smartlist_t *sl, exit_weight, (int)(rule == WEIGHT_FOR_EXIT), U64_PRINTF_ARG(total_guard_bw), U64_PRINTF_ARG(total_nonguard_bw), guard_weight, (int)(rule == WEIGHT_FOR_GUARD)); +#endif - /* Almost done: choose a random value from the bandwidth weights. */ - rand_bw = crypto_rand_uint64(total_bw); - - /* Last, count through sl until we get to the element we picked */ - tmp = 0; - i_chosen = (unsigned)smartlist_len(sl); - for (i=0; i < (unsigned)smartlist_len(sl); i++) { - tmp += bandwidths[i]; - - if (tmp > rand_bw) { - i_chosen = i; - rand_bw = UINT64_MAX; - } + { + int idx = choose_array_element_by_weight(bandwidths, + smartlist_len(sl), + &sl_last_total_weighted_bw); + tor_free(bandwidths); + tor_free(fast_bits); + tor_free(exit_bits); + tor_free(guard_bits); + return idx < 0 ? NULL : smartlist_get(sl, idx); } - i = i_chosen; - if (i == (unsigned)smartlist_len(sl)) { - /* This was once possible due to round-off error, but shouldn't be able - * to occur any longer. */ - tor_fragile_assert(); - --i; - log_warn(LD_BUG, "Round-off error in computing bandwidth had an effect on " - " which router we chose. Please tell the developers. " - U64_FORMAT " " U64_FORMAT " " U64_FORMAT, U64_PRINTF_ARG(tmp), - U64_PRINTF_ARG(rand_bw), U64_PRINTF_ARG(total_bw)); - } - tor_free(bandwidths); - tor_free(fast_bits); - tor_free(exit_bits); - tor_free(guard_bits); - return smartlist_get(sl, i); } /** Choose a random element of status list <b>sl</b>, weighted by diff --git a/src/or/routerlist.h b/src/or/routerlist.h index 8dcc6eb026..0b9b297514 100644 --- a/src/or/routerlist.h +++ b/src/or/routerlist.h @@ -216,5 +216,10 @@ int hex_digest_nickname_decode(const char *hexdigest, char *nickname_qualifier_out, char *nickname_out); +#ifdef ROUTERLIST_PRIVATE +int choose_array_element_by_weight(const uint64_t *entries, int n_entries, + uint64_t *total_out); +#endif + #endif diff --git a/src/test/test.h b/src/test/test.h index 0b6e6c60cb..6dcb9490bd 100644 --- a/src/test/test.h +++ b/src/test/test.h @@ -65,6 +65,10 @@ #define test_memeq_hex(expr1, hex) test_mem_op_hex(expr1, ==, hex) +#define tt_double_op(a,op,b) \ + tt_assert_test_type(a,b,#a" "#op" "#b,double,(val1_ op val2_),"%f", \ + TT_EXIT_TEST_FUNCTION) + const char *get_fname(const char *name); crypto_pk_t *pk_generate(int idx); diff --git a/src/test/test_dir.c b/src/test/test_dir.c index 83c612045b..ed0c5a1afb 100644 --- a/src/test/test_dir.c +++ b/src/test/test_dir.c @@ -7,6 +7,7 @@ #define DIRSERV_PRIVATE #define DIRVOTE_PRIVATE #define ROUTER_PRIVATE +#define ROUTERLIST_PRIVATE #define HIBERNATE_PRIVATE #include "or.h" #include "directory.h" @@ -1381,6 +1382,85 @@ test_dir_v3_networkstatus(void) ns_detached_signatures_free(dsig2); } +static void +test_dir_random_weighted(void *testdata) +{ + int histogram[10]; + uint64_t vals[10] = {3,1,2,4,6,0,7,5,8,9}, total=0; + uint64_t zeros[5] = {0,0,0,0,0}; + int i, choice; + const int n = 50000; + double max_sq_error; + (void) testdata; + + /* Try a ten-element array with values from 0 through 10. The values are + * in a scrambled order to make sure we don't depend on order. */ + memset(histogram,0,sizeof(histogram)); + for (i=0; i<10; ++i) + total += vals[i]; + tt_int_op(total, ==, 45); + for (i=0; i<n; ++i) { + uint64_t t; + choice = choose_array_element_by_weight(vals, 10, &t); + tt_int_op(t, ==, total); + tt_int_op(choice, >=, 0); + tt_int_op(choice, <, 10); + histogram[choice]++; + } + + /* Now see if we chose things about frequently enough. */ + max_sq_error = 0; + for (i=0; i<10; ++i) { + int expected = (int)(n*vals[i]/total); + double frac_diff = 0, sq; + TT_BLATHER((" %d : %5d vs %5d\n", (int)vals[i], histogram[i], expected)); + if (expected) + frac_diff = (histogram[i] - expected) / ((double)expected); + else + tt_int_op(histogram[i], ==, 0); + + sq = frac_diff * frac_diff; + if (sq > max_sq_error) + max_sq_error = sq; + } + /* It should almost always be much much less than this. If you want to + * figure out the odds, please feel free. */ + tt_double_op(max_sq_error, <, .05); + + /* Now try a singleton; do we choose it? */ + for (i = 0; i < 100; ++i) { + choice = choose_array_element_by_weight(vals, 1, NULL); + tt_int_op(choice, ==, 0); + } + + /* Now try an array of zeros. We should choose randomly. */ + memset(histogram,0,sizeof(histogram)); + for (i = 0; i < n; ++i) { + uint64_t t; + choice = choose_array_element_by_weight(zeros, 5, &t); + tt_int_op(t, ==, 0); + tt_int_op(choice, >=, 0); + tt_int_op(choice, <, 5); + histogram[choice]++; + } + /* Now see if we chose things about frequently enough. */ + max_sq_error = 0; + for (i=0; i<5; ++i) { + int expected = n/5; + double frac_diff = 0, sq; + TT_BLATHER((" %d : %5d vs %5d\n", (int)vals[i], histogram[i], expected)); + frac_diff = (histogram[i] - expected) / ((double)expected); + sq = frac_diff * frac_diff; + if (sq > max_sq_error) + max_sq_error = sq; + } + /* It should almost always be much much less than this. If you want to + * figure out the odds, please feel free. */ + tt_double_op(max_sq_error, <, .05); + done: + ; +} + #define DIR_LEGACY(name) \ { #name, legacy_test_helper, TT_FORK, &legacy_setup, test_dir_ ## name } @@ -1396,6 +1476,7 @@ struct testcase_t dir_tests[] = { DIR_LEGACY(measured_bw), DIR_LEGACY(param_voting), DIR_LEGACY(v3_networkstatus), + DIR(random_weighted), END_OF_TESTCASES }; |