aboutsummaryrefslogtreecommitdiff
path: root/src/net/http/server.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http/server.go')
-rw-r--r--src/net/http/server.go60
1 files changed, 40 insertions, 20 deletions
diff --git a/src/net/http/server.go b/src/net/http/server.go
index 25fab288f2..4776d960e5 100644
--- a/src/net/http/server.go
+++ b/src/net/http/server.go
@@ -14,8 +14,8 @@ import (
"errors"
"fmt"
"io"
- "io/ioutil"
"log"
+ "math/rand"
"net"
"net/textproto"
"net/url"
@@ -890,12 +890,12 @@ func (srv *Server) initialReadLimitSize() int64 {
type expectContinueReader struct {
resp *response
readCloser io.ReadCloser
- closed bool
+ closed atomicBool
sawEOF atomicBool
}
func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
- if ecr.closed {
+ if ecr.closed.isSet() {
return 0, ErrBodyReadAfterClose
}
w := ecr.resp
@@ -917,7 +917,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
}
func (ecr *expectContinueReader) Close() error {
- ecr.closed = true
+ ecr.closed.setTrue()
return ecr.readCloser.Close()
}
@@ -992,7 +992,7 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) {
}
if !http1ServerSupportsRequest(req) {
- return nil, badRequestError("unsupported protocol version")
+ return nil, statusError{StatusHTTPVersionNotSupported, "unsupported protocol version"}
}
c.lastMethod = req.Method
@@ -1368,7 +1368,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
}
if discard {
- _, err := io.CopyN(ioutil.Discard, w.reqBody, maxPostHandlerReadBytes+1)
+ _, err := io.CopyN(io.Discard, w.reqBody, maxPostHandlerReadBytes+1)
switch err {
case nil:
// There must be even more data left over.
@@ -1773,9 +1773,16 @@ func (c *conn) getState() (state ConnState, unixSec int64) {
// badRequestError is a literal string (used by in the server in HTML,
// unescaped) to tell the user why their request was bad. It should
// be plain text without user info or other embedded errors.
-type badRequestError string
+func badRequestError(e string) error { return statusError{StatusBadRequest, e} }
-func (e badRequestError) Error() string { return "Bad Request: " + string(e) }
+// statusError is an error used to respond to a request with an HTTP status.
+// The text should be plain text without user info or other embedded errors.
+type statusError struct {
+ code int
+ text string
+}
+
+func (e statusError) Error() string { return StatusText(e.code) + ": " + e.text }
// ErrAbortHandler is a sentinel panic value to abort a handler.
// While any panic from ServeHTTP aborts the response to the client,
@@ -1898,11 +1905,11 @@ func (c *conn) serve(ctx context.Context) {
return // don't reply
default:
- publicErr := "400 Bad Request"
- if v, ok := err.(badRequestError); ok {
- publicErr = publicErr + ": " + string(v)
+ if v, ok := err.(statusError); ok {
+ fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, StatusText(v.code), v.text, errorHeaders, v.code, StatusText(v.code), v.text)
+ return
}
-
+ publicErr := "400 Bad Request"
fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr)
return
}
@@ -2685,14 +2692,14 @@ func (srv *Server) Close() error {
return err
}
-// shutdownPollInterval is how often we poll for quiescence
-// during Server.Shutdown. This is lower during tests, to
-// speed up tests.
+// shutdownPollIntervalMax is the max polling interval when checking
+// quiescence during Server.Shutdown. Polling starts with a small
+// interval and backs off to the max.
// Ideally we could find a solution that doesn't involve polling,
// but which also doesn't have a high runtime cost (and doesn't
// involve any contentious mutexes), but that is left as an
// exercise for the reader.
-var shutdownPollInterval = 500 * time.Millisecond
+const shutdownPollIntervalMax = 500 * time.Millisecond
// Shutdown gracefully shuts down the server without interrupting any
// active connections. Shutdown works by first closing all open
@@ -2725,8 +2732,20 @@ func (srv *Server) Shutdown(ctx context.Context) error {
}
srv.mu.Unlock()
- ticker := time.NewTicker(shutdownPollInterval)
- defer ticker.Stop()
+ pollIntervalBase := time.Millisecond
+ nextPollInterval := func() time.Duration {
+ // Add 10% jitter.
+ interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10)))
+ // Double and clamp for next time.
+ pollIntervalBase *= 2
+ if pollIntervalBase > shutdownPollIntervalMax {
+ pollIntervalBase = shutdownPollIntervalMax
+ }
+ return interval
+ }
+
+ timer := time.NewTimer(nextPollInterval())
+ defer timer.Stop()
for {
if srv.closeIdleConns() && srv.numListeners() == 0 {
return lnerr
@@ -2734,7 +2753,8 @@ func (srv *Server) Shutdown(ctx context.Context) error {
select {
case <-ctx.Done():
return ctx.Err()
- case <-ticker.C:
+ case <-timer.C:
+ timer.Reset(nextPollInterval())
}
}
}
@@ -3400,7 +3420,7 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) {
// (or an attack) and we abort and close the connection,
// courtesy of MaxBytesReader's EOF behavior.
mb := MaxBytesReader(w, r.Body, 4<<10)
- io.Copy(ioutil.Discard, mb)
+ io.Copy(io.Discard, mb)
}
}