diff options
author | Shelikhoo <xiaokangwang@outlook.com> | 2022-02-16 11:11:37 +0000 |
---|---|---|
committer | Shelikhoo <xiaokangwang@outlook.com> | 2022-03-16 09:13:29 +0000 |
commit | 3132f680122e27bb9cfb957fbb29c3cbe73935cf (patch) | |
tree | c21ac78a73330c4d626e2d88ed6402e74244777c | |
parent | 8d5998b7441eb9e213b8d86052e94a27d7656495 (diff) | |
download | snowflake-3132f680122e27bb9cfb957fbb29c3cbe73935cf.tar.gz snowflake-3132f680122e27bb9cfb957fbb29c3cbe73935cf.zip |
Add connection expire time for uTLS pendingConn
-rw-r--r-- | common/utls/roundtripper.go | 47 |
1 files changed, 43 insertions, 4 deletions
diff --git a/common/utls/roundtripper.go b/common/utls/roundtripper.go index e2fc82b..df31ff4 100644 --- a/common/utls/roundtripper.go +++ b/common/utls/roundtripper.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "sync" + "time" utls "github.com/refraction-networking/utls" "golang.org/x/net/http2" @@ -19,7 +20,7 @@ func NewUTLSHTTPRoundTripper(clientHelloID utls.ClientHelloID, uTlsConfig *utls. config: uTlsConfig, connectWithH1: map[string]bool{}, backdropTransport: backdropTransport, - pendingConn: map[pendingConnKey]net.Conn{}, + pendingConn: map[pendingConnKey]*unclaimedConnection{}, removeSNI: removeSNI, } rtImpl.init() @@ -38,7 +39,7 @@ type uTLSHTTPRoundTripperImpl struct { backdropTransport http.RoundTripper accessDialingConnection sync.Mutex - pendingConn map[pendingConnKey]net.Conn + pendingConn map[pendingConnKey]*unclaimedConnection removeSNI bool } @@ -50,6 +51,7 @@ type pendingConnKey struct { var errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN") var errEAGAINTooMany = errors.New("incorrect ALPN negotiated") +var errExpired = errors.New("connection have expired") func (r *uTLSHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) { if req.URL.Scheme != "https" { @@ -99,12 +101,15 @@ func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey { func (r *uTLSHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) { connId := getPendingConnectionID(addr, alpnIsH2) - r.pendingConn[connId] = conn + r.pendingConn[connId] = NewUnclaimedConnection(conn, time.Minute) } func (r *uTLSHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn { connId := getPendingConnectionID(addr, alpnIsH2) if conn, ok := r.pendingConn[connId]; ok { - return conn + delete(r.pendingConn, connId) + if claimedConnection, err := conn.claimConnection(); err == nil { + return claimedConnection + } } return nil } @@ -189,3 +194,37 @@ func (r *uTLSHTTPRoundTripperImpl) init() { }, } } + +func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection { + c := &unclaimedConnection{ + Conn: conn, + } + time.AfterFunc(expireTime, c.tick) + return c +} + +type unclaimedConnection struct { + net.Conn + claimed bool + access sync.Mutex +} + +func (c *unclaimedConnection) claimConnection() (net.Conn, error) { + c.access.Lock() + defer c.access.Unlock() + if !c.claimed { + c.claimed = true + return c.Conn, nil + } + return nil, errExpired +} + +func (c *unclaimedConnection) tick() { + c.access.Lock() + defer c.access.Unlock() + if !c.claimed { + c.claimed = true + c.Conn.Close() + c.Conn = nil + } +} |