aboutsummaryrefslogtreecommitdiff
path: root/src/cmd/vet/cgo.go
blob: 984911c489f9d70c8c71e371ccbc6c1ac452146f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Check for invalid cgo pointer passing.
// This looks for code that uses cgo to call C code passing values
// whose types are almost always invalid according to the cgo pointer
// sharing rules.
// Specifically, it warns about attempts to pass a Go chan, map, func,
// or slice to C, either directly, or via a pointer, array, or struct.

package main

import (
	"go/ast"
	"go/token"
	"go/types"
)

// init registers checkCgoCall under the check name "cgocall", handing
// callExpr to register as the final argument (presumably the node kind
// the check runs on; confirm against register).
func init() {
	const (
		name  = "cgocall"
		usage = "check for types that may not be passed to cgo calls"
	)
	register(name, usage, checkCgoCall, callExpr)
}

// checkCgoCall inspects a call expression and, when it has the form
// C.name(args...), reports through f.Badf every argument whose type is
// rejected by typeOKForCgoCall. The test for the "C" package is purely
// syntactic: the qualifier must be an identifier spelled C.
//
// node must be an *ast.CallExpr; any other node type panics on the
// type assertion.
func checkCgoCall(f *File, node ast.Node) {
	call := node.(*ast.CallExpr)

	// Only selector calls qualified by the identifier "C" are of interest.
	fun, isSel := call.Fun.(*ast.SelectorExpr)
	if !isSel {
		return
	}
	if qual, isIdent := fun.X.(*ast.Ident); !isIdent || qual.Name != "C" {
		return
	}

	// C.CBytes receives a pointer, but that call is always safe.
	if fun.Sel.Name == "CBytes" {
		return
	}

	const msg = "possibly passing Go type with embedded pointer to C"

	for _, arg := range call.Args {
		// First look at the type of the argument itself.
		if base := cgoBaseType(f, arg); !typeOKForCgoCall(base, make(map[types.Type]bool)) {
			f.Badf(arg.Pos(), msg)
		}

		// Then look for the address of a bad type, optionally wrapped
		// in a single-argument unsafe.Pointer conversion.
		inner := arg
		if conv, isCall := arg.(*ast.CallExpr); isCall && len(conv.Args) == 1 && f.hasBasicType(conv.Fun, types.UnsafePointer) {
			inner = conv.Args[0]
		}
		addr, isUnary := inner.(*ast.UnaryExpr)
		if !isUnary || addr.Op != token.AND {
			continue
		}
		if !typeOKForCgoCall(cgoBaseType(f, addr.X), make(map[types.Type]bool)) {
			// The report is positioned at the unwrapped expression.
			f.Badf(inner.Pos(), msg)
		}
	}
}

// cgoBaseType returns the type recorded for arg after stripping
// conversions that go through unsafe.Pointer, so that the caller sees
// the type of the value really being passed. It rewrites:
//
//	unsafe.Pointer(x) => x
//	*(*unsafe.Pointer)(unsafe.Pointer(&x)) => x
//
// and applies itself again to x. When neither pattern matches, the
// result is whatever f.pkg.types holds for arg, which may be nil.
func cgoBaseType(f *File, arg ast.Expr) types.Type {
	switch e := arg.(type) {
	case *ast.CallExpr:
		// unsafe.Pointer(x): continue with x.
		if len(e.Args) == 1 && f.hasBasicType(e.Fun, types.UnsafePointer) {
			return cgoBaseType(f, e.Args[0])
		}
	case *ast.StarExpr:
		// *(*unsafe.Pointer)(unsafe.Pointer(&x)): continue with x.
		if operand, matched := cgoAddrOfOperand(f, e); matched {
			return cgoBaseType(f, operand)
		}
	}

	return f.pkg.types[arg].Type
}

// cgoAddrOfOperand reports whether star has the shape
// *(*unsafe.Pointer)(unsafe.Pointer(&x)) and, if so, returns x.
func cgoAddrOfOperand(f *File, star *ast.StarExpr) (ast.Expr, bool) {
	// star must be *g(v) with exactly one argument.
	outer, ok := star.X.(*ast.CallExpr)
	if !ok || len(outer.Args) != 1 {
		return nil, false
	}

	// g must have a recorded type whose underlying type is a pointer...
	convType := f.pkg.types[outer.Fun].Type
	if convType == nil {
		return nil, false
	}
	ptr, ok := convType.Underlying().(*types.Pointer)
	if !ok {
		return nil, false
	}
	// ...to unsafe.Pointer, making star *(*unsafe.Pointer)(v).
	if basic, isBasic := ptr.Elem().Underlying().(*types.Basic); !isBasic || basic.Kind() != types.UnsafePointer {
		return nil, false
	}

	// v must itself be a one-argument unsafe.Pointer conversion.
	inner, ok := outer.Args[0].(*ast.CallExpr)
	if !ok || len(inner.Args) != 1 || !f.hasBasicType(inner.Fun, types.UnsafePointer) {
		return nil, false
	}

	// The converted operand must be an address-of expression, &x.
	addr, ok := inner.Args[0].(*ast.UnaryExpr)
	if !ok || addr.Op != token.AND {
		return nil, false
	}
	return addr.X, true
}

// typeOKForCgoCall reports whether a value of type t may be passed to
// a C function through cgo. Chan, map, func and slice types are not OK,
// and neither is any pointer, array or struct that reaches one of them.
//
// m records the types already visited. A nil type, or a type already
// present in m, is reported as OK; the latter is what stops the walk on
// recursive types. m is written to, so it must be non-nil whenever t is
// non-nil.
func typeOKForCgoCall(t types.Type, m map[types.Type]bool) bool {
	if t == nil {
		return true
	}
	if visited := m[t]; visited {
		return true
	}
	m[t] = true

	switch u := t.Underlying().(type) {
	case *types.Pointer:
		return typeOKForCgoCall(u.Elem(), m)
	case *types.Array:
		return typeOKForCgoCall(u.Elem(), m)
	case *types.Struct:
		// A struct is OK only if every one of its fields is.
		for i, n := 0, u.NumFields(); i < n; i++ {
			if !typeOKForCgoCall(u.Field(i).Type(), m) {
				return false
			}
		}
		return true
	case *types.Chan, *types.Map, *types.Signature, *types.Slice:
		return false
	default:
		return true
	}
}