From a39310f111cef49ff630cc12fdebabc4df37ec28 Mon Sep 17 00:00:00 2001 From: Jordan Date: Mon, 14 Feb 2022 23:17:42 -0700 Subject: client, crawl: fix/simplify net.Dialer overrides --- client.go | 47 ++++++++++++++--------------------------------- cmd/crawl/crawl.go | 6 +++--- 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/client.go b/client.go index 2284e9e..b369a05 100644 --- a/client.go +++ b/client.go @@ -37,11 +37,21 @@ func NewHTTPClient() *http.Client { } } -// NewHTTPClientWithDNSOverride returns an http.Client suitable for -// crawling, with some additional DNS overrides. -func NewHTTPClientWithDNSOverride(dnsMap map[string]string) *http.Client { +// NewHTTPClientWithOverrides returns an http.Client suitable for +// crawling, with additional (optional) DNS and LocalAddr overrides. +func NewHTTPClientWithOverrides(dnsMap map[string]string, localAddr *net.IPAddr) *http.Client { jar, _ := cookiejar.New(nil) // nolint - dialer := new(net.Dialer) + var dialer *net.Dialer + if localAddr != nil { + localTCPAddr := net.TCPAddr{ + IP: localAddr.IP, + } + dialer = &net.Dialer{ + LocalAddr: &localTCPAddr, + } + } else { + dialer = new(net.Dialer) + } transport := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) @@ -66,32 +76,3 @@ func NewHTTPClientWithDNSOverride(dnsMap map[string]string) *http.Client { Jar: jar, } } - -// NewHTTPClientWithLocalAddrOverride returns an http.Client suitable for -// crawling, with a LocalAddr override for making outbound connections using -// an explicit interface -func NewHTTPClientWithLocalAddrOverride(addr *net.IPAddr) *http.Client { - jar, _ := cookiejar.New(nil) // nolint - localTCPAddr := net.TCPAddr{ - IP: addr.IP, - } - transport := &http.Transport{ - DialContext: (&net.Dialer{ - LocalAddr: &localTCPAddr, - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: false, - }).DialContext, - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, // nolint - }, - } - return &http.Client{ - Timeout: defaultClientTimeout, - Transport: transport, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - Jar: jar, - } -} diff --git a/cmd/crawl/crawl.go b/cmd/crawl/crawl.go index 63fda1a..8f84837 100644 --- a/cmd/crawl/crawl.go +++ b/cmd/crawl/crawl.go @@ -348,14 +348,14 @@ func main() { log.Fatal(err) } - httpClient = crawl.NewHTTPClientWithDNSOverride(dnsMap) - if *bindIP != "" { addr, err := net.ResolveIPAddr("ip", *bindIP) if err != nil { log.Fatal(err) } - httpClient = crawl.NewHTTPClientWithLocalAddrOverride(addr) + httpClient = crawl.NewHTTPClientWithOverrides(dnsMap, addr) + } else { + httpClient = crawl.NewHTTPClientWithOverrides(dnsMap, nil) } crawler, err := crawl.NewCrawler( -- cgit v1.2.3-54-g00ecf