aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorArlo Breault <arlolra@gmail.com>2022-03-09 19:48:16 -0500
committerArlo Breault <arlolra@gmail.com>2022-03-16 15:43:10 -0400
commit829cacac5f7ecb2cc701a24061679814fc1841bc (patch)
treea3bc15a0d1a601bb2538faa4b4afdc3e94841d52
parent6fd0f1ae5dd22bb30100353d80f681b70d879d92 (diff)
downloadsnowflake-829cacac5f7ecb2cc701a24061679814fc1841bc.tar.gz
snowflake-829cacac5f7ecb2cc701a24061679814fc1841bc.zip
Parse ClientPollRequest version in DecodeClientPollRequest
Instead of IPC.ClientOffers. This makes things consistent with EncodeClientPollRequest which adds the version while serializing.
-rw-r--r--broker/http.go5
-rw-r--r--broker/ipc.go45
-rw-r--r--client/lib/rendezvous.go5
-rw-r--r--client/lib/rendezvous_test.go5
-rw-r--r--common/messages/client.go28
-rw-r--r--common/messages/messages_test.go21
6 files changed, 52 insertions, 57 deletions
diff --git a/broker/http.go b/broker/http.go
index 3b0ba1f..7acc465 100644
--- a/broker/http.go
+++ b/broker/http.go
@@ -146,8 +146,9 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
if len(body) > 0 && body[0] == '{' {
isLegacy = true
req := messages.ClientPollRequest{
- Offer: string(body),
- NAT: r.Header.Get("Snowflake-NAT-Type"),
+ Offer: string(body),
+ NAT: r.Header.Get("Snowflake-NAT-Type"),
+ Version: messages.ClientVersion1_0,
}
body, err = req.EncodeClientPollRequest()
if err != nil {
diff --git a/broker/ipc.go b/broker/ipc.go
index b8359f6..768c0b7 100644
--- a/broker/ipc.go
+++ b/broker/ipc.go
@@ -1,7 +1,6 @@
package main
import (
- "bytes"
"container/heap"
"fmt"
"log"
@@ -21,12 +20,6 @@ const (
NATUnrestricted = "unrestricted"
)
-type clientVersion int
-
-const (
- v1 clientVersion = iota
-)
-
type IPC struct {
ctx *BrokerContext
}
@@ -132,32 +125,16 @@ func sendClientResponse(resp *messages.ClientPollResponse, response *[]byte) err
}
func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
- var version clientVersion
-
startTime := time.Now()
- body := arg.Body
- parts := bytes.SplitN(body, []byte("\n"), 2)
- if len(parts) < 2 {
- // no version number found
- err := fmt.Errorf("unsupported message version")
- return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
- }
- body = parts[1]
- if string(parts[0]) == "1.0" {
- version = v1
- } else {
- err := fmt.Errorf("unsupported message version")
+ req, err := messages.DecodeClientPollRequest(arg.Body)
+ if err != nil {
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
}
var offer *ClientOffer
- switch version {
- case v1:
- req, err := messages.DecodeClientPollRequest(body)
- if err != nil {
- return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
- }
+ switch req.Version {
+ case messages.ClientVersion1_0:
offer = &ClientOffer{
natType: req.NAT,
sdp: []byte(req.Offer),
@@ -188,8 +165,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientRestrictedDeniedCount++
}
i.ctx.metrics.lock.Unlock()
- switch version {
- case v1:
+ switch req.Version {
+ case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Error: messages.StrNoProxies}
return sendClientResponse(resp, response)
default:
@@ -204,8 +181,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.snowflakeLock.Unlock()
snowflake.offerChannel <- offer
- var err error
-
// Wait for the answer to be returned on the channel or timeout.
select {
case answer := <-snowflake.answerChannel:
@@ -213,8 +188,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientProxyMatchCount++
i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc()
i.ctx.metrics.lock.Unlock()
- switch version {
- case v1:
+ switch req.Version {
+ case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Answer: answer}
err = sendClientResponse(resp, response)
default:
@@ -224,8 +199,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
case <-time.After(time.Second * ClientTimeout):
log.Println("Client: Timed out.")
- switch version {
- case v1:
+ switch req.Version {
+ case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Error: messages.StrTimedOut}
err = sendClientResponse(resp, response)
default:
diff --git a/client/lib/rendezvous.go b/client/lib/rendezvous.go
index 0ce2744..e7543ad 100644
--- a/client/lib/rendezvous.go
+++ b/client/lib/rendezvous.go
@@ -122,8 +122,9 @@ func (bc *BrokerChannel) Negotiate(offer *webrtc.SessionDescription) (
// Encode the client poll request.
bc.lock.Lock()
req := &messages.ClientPollRequest{
- Offer: offerSDP,
- NAT: bc.natType,
+ Offer: offerSDP,
+ NAT: bc.natType,
+ Version: messages.ClientVersion1_0,
}
encReq, err := req.EncodeClientPollRequest()
bc.lock.Unlock()
diff --git a/client/lib/rendezvous_test.go b/client/lib/rendezvous_test.go
index 21b9f57..a233e7d 100644
--- a/client/lib/rendezvous_test.go
+++ b/client/lib/rendezvous_test.go
@@ -43,8 +43,9 @@ func (t errorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// offer.
func makeEncPollReq(offer string) []byte {
encPollReq, err := (&messages.ClientPollRequest{
- Offer: offer,
- NAT: nat.NATUnknown,
+ Offer: offer,
+ NAT: nat.NATUnknown,
+ Version: messages.ClientVersion1_0,
}).EncodeClientPollRequest()
if err != nil {
panic(err)
diff --git a/common/messages/client.go b/common/messages/client.go
index 5a7d73b..2a35594 100644
--- a/common/messages/client.go
+++ b/common/messages/client.go
@@ -4,13 +4,14 @@
package messages
import (
+ "bytes"
"encoding/json"
"fmt"
"git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat"
)
-const ClientVersion = "1.0"
+const ClientVersion1_0 = "1.0"
/* Client--Broker protocol v1.x specification:
@@ -49,24 +50,41 @@ for the error.
*/
type ClientPollRequest struct {
- Offer string `json:"offer"`
- NAT string `json:"nat"`
+ Offer string `json:"offer"`
+ NAT string `json:"nat"`
+ Version string `json:"-"`
}
// Encodes a poll message from a snowflake client
func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) {
+ if req.Version != ClientVersion1_0 {
+ return nil, fmt.Errorf("unsupported message version")
+ }
body, err := json.Marshal(req)
if err != nil {
return nil, err
}
- return append([]byte(ClientVersion+"\n"), body...), nil
+ return append([]byte(req.Version+"\n"), body...), nil
}
// Decodes a poll message from a snowflake client
func DecodeClientPollRequest(data []byte) (*ClientPollRequest, error) {
+ parts := bytes.SplitN(data, []byte("\n"), 2)
+
+ if len(parts) < 2 {
+ // no version number found
+ return nil, fmt.Errorf("unsupported message version")
+ }
+
var message ClientPollRequest
- err := json.Unmarshal(data, &message)
+ if string(parts[0]) == ClientVersion1_0 {
+ message.Version = ClientVersion1_0
+ } else {
+ return nil, fmt.Errorf("unsupported message version")
+ }
+
+ err := json.Unmarshal(parts[1], &message)
if err != nil {
return nil, err
}
diff --git a/common/messages/messages_test.go b/common/messages/messages_test.go
index 0d8b450..e0aa2a8 100644
--- a/common/messages/messages_test.go
+++ b/common/messages/messages_test.go
@@ -1,7 +1,6 @@
package messages
import (
- "bytes"
"encoding/json"
"fmt"
"testing"
@@ -286,14 +285,16 @@ func TestDecodeClientPollRequest(t *testing.T) {
//version 1.0 client message
"unknown",
"fake",
- `{"nat":"unknown","offer":"fake"}`,
+ `1.0
+{"nat":"unknown","offer":"fake"}`,
nil,
},
{
//version 1.0 client message
"unknown",
"fake",
- `{"offer":"fake"}`,
+ `1.0
+{"offer":"fake"}`,
nil,
},
{
@@ -307,16 +308,17 @@ func TestDecodeClientPollRequest(t *testing.T) {
//no offer
"",
"",
- `{"nat":"unknown"}`,
+ `1.0
+{"nat":"unknown"}`,
fmt.Errorf(""),
},
} {
req, err := DecodeClientPollRequest([]byte(test.data))
+ So(err, ShouldHaveSameTypeAs, test.err)
if test.err == nil {
So(req.NAT, ShouldResemble, test.natType)
So(req.Offer, ShouldResemble, test.offer)
}
- So(err, ShouldHaveSameTypeAs, test.err)
}
})
@@ -325,15 +327,12 @@ func TestDecodeClientPollRequest(t *testing.T) {
func TestEncodeClientPollRequests(t *testing.T) {
Convey("Context", t, func() {
req1 := &ClientPollRequest{
- NAT: "unknown",
- Offer: "fake",
+ NAT: "unknown",
+ Offer: "fake",
+ Version: ClientVersion1_0,
}
b, err := req1.EncodeClientPollRequest()
So(err, ShouldEqual, nil)
- fmt.Println(string(b))
- parts := bytes.SplitN(b, []byte("\n"), 2)
- So(string(parts[0]), ShouldEqual, "1.0")
- b = parts[1]
req2, err := DecodeClientPollRequest(b)
So(err, ShouldEqual, nil)
So(req2, ShouldResemble, req1)