diff options
author | Arlo Breault <arlolra@gmail.com> | 2022-03-09 19:48:16 -0500 |
---|---|---|
committer | Arlo Breault <arlolra@gmail.com> | 2022-03-16 15:43:10 -0400 |
commit | 829cacac5f7ecb2cc701a24061679814fc1841bc (patch) | |
tree | a3bc15a0d1a601bb2538faa4b4afdc3e94841d52 | |
parent | 6fd0f1ae5dd22bb30100353d80f681b70d879d92 (diff) | |
download | snowflake-829cacac5f7ecb2cc701a24061679814fc1841bc.tar.gz snowflake-829cacac5f7ecb2cc701a24061679814fc1841bc.zip |
Parse ClientPollRequest version in DecodeClientPollRequest
Instead of IPC.ClientOffers. This makes things consistent with
EncodeClientPollRequest which adds the version while serializing.
-rw-r--r-- | broker/http.go | 5 | ||||
-rw-r--r-- | broker/ipc.go | 45 | ||||
-rw-r--r-- | client/lib/rendezvous.go | 5 | ||||
-rw-r--r-- | client/lib/rendezvous_test.go | 5 | ||||
-rw-r--r-- | common/messages/client.go | 28 | ||||
-rw-r--r-- | common/messages/messages_test.go | 21 |
6 files changed, 52 insertions, 57 deletions
diff --git a/broker/http.go b/broker/http.go index 3b0ba1f..7acc465 100644 --- a/broker/http.go +++ b/broker/http.go @@ -146,8 +146,9 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) { if len(body) > 0 && body[0] == '{' { isLegacy = true req := messages.ClientPollRequest{ - Offer: string(body), - NAT: r.Header.Get("Snowflake-NAT-Type"), + Offer: string(body), + NAT: r.Header.Get("Snowflake-NAT-Type"), + Version: messages.ClientVersion1_0, } body, err = req.EncodeClientPollRequest() if err != nil { diff --git a/broker/ipc.go b/broker/ipc.go index b8359f6..768c0b7 100644 --- a/broker/ipc.go +++ b/broker/ipc.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "container/heap" "fmt" "log" @@ -21,12 +20,6 @@ const ( NATUnrestricted = "unrestricted" ) -type clientVersion int - -const ( - v1 clientVersion = iota -) - type IPC struct { ctx *BrokerContext } @@ -132,32 +125,16 @@ func sendClientResponse(resp *messages.ClientPollResponse, response *[]byte) err } func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error { - var version clientVersion - startTime := time.Now() - body := arg.Body - parts := bytes.SplitN(body, []byte("\n"), 2) - if len(parts) < 2 { - // no version number found - err := fmt.Errorf("unsupported message version") - return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response) - } - body = parts[1] - if string(parts[0]) == "1.0" { - version = v1 - } else { - err := fmt.Errorf("unsupported message version") + req, err := messages.DecodeClientPollRequest(arg.Body) + if err != nil { return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response) } var offer *ClientOffer - switch version { - case v1: - req, err := messages.DecodeClientPollRequest(body) - if err != nil { - return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response) - } + switch req.Version { + case messages.ClientVersion1_0: offer = &ClientOffer{ natType: req.NAT, sdp: []byte(req.Offer), @@ -188,8 +165,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error { i.ctx.metrics.clientRestrictedDeniedCount++ } i.ctx.metrics.lock.Unlock() - switch version { - case v1: + switch req.Version { + case messages.ClientVersion1_0: resp := &messages.ClientPollResponse{Error: messages.StrNoProxies} return sendClientResponse(resp, response) default: @@ -204,8 +181,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error { i.ctx.snowflakeLock.Unlock() snowflake.offerChannel <- offer - var err error - // Wait for the answer to be returned on the channel or timeout. select { case answer := <-snowflake.answerChannel: @@ -213,8 +188,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error { i.ctx.metrics.clientProxyMatchCount++ i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc() i.ctx.metrics.lock.Unlock() - switch version { - case v1: + switch req.Version { + case messages.ClientVersion1_0: resp := &messages.ClientPollResponse{Answer: answer} err = sendClientResponse(resp, response) default: @@ -224,8 +199,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error { i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond case <-time.After(time.Second * ClientTimeout): log.Println("Client: Timed out.") - switch version { - case v1: + switch req.Version { + case messages.ClientVersion1_0: resp := &messages.ClientPollResponse{Error: messages.StrTimedOut} err = sendClientResponse(resp, response) default: diff --git a/client/lib/rendezvous.go b/client/lib/rendezvous.go index 0ce2744..e7543ad 100644 --- a/client/lib/rendezvous.go +++ b/client/lib/rendezvous.go @@ -122,8 +122,9 @@ func (bc *BrokerChannel) Negotiate(offer *webrtc.SessionDescription) ( // Encode the client poll request. bc.lock.Lock() req := &messages.ClientPollRequest{ - Offer: offerSDP, - NAT: bc.natType, + Offer: offerSDP, + NAT: bc.natType, + Version: messages.ClientVersion1_0, } encReq, err := req.EncodeClientPollRequest() bc.lock.Unlock() diff --git a/client/lib/rendezvous_test.go b/client/lib/rendezvous_test.go index 21b9f57..a233e7d 100644 --- a/client/lib/rendezvous_test.go +++ b/client/lib/rendezvous_test.go @@ -43,8 +43,9 @@ func (t errorTransport) RoundTrip(req *http.Request) (*http.Response, error) { // offer. func makeEncPollReq(offer string) []byte { encPollReq, err := (&messages.ClientPollRequest{ - Offer: offer, - NAT: nat.NATUnknown, + Offer: offer, + NAT: nat.NATUnknown, + Version: messages.ClientVersion1_0, }).EncodeClientPollRequest() if err != nil { panic(err) diff --git a/common/messages/client.go b/common/messages/client.go index 5a7d73b..2a35594 100644 --- a/common/messages/client.go +++ b/common/messages/client.go @@ -4,13 +4,14 @@ package messages import ( + "bytes" "encoding/json" "fmt" "git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat" ) -const ClientVersion = "1.0" +const ClientVersion1_0 = "1.0" /* Client--Broker protocol v1.x specification: @@ -49,24 +50,41 @@ for the error. */ type ClientPollRequest struct { - Offer string `json:"offer"` - NAT string `json:"nat"` + Offer string `json:"offer"` + NAT string `json:"nat"` + Version string `json:"-"` } // Encodes a poll message from a snowflake client func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) { + if req.Version != ClientVersion1_0 { + return nil, fmt.Errorf("unsupported message version") + } body, err := json.Marshal(req) if err != nil { return nil, err } - return append([]byte(ClientVersion+"\n"), body...), nil + return append([]byte(req.Version+"\n"), body...), nil } // Decodes a poll message from a snowflake client func DecodeClientPollRequest(data []byte) (*ClientPollRequest, error) { + parts := bytes.SplitN(data, []byte("\n"), 2) + + if len(parts) < 2 { + // no version number found + return nil, fmt.Errorf("unsupported message version") + } + var message ClientPollRequest - err := json.Unmarshal(data, &message) + if string(parts[0]) == ClientVersion1_0 { + message.Version = ClientVersion1_0 + } else { + return nil, fmt.Errorf("unsupported message version") + } + + err := json.Unmarshal(parts[1], &message) if err != nil { return nil, err } diff --git a/common/messages/messages_test.go b/common/messages/messages_test.go index 0d8b450..e0aa2a8 100644 --- a/common/messages/messages_test.go +++ b/common/messages/messages_test.go @@ -1,7 +1,6 @@ package messages import ( - "bytes" "encoding/json" "fmt" "testing" @@ -286,14 +285,16 @@ func TestDecodeClientPollRequest(t *testing.T) { //version 1.0 client message "unknown", "fake", - `{"nat":"unknown","offer":"fake"}`, + `1.0 +{"nat":"unknown","offer":"fake"}`, nil, }, { //version 1.0 client message "unknown", "fake", - `{"offer":"fake"}`, + `1.0 +{"offer":"fake"}`, nil, }, { @@ -307,16 +308,17 @@ func TestDecodeClientPollRequest(t *testing.T) { //no offer "", "", - `{"nat":"unknown"}`, + `1.0 +{"nat":"unknown"}`, fmt.Errorf(""), }, } { req, err := DecodeClientPollRequest([]byte(test.data)) + So(err, ShouldHaveSameTypeAs, test.err) if test.err == nil { So(req.NAT, ShouldResemble, test.natType) So(req.Offer, ShouldResemble, test.offer) } - So(err, ShouldHaveSameTypeAs, test.err) } }) @@ -325,15 +327,12 @@ func TestDecodeClientPollRequest(t *testing.T) { func TestEncodeClientPollRequests(t *testing.T) { Convey("Context", t, func() { req1 := &ClientPollRequest{ - NAT: "unknown", - Offer: "fake", + NAT: "unknown", + Offer: "fake", + Version: ClientVersion1_0, } b, err := req1.EncodeClientPollRequest() So(err, ShouldEqual, nil) - fmt.Println(string(b)) - parts := bytes.SplitN(b, []byte("\n"), 2) - So(string(parts[0]), ShouldEqual, "1.0") - b = parts[1] req2, err := DecodeClientPollRequest(b) So(err, ShouldEqual, nil) So(req2, ShouldResemble, req1) |