diff --git a/colly.go b/colly.go index decfd31f..741bdc13 100644 --- a/colly.go +++ b/colly.go @@ -56,6 +56,8 @@ type CollectorOption func(*Collector) type Collector struct { // UserAgent is the User-Agent string used by HTTP requests UserAgent string + // Custom headers for the request + Headers *http.Header // MaxDepth limits the recursion depth of visited URLs. // Set it to 0 for infinite recursion (default). MaxDepth int @@ -281,6 +283,17 @@ func UserAgent(ua string) CollectorOption { } } +// Header sets the custom headers used by the Collector. +func Headers(headers map[string]string) CollectorOption { + return func(c *Collector) { + custom_headers := make(http.Header) + for header, value := range headers { + custom_headers.Add(header, value) + } + c.Headers = &custom_headers + } +} + // MaxDepth limits the recursion depth of visited URLs. func MaxDepth(depth int) CollectorOption { return func(c *Collector) { @@ -415,6 +428,7 @@ func CheckHead() CollectorOption { // configuration for the Collector func (c *Collector) Init() { c.UserAgent = "colly - https://github.com/gocolly/colly/v2" + c.Headers = nil c.MaxDepth = 0 c.store = &storage.InMemoryStorage{} c.store.Init() @@ -568,6 +582,13 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c if hdr == nil { hdr = http.Header{} + if c.Headers != nil { + for k, v := range *c.Headers { + for _, value := range v { + hdr.Add(k, value) + } + } + } } if _, ok := hdr["User-Agent"]; !ok { hdr.Set("User-Agent", c.UserAgent) @@ -645,6 +666,7 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct request := &Request{ URL: req.URL, Headers: &req.Header, + Host: req.Host, Ctx: ctx, Depth: depth, Method: method, @@ -1284,6 +1306,7 @@ func (c *Collector) Clone() *Collector { CheckHead: c.CheckHead, ParseHTTPErrorResponse: c.ParseHTTPErrorResponse, UserAgent: c.UserAgent, + Headers: c.Headers, TraceHTTP: c.TraceHTTP, Context: c.Context, store: c.store, diff --git a/colly_test.go b/colly_test.go index 507027da..ca2aeb2b 100644 --- a/colly_test.go +++ b/colly_test.go @@ -138,6 +138,16 @@ func newTestServer() *httptest.Server { w.Write([]byte(r.Header.Get("User-Agent"))) }) + mux.HandleFunc("/host_header", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(r.Host)) + }) + + mux.HandleFunc("/custom_header", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + w.Write([]byte(r.Header.Get("Test"))) + }) + mux.HandleFunc("/base", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/html") w.Write([]byte(` @@ -1139,6 +1149,41 @@ func TestUserAgent(t *testing.T) { }() } +func TestHeaders(t *testing.T) { + const exampleHostHeader = "example.com" + const exampleTestHeader = "Testing" + + ts := newTestServer() + defer ts.Close() + + var receivedHeader string + + func() { + c := NewCollector( + Headers(map[string]string{"Host": exampleHostHeader}), + ) + c.OnResponse(func(resp *Response) { + receivedHeader = string(resp.Body) + }) + c.Visit(ts.URL + "/host_header") + if got, want := receivedHeader, exampleHostHeader; got != want { + t.Errorf("mismatched Host header: got=%q want=%q", got, want) + } + }() + func() { + c := NewCollector( + Headers(map[string]string{"Test": exampleTestHeader}), + ) + c.OnResponse(func(resp *Response) { + receivedHeader = string(resp.Body) + }) + c.Visit(ts.URL + "/custom_header") + if got, want := receivedHeader, exampleTestHeader; got != want { + t.Errorf("mismatched custom header: got=%q want=%q", got, want) + } + }() +} + func TestParseHTTPErrorResponse(t *testing.T) { contentCount := 0 ts := newTestServer() diff --git a/request.go b/request.go index 0c4b9cd9..e650efea 100644 --- a/request.go +++ b/request.go @@ -33,6 +33,8 @@ type Request struct { URL *url.URL // Headers contains the Request's HTTP headers Headers *http.Header + // the Host header + Host string // Ctx is a context between a Request and a Response Ctx *Context // Depth is the number of the parents of the request @@ -62,6 +64,7 @@ type serializableRequest struct { ID uint32 Ctx map[string]interface{} Headers http.Header + Host string } // New creates a new request with the context of the original request @@ -80,6 +83,7 @@ func (r *Request) New(method, URL string, body io.Reader) (*Request, error) { Body: body, Ctx: r.Ctx, Headers: &http.Header{}, + Host: r.Host, ID: atomic.AddUint32(&r.collector.requestCount, 1), collector: r.collector, }, nil @@ -178,6 +182,7 @@ func (r *Request) Marshal() ([]byte, error) { } sr := &serializableRequest{ URL: r.URL.String(), + Host: r.Host, Method: r.Method, Depth: r.Depth, Body: body,