extra.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. // Copyright (c) 2021 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package edwards25519
  5. // This file contains additional functionality that is not included in the
  6. // upstream crypto/internal/edwards25519 package.
  7. import (
  8. "errors"
  9. "filippo.io/edwards25519/field"
  10. )
  11. // ExtendedCoordinates returns v in extended coordinates (X:Y:Z:T) where
  12. // x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
  13. func (v *Point) ExtendedCoordinates() (X, Y, Z, T *field.Element) {
  14. // This function is outlined to make the allocations inline in the caller
  15. // rather than happen on the heap. Don't change the style without making
  16. // sure it doesn't increase the inliner cost.
  17. var e [4]field.Element
  18. X, Y, Z, T = v.extendedCoordinates(&e)
  19. return
  20. }
  21. func (v *Point) extendedCoordinates(e *[4]field.Element) (X, Y, Z, T *field.Element) {
  22. checkInitialized(v)
  23. X = e[0].Set(&v.x)
  24. Y = e[1].Set(&v.y)
  25. Z = e[2].Set(&v.z)
  26. T = e[3].Set(&v.t)
  27. return
  28. }
  29. // SetExtendedCoordinates sets v = (X:Y:Z:T) in extended coordinates where
  30. // x = X/Z, y = Y/Z, and xy = T/Z as in https://eprint.iacr.org/2008/522.
  31. //
  32. // If the coordinates are invalid or don't represent a valid point on the curve,
  33. // SetExtendedCoordinates returns nil and an error and the receiver is
  34. // unchanged. Otherwise, SetExtendedCoordinates returns v.
  35. func (v *Point) SetExtendedCoordinates(X, Y, Z, T *field.Element) (*Point, error) {
  36. if !isOnCurve(X, Y, Z, T) {
  37. return nil, errors.New("edwards25519: invalid point coordinates")
  38. }
  39. v.x.Set(X)
  40. v.y.Set(Y)
  41. v.z.Set(Z)
  42. v.t.Set(T)
  43. return v, nil
  44. }
  45. func isOnCurve(X, Y, Z, T *field.Element) bool {
  46. var lhs, rhs field.Element
  47. XX := new(field.Element).Square(X)
  48. YY := new(field.Element).Square(Y)
  49. ZZ := new(field.Element).Square(Z)
  50. TT := new(field.Element).Square(T)
  51. // -x² + y² = 1 + dx²y²
  52. // -(X/Z)² + (Y/Z)² = 1 + d(T/Z)²
  53. // -X² + Y² = Z² + dT²
  54. lhs.Subtract(YY, XX)
  55. rhs.Multiply(d, TT).Add(&rhs, ZZ)
  56. if lhs.Equal(&rhs) != 1 {
  57. return false
  58. }
  59. // xy = T/Z
  60. // XY/Z² = T/Z
  61. // XY = TZ
  62. lhs.Multiply(X, Y)
  63. rhs.Multiply(T, Z)
  64. return lhs.Equal(&rhs) == 1
  65. }
  66. // BytesMontgomery converts v to a point on the birationally-equivalent
  67. // Curve25519 Montgomery curve, and returns its canonical 32 bytes encoding
  68. // according to RFC 7748.
  69. //
  70. // Note that BytesMontgomery only encodes the u-coordinate, so v and -v encode
  71. // to the same value. If v is the identity point, BytesMontgomery returns 32
  72. // zero bytes, analogously to the X25519 function.
  73. //
  74. // The lack of an inverse operation (such as SetMontgomeryBytes) is deliberate:
  75. // while every valid edwards25519 point has a unique u-coordinate Montgomery
  76. // encoding, X25519 accepts inputs on the quadratic twist, which don't correspond
  77. // to any edwards25519 point, and every other X25519 input corresponds to two
  78. // edwards25519 points.
  79. func (v *Point) BytesMontgomery() []byte {
  80. // This function is outlined to make the allocations inline in the caller
  81. // rather than happen on the heap.
  82. var buf [32]byte
  83. return v.bytesMontgomery(&buf)
  84. }
  85. func (v *Point) bytesMontgomery(buf *[32]byte) []byte {
  86. checkInitialized(v)
  87. // RFC 7748, Section 4.1 provides the bilinear map to calculate the
  88. // Montgomery u-coordinate
  89. //
  90. // u = (1 + y) / (1 - y)
  91. //
  92. // where y = Y / Z.
  93. var y, recip, u field.Element
  94. y.Multiply(&v.y, y.Invert(&v.z)) // y = Y / Z
  95. recip.Invert(recip.Subtract(feOne, &y)) // r = 1/(1 - y)
  96. u.Multiply(u.Add(feOne, &y), &recip) // u = (1 + y)*r
  97. return copyFieldElement(buf, &u)
  98. }
  99. // MultByCofactor sets v = 8 * p, and returns v.
  100. func (v *Point) MultByCofactor(p *Point) *Point {
  101. checkInitialized(p)
  102. result := projP1xP1{}
  103. pp := (&projP2{}).FromP3(p)
  104. result.Double(pp)
  105. pp.FromP1xP1(&result)
  106. result.Double(pp)
  107. pp.FromP1xP1(&result)
  108. result.Double(pp)
  109. return v.fromP1xP1(&result)
  110. }
  111. // Given k > 0, set s = s**(2*i).
  112. func (s *Scalar) pow2k(k int) {
  113. for i := 0; i < k; i++ {
  114. s.Multiply(s, s)
  115. }
  116. }
  117. // Invert sets s to the inverse of a nonzero scalar v, and returns s.
  118. //
  119. // If t is zero, Invert returns zero.
  120. func (s *Scalar) Invert(t *Scalar) *Scalar {
  121. // Uses a hardcoded sliding window of width 4.
  122. var table [8]Scalar
  123. var tt Scalar
  124. tt.Multiply(t, t)
  125. table[0] = *t
  126. for i := 0; i < 7; i++ {
  127. table[i+1].Multiply(&table[i], &tt)
  128. }
  129. // Now table = [t**1, t**3, t**5, t**7, t**9, t**11, t**13, t**15]
  130. // so t**k = t[k/2] for odd k
  131. // To compute the sliding window digits, use the following Sage script:
  132. // sage: import itertools
  133. // sage: def sliding_window(w,k):
  134. // ....: digits = []
  135. // ....: while k > 0:
  136. // ....: if k % 2 == 1:
  137. // ....: kmod = k % (2**w)
  138. // ....: digits.append(kmod)
  139. // ....: k = k - kmod
  140. // ....: else:
  141. // ....: digits.append(0)
  142. // ....: k = k // 2
  143. // ....: return digits
  144. // Now we can compute s roughly as follows:
  145. // sage: s = 1
  146. // sage: for coeff in reversed(sliding_window(4,l-2)):
  147. // ....: s = s*s
  148. // ....: if coeff > 0 :
  149. // ....: s = s*t**coeff
  150. // This works on one bit at a time, with many runs of zeros.
  151. // The digits can be collapsed into [(count, coeff)] as follows:
  152. // sage: [(len(list(group)),d) for d,group in itertools.groupby(sliding_window(4,l-2))]
  153. // Entries of the form (k, 0) turn into pow2k(k)
  154. // Entries of the form (1, coeff) turn into a squaring and then a table lookup.
  155. // We can fold the squaring into the previous pow2k(k) as pow2k(k+1).
  156. *s = table[1/2]
  157. s.pow2k(127 + 1)
  158. s.Multiply(s, &table[1/2])
  159. s.pow2k(4 + 1)
  160. s.Multiply(s, &table[9/2])
  161. s.pow2k(3 + 1)
  162. s.Multiply(s, &table[11/2])
  163. s.pow2k(3 + 1)
  164. s.Multiply(s, &table[13/2])
  165. s.pow2k(3 + 1)
  166. s.Multiply(s, &table[15/2])
  167. s.pow2k(4 + 1)
  168. s.Multiply(s, &table[7/2])
  169. s.pow2k(4 + 1)
  170. s.Multiply(s, &table[15/2])
  171. s.pow2k(3 + 1)
  172. s.Multiply(s, &table[5/2])
  173. s.pow2k(3 + 1)
  174. s.Multiply(s, &table[1/2])
  175. s.pow2k(4 + 1)
  176. s.Multiply(s, &table[15/2])
  177. s.pow2k(4 + 1)
  178. s.Multiply(s, &table[15/2])
  179. s.pow2k(4 + 1)
  180. s.Multiply(s, &table[7/2])
  181. s.pow2k(3 + 1)
  182. s.Multiply(s, &table[3/2])
  183. s.pow2k(4 + 1)
  184. s.Multiply(s, &table[11/2])
  185. s.pow2k(5 + 1)
  186. s.Multiply(s, &table[11/2])
  187. s.pow2k(9 + 1)
  188. s.Multiply(s, &table[9/2])
  189. s.pow2k(3 + 1)
  190. s.Multiply(s, &table[3/2])
  191. s.pow2k(4 + 1)
  192. s.Multiply(s, &table[3/2])
  193. s.pow2k(4 + 1)
  194. s.Multiply(s, &table[3/2])
  195. s.pow2k(4 + 1)
  196. s.Multiply(s, &table[9/2])
  197. s.pow2k(3 + 1)
  198. s.Multiply(s, &table[7/2])
  199. s.pow2k(3 + 1)
  200. s.Multiply(s, &table[3/2])
  201. s.pow2k(3 + 1)
  202. s.Multiply(s, &table[13/2])
  203. s.pow2k(3 + 1)
  204. s.Multiply(s, &table[7/2])
  205. s.pow2k(4 + 1)
  206. s.Multiply(s, &table[9/2])
  207. s.pow2k(3 + 1)
  208. s.Multiply(s, &table[15/2])
  209. s.pow2k(4 + 1)
  210. s.Multiply(s, &table[11/2])
  211. return s
  212. }
  213. // MultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
  214. //
  215. // Execution time depends only on the lengths of the two slices, which must match.
  216. func (v *Point) MultiScalarMult(scalars []*Scalar, points []*Point) *Point {
  217. if len(scalars) != len(points) {
  218. panic("edwards25519: called MultiScalarMult with different size inputs")
  219. }
  220. checkInitialized(points...)
  221. // Proceed as in the single-base case, but share doublings
  222. // between each point in the multiscalar equation.
  223. // Build lookup tables for each point
  224. tables := make([]projLookupTable, len(points))
  225. for i := range tables {
  226. tables[i].FromP3(points[i])
  227. }
  228. // Compute signed radix-16 digits for each scalar
  229. digits := make([][64]int8, len(scalars))
  230. for i := range digits {
  231. digits[i] = scalars[i].signedRadix16()
  232. }
  233. // Unwrap first loop iteration to save computing 16*identity
  234. multiple := &projCached{}
  235. tmp1 := &projP1xP1{}
  236. tmp2 := &projP2{}
  237. // Lookup-and-add the appropriate multiple of each input point
  238. for j := range tables {
  239. tables[j].SelectInto(multiple, digits[j][63])
  240. tmp1.Add(v, multiple) // tmp1 = v + x_(j,63)*Q in P1xP1 coords
  241. v.fromP1xP1(tmp1) // update v
  242. }
  243. tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
  244. for i := 62; i >= 0; i-- {
  245. tmp1.Double(tmp2) // tmp1 = 2*(prev) in P1xP1 coords
  246. tmp2.FromP1xP1(tmp1) // tmp2 = 2*(prev) in P2 coords
  247. tmp1.Double(tmp2) // tmp1 = 4*(prev) in P1xP1 coords
  248. tmp2.FromP1xP1(tmp1) // tmp2 = 4*(prev) in P2 coords
  249. tmp1.Double(tmp2) // tmp1 = 8*(prev) in P1xP1 coords
  250. tmp2.FromP1xP1(tmp1) // tmp2 = 8*(prev) in P2 coords
  251. tmp1.Double(tmp2) // tmp1 = 16*(prev) in P1xP1 coords
  252. v.fromP1xP1(tmp1) // v = 16*(prev) in P3 coords
  253. // Lookup-and-add the appropriate multiple of each input point
  254. for j := range tables {
  255. tables[j].SelectInto(multiple, digits[j][i])
  256. tmp1.Add(v, multiple) // tmp1 = v + x_(j,i)*Q in P1xP1 coords
  257. v.fromP1xP1(tmp1) // update v
  258. }
  259. tmp2.FromP3(v) // set up tmp2 = v in P2 coords for next iteration
  260. }
  261. return v
  262. }
  263. // VarTimeMultiScalarMult sets v = sum(scalars[i] * points[i]), and returns v.
  264. //
  265. // Execution time depends on the inputs.
  266. func (v *Point) VarTimeMultiScalarMult(scalars []*Scalar, points []*Point) *Point {
  267. if len(scalars) != len(points) {
  268. panic("edwards25519: called VarTimeMultiScalarMult with different size inputs")
  269. }
  270. checkInitialized(points...)
  271. // Generalize double-base NAF computation to arbitrary sizes.
  272. // Here all the points are dynamic, so we only use the smaller
  273. // tables.
  274. // Build lookup tables for each point
  275. tables := make([]nafLookupTable5, len(points))
  276. for i := range tables {
  277. tables[i].FromP3(points[i])
  278. }
  279. // Compute a NAF for each scalar
  280. nafs := make([][256]int8, len(scalars))
  281. for i := range nafs {
  282. nafs[i] = scalars[i].nonAdjacentForm(5)
  283. }
  284. multiple := &projCached{}
  285. tmp1 := &projP1xP1{}
  286. tmp2 := &projP2{}
  287. tmp2.Zero()
  288. // Move from high to low bits, doubling the accumulator
  289. // at each iteration and checking whether there is a nonzero
  290. // coefficient to look up a multiple of.
  291. //
  292. // Skip trying to find the first nonzero coefficent, because
  293. // searching might be more work than a few extra doublings.
  294. for i := 255; i >= 0; i-- {
  295. tmp1.Double(tmp2)
  296. for j := range nafs {
  297. if nafs[j][i] > 0 {
  298. v.fromP1xP1(tmp1)
  299. tables[j].SelectInto(multiple, nafs[j][i])
  300. tmp1.Add(v, multiple)
  301. } else if nafs[j][i] < 0 {
  302. v.fromP1xP1(tmp1)
  303. tables[j].SelectInto(multiple, -nafs[j][i])
  304. tmp1.Sub(v, multiple)
  305. }
  306. }
  307. tmp2.FromP1xP1(tmp1)
  308. }
  309. v.fromP2(tmp2)
  310. return v
  311. }