aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrad Fitzpatrick <bradfitz@golang.org>2014-04-17 13:28:40 -0700
committerBrad Fitzpatrick <bradfitz@golang.org>2014-04-17 13:28:40 -0700
commitc45392bae0332e1407c4468ee35a5176651b03fa (patch)
treeebaed86deb3a963e919a51c334781f8d72e53c76
parent576318c7bd9c1deeb5877df65dc26aeb53999ee2 (diff)
downloadgo-c45392bae0332e1407c4468ee35a5176651b03fa.tar.gz
go-c45392bae0332e1407c4468ee35a5176651b03fa.zip
net: fix probabilities in DNS SRV shuffleByWeight
Patch from msolo. Just moving it to a CL. The test fails before and passes with the fix. Fixes #7098 LGTM=msolo, rsc R=rsc, iant, msolo CC=golang-codereviews https://golang.org/cl/88900044
-rw-r--r--src/pkg/net/dnsclient.go4
-rw-r--r--src/pkg/net/dnsclient_test.go69
2 files changed, 71 insertions, 2 deletions
diff --git a/src/pkg/net/dnsclient.go b/src/pkg/net/dnsclient.go
index 01db437294..9bffa11f91 100644
--- a/src/pkg/net/dnsclient.go
+++ b/src/pkg/net/dnsclient.go
@@ -191,10 +191,10 @@ func (addrs byPriorityWeight) shuffleByWeight() {
}
for sum > 0 && len(addrs) > 1 {
s := 0
- n := rand.Intn(sum + 1)
+ n := rand.Intn(sum)
for i := range addrs {
s += int(addrs[i].Weight)
- if s >= n {
+ if s > n {
if i > 0 {
t := addrs[i]
copy(addrs[1:i+1], addrs[0:i])
diff --git a/src/pkg/net/dnsclient_test.go b/src/pkg/net/dnsclient_test.go
new file mode 100644
index 0000000000..435eb35506
--- /dev/null
+++ b/src/pkg/net/dnsclient_test.go
@@ -0,0 +1,69 @@
+// Copyright 2014 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "math/rand"
+ "testing"
+)
+
+func checkDistribution(t *testing.T, data []*SRV, margin float64) {
+ sum := 0
+ for _, srv := range data {
+ sum += int(srv.Weight)
+ }
+
+ results := make(map[string]int)
+
+ count := 1000
+ for j := 0; j < count; j++ {
+ d := make([]*SRV, len(data))
+ copy(d, data)
+ byPriorityWeight(d).shuffleByWeight()
+ key := d[0].Target
+ results[key] = results[key] + 1
+ }
+
+ actual := results[data[0].Target]
+ expected := float64(count) * float64(data[0].Weight) / float64(sum)
+ diff := float64(actual) - expected
+ t.Logf("actual: %v diff: %v e: %v m: %v", actual, diff, expected, margin)
+ if diff < 0 {
+ diff = -diff
+ }
+ if diff > (expected * margin) {
+ t.Errorf("missed target weight: expected %v, %v", expected, actual)
+ }
+}
+
+func testUniformity(t *testing.T, size int, margin float64) {
+ rand.Seed(1)
+ data := make([]*SRV, size)
+ for i := 0; i < size; i++ {
+ data[i] = &SRV{Target: string('a' + i), Weight: 1}
+ }
+ checkDistribution(t, data, margin)
+}
+
+func TestUniformity(t *testing.T) {
+ testUniformity(t, 2, 0.05)
+ testUniformity(t, 3, 0.10)
+ testUniformity(t, 10, 0.20)
+ testWeighting(t, 0.05)
+}
+
+func testWeighting(t *testing.T, margin float64) {
+ rand.Seed(1)
+ data := []*SRV{
+ {Target: "a", Weight: 60},
+ {Target: "b", Weight: 30},
+ {Target: "c", Weight: 10},
+ }
+ checkDistribution(t, data, margin)
+}
+
+func TestWeighting(t *testing.T) {
+ testWeighting(t, 0.05)
+}