diff options
author | Francesc Campoy <campoy@golang.org> | 2016-11-07 23:37:21 -0800 |
---|---|---|
committer | Francesc Campoy Flores <campoy@golang.org> | 2016-11-10 20:38:11 +0000 |
commit | 47bdae942242eca4be94989bab485bb1335f354d (patch) | |
tree | aa30ca983d460ac40a31f8b7b8b3490e648a5205 | |
parent | 91135f27415ae582081e59d1b70f77d4d4d58112 (diff) | |
download | go-47bdae942242eca4be94989bab485bb1335f354d.tar.gz go-47bdae942242eca4be94989bab485bb1335f354d.zip |
cmd/vet: detect defer resp.Body.Close() before error check
This check detects the code
resp, err := http.Get("http://foo.com")
defer resp.Body.Close()
if err != nil {
...
}
For every call to a function on the net/http package or any method
on http.Client that returns (*http.Response, error), it checks
whether the next line is a defer statement that calls on the response.
Fixes #17780.
Change-Id: I9d70edcbfa2bad205bf7f45281597d074c795977
Reviewed-on: https://go-review.googlesource.com/32911
Reviewed-by: Rob Pike <r@golang.org>
-rw-r--r-- | src/cmd/vet/doc.go | 15 | ||||
-rw-r--r-- | src/cmd/vet/httpresponse.go | 153 | ||||
-rw-r--r-- | src/cmd/vet/testdata/httpresponse.go | 85 |
3 files changed, 249 insertions, 4 deletions
diff --git a/src/cmd/vet/doc.go b/src/cmd/vet/doc.go index 2baa53099d..5cbe116abe 100644 --- a/src/cmd/vet/doc.go +++ b/src/cmd/vet/doc.go @@ -84,12 +84,12 @@ Flag: -copylocks Locks that are erroneously passed by value. -Tests and documentation examples +HTTP responses used incorrectly -Flag: -tests +Flag: -httpresponse -Mistakes involving tests including functions with incorrect names or signatures -and example tests that document identifiers not in the package. +Mistakes deferring a function call on an HTTP response before +checking whether the error returned with the response was nil. Failure to call the cancelation function returned by WithCancel @@ -162,6 +162,13 @@ Flag: -structtags Struct tags that do not follow the format understood by reflect.StructTag.Get. Well-known encoding struct tags (json, xml) used with unexported fields. +Tests and documentation examples + +Flag: -tests + +Mistakes involving tests including functions with incorrect names or signatures +and example tests that document identifiers not in the package. + Unreachable code Flag: -unreachable diff --git a/src/cmd/vet/httpresponse.go b/src/cmd/vet/httpresponse.go new file mode 100644 index 0000000000..f667edb515 --- /dev/null +++ b/src/cmd/vet/httpresponse.go @@ -0,0 +1,153 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file contains the check for http.Response values being used before +// checking for errors. + +package main + +import ( + "go/ast" + "go/types" +) + +var ( + httpResponseType types.Type + httpClientType types.Type +) + +func init() { + if typ := importType("net/http", "Response"); typ != nil { + httpResponseType = typ + } + if typ := importType("net/http", "Client"); typ != nil { + httpClientType = typ + } + // if http.Response or http.Client are not defined don't register this check. + if httpResponseType == nil || httpClientType == nil { + return + } + + register("httpresponse", + "check errors are checked before using an http Response", + checkHTTPResponse, callExpr) +} + +func checkHTTPResponse(f *File, node ast.Node) { + call := node.(*ast.CallExpr) + if !isHTTPFuncOrMethodOnClient(f, call) { + return // the function call is not related to this check. + } + + finder := &blockStmtFinder{node: call} + ast.Walk(finder, f.file) + stmts := finder.stmts() + if len(stmts) < 2 { + return // the call to the http function is the last statement of the block. + } + + asg, ok := stmts[0].(*ast.AssignStmt) + if !ok { + return // the first statement is not assignment. + } + resp := rootIdent(asg.Lhs[0]) + if resp == nil { + return // could not find the http.Response in the assignment. + } + + def, ok := stmts[1].(*ast.DeferStmt) + if !ok { + return // the following statement is not a defer. + } + root := rootIdent(def.Call.Fun) + if root == nil { + return // could not find the receiver of the defer call. + } + + if resp.Obj == root.Obj { + f.Badf(root.Pos(), "using %s before checking for errors", resp.Name) + } +} + +// isHTTPFuncOrMethodOnClient checks whether the given call expression is on +// either a function of the net/http package or a method of http.Client that +// returns (*http.Response, error). +func isHTTPFuncOrMethodOnClient(f *File, expr *ast.CallExpr) bool { + fun, _ := expr.Fun.(*ast.SelectorExpr) + sig, _ := f.pkg.types[fun].Type.(*types.Signature) + if sig == nil { + return false // the call is not on of the form x.f() + } + + res := sig.Results() + if res.Len() != 2 { + return false // the function called does not return two values. + } + if ptr, ok := res.At(0).Type().(*types.Pointer); !ok || !types.Identical(ptr.Elem(), httpResponseType) { + return false // the first return type is not *http.Response. + } + if !types.Identical(res.At(1).Type().Underlying(), errorType) { + return false // the second return type is not error + } + + typ := f.pkg.types[fun.X].Type + if typ == nil { + id, ok := fun.X.(*ast.Ident) + return ok && id.Name == "http" // function in net/http package. + } + + if types.Identical(typ, httpClientType) { + return true // method on http.Client. + } + ptr, ok := typ.(*types.Pointer) + return ok && types.Identical(ptr.Elem(), httpClientType) // method on *http.Client. +} + +// blockStmtFinder is an ast.Visitor that given any ast node can find the +// statement containing it and its succeeding statements in the same block. +type blockStmtFinder struct { + node ast.Node // target of search + stmt ast.Stmt // innermost statement enclosing argument to Visit + block *ast.BlockStmt // innermost block enclosing argument to Visit. +} + +// Visit finds f.node performing a search down the ast tree. +// It keeps the last block statement and statement seen for later use. +func (f *blockStmtFinder) Visit(node ast.Node) ast.Visitor { + if node == nil || f.node.Pos() < node.Pos() || f.node.End() > node.End() { + return nil // not here + } + switch n := node.(type) { + case *ast.BlockStmt: + f.block = n + case ast.Stmt: + f.stmt = n + } + if f.node.Pos() == node.Pos() && f.node.End() == node.End() { + return nil // found + } + return f // keep looking +} + +// stmts returns the statements of f.block starting from the one including f.node. +func (f *blockStmtFinder) stmts() []ast.Stmt { + for i, v := range f.block.List { + if f.stmt == v { + return f.block.List[i:] + } + } + return nil +} + +// rootIdent finds the root identifier x in a chain of selections x.y.z, or nil if not found. +func rootIdent(n ast.Node) *ast.Ident { + switch n := n.(type) { + case *ast.SelectorExpr: + return rootIdent(n.X) + case *ast.Ident: + return n + default: + return nil + } +} diff --git a/src/cmd/vet/testdata/httpresponse.go b/src/cmd/vet/testdata/httpresponse.go new file mode 100644 index 0000000000..7302a64a3b --- /dev/null +++ b/src/cmd/vet/testdata/httpresponse.go @@ -0,0 +1,85 @@ +package testdata + +import ( + "log" + "net/http" +) + +func goodHTTPGet() { + res, err := http.Get("http://foo.com") + if err != nil { + log.Fatal(err) + } + defer res.Body.Close() +} + +func badHTTPGet() { + res, err := http.Get("http://foo.com") + defer res.Body.Close() // ERROR "using res before checking for errors" + if err != nil { + log.Fatal(err) + } +} + +func badHTTPHead() { + res, err := http.Head("http://foo.com") + defer res.Body.Close() // ERROR "using res before checking for errors" + if err != nil { + log.Fatal(err) + } +} + +func goodClientGet() { + client := http.DefaultClient + res, err := client.Get("http://foo.com") + if err != nil { + log.Fatal(err) + } + defer res.Body.Close() +} + +func badClientPtrGet() { + client := http.DefaultClient + resp, err := client.Get("http://foo.com") + defer resp.Body.Close() // ERROR "using resp before checking for errors" + if err != nil { + log.Fatal(err) + } +} + +func badClientGet() { + client := http.Client{} + resp, err := client.Get("http://foo.com") + defer resp.Body.Close() // ERROR "using resp before checking for errors" + if err != nil { + log.Fatal(err) + } +} + +func badClientPtrDo() { + client := http.DefaultClient + req, err := http.NewRequest("GET", "http://foo.com", nil) + if err != nil { + log.Fatal(err) + } + + resp, err := client.Do(req) + defer resp.Body.Close() // ERROR "using resp before checking for errors" + if err != nil { + log.Fatal(err) + } +} + +func badClientDo() { + var client http.Client + req, err := http.NewRequest("GET", "http://foo.com", nil) + if err != nil { + log.Fatal(err) + } + + resp, err := client.Do(req) + defer resp.Body.Close() // ERROR "using resp before checking for errors" + if err != nil { + log.Fatal(err) + } +} |