summaryrefslogtreecommitdiff
path: root/src/test/slownacl_curve25519.py
blob: 25244fb122d4cd0ccbfe23045338b1178d3e4ae8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# This is the curve25519 implementation from Matthew Dempsky's "Slownacl"
# library.  It is in the public domain.
#
# It isn't constant-time.  Don't use it except for testing.
#
# Nick got the slownacl source from:
#        https://github.com/mdempsky/dnscurve/tree/master/slownacl

__all__ = ['smult_curve25519_base', 'smult_curve25519']

P = 2 ** 255 - 19
A = 486662

def expmod(b, e, m):
  if e == 0: return 1
  t = expmod(b, e / 2, m) ** 2 % m
  if e & 1: t = (t * b) % m
  return t

def inv(x):
  return expmod(x, P - 2, P)

# Addition and doubling formulas taken from Appendix D of "Curve25519:
# new Diffie-Hellman speed records".

def add((xn,zn), (xm,zm), (xd,zd)):
  x = 4 * (xm * xn - zm * zn) ** 2 * zd
  z = 4 * (xm * zn - zm * xn) ** 2 * xd
  return (x % P, z % P)

def double((xn,zn)):
  x = (xn ** 2 - zn ** 2) ** 2
  z = 4 * xn * zn * (xn ** 2 + A * xn * zn + zn ** 2)
  return (x % P, z % P)

def curve25519(n, base):
  one = (base,1)
  two = double(one)
  # f(m) evaluates to a tuple containing the mth multiple and the
  # (m+1)th multiple of base.
  def f(m):
    if m == 1: return (one, two)
    (pm, pm1) = f(m / 2)
    if (m & 1):
      return (add(pm, pm1, one), double(pm1))
    return (double(pm), add(pm, pm1, one))
  ((x,z), _) = f(n)
  return (x * inv(z)) % P

def unpack(s):
  if len(s) != 32: raise ValueError('Invalid Curve25519 argument')
  return sum(ord(s[i]) << (8 * i) for i in range(32))

def pack(n):
  return ''.join([chr((n >> (8 * i)) & 255) for i in range(32)])

def clamp(n):
  n &= ~7
  n &= ~(128 << 8 * 31)
  n |= 64 << 8 * 31
  return n

def smult_curve25519(n, p):
  n = clamp(unpack(n))
  p = unpack(p)
  return pack(curve25519(n, p))

def smult_curve25519_base(n):
  n = clamp(unpack(n))
  return pack(curve25519(n, 9))


#
# This part I'm adding in for compatibility with the curve25519 python
# module. -Nick
#
import os

class Private:
    def __init__(self, secret=None, seed=None):
        self.private = pack(clamp(unpack(os.urandom(32))))

    def get_public(self):
        return Public(smult_curve25519_base(self.private))

    def get_shared_key(self, public, hashfn):
        return hashfn(smult_curve25519(self.private, public.public))

    def serialize(self):
        return self.private

class Public:
    def __init__(self, public):
        self.public = public

    def serialize(self):
        return self.public