author    ale <ale@incal.net>  2020-08-23 16:53:40 +0100
committer ale <ale@incal.net>  2020-08-23 16:53:40 +0100
commit    37c649a8b693ba65a59eab5de3c01bf212f791ad (patch)
tree      72da32ab179da107226924f6fdbca417e1438c1b
parent    cf35cce601579804ed91027fce92c2d507bc939d (diff)
Allow setting DNS overrides using the --resolve option
 client.go          | 43
 cmd/crawl/crawl.go | 46
 2 files changed, 72 insertions(+), 17 deletions(-)
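
Usage sketch (the hostname and address are invented for illustration): the
flag takes hostname=addr pairs and can be repeated, so a crawl can be pointed
at a specific server regardless of what DNS returns:

    crawl --resolve example.com=127.0.0.1 http://example.com/
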
diff --git a/client.go b/client.go
index 45736f5..c028e42 100644
--- a/client.go
+++ b/client.go
@@ -1,7 +1,9 @@
package crawl
import (
+ "context"
"crypto/tls"
+ "net"
"net/http"
"net/http/cookiejar"
"time"
@@ -9,14 +11,19 @@ import (
var defaultClientTimeout = 60 * time.Second
-// DefaultClient returns a http.Client suitable for crawling: does not
-// follow redirects, accepts invalid TLS certificates, sets a
+// DefaultClient points at a shared http.Client suitable for crawling:
+// does not follow redirects, accepts invalid TLS certificates, sets a
// reasonable timeout for requests.
var DefaultClient *http.Client
func init() {
+ DefaultClient = NewHTTPClient()
+}
+
+// NewHTTPClient returns an http.Client suitable for crawling.
+func NewHTTPClient() *http.Client {
jar, _ := cookiejar.New(nil) // nolint
- DefaultClient = &http.Client{
+ return &http.Client{
Timeout: defaultClientTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
@@ -29,3 +36,33 @@ func init() {
Jar: jar,
}
}
+
+// NewHTTPClientWithDNSOverride returns an http.Client suitable for
+// crawling, with some additional DNS overrides.
+func NewHTTPClientWithDNSOverride(dnsMap map[string]string) *http.Client {
+ jar, _ := cookiejar.New(nil) // nolint
+ dialer := new(net.Dialer)
+ transport := &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ host, port, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err
+ }
+ if override, ok := dnsMap[host]; ok {
+ addr = net.JoinHostPort(override, port)
+ }
+ return dialer.DialContext(ctx, network, addr)
+ },
+ TLSClientConfig: &tls.Config{
+ InsecureSkipVerify: true, // nolint
+ },
+ }
+ return &http.Client{
+ Timeout: defaultClientTimeout,
+ Transport: transport,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return http.ErrUseLastResponse
+ },
+ Jar: jar,
+ }
+}
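
A minimal, self-contained sketch of the new constructor in action, assuming
the module path git.autistici.org/ale/crawl; the hostname fake.test is
invented, and a local httptest server stands in for the address the override
points at:

    package main

    import (
        "fmt"
        "log"
        "net"
        "net/http"
        "net/http/httptest"
        "net/url"

        "git.autistici.org/ale/crawl"
    )

    func main() {
        // A local server stands in for the host that "fake.test"
        // is overridden to resolve to.
        srv := httptest.NewServer(http.HandlerFunc(
            func(w http.ResponseWriter, r *http.Request) {
                fmt.Fprintln(w, "hello from the override target")
            }))
        defer srv.Close()

        // Recover the port the test server picked.
        u, err := url.Parse(srv.URL)
        if err != nil {
            log.Fatal(err)
        }
        _, port, err := net.SplitHostPort(u.Host)
        if err != nil {
            log.Fatal(err)
        }

        // "fake.test" never touches DNS: the custom DialContext
        // rewrites it to 127.0.0.1 before dialing.
        client := crawl.NewHTTPClientWithDNSOverride(map[string]string{
            "fake.test": "127.0.0.1",
        })
        resp, err := client.Get("http://fake.test:" + port + "/")
        if err != nil {
            log.Fatal(err)
        }
        defer resp.Body.Close()
        fmt.Println(resp.Status) // 200 OK
    }
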
diff --git a/cmd/crawl/crawl.go b/cmd/crawl/crawl.go
index 93506ac..b54b999 100644
--- a/cmd/crawl/crawl.go
+++ b/cmd/crawl/crawl.go
@@ -5,6 +5,7 @@ package main
import (
"bufio"
"bytes"
+ "errors"
"flag"
"fmt"
"io"
@@ -38,12 +39,27 @@ var (
warcFileSizeMB = flag.Int("output-max-size", 100, "maximum output WARC file size (in MB) when using patterns")
cpuprofile = flag.String("cpuprofile", "", "create cpu profile")
+ dnsMap = dnsMapFlag(make(map[string]string))
excludes []*regexp.Regexp
+
+ httpClient *http.Client
)
func init() {
flag.Var(&excludesFlag{}, "exclude", "exclude regex URL patterns")
flag.Var(&excludesFileFlag{}, "exclude-from-file", "load exclude regex URL patterns from a file")
+ flag.Var(dnsMap, "resolve", "set DNS overrides (in hostname=addr format)")
+
+ stats = &crawlStats{
+ states: make(map[int]int),
+ start: time.Now(),
+ }
+
+ go func() {
+ for range time.Tick(10 * time.Second) {
+ stats.Dump()
+ }
+ }()
}
type excludesFlag struct{}
@@ -82,6 +98,19 @@ func (f *excludesFileFlag) Set(s string) error {
return nil
}
+type dnsMapFlag map[string]string
+
+func (f dnsMapFlag) String() string { return "" }
+
+func (f dnsMapFlag) Set(s string) error {
+ parts := strings.Split(s, "=")
+ if len(parts) != 2 {
+ return errors.New("value not in host=addr format")
+ }
+ f[parts[0]] = parts[1]
+ return nil
+}
+
func extractLinks(p crawl.Publisher, u string, depth int, resp *http.Response, _ error) error {
links, err := analysis.GetLinks(resp)
if err != nil {
@@ -217,26 +246,13 @@ func (c *crawlStats) Dump() {
var stats *crawlStats
func fetch(urlstr string) (*http.Response, error) {
- resp, err := crawl.DefaultClient.Get(urlstr)
+ resp, err := httpClient.Get(urlstr)
if err == nil {
stats.Update(resp)
}
return resp, err
}
-func init() {
- stats = &crawlStats{
- states: make(map[int]int),
- start: time.Now(),
- }
-
- go func() {
- for range time.Tick(10 * time.Second) {
- stats.Dump()
- }
- }()
-}
-
type byteCounter struct {
io.ReadCloser
}
@@ -298,6 +314,8 @@ func main() {
log.Fatal(err)
}
+ httpClient = crawl.NewHTTPClientWithDNSOverride(dnsMap)
+
crawler, err := crawl.NewCrawler(
*dbPath,
seeds,
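
For reference, the dnsMapFlag type above is the standard flag.Value pattern
for accumulating repeated key=value arguments into a map. An isolated,
runnable sketch of the same idea (names are illustrative; SplitN is a small
deviation from the commit's Split that tolerates "=" inside the value):

    package main

    import (
        "errors"
        "flag"
        "fmt"
        "strings"
    )

    // mapFlag collects repeated key=value arguments into a map.
    type mapFlag map[string]string

    func (f mapFlag) String() string { return "" }

    func (f mapFlag) Set(s string) error {
        parts := strings.SplitN(s, "=", 2)
        if len(parts) != 2 {
            return errors.New("value not in key=value format")
        }
        f[parts[0]] = parts[1]
        return nil
    }

    func main() {
        m := mapFlag(make(map[string]string))
        flag.Var(m, "resolve", "set DNS overrides (in hostname=addr format)")
        flag.Parse()
        fmt.Println(m) // e.g. map[example.com:127.0.0.1]
    }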