diff options
Diffstat (limited to 'proxy')
-rw-r--r-- | proxy/lib/snowflake.go | 76 |
1 files changed, 33 insertions, 43 deletions
diff --git a/proxy/lib/snowflake.go b/proxy/lib/snowflake.go index 39c4c2f..4157477 100644 --- a/proxy/lib/snowflake.go +++ b/proxy/lib/snowflake.go @@ -60,13 +60,13 @@ const ( ) const ( - // NATUnknown represents a NAT type which is unknown. + // NATUnknown is set if the proxy cannot connect to probetest. NATUnknown = "unknown" - // NATRestricted represents a restricted NAT. + // NATRestricted is set if the proxy times out when connecting to a symmetric NAT. NATRestricted = "restricted" - // NATUnrestricted represents an unrestricted NAT. + // NATUnrestricted is set if the proxy successfully connects to a symmetric NAT. NATUnrestricted = "unrestricted" ) @@ -99,6 +99,12 @@ func getCurrentNATType() string { return currentNATType } +func setCurrentNATType(newType string) { + currentNATTypeAccess.Lock() + defer currentNATTypeAccess.Unlock() + currentNATType = newType +} + var ( tokens *tokens_t config webrtc.Configuration @@ -694,9 +700,13 @@ func (sf *SnowflakeProxy) Start() error { } tokens = newTokens(sf.Capacity) - sf.checkNATType(config, sf.NATProbeURL) - currentNATTypeLoaded := getCurrentNATType() - sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: currentNATTypeLoaded}) + err = sf.checkNATType(config, sf.NATProbeURL) + if err != nil { + // non-fatal error. Log it and continue + log.Printf(err.Error()) + setCurrentNATType(NATUnknown) + } + sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()}) NatRetestTask := task.Periodic{ Interval: sf.NATTypeMeasurementInterval, @@ -704,6 +714,9 @@ func (sf *SnowflakeProxy) Start() error { sf.checkNATType(config, sf.NATProbeURL) return nil }, + OnError: func(err error) { + log.Printf("Periodic probetest failed: %s, retaining current NAT type: %s", err.Error(), getCurrentNATType()) + }, } if sf.NATTypeMeasurementInterval != 0 { @@ -735,87 +748,64 @@ func (sf *SnowflakeProxy) Stop() { // checkNATType use probetest to determine NAT compatability by // attempting to connect with a known symmetric NAT. If success, // it is considered "unrestricted". If timeout it is considered "restricted" -func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) { +func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) error { probe, err := newSignalingServer(probeURL, false) if err != nil { - log.Printf("Error parsing url: %s", err.Error()) + return fmt.Errorf("Error parsing url: %w", err) } dataChan := make(chan struct{}) pc, err := sf.makeNewPeerConnection(config, dataChan) if err != nil { - log.Printf("error making WebRTC connection: %s", err) - return + return fmt.Errorf("Error making WebRTC connection: %w", err) } offer := pc.LocalDescription() log.Printf("Probetest offer: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t")) sdp, err := util.SerializeSessionDescription(offer) if err != nil { - log.Printf("Error encoding probe message: %s", err.Error()) - return + return fmt.Errorf("Error encoding probe message: %w", err) } // send offer body, err := messages.EncodePollResponse(sdp, true, "") if err != nil { - log.Printf("Error encoding probe message: %s", err.Error()) - return + return fmt.Errorf("Error encoding probe message: %w", err) } resp, err := probe.Post(probe.url.String(), bytes.NewBuffer(body)) if err != nil { - log.Printf("error polling probe: %s", err.Error()) - return + return fmt.Errorf("Error polling probe: %w", err) } sdp, _, err = messages.DecodeAnswerRequest(resp) if err != nil { - log.Printf("Error reading probe response: %s", err.Error()) - return + return fmt.Errorf("Error reading probe response: %w", err) } answer, err := util.DeserializeSessionDescription(sdp) if err != nil { - log.Printf("Error setting answer: %s", err.Error()) - return + return fmt.Errorf("Error setting answer: %w", err) } err = pc.SetRemoteDescription(*answer) if err != nil { - log.Printf("Error setting answer: %s", err.Error()) - return + return fmt.Errorf("Error setting answer: %w", err) } - currentNATTypeLoaded := getCurrentNATType() + prevNATType := getCurrentNATType() - currentNATTypeTestResult := NATUnknown select { case <-dataChan: - currentNATTypeTestResult = NATUnrestricted + setCurrentNATType(NATUnrestricted) case <-time.After(dataChannelTimeout): - currentNATTypeTestResult = NATRestricted - } - - currentNATTypeToStore := NATUnknown - switch currentNATTypeLoaded + "->" + currentNATTypeTestResult { - case NATUnrestricted + "->" + NATUnknown: - currentNATTypeToStore = NATUnrestricted - - case NATRestricted + "->" + NATUnknown: - currentNATTypeToStore = NATRestricted - - default: - currentNATTypeToStore = currentNATTypeTestResult + setCurrentNATType(NATRestricted) } - log.Printf("NAT Type measurement: %v -> %v = %v\n", currentNATTypeLoaded, currentNATTypeTestResult, currentNATTypeToStore) - - currentNATTypeAccess.Lock() - currentNATType = currentNATTypeToStore - currentNATTypeAccess.Unlock() + log.Printf("NAT Type measurement: %v -> %v\n", prevNATType, getCurrentNATType()) if err := pc.Close(); err != nil { log.Printf("error calling pc.Close: %v", err) } + return nil } |