diff options
Diffstat (limited to 'src/cmd/vet/httpresponse.go')
-rw-r--r-- | src/cmd/vet/httpresponse.go | 153 |
1 files changed, 153 insertions, 0 deletions
diff --git a/src/cmd/vet/httpresponse.go b/src/cmd/vet/httpresponse.go new file mode 100644 index 0000000000..f667edb515 --- /dev/null +++ b/src/cmd/vet/httpresponse.go @@ -0,0 +1,153 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file contains the check for http.Response values being used before +// checking for errors. + +package main + +import ( + "go/ast" + "go/types" +) + +var ( + httpResponseType types.Type + httpClientType types.Type +) + +func init() { + if typ := importType("net/http", "Response"); typ != nil { + httpResponseType = typ + } + if typ := importType("net/http", "Client"); typ != nil { + httpClientType = typ + } + // if http.Response or http.Client are not defined don't register this check. + if httpResponseType == nil || httpClientType == nil { + return + } + + register("httpresponse", + "check errors are checked before using an http Response", + checkHTTPResponse, callExpr) +} + +func checkHTTPResponse(f *File, node ast.Node) { + call := node.(*ast.CallExpr) + if !isHTTPFuncOrMethodOnClient(f, call) { + return // the function call is not related to this check. + } + + finder := &blockStmtFinder{node: call} + ast.Walk(finder, f.file) + stmts := finder.stmts() + if len(stmts) < 2 { + return // the call to the http function is the last statement of the block. + } + + asg, ok := stmts[0].(*ast.AssignStmt) + if !ok { + return // the first statement is not assignment. + } + resp := rootIdent(asg.Lhs[0]) + if resp == nil { + return // could not find the http.Response in the assignment. + } + + def, ok := stmts[1].(*ast.DeferStmt) + if !ok { + return // the following statement is not a defer. + } + root := rootIdent(def.Call.Fun) + if root == nil { + return // could not find the receiver of the defer call. + } + + if resp.Obj == root.Obj { + f.Badf(root.Pos(), "using %s before checking for errors", resp.Name) + } +} + +// isHTTPFuncOrMethodOnClient checks whether the given call expression is on +// either a function of the net/http package or a method of http.Client that +// returns (*http.Response, error). +func isHTTPFuncOrMethodOnClient(f *File, expr *ast.CallExpr) bool { + fun, _ := expr.Fun.(*ast.SelectorExpr) + sig, _ := f.pkg.types[fun].Type.(*types.Signature) + if sig == nil { + return false // the call is not on of the form x.f() + } + + res := sig.Results() + if res.Len() != 2 { + return false // the function called does not return two values. + } + if ptr, ok := res.At(0).Type().(*types.Pointer); !ok || !types.Identical(ptr.Elem(), httpResponseType) { + return false // the first return type is not *http.Response. + } + if !types.Identical(res.At(1).Type().Underlying(), errorType) { + return false // the second return type is not error + } + + typ := f.pkg.types[fun.X].Type + if typ == nil { + id, ok := fun.X.(*ast.Ident) + return ok && id.Name == "http" // function in net/http package. + } + + if types.Identical(typ, httpClientType) { + return true // method on http.Client. + } + ptr, ok := typ.(*types.Pointer) + return ok && types.Identical(ptr.Elem(), httpClientType) // method on *http.Client. +} + +// blockStmtFinder is an ast.Visitor that given any ast node can find the +// statement containing it and its succeeding statements in the same block. +type blockStmtFinder struct { + node ast.Node // target of search + stmt ast.Stmt // innermost statement enclosing argument to Visit + block *ast.BlockStmt // innermost block enclosing argument to Visit. +} + +// Visit finds f.node performing a search down the ast tree. +// It keeps the last block statement and statement seen for later use. +func (f *blockStmtFinder) Visit(node ast.Node) ast.Visitor { + if node == nil || f.node.Pos() < node.Pos() || f.node.End() > node.End() { + return nil // not here + } + switch n := node.(type) { + case *ast.BlockStmt: + f.block = n + case ast.Stmt: + f.stmt = n + } + if f.node.Pos() == node.Pos() && f.node.End() == node.End() { + return nil // found + } + return f // keep looking +} + +// stmts returns the statements of f.block starting from the one including f.node. +func (f *blockStmtFinder) stmts() []ast.Stmt { + for i, v := range f.block.List { + if f.stmt == v { + return f.block.List[i:] + } + } + return nil +} + +// rootIdent finds the root identifier x in a chain of selections x.y.z, or nil if not found. +func rootIdent(n ast.Node) *ast.Ident { + switch n := n.(type) { + case *ast.SelectorExpr: + return rootIdent(n.X) + case *ast.Ident: + return n + default: + return nil + } +} |