From 37c649a8b693ba65a59eab5de3c01bf212f791ad Mon Sep 17 00:00:00 2001 From: ale Date: Sun, 23 Aug 2020 16:53:40 +0100 Subject: Allow setting DNS overrides using the --resolve option --- client.go | 43 ++++++++++++++++++++++++++++++++++++++++--- cmd/crawl/crawl.go | 46 ++++++++++++++++++++++++++++++++-------------- 2 files changed, 72 insertions(+), 17 deletions(-) diff --git a/client.go b/client.go index 45736f5..c028e42 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,9 @@ package crawl import ( + "context" "crypto/tls" + "net" "net/http" "net/http/cookiejar" "time" @@ -9,14 +11,19 @@ import ( var defaultClientTimeout = 60 * time.Second -// DefaultClient returns a http.Client suitable for crawling: does not -// follow redirects, accepts invalid TLS certificates, sets a +// DefaultClient points at a shared http.Client suitable for crawling: +// does not follow redirects, accepts invalid TLS certificates, sets a // reasonable timeout for requests. var DefaultClient *http.Client func init() { + DefaultClient = NewHTTPClient() +} + +// NewHTTPClient returns an http.Client suitable for crawling. +func NewHTTPClient() *http.Client { jar, _ := cookiejar.New(nil) // nolint - DefaultClient = &http.Client{ + return &http.Client{ Timeout: defaultClientTimeout, Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -29,3 +36,33 @@ func init() { Jar: jar, } } + +// NewHTTPClientWithDNSOverride returns an http.Client suitable for +// crawling, with some additional DNS overrides. +func NewHTTPClientWithDNSOverride(dnsMap map[string]string) *http.Client { + jar, _ := cookiejar.New(nil) // nolint + dialer := new(net.Dialer) + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + if override, ok := dnsMap[host]; ok { + addr = net.JoinHostPort(override, port) + } + return dialer.DialContext(ctx, network, addr) + }, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, // nolint + }, + } + return &http.Client{ + Timeout: defaultClientTimeout, + Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Jar: jar, + } +} diff --git a/cmd/crawl/crawl.go b/cmd/crawl/crawl.go index 93506ac..b54b999 100644 --- a/cmd/crawl/crawl.go +++ b/cmd/crawl/crawl.go @@ -5,6 +5,7 @@ package main import ( "bufio" "bytes" + "errors" "flag" "fmt" "io" @@ -38,12 +39,27 @@ var ( warcFileSizeMB = flag.Int("output-max-size", 100, "maximum output WARC file size (in MB) when using patterns") cpuprofile = flag.String("cpuprofile", "", "create cpu profile") + dnsMap = dnsMapFlag(make(map[string]string)) excludes []*regexp.Regexp + + httpClient *http.Client ) func init() { flag.Var(&excludesFlag{}, "exclude", "exclude regex URL patterns") flag.Var(&excludesFileFlag{}, "exclude-from-file", "load exclude regex URL patterns from a file") + flag.Var(dnsMap, "resolve", "set DNS overrides (in hostname=addr format)") + + stats = &crawlStats{ + states: make(map[int]int), + start: time.Now(), + } + + go func() { + for range time.Tick(10 * time.Second) { + stats.Dump() + } + }() } type excludesFlag struct{} @@ -82,6 +98,19 @@ func (f *excludesFileFlag) Set(s string) error { return nil } +type dnsMapFlag map[string]string + +func (f dnsMapFlag) String() string { return "" } + +func (f dnsMapFlag) Set(s string) error { + parts := strings.Split(s, "=") + if len(parts) != 2 { + return errors.New("value not in host=addr format") + } + f[parts[0]] = parts[1] + return nil +} + func extractLinks(p crawl.Publisher, u string, depth int, resp *http.Response, _ error) error { links, err := analysis.GetLinks(resp) if err != nil { @@ -217,26 +246,13 @@ func (c *crawlStats) Dump() { var stats *crawlStats func fetch(urlstr string) (*http.Response, error) { - resp, err := crawl.DefaultClient.Get(urlstr) + resp, err := httpClient.Get(urlstr) if err == nil { stats.Update(resp) } return resp, err } -func init() { - stats = &crawlStats{ - states: make(map[int]int), - start: time.Now(), - } - - go func() { - for range time.Tick(10 * time.Second) { - stats.Dump() - } - }() -} - type byteCounter struct { io.ReadCloser } @@ -298,6 +314,8 @@ func main() { log.Fatal(err) } + httpClient = crawl.NewHTTPClientWithDNSOverride(dnsMap) + crawler, err := crawl.NewCrawler( *dbPath, seeds, -- cgit v1.2.3-54-g00ecf