aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJordan <me@jordan.im>2022-02-14 23:17:42 -0700
committerJordan <me@jordan.im>2022-02-14 23:17:42 -0700
commita39310f111cef49ff630cc12fdebabc4df37ec28 (patch)
treeb2f3778cb49ed2ed2c3283db079285ca0084b704
parenta6a6fef1c7cc7d6878e8aa36541565fb3e0c9747 (diff)
downloadcrawl-a39310f111cef49ff630cc12fdebabc4df37ec28.tar.gz
crawl-a39310f111cef49ff630cc12fdebabc4df37ec28.zip
client, crawl: fix/simplify net.Dialer overrides
-rw-r--r--client.go47
-rw-r--r--cmd/crawl/crawl.go6
2 files changed, 17 insertions, 36 deletions
diff --git a/client.go b/client.go
index 2284e9e..b369a05 100644
--- a/client.go
+++ b/client.go
@@ -37,11 +37,21 @@ func NewHTTPClient() *http.Client {
}
}
-// NewHTTPClientWithDNSOverride returns an http.Client suitable for
-// crawling, with some additional DNS overrides.
-func NewHTTPClientWithDNSOverride(dnsMap map[string]string) *http.Client {
+// NewHTTPClientWithOverrides returns an http.Client suitable for
+// crawling, with additional (optional) DNS and LocalAddr overrides.
+func NewHTTPClientWithOverrides(dnsMap map[string]string, localAddr *net.IPAddr) *http.Client {
jar, _ := cookiejar.New(nil) // nolint
- dialer := new(net.Dialer)
+ var dialer *net.Dialer
+ if localAddr != nil {
+ localTCPAddr := net.TCPAddr{
+ IP: localAddr.IP,
+ }
+ dialer = &net.Dialer{
+ LocalAddr: &localTCPAddr,
+ }
+ } else {
+ dialer = new(net.Dialer)
+ }
transport := &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, port, err := net.SplitHostPort(addr)
@@ -66,32 +76,3 @@ func NewHTTPClientWithDNSOverride(dnsMap map[string]string) *http.Client {
Jar: jar,
}
}
-
-// NewHTTPClientWithLocalAddrOverride returns an http.Client suitable for
-// crawling, with a LocalAddr override for making outbound connections using
-// an explicit interface
-func NewHTTPClientWithLocalAddrOverride(addr *net.IPAddr) *http.Client {
- jar, _ := cookiejar.New(nil) // nolint
- localTCPAddr := net.TCPAddr{
- IP: addr.IP,
- }
- transport := &http.Transport{
- DialContext: (&net.Dialer{
- LocalAddr: &localTCPAddr,
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- DualStack: false,
- }).DialContext,
- TLSClientConfig: &tls.Config{
- InsecureSkipVerify: true, // nolint
- },
- }
- return &http.Client{
- Timeout: defaultClientTimeout,
- Transport: transport,
- CheckRedirect: func(req *http.Request, via []*http.Request) error {
- return http.ErrUseLastResponse
- },
- Jar: jar,
- }
-}
diff --git a/cmd/crawl/crawl.go b/cmd/crawl/crawl.go
index 63fda1a..8f84837 100644
--- a/cmd/crawl/crawl.go
+++ b/cmd/crawl/crawl.go
@@ -348,14 +348,14 @@ func main() {
log.Fatal(err)
}
- httpClient = crawl.NewHTTPClientWithDNSOverride(dnsMap)
-
if *bindIP != "" {
addr, err := net.ResolveIPAddr("ip", *bindIP)
if err != nil {
log.Fatal(err)
}
- httpClient = crawl.NewHTTPClientWithLocalAddrOverride(addr)
+ httpClient = crawl.NewHTTPClientWithOverrides(dnsMap, addr)
+ } else {
+ httpClient = crawl.NewHTTPClientWithOverrides(dnsMap, nil)
}
crawler, err := crawl.NewCrawler(