aescts.go 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. // Package aescts provides AES CBC CipherText Stealing encryption and decryption methods
  2. package aescts
  3. import (
  4. "crypto/aes"
  5. "crypto/cipher"
  6. "errors"
  7. "fmt"
  8. )
  9. // Encrypt the message with the key and the initial vector.
  10. // Returns: next iv, ciphertext bytes, error
  11. func Encrypt(key, iv, plaintext []byte) ([]byte, []byte, error) {
  12. l := len(plaintext)
  13. block, err := aes.NewCipher(key)
  14. if err != nil {
  15. return []byte{}, []byte{}, fmt.Errorf("error creating cipher: %v", err)
  16. }
  17. mode := cipher.NewCBCEncrypter(block, iv)
  18. m := make([]byte, len(plaintext))
  19. copy(m, plaintext)
  20. /*For consistency, ciphertext stealing is always used for the last two
  21. blocks of the data to be encrypted, as in [RC5]. If the data length
  22. is a multiple of the block size, this is equivalent to plain CBC mode
  23. with the last two ciphertext blocks swapped.*/
  24. /*The initial vector carried out from one encryption for use in a
  25. subsequent encryption is the next-to-last block of the encryption
  26. output; this is the encrypted form of the last plaintext block.*/
  27. if l <= aes.BlockSize {
  28. m, _ = zeroPad(m, aes.BlockSize)
  29. mode.CryptBlocks(m, m)
  30. return m, m, nil
  31. }
  32. if l%aes.BlockSize == 0 {
  33. mode.CryptBlocks(m, m)
  34. iv = m[len(m)-aes.BlockSize:]
  35. rb, _ := swapLastTwoBlocks(m, aes.BlockSize)
  36. return iv, rb, nil
  37. }
  38. m, _ = zeroPad(m, aes.BlockSize)
  39. rb, pb, lb, err := tailBlocks(m, aes.BlockSize)
  40. if err != nil {
  41. return []byte{}, []byte{}, fmt.Errorf("error tailing blocks: %v", err)
  42. }
  43. var ct []byte
  44. if rb != nil {
  45. // Encrpt all but the lats 2 blocks and update the rolling iv
  46. mode.CryptBlocks(rb, rb)
  47. iv = rb[len(rb)-aes.BlockSize:]
  48. mode = cipher.NewCBCEncrypter(block, iv)
  49. ct = append(ct, rb...)
  50. }
  51. mode.CryptBlocks(pb, pb)
  52. mode = cipher.NewCBCEncrypter(block, pb)
  53. mode.CryptBlocks(lb, lb)
  54. // Cipher Text Stealing (CTS) - Ref: https://en.wikipedia.org/wiki/Ciphertext_stealing#CBC_ciphertext_stealing
  55. // Swap the last two cipher blocks
  56. // Truncate the ciphertext to the length of the original plaintext
  57. ct = append(ct, lb...)
  58. ct = append(ct, pb...)
  59. return lb, ct[:l], nil
  60. }
  61. // Decrypt the ciphertext with the key and the initial vector.
  62. func Decrypt(key, iv, ciphertext []byte) ([]byte, error) {
  63. // Copy the cipher text as golang slices even when passed by value to this method can result in the backing arrays of the calling code value being updated.
  64. ct := make([]byte, len(ciphertext))
  65. copy(ct, ciphertext)
  66. if len(ct) < aes.BlockSize {
  67. return []byte{}, fmt.Errorf("ciphertext is not large enough. It is less that one block size. Blocksize:%v; Ciphertext:%v", aes.BlockSize, len(ct))
  68. }
  69. // Configure the CBC
  70. block, err := aes.NewCipher(key)
  71. if err != nil {
  72. return nil, fmt.Errorf("error creating cipher: %v", err)
  73. }
  74. var mode cipher.BlockMode
  75. //If ciphertext is multiple of blocksize we just need to swap back the last two blocks and then do CBC
  76. //If the ciphertext is just one block we can't swap so we just decrypt
  77. if len(ct)%aes.BlockSize == 0 {
  78. if len(ct) > aes.BlockSize {
  79. ct, _ = swapLastTwoBlocks(ct, aes.BlockSize)
  80. }
  81. mode = cipher.NewCBCDecrypter(block, iv)
  82. message := make([]byte, len(ct))
  83. mode.CryptBlocks(message, ct)
  84. return message[:len(ct)], nil
  85. }
  86. // Cipher Text Stealing (CTS) using CBC interface. Ref: https://en.wikipedia.org/wiki/Ciphertext_stealing#CBC_ciphertext_stealing
  87. // Get ciphertext of the 2nd to last (penultimate) block (cpb), the last block (clb) and the rest (crb)
  88. crb, cpb, clb, _ := tailBlocks(ct, aes.BlockSize)
  89. v := make([]byte, len(iv), len(iv))
  90. copy(v, iv)
  91. var message []byte
  92. if crb != nil {
  93. //If there is more than just the last and the penultimate block we decrypt it and the last bloc of this becomes the iv for later
  94. rb := make([]byte, len(crb))
  95. mode = cipher.NewCBCDecrypter(block, v)
  96. v = crb[len(crb)-aes.BlockSize:]
  97. mode.CryptBlocks(rb, crb)
  98. message = append(message, rb...)
  99. }
  100. // We need to modify the cipher text
  101. // Decryt the 2nd to last (penultimate) block with a the original iv
  102. pb := make([]byte, aes.BlockSize)
  103. mode = cipher.NewCBCDecrypter(block, iv)
  104. mode.CryptBlocks(pb, cpb)
  105. // number of byte needed to pad
  106. npb := aes.BlockSize - len(ct)%aes.BlockSize
  107. //pad last block using the number of bytes needed from the tail of the plaintext 2nd to last (penultimate) block
  108. clb = append(clb, pb[len(pb)-npb:]...)
  109. // Now decrypt the last block in the penultimate position (iv will be from the crb, if the is no crb it's zeros)
  110. // iv for the penultimate block decrypted in the last position becomes the modified last block
  111. lb := make([]byte, aes.BlockSize)
  112. mode = cipher.NewCBCDecrypter(block, v)
  113. v = clb
  114. mode.CryptBlocks(lb, clb)
  115. message = append(message, lb...)
  116. // Now decrypt the penultimate block in the last position (iv will be from the modified last block)
  117. mode = cipher.NewCBCDecrypter(block, v)
  118. mode.CryptBlocks(cpb, cpb)
  119. message = append(message, cpb...)
  120. // Truncate to the size of the original cipher text
  121. return message[:len(ct)], nil
  122. }
  123. func tailBlocks(b []byte, c int) ([]byte, []byte, []byte, error) {
  124. if len(b) <= c {
  125. return []byte{}, []byte{}, []byte{}, errors.New("bytes slice is not larger than one block so cannot tail")
  126. }
  127. // Get size of last block
  128. var lbs int
  129. if l := len(b) % aes.BlockSize; l == 0 {
  130. lbs = aes.BlockSize
  131. } else {
  132. lbs = l
  133. }
  134. // Get last block
  135. lb := b[len(b)-lbs:]
  136. // Get 2nd to last (penultimate) block
  137. pb := b[len(b)-lbs-c : len(b)-lbs]
  138. if len(b) > 2*c {
  139. rb := b[:len(b)-lbs-c]
  140. return rb, pb, lb, nil
  141. }
  142. return nil, pb, lb, nil
  143. }
  144. func swapLastTwoBlocks(b []byte, c int) ([]byte, error) {
  145. rb, pb, lb, err := tailBlocks(b, c)
  146. if err != nil {
  147. return nil, err
  148. }
  149. var out []byte
  150. if rb != nil {
  151. out = append(out, rb...)
  152. }
  153. out = append(out, lb...)
  154. out = append(out, pb...)
  155. return out, nil
  156. }
  157. // zeroPad pads bytes with zeros to nearest multiple of message size m.
  158. func zeroPad(b []byte, m int) ([]byte, error) {
  159. if m <= 0 {
  160. return nil, errors.New("invalid message block size when padding")
  161. }
  162. if b == nil || len(b) == 0 {
  163. return nil, errors.New("data not valid to pad: Zero size")
  164. }
  165. if l := len(b) % m; l != 0 {
  166. n := m - l
  167. z := make([]byte, n)
  168. b = append(b, z...)
  169. }
  170. return b, nil
  171. }