aboutsummaryrefslogtreecommitdiff
path: root/tun/checksum.go
blob: 29a8fc8fc0fe0d7e9a824bc76e79437f666fa473 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package tun

import "encoding/binary"

// checksumNoFold adds the contents of b to the running partial sum initial and
// returns the new partial sum, without folding it down to 16 bits.
//
// b is consumed as big-endian 32-bit words while at least 4 bytes remain, then
// as one big-endian 16-bit word, and finally a lone trailing byte is treated as
// the high byte of a 16-bit word padded with zero (RFC 1071). Summing 32-bit
// words instead of 16-bit words changes the raw value but not the value modulo
// 0xffff, so the folded ones'-complement result is the same; checksum performs
// that fold.
//
// Because an odd trailing byte is padded in place, chaining calls through
// initial only matches the checksum of the concatenated data when every chunk
// except the last has even length.
//
// The main loop is manually unrolled to 128 bytes per iteration (presumably
// for throughput), and the remainder is peeled off in halving steps.
//
// TODO: Explore SIMD and/or other assembly optimizations.
// TODO: Test native endian loads. See RFC 1071 section 2 part B.
func checksumNoFold(b []byte, initial uint64) uint64 {
	sum := initial

	// 32 words per iteration.
	for len(b) >= 128 {
		sum += uint64(binary.BigEndian.Uint32(b[0:4])) + uint64(binary.BigEndian.Uint32(b[4:8])) +
			uint64(binary.BigEndian.Uint32(b[8:12])) + uint64(binary.BigEndian.Uint32(b[12:16]))
		sum += uint64(binary.BigEndian.Uint32(b[16:20])) + uint64(binary.BigEndian.Uint32(b[20:24])) +
			uint64(binary.BigEndian.Uint32(b[24:28])) + uint64(binary.BigEndian.Uint32(b[28:32]))
		sum += uint64(binary.BigEndian.Uint32(b[32:36])) + uint64(binary.BigEndian.Uint32(b[36:40])) +
			uint64(binary.BigEndian.Uint32(b[40:44])) + uint64(binary.BigEndian.Uint32(b[44:48]))
		sum += uint64(binary.BigEndian.Uint32(b[48:52])) + uint64(binary.BigEndian.Uint32(b[52:56])) +
			uint64(binary.BigEndian.Uint32(b[56:60])) + uint64(binary.BigEndian.Uint32(b[60:64]))
		sum += uint64(binary.BigEndian.Uint32(b[64:68])) + uint64(binary.BigEndian.Uint32(b[68:72])) +
			uint64(binary.BigEndian.Uint32(b[72:76])) + uint64(binary.BigEndian.Uint32(b[76:80]))
		sum += uint64(binary.BigEndian.Uint32(b[80:84])) + uint64(binary.BigEndian.Uint32(b[84:88])) +
			uint64(binary.BigEndian.Uint32(b[88:92])) + uint64(binary.BigEndian.Uint32(b[92:96]))
		sum += uint64(binary.BigEndian.Uint32(b[96:100])) + uint64(binary.BigEndian.Uint32(b[100:104])) +
			uint64(binary.BigEndian.Uint32(b[104:108])) + uint64(binary.BigEndian.Uint32(b[108:112]))
		sum += uint64(binary.BigEndian.Uint32(b[112:116])) + uint64(binary.BigEndian.Uint32(b[116:120])) +
			uint64(binary.BigEndian.Uint32(b[120:124])) + uint64(binary.BigEndian.Uint32(b[124:128]))
		b = b[128:]
	}

	// Fewer than 128 bytes remain: take 64, 32, 16, 8 and 4 bytes as available.
	if len(b) >= 64 {
		sum += uint64(binary.BigEndian.Uint32(b[0:4])) + uint64(binary.BigEndian.Uint32(b[4:8])) +
			uint64(binary.BigEndian.Uint32(b[8:12])) + uint64(binary.BigEndian.Uint32(b[12:16]))
		sum += uint64(binary.BigEndian.Uint32(b[16:20])) + uint64(binary.BigEndian.Uint32(b[20:24])) +
			uint64(binary.BigEndian.Uint32(b[24:28])) + uint64(binary.BigEndian.Uint32(b[28:32]))
		sum += uint64(binary.BigEndian.Uint32(b[32:36])) + uint64(binary.BigEndian.Uint32(b[36:40])) +
			uint64(binary.BigEndian.Uint32(b[40:44])) + uint64(binary.BigEndian.Uint32(b[44:48]))
		sum += uint64(binary.BigEndian.Uint32(b[48:52])) + uint64(binary.BigEndian.Uint32(b[52:56])) +
			uint64(binary.BigEndian.Uint32(b[56:60])) + uint64(binary.BigEndian.Uint32(b[60:64]))
		b = b[64:]
	}
	if len(b) >= 32 {
		sum += uint64(binary.BigEndian.Uint32(b[0:4])) + uint64(binary.BigEndian.Uint32(b[4:8])) +
			uint64(binary.BigEndian.Uint32(b[8:12])) + uint64(binary.BigEndian.Uint32(b[12:16]))
		sum += uint64(binary.BigEndian.Uint32(b[16:20])) + uint64(binary.BigEndian.Uint32(b[20:24])) +
			uint64(binary.BigEndian.Uint32(b[24:28])) + uint64(binary.BigEndian.Uint32(b[28:32]))
		b = b[32:]
	}
	if len(b) >= 16 {
		sum += uint64(binary.BigEndian.Uint32(b[0:4])) + uint64(binary.BigEndian.Uint32(b[4:8])) +
			uint64(binary.BigEndian.Uint32(b[8:12])) + uint64(binary.BigEndian.Uint32(b[12:16]))
		b = b[16:]
	}
	if len(b) >= 8 {
		sum += uint64(binary.BigEndian.Uint32(b[0:4])) + uint64(binary.BigEndian.Uint32(b[4:8]))
		b = b[8:]
	}
	if len(b) >= 4 {
		sum += uint64(binary.BigEndian.Uint32(b[0:4]))
		b = b[4:]
	}

	// At most 3 bytes remain: an optional 16-bit word, then an optional odd
	// byte that is padded on the right with zero.
	switch len(b) {
	case 3:
		sum += uint64(binary.BigEndian.Uint16(b[:2])) + uint64(b[2])<<8
	case 2:
		sum += uint64(binary.BigEndian.Uint16(b[:2]))
	case 1:
		sum += uint64(b[0]) << 8
	}

	return sum
}

// checksum returns the 16-bit ones'-complement sum (RFC 1071) of b, continuing
// from the unfolded partial sum initial. The returned value is folded but not
// bit-inverted; the complemented form is ^checksum(b, initial).
func checksum(b []byte, initial uint64) uint16 {
	sum := checksumNoFold(b, initial)
	// Each end-around-carry fold shrinks the value from 64 to at most ~48,
	// ~32, ~17 and finally 16 bits, so four folds always suffice. Folding a
	// value that already fits in 16 bits leaves it unchanged.
	for i := 0; i < 4; i++ {
		sum = sum>>16 + sum&0xffff
	}
	return uint16(sum)
}

// pseudoHeaderChecksumNoFold returns the unfolded checksum partial sum of a
// pseudo-header made of srcAddr, dstAddr, a zero byte followed by protocol,
// and the 16-bit totalLen (the layout used by TCP/UDP checksums). The result
// is meant to be passed as the initial value to checksum or checksumNoFold
// when summing the rest of the segment.
//
// srcAddr and dstAddr are summed as independent chunks, so each must have even
// length for the combined sum to be correct (4-byte IPv4 and 16-byte IPv6
// addresses satisfy this).
func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
	sum := checksumNoFold(srcAddr, 0)
	sum = checksumNoFold(dstAddr, sum)
	// The remaining fields form two big-endian 16-bit words, {0, protocol} and
	// totalLen, whose values are just protocol and totalLen. Adding them
	// directly avoids building and rescanning temporary byte slices.
	return sum + uint64(protocol) + uint64(totalLen)
}