aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKokaKiwi <kokakiwi+tor@kokakiwi.net>2023-09-11 23:34:35 +0200
committerShelikhoo <xiaokangwang@outlook.com>2023-10-12 15:52:43 +0100
commit7142fa3ddb95906784b87ccd9fb2b1dac44d0c58 (patch)
treedf5d4f9b1c20f5c949452296c237708bd82365ca
parent6393af6bab0f7c3c95b11352d5c582d2000062fa (diff)
downloadsnowflake-7142fa3ddb95906784b87ccd9fb2b1dac44d0c58.tar.gz
snowflake-7142fa3ddb95906784b87ccd9fb2b1dac44d0c58.zip
fix(proxy): Correctly close connection pipe when dealing with error
-rw-r--r--proxy/lib/snowflake.go9
-rw-r--r--proxy/lib/webrtcconn.go7
2 files changed, 11 insertions, 5 deletions
diff --git a/proxy/lib/snowflake.go b/proxy/lib/snowflake.go
index 97bea3e..243b49d 100644
--- a/proxy/lib/snowflake.go
+++ b/proxy/lib/snowflake.go
@@ -452,12 +452,17 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(sdp *webrtc.SessionDescrip
var n int
n, err = pw.Write(msg.Data)
if err != nil {
- if inerr := pw.CloseWithError(err); inerr != nil {
- log.Printf("close with error generated an error: %v", inerr)
+ if inErr := pw.CloseWithError(err); inErr != nil {
+ log.Printf("close with error generated an error: %v", inErr)
}
+
+ return
}
+
conn.bytesLogger.AddOutbound(int64(n))
+
if n != len(msg.Data) {
+ // XXX: Maybe don't panic here and log an error instead?
panic("short write")
}
})
diff --git a/proxy/lib/webrtcconn.go b/proxy/lib/webrtcconn.go
index f84734f..024945a 100644
--- a/proxy/lib/webrtcconn.go
+++ b/proxy/lib/webrtcconn.go
@@ -2,6 +2,7 @@ package snowflake_proxy
import (
"context"
+ "errors"
"fmt"
"io"
"log"
@@ -62,7 +63,7 @@ func (c *webRTCConn) timeoutLoop(ctx context.Context) {
for {
select {
case <-timer.C:
- c.Close()
+ _ = c.Close()
log.Println("Closed connection due to inactivity")
return
case <-c.activity:
@@ -90,7 +91,7 @@ func (c *webRTCConn) Write(b []byte) (int, error) {
c.lock.Lock()
defer c.lock.Unlock()
if c.dc != nil {
- c.dc.Send(b)
+ _ = c.dc.Send(b)
if !c.isClosing && c.dc.BufferedAmount() >= maxBufferedAmount {
<-c.sendMoreCh
}
@@ -106,7 +107,7 @@ func (c *webRTCConn) Close() (err error) {
}
c.once.Do(func() {
c.cancelTimeoutLoop()
- err = c.pc.Close()
+ err = errors.Join(c.pr.Close(), c.pc.Close())
})
return
}