common/encrypt: support verifying ciphertext of v2 encryption policies
[xfstests-dev.git] / src / fscrypt-crypt-util.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * fscrypt-crypt-util.c - utility for verifying fscrypt-encrypted data
4  *
5  * Copyright 2019 Google LLC
6  */
7
8 /*
9  * This program implements all crypto algorithms supported by fscrypt (a.k.a.
10  * ext4, f2fs, and ubifs encryption), for the purpose of verifying the
11  * correctness of the ciphertext stored on-disk.  See usage() below.
12  *
13  * All algorithms are implemented in portable C code to avoid depending on
14  * libcrypto (OpenSSL), and because some fscrypt-supported algorithms aren't
15  * available in libcrypto anyway (e.g. Adiantum), or are only supported in
16  * recent versions (e.g. HKDF-SHA512).  For simplicity, all crypto code here
17  * tries to follow the mathematical definitions directly, without optimizing for
18  * performance or worrying about following security best practices such as
19  * mitigating side-channel attacks.  So, only use this program for testing!
20  */
21
22 #include <asm/byteorder.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <limits.h>
26 #include <linux/types.h>
27 #include <stdarg.h>
28 #include <stdbool.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 #include <unistd.h>
33
34 #define PROGRAM_NAME "fscrypt-crypt-util"
35
36 /*
37  * Define to enable the tests of the crypto code in this file.  If enabled, you
38  * must link this program with OpenSSL (-lcrypto) v1.1.0 or later, and your
39  * kernel needs CONFIG_CRYPTO_USER_API_SKCIPHER=y and CONFIG_CRYPTO_ADIANTUM=y.
40  */
41 #undef ENABLE_ALG_TESTS
42
43 #define NUM_ALG_TEST_ITERATIONS 10000
44
45 static void usage(FILE *fp)
46 {
47         fputs(
48 "Usage: " PROGRAM_NAME " [OPTION]... CIPHER MASTER_KEY\n"
49 "\n"
50 "Utility for verifying fscrypt-encrypted data.  This program encrypts\n"
51 "(or decrypts) the data on stdin using the given CIPHER with the given\n"
52 "MASTER_KEY (or a key derived from it, if a KDF is specified), and writes the\n"
53 "resulting ciphertext (or plaintext) to stdout.\n"
54 "\n"
55 "CIPHER can be AES-256-XTS, AES-256-CTS-CBC, AES-128-CBC-ESSIV, AES-128-CTS-CBC,\n"
56 "or Adiantum.  MASTER_KEY must be a hex string long enough for the cipher.\n"
57 "\n"
58 "WARNING: this program is only meant for testing, not for \"real\" use!\n"
59 "\n"
60 "Options:\n"
61 "  --block-size=BLOCK_SIZE     Encrypt each BLOCK_SIZE bytes independently.\n"
62 "                                Default: 4096 bytes\n"
63 "  --decrypt                   Decrypt instead of encrypt\n"
64 "  --file-nonce=NONCE          File's nonce as a 32-character hex string\n"
65 "  --help                      Show this help\n"
66 "  --kdf=KDF                   Key derivation function to use: AES-128-ECB,\n"
67 "                                HKDF-SHA512, or none.  Default: none\n"
68 "  --mode-num=NUM              Derive per-mode key using mode number NUM\n"
69 "  --padding=PADDING           If last block is partial, zero-pad it to next\n"
70 "                                PADDING-byte boundary.  Default: BLOCK_SIZE\n"
71         , fp);
72 }
73
74 /*----------------------------------------------------------------------------*
75  *                                 Utilities                                  *
76  *----------------------------------------------------------------------------*/
77
78 #define ARRAY_SIZE(A)           (sizeof(A) / sizeof((A)[0]))
79 #define MIN(x, y)               ((x) < (y) ? (x) : (y))
80 #define MAX(x, y)               ((x) > (y) ? (x) : (y))
81 #define ROUND_DOWN(x, y)        ((x) & ~((y) - 1))
82 #define ROUND_UP(x, y)          (((x) + (y) - 1) & ~((y) - 1))
83 #define DIV_ROUND_UP(n, d)      (((n) + (d) - 1) / (d))
84 #define STATIC_ASSERT(e)        ((void)sizeof(char[1 - 2*!(e)]))
85
86 typedef __u8                    u8;
87 typedef __u16                   u16;
88 typedef __u32                   u32;
89 typedef __u64                   u64;
90
91 #define cpu_to_le32             __cpu_to_le32
92 #define cpu_to_be32             __cpu_to_be32
93 #define cpu_to_le64             __cpu_to_le64
94 #define cpu_to_be64             __cpu_to_be64
95 #define le32_to_cpu             __le32_to_cpu
96 #define be32_to_cpu             __be32_to_cpu
97 #define le64_to_cpu             __le64_to_cpu
98 #define be64_to_cpu             __be64_to_cpu
99
100 #define DEFINE_UNALIGNED_ACCESS_HELPERS(type, native_type)      \
101 static inline native_type __attribute__((unused))               \
102 get_unaligned_##type(const void *p)                             \
103 {                                                               \
104         __##type x;                                             \
105                                                                 \
106         memcpy(&x, p, sizeof(x));                               \
107         return type##_to_cpu(x);                                \
108 }                                                               \
109                                                                 \
110 static inline void __attribute__((unused))                      \
111 put_unaligned_##type(native_type v, void *p)                    \
112 {                                                               \
113         __##type x = cpu_to_##type(v);                          \
114                                                                 \
115         memcpy(p, &x, sizeof(x));                               \
116 }
117
118 DEFINE_UNALIGNED_ACCESS_HELPERS(le32, u32)
119 DEFINE_UNALIGNED_ACCESS_HELPERS(be32, u32)
120 DEFINE_UNALIGNED_ACCESS_HELPERS(le64, u64)
121 DEFINE_UNALIGNED_ACCESS_HELPERS(be64, u64)
122
123 static inline bool is_power_of_2(unsigned long v)
124 {
125         return v != 0 && (v & (v - 1)) == 0;
126 }
127
128 static inline u32 rol32(u32 v, int n)
129 {
130         return (v << n) | (v >> (32 - n));
131 }
132
133 static inline u32 ror32(u32 v, int n)
134 {
135         return (v >> n) | (v << (32 - n));
136 }
137
138 static inline u64 ror64(u64 v, int n)
139 {
140         return (v >> n) | (v << (64 - n));
141 }
142
143 static inline void xor(u8 *res, const u8 *a, const u8 *b, size_t count)
144 {
145         while (count--)
146                 *res++ = *a++ ^ *b++;
147 }
148
149 static void __attribute__((noreturn, format(printf, 2, 3)))
150 do_die(int err, const char *format, ...)
151 {
152         va_list va;
153
154         va_start(va, format);
155         fputs("[" PROGRAM_NAME "] ERROR: ", stderr);
156         vfprintf(stderr, format, va);
157         if (err)
158                 fprintf(stderr, ": %s", strerror(errno));
159         putc('\n', stderr);
160         va_end(va);
161         exit(1);
162 }
163
164 #define die(format, ...)        do_die(0,     (format), ##__VA_ARGS__)
165 #define die_errno(format, ...)  do_die(errno, (format), ##__VA_ARGS__)
166
167 static __attribute__((noreturn)) void
168 assertion_failed(const char *expr, const char *file, int line)
169 {
170         die("Assertion failed: %s at %s:%d", expr, file, line);
171 }
172
173 #define ASSERT(e) ({ if (!(e)) assertion_failed(#e, __FILE__, __LINE__); })
174
175 static void *xmalloc(size_t size)
176 {
177         void *p = malloc(size);
178
179         ASSERT(p != NULL);
180         return p;
181 }
182
183 static int hexchar2bin(char c)
184 {
185         if (c >= 'a' && c <= 'f')
186                 return 10 + c - 'a';
187         if (c >= 'A' && c <= 'F')
188                 return 10 + c - 'A';
189         if (c >= '0' && c <= '9')
190                 return c - '0';
191         return -1;
192 }
193
194 static int hex2bin(const char *hex, u8 *bin, int max_bin_size)
195 {
196         size_t len = strlen(hex);
197         size_t i;
198
199         if (len & 1)
200                 return -1;
201         len /= 2;
202         if (len > max_bin_size)
203                 return -1;
204
205         for (i = 0; i < len; i++) {
206                 int high = hexchar2bin(hex[2 * i]);
207                 int low = hexchar2bin(hex[2 * i + 1]);
208
209                 if (high < 0 || low < 0)
210                         return -1;
211                 bin[i] = (high << 4) | low;
212         }
213         return len;
214 }
215
216 static size_t xread(int fd, void *buf, size_t count)
217 {
218         const size_t orig_count = count;
219
220         while (count) {
221                 ssize_t res = read(fd, buf, count);
222
223                 if (res < 0)
224                         die_errno("read error");
225                 if (res == 0)
226                         break;
227                 buf += res;
228                 count -= res;
229         }
230         return orig_count - count;
231 }
232
233 static void full_write(int fd, const void *buf, size_t count)
234 {
235         while (count) {
236                 ssize_t res = write(fd, buf, count);
237
238                 if (res < 0)
239                         die_errno("write error");
240                 buf += res;
241                 count -= res;
242         }
243 }
244
245 #ifdef ENABLE_ALG_TESTS
246 static void rand_bytes(u8 *buf, size_t count)
247 {
248         while (count--)
249                 *buf++ = rand();
250 }
251 #endif
252
253 /*----------------------------------------------------------------------------*
254  *                          Finite field arithmetic                           *
255  *----------------------------------------------------------------------------*/
256
257 /* Multiply a GF(2^8) element by the polynomial 'x' */
258 static inline u8 gf2_8_mul_x(u8 b)
259 {
260         return (b << 1) ^ ((b & 0x80) ? 0x1B : 0);
261 }
262
263 /* Multiply four packed GF(2^8) elements by the polynomial 'x' */
264 static inline u32 gf2_8_mul_x_4way(u32 w)
265 {
266         return ((w & 0x7F7F7F7F) << 1) ^ (((w & 0x80808080) >> 7) * 0x1B);
267 }
268
269 /* Element of GF(2^128) */
270 typedef struct {
271         __le64 lo;
272         __le64 hi;
273 } ble128;
274
275 /* Multiply a GF(2^128) element by the polynomial 'x' */
276 static inline void gf2_128_mul_x(ble128 *t)
277 {
278         u64 lo = le64_to_cpu(t->lo);
279         u64 hi = le64_to_cpu(t->hi);
280
281         t->hi = cpu_to_le64((hi << 1) | (lo >> 63));
282         t->lo = cpu_to_le64((lo << 1) ^ ((hi & (1ULL << 63)) ? 0x87 : 0));
283 }
284
285 /*----------------------------------------------------------------------------*
286  *                             Group arithmetic                               *
287  *----------------------------------------------------------------------------*/
288
289 /* Element of Z/(2^{128}Z)  (a.k.a. the integers modulo 2^128) */
290 typedef struct {
291         __le64 lo;
292         __le64 hi;
293 } le128;
294
295 static inline void le128_add(le128 *res, const le128 *a, const le128 *b)
296 {
297         u64 a_lo = le64_to_cpu(a->lo);
298         u64 b_lo = le64_to_cpu(b->lo);
299
300         res->lo = cpu_to_le64(a_lo + b_lo);
301         res->hi = cpu_to_le64(le64_to_cpu(a->hi) + le64_to_cpu(b->hi) +
302                               (a_lo + b_lo < a_lo));
303 }
304
305 static inline void le128_sub(le128 *res, const le128 *a, const le128 *b)
306 {
307         u64 a_lo = le64_to_cpu(a->lo);
308         u64 b_lo = le64_to_cpu(b->lo);
309
310         res->lo = cpu_to_le64(a_lo - b_lo);
311         res->hi = cpu_to_le64(le64_to_cpu(a->hi) - le64_to_cpu(b->hi) -
312                               (a_lo - b_lo > a_lo));
313 }
314
315 /*----------------------------------------------------------------------------*
316  *                              AES block cipher                              *
317  *----------------------------------------------------------------------------*/
318
319 /*
320  * Reference: "FIPS 197, Advanced Encryption Standard"
321  *      https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.197.pdf
322  */
323
324 #define AES_BLOCK_SIZE          16
325 #define AES_128_KEY_SIZE        16
326 #define AES_192_KEY_SIZE        24
327 #define AES_256_KEY_SIZE        32
328
329 static inline void AddRoundKey(u32 state[4], const u32 *rk)
330 {
331         int i;
332
333         for (i = 0; i < 4; i++)
334                 state[i] ^= rk[i];
335 }
336
337 static const u8 aes_sbox[256] = {
338         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b,
339         0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0,
340         0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26,
341         0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
342         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2,
343         0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0,
344         0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed,
345         0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
346         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f,
347         0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
348         0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec,
349         0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
350         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14,
351         0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c,
352         0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d,
353         0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
354         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f,
355         0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e,
356         0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11,
357         0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
358         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f,
359         0xb0, 0x54, 0xbb, 0x16,
360 };
361
362 static u8 aes_inverse_sbox[256];
363
364 static void aes_init(void)
365 {
366         int i;
367
368         for (i = 0; i < 256; i++)
369                 aes_inverse_sbox[aes_sbox[i]] = i;
370 }
371
372 static inline u32 DoSubWord(u32 w, const u8 sbox[256])
373 {
374         return ((u32)sbox[(u8)(w >> 24)] << 24) |
375                ((u32)sbox[(u8)(w >> 16)] << 16) |
376                ((u32)sbox[(u8)(w >>  8)] <<  8) |
377                ((u32)sbox[(u8)(w >>  0)] <<  0);
378 }
379
380 static inline u32 SubWord(u32 w)
381 {
382         return DoSubWord(w, aes_sbox);
383 }
384
385 static inline u32 InvSubWord(u32 w)
386 {
387         return DoSubWord(w, aes_inverse_sbox);
388 }
389
390 static inline void SubBytes(u32 state[4])
391 {
392         int i;
393
394         for (i = 0; i < 4; i++)
395                 state[i] = SubWord(state[i]);
396 }
397
398 static inline void InvSubBytes(u32 state[4])
399 {
400         int i;
401
402         for (i = 0; i < 4; i++)
403                 state[i] = InvSubWord(state[i]);
404 }
405
406 static inline void DoShiftRows(u32 state[4], int direction)
407 {
408         u32 newstate[4];
409         int i;
410
411         for (i = 0; i < 4; i++)
412                 newstate[i] = (state[(i + direction*0) & 3] & 0xff) |
413                               (state[(i + direction*1) & 3] & 0xff00) |
414                               (state[(i + direction*2) & 3] & 0xff0000) |
415                               (state[(i + direction*3) & 3] & 0xff000000);
416         memcpy(state, newstate, 16);
417 }
418
419 static inline void ShiftRows(u32 state[4])
420 {
421         DoShiftRows(state, 1);
422 }
423
424 static inline void InvShiftRows(u32 state[4])
425 {
426         DoShiftRows(state, -1);
427 }
428
429 /*
430  * Mix one column by doing the following matrix multiplication in GF(2^8):
431  *
432  *     | 2 3 1 1 |   | w[0] |
433  *     | 1 2 3 1 |   | w[1] |
434  *     | 1 1 2 3 | x | w[2] |
435  *     | 3 1 1 2 |   | w[3] |
436  *
437  * a.k.a. w[i] = 2*w[i] + 3*w[(i+1)%4] + w[(i+2)%4] + w[(i+3)%4]
438  */
439 static inline u32 MixColumn(u32 w)
440 {
441         u32 _2w0_w2 = gf2_8_mul_x_4way(w) ^ ror32(w, 16);
442         u32 _3w1_w3 = ror32(_2w0_w2 ^ w, 8);
443
444         return _2w0_w2 ^ _3w1_w3;
445 }
446
447 /*
448  *           ( | 5 0 4 0 |   | w[0] | )
449  *          (  | 0 5 0 4 |   | w[1] |  )
450  * MixColumn(  | 4 0 5 0 | x | w[2] |  )
451  *           ( | 0 4 0 5 |   | w[3] | )
452  */
453 static inline u32 InvMixColumn(u32 w)
454 {
455         u32 _4w = gf2_8_mul_x_4way(gf2_8_mul_x_4way(w));
456
457         return MixColumn(_4w ^ w ^ ror32(_4w, 16));
458 }
459
460 static inline void MixColumns(u32 state[4])
461 {
462         int i;
463
464         for (i = 0; i < 4; i++)
465                 state[i] = MixColumn(state[i]);
466 }
467
468 static inline void InvMixColumns(u32 state[4])
469 {
470         int i;
471
472         for (i = 0; i < 4; i++)
473                 state[i] = InvMixColumn(state[i]);
474 }
475
476 struct aes_key {
477         u32 round_keys[15 * 4];
478         int nrounds;
479 };
480
481 /* Expand an AES key */
482 static void aes_setkey(struct aes_key *k, const u8 *key, int keysize)
483 {
484         const int N = keysize / 4;
485         u32 * const rk = k->round_keys;
486         u8 rcon = 1;
487         int i;
488
489         ASSERT(keysize == 16 || keysize == 24 || keysize == 32);
490         k->nrounds = 6 + N;
491         for (i = 0; i < 4 * (k->nrounds + 1); i++) {
492                 if (i < N) {
493                         rk[i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
494                 } else if (i % N == 0) {
495                         rk[i] = rk[i - N] ^ SubWord(ror32(rk[i - 1], 8)) ^ rcon;
496                         rcon = gf2_8_mul_x(rcon);
497                 } else if (N > 6 && i % N == 4) {
498                         rk[i] = rk[i - N] ^ SubWord(rk[i - 1]);
499                 } else {
500                         rk[i] = rk[i - N] ^ rk[i - 1];
501                 }
502         }
503 }
504
505 /* Encrypt one 16-byte block with AES */
506 static void aes_encrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
507                         u8 dst[AES_BLOCK_SIZE])
508 {
509         u32 state[4];
510         int i;
511
512         for (i = 0; i < 4; i++)
513                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
514
515         AddRoundKey(state, k->round_keys);
516         for (i = 1; i < k->nrounds; i++) {
517                 SubBytes(state);
518                 ShiftRows(state);
519                 MixColumns(state);
520                 AddRoundKey(state, &k->round_keys[4 * i]);
521         }
522         SubBytes(state);
523         ShiftRows(state);
524         AddRoundKey(state, &k->round_keys[4 * i]);
525
526         for (i = 0; i < 4; i++)
527                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
528 }
529
530 /* Decrypt one 16-byte block with AES */
531 static void aes_decrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
532                         u8 dst[AES_BLOCK_SIZE])
533 {
534         u32 state[4];
535         int i;
536
537         for (i = 0; i < 4; i++)
538                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
539
540         AddRoundKey(state, &k->round_keys[4 * k->nrounds]);
541         InvShiftRows(state);
542         InvSubBytes(state);
543         for (i = k->nrounds - 1; i >= 1; i--) {
544                 AddRoundKey(state, &k->round_keys[4 * i]);
545                 InvMixColumns(state);
546                 InvShiftRows(state);
547                 InvSubBytes(state);
548         }
549         AddRoundKey(state, k->round_keys);
550
551         for (i = 0; i < 4; i++)
552                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
553 }
554
555 #ifdef ENABLE_ALG_TESTS
556 #include <openssl/aes.h>
557 static void test_aes_keysize(int keysize)
558 {
559         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
560
561         while (num_tests--) {
562                 struct aes_key k;
563                 AES_KEY ref_k;
564                 u8 key[AES_256_KEY_SIZE];
565                 u8 ptext[AES_BLOCK_SIZE];
566                 u8 ctext[AES_BLOCK_SIZE];
567                 u8 ref_ctext[AES_BLOCK_SIZE];
568                 u8 decrypted[AES_BLOCK_SIZE];
569
570                 rand_bytes(key, keysize);
571                 rand_bytes(ptext, AES_BLOCK_SIZE);
572
573                 aes_setkey(&k, key, keysize);
574                 aes_encrypt(&k, ptext, ctext);
575
576                 ASSERT(AES_set_encrypt_key(key, keysize*8, &ref_k) == 0);
577                 AES_encrypt(ptext, ref_ctext, &ref_k);
578
579                 ASSERT(memcmp(ctext, ref_ctext, AES_BLOCK_SIZE) == 0);
580
581                 aes_decrypt(&k, ctext, decrypted);
582                 ASSERT(memcmp(ptext, decrypted, AES_BLOCK_SIZE) == 0);
583         }
584 }
585
586 static void test_aes(void)
587 {
588         test_aes_keysize(AES_128_KEY_SIZE);
589         test_aes_keysize(AES_192_KEY_SIZE);
590         test_aes_keysize(AES_256_KEY_SIZE);
591 }
592 #endif /* ENABLE_ALG_TESTS */
593
594 /*----------------------------------------------------------------------------*
595  *                            SHA-512 and SHA-256                             *
596  *----------------------------------------------------------------------------*/
597
598 /*
599  * Reference: "FIPS 180-2, Secure Hash Standard"
600  *      https://csrc.nist.gov/csrc/media/publications/fips/180/2/archive/2002-08-01/documents/fips180-2withchangenotice.pdf
601  */
602
603 #define SHA512_DIGEST_SIZE      64
604 #define SHA512_BLOCK_SIZE       128
605
606 #define SHA256_DIGEST_SIZE      32
607 #define SHA256_BLOCK_SIZE       64
608
609 #define Ch(x, y, z)     (((x) & (y)) ^ (~(x) & (z)))
610 #define Maj(x, y, z)    (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z)))
611
612 #define Sigma512_0(x)   (ror64((x), 28) ^ ror64((x), 34) ^ ror64((x), 39))
613 #define Sigma512_1(x)   (ror64((x), 14) ^ ror64((x), 18) ^ ror64((x), 41))
614 #define sigma512_0(x)   (ror64((x),  1) ^ ror64((x),  8) ^ ((x) >> 7))
615 #define sigma512_1(x)   (ror64((x), 19) ^ ror64((x), 61) ^ ((x) >> 6))
616
617 #define Sigma256_0(x)   (ror32((x),  2) ^ ror32((x), 13) ^ ror32((x), 22))
618 #define Sigma256_1(x)   (ror32((x),  6) ^ ror32((x), 11) ^ ror32((x), 25))
619 #define sigma256_0(x)   (ror32((x),  7) ^ ror32((x), 18) ^ ((x) >>  3))
620 #define sigma256_1(x)   (ror32((x), 17) ^ ror32((x), 19) ^ ((x) >> 10))
621
622 static const u64 sha512_iv[8] = {
623         0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b,
624         0xa54ff53a5f1d36f1, 0x510e527fade682d1, 0x9b05688c2b3e6c1f,
625         0x1f83d9abfb41bd6b, 0x5be0cd19137e2179,
626 };
627
628 static const u64 sha512_round_constants[80] = {
629         0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f,
630         0xe9b5dba58189dbbc, 0x3956c25bf348b538, 0x59f111f1b605d019,
631         0x923f82a4af194f9b, 0xab1c5ed5da6d8118, 0xd807aa98a3030242,
632         0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2,
633         0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235,
634         0xc19bf174cf692694, 0xe49b69c19ef14ad2, 0xefbe4786384f25e3,
635         0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, 0x2de92c6f592b0275,
636         0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5,
637         0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f,
638         0xbf597fc7beef0ee4, 0xc6e00bf33da88fc2, 0xd5a79147930aa725,
639         0x06ca6351e003826f, 0x142929670a0e6e70, 0x27b70a8546d22ffc,
640         0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df,
641         0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6,
642         0x92722c851482353b, 0xa2bfe8a14cf10364, 0xa81a664bbc423001,
643         0xc24b8b70d0f89791, 0xc76c51a30654be30, 0xd192e819d6ef5218,
644         0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8,
645         0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99,
646         0x34b0bcb5e19b48a8, 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb,
647         0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, 0x748f82ee5defb2fc,
648         0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec,
649         0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915,
650         0xc67178f2e372532b, 0xca273eceea26619c, 0xd186b8c721c0c207,
651         0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, 0x06f067aa72176fba,
652         0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b,
653         0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc,
654         0x431d67c49c100d4c, 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a,
655         0x5fcb6fab3ad6faec, 0x6c44198c4a475817,
656 };
657
658 /* Compute the SHA-512 digest of the given buffer */
659 static void sha512(const u8 *in, size_t inlen, u8 out[SHA512_DIGEST_SIZE])
660 {
661         const size_t msglen = ROUND_UP(inlen + 17, SHA512_BLOCK_SIZE);
662         u8 * const msg = xmalloc(msglen);
663         u64 H[8];
664         int i;
665
666         /* super naive way of handling the padding */
667         memcpy(msg, in, inlen);
668         memset(&msg[inlen], 0, msglen - inlen);
669         msg[inlen] = 0x80;
670         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
671         in = msg;
672
673         memcpy(H, sha512_iv, sizeof(H));
674         do {
675                 u64 a = H[0], b = H[1], c = H[2], d = H[3],
676                     e = H[4], f = H[5], g = H[6], h = H[7];
677                 u64 W[80];
678
679                 for (i = 0; i < 16; i++)
680                         W[i] = get_unaligned_be64(&in[i * sizeof(__be64)]);
681                 for (; i < ARRAY_SIZE(W); i++)
682                         W[i] = sigma512_1(W[i - 2]) + W[i - 7] +
683                                sigma512_0(W[i - 15]) + W[i - 16];
684                 for (i = 0; i < ARRAY_SIZE(W); i++) {
685                         u64 T1 = h + Sigma512_1(e) + Ch(e, f, g) +
686                                  sha512_round_constants[i] + W[i];
687                         u64 T2 = Sigma512_0(a) + Maj(a, b, c);
688
689                         h = g; g = f; f = e; e = d + T1;
690                         d = c; c = b; b = a; a = T1 + T2;
691                 }
692                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
693                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
694         } while ((in += SHA512_BLOCK_SIZE) != &msg[msglen]);
695
696         for (i = 0; i < ARRAY_SIZE(H); i++)
697                 put_unaligned_be64(H[i], &out[i * sizeof(__be64)]);
698         free(msg);
699 }
700
701 /* Compute the SHA-256 digest of the given buffer */
702 static void sha256(const u8 *in, size_t inlen, u8 out[SHA256_DIGEST_SIZE])
703 {
704         const size_t msglen = ROUND_UP(inlen + 9, SHA256_BLOCK_SIZE);
705         u8 * const msg = xmalloc(msglen);
706         u32 H[8];
707         int i;
708
709         /* super naive way of handling the padding */
710         memcpy(msg, in, inlen);
711         memset(&msg[inlen], 0, msglen - inlen);
712         msg[inlen] = 0x80;
713         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
714         in = msg;
715
716         for (i = 0; i < ARRAY_SIZE(H); i++)
717                 H[i] = (u32)(sha512_iv[i] >> 32);
718         do {
719                 u32 a = H[0], b = H[1], c = H[2], d = H[3],
720                     e = H[4], f = H[5], g = H[6], h = H[7];
721                 u32 W[64];
722
723                 for (i = 0; i < 16; i++)
724                         W[i] = get_unaligned_be32(&in[i * sizeof(__be32)]);
725                 for (; i < ARRAY_SIZE(W); i++)
726                         W[i] = sigma256_1(W[i - 2]) + W[i - 7] +
727                                sigma256_0(W[i - 15]) + W[i - 16];
728                 for (i = 0; i < ARRAY_SIZE(W); i++) {
729                         u32 T1 = h + Sigma256_1(e) + Ch(e, f, g) +
730                                  (u32)(sha512_round_constants[i] >> 32) + W[i];
731                         u32 T2 = Sigma256_0(a) + Maj(a, b, c);
732
733                         h = g; g = f; f = e; e = d + T1;
734                         d = c; c = b; b = a; a = T1 + T2;
735                 }
736                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
737                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
738         } while ((in += SHA256_BLOCK_SIZE) != &msg[msglen]);
739
740         for (i = 0; i < ARRAY_SIZE(H); i++)
741                 put_unaligned_be32(H[i], &out[i * sizeof(__be32)]);
742         free(msg);
743 }
744
745 #ifdef ENABLE_ALG_TESTS
746 #include <openssl/sha.h>
747 static void test_sha2(void)
748 {
749         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
750
751         while (num_tests--) {
752                 u8 in[4096];
753                 u8 digest[SHA512_DIGEST_SIZE];
754                 u8 ref_digest[SHA512_DIGEST_SIZE];
755                 const size_t inlen = rand() % (1 + sizeof(in));
756
757                 rand_bytes(in, inlen);
758
759                 sha256(in, inlen, digest);
760                 SHA256(in, inlen, ref_digest);
761                 ASSERT(memcmp(digest, ref_digest, SHA256_DIGEST_SIZE) == 0);
762
763                 sha512(in, inlen, digest);
764                 SHA512(in, inlen, ref_digest);
765                 ASSERT(memcmp(digest, ref_digest, SHA512_DIGEST_SIZE) == 0);
766         }
767 }
768 #endif /* ENABLE_ALG_TESTS */
769
770 /*----------------------------------------------------------------------------*
771  *                            HKDF implementation                             *
772  *----------------------------------------------------------------------------*/
773
774 static void hmac_sha512(const u8 *key, size_t keylen, const u8 *msg,
775                         size_t msglen, u8 mac[SHA512_DIGEST_SIZE])
776 {
777         u8 *ibuf = xmalloc(SHA512_BLOCK_SIZE + msglen);
778         u8 obuf[SHA512_BLOCK_SIZE + SHA512_DIGEST_SIZE];
779
780         ASSERT(keylen <= SHA512_BLOCK_SIZE); /* keylen > bs not implemented */
781
782         memset(ibuf, 0x36, SHA512_BLOCK_SIZE);
783         xor(ibuf, ibuf, key, keylen);
784         memcpy(&ibuf[SHA512_BLOCK_SIZE], msg, msglen);
785
786         memset(obuf, 0x5c, SHA512_BLOCK_SIZE);
787         xor(obuf, obuf, key, keylen);
788         sha512(ibuf, SHA512_BLOCK_SIZE + msglen, &obuf[SHA512_BLOCK_SIZE]);
789         sha512(obuf, sizeof(obuf), mac);
790
791         free(ibuf);
792 }
793
794 static void hkdf_sha512(const u8 *ikm, size_t ikmlen,
795                         const u8 *salt, size_t saltlen,
796                         const u8 *info, size_t infolen,
797                         u8 *output, size_t outlen)
798 {
799         static const u8 default_salt[SHA512_DIGEST_SIZE];
800         u8 prk[SHA512_DIGEST_SIZE]; /* pseudorandom key */
801         u8 *buf = xmalloc(1 + infolen + SHA512_DIGEST_SIZE);
802         u8 counter = 1;
803         size_t i;
804
805         if (saltlen == 0) {
806                 salt = default_salt;
807                 saltlen = sizeof(default_salt);
808         }
809
810         /* HKDF-Extract */
811         ASSERT(ikmlen > 0);
812         hmac_sha512(salt, saltlen, ikm, ikmlen, prk);
813
814         /* HKDF-Expand */
815         for (i = 0; i < outlen; i += SHA512_DIGEST_SIZE) {
816                 u8 *p = buf;
817                 u8 tmp[SHA512_DIGEST_SIZE];
818
819                 ASSERT(counter != 0);
820                 if (i > 0) {
821                         memcpy(p, &output[i - SHA512_DIGEST_SIZE],
822                                SHA512_DIGEST_SIZE);
823                         p += SHA512_DIGEST_SIZE;
824                 }
825                 memcpy(p, info, infolen);
826                 p += infolen;
827                 *p++ = counter++;
828                 hmac_sha512(prk, sizeof(prk), buf, p - buf, tmp);
829                 memcpy(&output[i], tmp, MIN(sizeof(tmp), outlen - i));
830         }
831         free(buf);
832 }
833
834 #ifdef ENABLE_ALG_TESTS
835 #include <openssl/evp.h>
836 #include <openssl/kdf.h>
837 static void openssl_hkdf_sha512(const u8 *ikm, size_t ikmlen,
838                                 const u8 *salt, size_t saltlen,
839                                 const u8 *info, size_t infolen,
840                                 u8 *output, size_t outlen)
841 {
842         EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, NULL);
843         size_t actual_outlen = outlen;
844
845         ASSERT(pctx != NULL);
846         ASSERT(EVP_PKEY_derive_init(pctx) > 0);
847         ASSERT(EVP_PKEY_CTX_set_hkdf_md(pctx, EVP_sha512()) > 0);
848         ASSERT(EVP_PKEY_CTX_set1_hkdf_key(pctx, ikm, ikmlen) > 0);
849         ASSERT(EVP_PKEY_CTX_set1_hkdf_salt(pctx, salt, saltlen) > 0);
850         ASSERT(EVP_PKEY_CTX_add1_hkdf_info(pctx, info, infolen) > 0);
851         ASSERT(EVP_PKEY_derive(pctx, output, &actual_outlen) > 0);
852         ASSERT(actual_outlen == outlen);
853         EVP_PKEY_CTX_free(pctx);
854 }
855
856 static void test_hkdf_sha512(void)
857 {
858         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
859
860         while (num_tests--) {
861                 u8 ikm[SHA512_DIGEST_SIZE];
862                 u8 salt[SHA512_DIGEST_SIZE];
863                 u8 info[128];
864                 u8 actual_output[512];
865                 u8 expected_output[sizeof(actual_output)];
866                 size_t ikmlen = 1 + (rand() % sizeof(ikm));
867                 size_t saltlen = rand() % (1 + sizeof(salt));
868                 size_t infolen = rand() % (1 + sizeof(info));
869                 size_t outlen = rand() % (1 + sizeof(actual_output));
870
871                 rand_bytes(ikm, ikmlen);
872                 rand_bytes(salt, saltlen);
873                 rand_bytes(info, infolen);
874
875                 hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
876                             actual_output, outlen);
877                 openssl_hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
878                                     expected_output, outlen);
879                 ASSERT(memcmp(actual_output, expected_output, outlen) == 0);
880         }
881 }
882 #endif /* ENABLE_ALG_TESTS */
883
884 /*----------------------------------------------------------------------------*
885  *                            AES encryption modes                            *
886  *----------------------------------------------------------------------------*/
887
888 static void aes_256_xts_crypt(const u8 key[2 * AES_256_KEY_SIZE],
889                               const u8 iv[AES_BLOCK_SIZE], const u8 *src,
890                               u8 *dst, size_t nbytes, bool decrypting)
891 {
892         struct aes_key tweak_key, cipher_key;
893         ble128 t;
894         size_t i;
895
896         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
897         aes_setkey(&cipher_key, key, AES_256_KEY_SIZE);
898         aes_setkey(&tweak_key, &key[AES_256_KEY_SIZE], AES_256_KEY_SIZE);
899         aes_encrypt(&tweak_key, iv, (u8 *)&t);
900         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
901                 xor(&dst[i], &src[i], (const u8 *)&t, AES_BLOCK_SIZE);
902                 if (decrypting)
903                         aes_decrypt(&cipher_key, &dst[i], &dst[i]);
904                 else
905                         aes_encrypt(&cipher_key, &dst[i], &dst[i]);
906                 xor(&dst[i], &dst[i], (const u8 *)&t, AES_BLOCK_SIZE);
907                 gf2_128_mul_x(&t);
908         }
909 }
910
911 static void aes_256_xts_encrypt(const u8 key[2 * AES_256_KEY_SIZE],
912                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
913                                 u8 *dst, size_t nbytes)
914 {
915         aes_256_xts_crypt(key, iv, src, dst, nbytes, false);
916 }
917
918 static void aes_256_xts_decrypt(const u8 key[2 * AES_256_KEY_SIZE],
919                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
920                                 u8 *dst, size_t nbytes)
921 {
922         aes_256_xts_crypt(key, iv, src, dst, nbytes, true);
923 }
924
925 #ifdef ENABLE_ALG_TESTS
926 #include <openssl/evp.h>
927 static void test_aes_256_xts(void)
928 {
929         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
930         EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
931
932         ASSERT(ctx != NULL);
933         while (num_tests--) {
934                 u8 key[2 * AES_256_KEY_SIZE];
935                 u8 iv[AES_BLOCK_SIZE];
936                 u8 ptext[512];
937                 u8 ctext[sizeof(ptext)];
938                 u8 ref_ctext[sizeof(ptext)];
939                 u8 decrypted[sizeof(ptext)];
940                 const size_t datalen = ROUND_DOWN(rand() % (1 + sizeof(ptext)),
941                                                   AES_BLOCK_SIZE);
942                 int outl, res;
943
944                 rand_bytes(key, sizeof(key));
945                 rand_bytes(iv, sizeof(iv));
946                 rand_bytes(ptext, datalen);
947
948                 aes_256_xts_encrypt(key, iv, ptext, ctext, datalen);
949                 res = EVP_EncryptInit_ex(ctx, EVP_aes_256_xts(), NULL, key, iv);
950                 ASSERT(res > 0);
951                 res = EVP_EncryptUpdate(ctx, ref_ctext, &outl, ptext, datalen);
952                 ASSERT(res > 0);
953                 ASSERT(outl == datalen);
954                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
955
956                 aes_256_xts_decrypt(key, iv, ctext, decrypted, datalen);
957                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
958         }
959         EVP_CIPHER_CTX_free(ctx);
960 }
961 #endif /* ENABLE_ALG_TESTS */
962
963 static void aes_cbc_encrypt(const struct aes_key *k,
964                             const u8 iv[AES_BLOCK_SIZE],
965                             const u8 *src, u8 *dst, size_t nbytes)
966 {
967         size_t i;
968
969         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
970         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
971                 xor(&dst[i], &src[i], (i == 0 ? iv : &dst[i - AES_BLOCK_SIZE]),
972                     AES_BLOCK_SIZE);
973                 aes_encrypt(k, &dst[i], &dst[i]);
974         }
975 }
976
977 static void aes_cbc_decrypt(const struct aes_key *k,
978                             const u8 iv[AES_BLOCK_SIZE],
979                             const u8 *src, u8 *dst, size_t nbytes)
980 {
981         size_t i = nbytes;
982
983         ASSERT(i % AES_BLOCK_SIZE == 0);
984         while (i) {
985                 i -= AES_BLOCK_SIZE;
986                 aes_decrypt(k, &src[i], &dst[i]);
987                 xor(&dst[i], &dst[i], (i == 0 ? iv : &src[i - AES_BLOCK_SIZE]),
988                     AES_BLOCK_SIZE);
989         }
990 }
991
992 static void aes_cts_cbc_encrypt(const u8 *key, int keysize,
993                                 const u8 iv[AES_BLOCK_SIZE],
994                                 const u8 *src, u8 *dst, size_t nbytes)
995 {
996         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
997         const size_t final_bsize = nbytes - offset;
998         struct aes_key k;
999         u8 *pad;
1000         u8 buf[AES_BLOCK_SIZE];
1001
1002         ASSERT(nbytes >= AES_BLOCK_SIZE);
1003
1004         aes_setkey(&k, key, keysize);
1005
1006         if (nbytes == AES_BLOCK_SIZE)
1007                 return aes_cbc_encrypt(&k, iv, src, dst, nbytes);
1008
1009         aes_cbc_encrypt(&k, iv, src, dst, offset);
1010         pad = &dst[offset - AES_BLOCK_SIZE];
1011
1012         memcpy(buf, pad, AES_BLOCK_SIZE);
1013         xor(buf, buf, &src[offset], final_bsize);
1014         memcpy(&dst[offset], pad, final_bsize);
1015         aes_encrypt(&k, buf, pad);
1016 }
1017
1018 static void aes_cts_cbc_decrypt(const u8 *key, int keysize,
1019                                 const u8 iv[AES_BLOCK_SIZE],
1020                                 const u8 *src, u8 *dst, size_t nbytes)
1021 {
1022         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1023         const size_t final_bsize = nbytes - offset;
1024         struct aes_key k;
1025         u8 *pad;
1026
1027         ASSERT(nbytes >= AES_BLOCK_SIZE);
1028
1029         aes_setkey(&k, key, keysize);
1030
1031         if (nbytes == AES_BLOCK_SIZE)
1032                 return aes_cbc_decrypt(&k, iv, src, dst, nbytes);
1033
1034         pad = &dst[offset - AES_BLOCK_SIZE];
1035         aes_decrypt(&k, &src[offset - AES_BLOCK_SIZE], pad);
1036         xor(&dst[offset], &src[offset], pad, final_bsize);
1037         xor(pad, pad, &dst[offset], final_bsize);
1038
1039         aes_cbc_decrypt(&k, (offset == AES_BLOCK_SIZE ?
1040                              iv : &src[offset - 2 * AES_BLOCK_SIZE]),
1041                         pad, pad, AES_BLOCK_SIZE);
1042         aes_cbc_decrypt(&k, iv, src, dst, offset - AES_BLOCK_SIZE);
1043 }
1044
1045 static void aes_256_cts_cbc_encrypt(const u8 key[AES_256_KEY_SIZE],
1046                                     const u8 iv[AES_BLOCK_SIZE],
1047                                     const u8 *src, u8 *dst, size_t nbytes)
1048 {
1049         aes_cts_cbc_encrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1050 }
1051
1052 static void aes_256_cts_cbc_decrypt(const u8 key[AES_256_KEY_SIZE],
1053                                     const u8 iv[AES_BLOCK_SIZE],
1054                                     const u8 *src, u8 *dst, size_t nbytes)
1055 {
1056         aes_cts_cbc_decrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1057 }
1058
1059 #ifdef ENABLE_ALG_TESTS
1060 #include <openssl/modes.h>
1061 static void aes_block128_f(const unsigned char in[16],
1062                            unsigned char out[16], const void *key)
1063 {
1064         aes_encrypt(key, in, out);
1065 }
1066
1067 static void test_aes_256_cts_cbc(void)
1068 {
1069         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1070
1071         while (num_tests--) {
1072                 u8 key[AES_256_KEY_SIZE];
1073                 u8 iv[AES_BLOCK_SIZE];
1074                 u8 iv_copy[AES_BLOCK_SIZE];
1075                 u8 ptext[512];
1076                 u8 ctext[sizeof(ptext)];
1077                 u8 ref_ctext[sizeof(ptext)];
1078                 u8 decrypted[sizeof(ptext)];
1079                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1080                 struct aes_key k;
1081
1082                 rand_bytes(key, sizeof(key));
1083                 rand_bytes(iv, sizeof(iv));
1084                 rand_bytes(ptext, datalen);
1085
1086                 aes_256_cts_cbc_encrypt(key, iv, ptext, ctext, datalen);
1087
1088                 /* OpenSSL doesn't allow datalen=AES_BLOCK_SIZE; Linux does */
1089                 if (datalen != AES_BLOCK_SIZE) {
1090                         aes_setkey(&k, key, sizeof(key));
1091                         memcpy(iv_copy, iv, sizeof(iv));
1092                         ASSERT(CRYPTO_cts128_encrypt_block(ptext, ref_ctext,
1093                                                            datalen, &k, iv_copy,
1094                                                            aes_block128_f)
1095                                == datalen);
1096                         ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1097                 }
1098                 aes_256_cts_cbc_decrypt(key, iv, ctext, decrypted, datalen);
1099                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1100         }
1101 }
1102 #endif /* ENABLE_ALG_TESTS */
1103
1104 static void essiv_generate_iv(const u8 orig_key[AES_128_KEY_SIZE],
1105                               const u8 orig_iv[AES_BLOCK_SIZE],
1106                               u8 real_iv[AES_BLOCK_SIZE])
1107 {
1108         u8 essiv_key[SHA256_DIGEST_SIZE];
1109         struct aes_key essiv;
1110
1111         /* AES encrypt the original IV using a hash of the original key */
1112         STATIC_ASSERT(SHA256_DIGEST_SIZE == AES_256_KEY_SIZE);
1113         sha256(orig_key, AES_128_KEY_SIZE, essiv_key);
1114         aes_setkey(&essiv, essiv_key, AES_256_KEY_SIZE);
1115         aes_encrypt(&essiv, orig_iv, real_iv);
1116 }
1117
1118 static void aes_128_cbc_essiv_encrypt(const u8 key[AES_128_KEY_SIZE],
1119                                       const u8 iv[AES_BLOCK_SIZE],
1120                                       const u8 *src, u8 *dst, size_t nbytes)
1121 {
1122         struct aes_key k;
1123         u8 real_iv[AES_BLOCK_SIZE];
1124
1125         aes_setkey(&k, key, AES_128_KEY_SIZE);
1126         essiv_generate_iv(key, iv, real_iv);
1127         aes_cbc_encrypt(&k, real_iv, src, dst, nbytes);
1128 }
1129
1130 static void aes_128_cbc_essiv_decrypt(const u8 key[AES_128_KEY_SIZE],
1131                                       const u8 iv[AES_BLOCK_SIZE],
1132                                       const u8 *src, u8 *dst, size_t nbytes)
1133 {
1134         struct aes_key k;
1135         u8 real_iv[AES_BLOCK_SIZE];
1136
1137         aes_setkey(&k, key, AES_128_KEY_SIZE);
1138         essiv_generate_iv(key, iv, real_iv);
1139         aes_cbc_decrypt(&k, real_iv, src, dst, nbytes);
1140 }
1141
1142 static void aes_128_cts_cbc_encrypt(const u8 key[AES_128_KEY_SIZE],
1143                                     const u8 iv[AES_BLOCK_SIZE],
1144                                     const u8 *src, u8 *dst, size_t nbytes)
1145 {
1146         aes_cts_cbc_encrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1147 }
1148
1149 static void aes_128_cts_cbc_decrypt(const u8 key[AES_128_KEY_SIZE],
1150                                     const u8 iv[AES_BLOCK_SIZE],
1151                                     const u8 *src, u8 *dst, size_t nbytes)
1152 {
1153         aes_cts_cbc_decrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1154 }
1155
1156 /*----------------------------------------------------------------------------*
1157  *                           XChaCha12 stream cipher                          *
1158  *----------------------------------------------------------------------------*/
1159
1160 /*
1161  * References:
1162  *   - "XChaCha: eXtended-nonce ChaCha and AEAD_XChaCha20_Poly1305"
1163  *      https://tools.ietf.org/html/draft-arciszewski-xchacha-03
1164  *
1165  *   - "ChaCha, a variant of Salsa20"
1166  *      https://cr.yp.to/chacha/chacha-20080128.pdf
1167  *
1168  *   - "Extending the Salsa20 nonce"
1169  *      https://cr.yp.to/snuffle/xsalsa-20081128.pdf
1170  */
1171
1172 #define CHACHA_KEY_SIZE         32
1173 #define XCHACHA_KEY_SIZE        CHACHA_KEY_SIZE
1174 #define XCHACHA_NONCE_SIZE      24
1175
1176 static void chacha_init_state(u32 state[16], const u8 key[CHACHA_KEY_SIZE],
1177                               const u8 iv[16])
1178 {
1179         static const u8 consts[16] = "expand 32-byte k";
1180         int i;
1181
1182         for (i = 0; i < 4; i++)
1183                 state[i] = get_unaligned_le32(&consts[i * sizeof(__le32)]);
1184         for (i = 0; i < 8; i++)
1185                 state[4 + i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
1186         for (i = 0; i < 4; i++)
1187                 state[12 + i] = get_unaligned_le32(&iv[i * sizeof(__le32)]);
1188 }
1189
1190 #define CHACHA_QUARTERROUND(a, b, c, d)         \
1191         do {                                    \
1192                 a += b; d = rol32(d ^ a, 16);   \
1193                 c += d; b = rol32(b ^ c, 12);   \
1194                 a += b; d = rol32(d ^ a, 8);    \
1195                 c += d; b = rol32(b ^ c, 7);    \
1196         } while (0)
1197
1198 static void chacha_permute(u32 x[16], int nrounds)
1199 {
1200         do {
1201                 /* column round */
1202                 CHACHA_QUARTERROUND(x[0], x[4], x[8], x[12]);
1203                 CHACHA_QUARTERROUND(x[1], x[5], x[9], x[13]);
1204                 CHACHA_QUARTERROUND(x[2], x[6], x[10], x[14]);
1205                 CHACHA_QUARTERROUND(x[3], x[7], x[11], x[15]);
1206
1207                 /* diagonal round */
1208                 CHACHA_QUARTERROUND(x[0], x[5], x[10], x[15]);
1209                 CHACHA_QUARTERROUND(x[1], x[6], x[11], x[12]);
1210                 CHACHA_QUARTERROUND(x[2], x[7], x[8], x[13]);
1211                 CHACHA_QUARTERROUND(x[3], x[4], x[9], x[14]);
1212         } while ((nrounds -= 2) != 0);
1213 }
1214
1215 static void xchacha(const u8 key[XCHACHA_KEY_SIZE],
1216                     const u8 nonce[XCHACHA_NONCE_SIZE],
1217                     const u8 *src, u8 *dst, size_t nbytes, int nrounds)
1218 {
1219         u32 state[16];
1220         u8 real_key[CHACHA_KEY_SIZE];
1221         u8 real_iv[16] = { 0 };
1222         size_t i, j;
1223
1224         /* Compute real key using original key and first 128 nonce bits */
1225         chacha_init_state(state, key, nonce);
1226         chacha_permute(state, nrounds);
1227         for (i = 0; i < 8; i++) /* state words 0..3, 12..15 */
1228                 put_unaligned_le32(state[(i < 4 ? 0 : 8) + i],
1229                                    &real_key[i * sizeof(__le32)]);
1230
1231         /* Now do regular ChaCha, using real key and remaining nonce bits */
1232         memcpy(&real_iv[8], nonce + 16, 8);
1233         chacha_init_state(state, real_key, real_iv);
1234         for (i = 0; i < nbytes; i += 64) {
1235                 u32 x[16];
1236                 __le32 keystream[16];
1237
1238                 memcpy(x, state, 64);
1239                 chacha_permute(x, nrounds);
1240                 for (j = 0; j < 16; j++)
1241                         keystream[j] = cpu_to_le32(x[j] + state[j]);
1242                 xor(&dst[i], &src[i], (u8 *)keystream, MIN(nbytes - i, 64));
1243                 if (++state[12] == 0)
1244                         state[13]++;
1245         }
1246 }
1247
1248 static void xchacha12(const u8 key[XCHACHA_KEY_SIZE],
1249                       const u8 nonce[XCHACHA_NONCE_SIZE],
1250                       const u8 *src, u8 *dst, size_t nbytes)
1251 {
1252         xchacha(key, nonce, src, dst, nbytes, 12);
1253 }
1254
1255 /*----------------------------------------------------------------------------*
1256  *                                 Poly1305                                   *
1257  *----------------------------------------------------------------------------*/
1258
1259 /*
1260  * Note: this is only the Poly1305 Îµ-almost-∆-universal hash function, not the
1261  * full Poly1305 MAC.  I.e., it doesn't add anything at the end.
1262  */
1263
1264 #define POLY1305_KEY_SIZE       16
1265 #define POLY1305_BLOCK_SIZE     16
1266
1267 static void poly1305(const u8 key[POLY1305_KEY_SIZE],
1268                      const u8 *msg, size_t msglen, le128 *out)
1269 {
1270         const u32 limb_mask = 0x3ffffff;        /* limbs are base 2^26 */
1271         const u64 r0 = (get_unaligned_le32(key +  0) >> 0) & 0x3ffffff;
1272         const u64 r1 = (get_unaligned_le32(key +  3) >> 2) & 0x3ffff03;
1273         const u64 r2 = (get_unaligned_le32(key +  6) >> 4) & 0x3ffc0ff;
1274         const u64 r3 = (get_unaligned_le32(key +  9) >> 6) & 0x3f03fff;
1275         const u64 r4 = (get_unaligned_le32(key + 12) >> 8) & 0x00fffff;
1276         u32 h0 = 0, h1 = 0, h2 = 0, h3 = 0, h4 = 0;
1277         u32 g0, g1, g2, g3, g4, ge_p_mask;
1278
1279         /* Partial block support is not necessary for Adiantum */
1280         ASSERT(msglen % POLY1305_BLOCK_SIZE == 0);
1281
1282         while (msglen) {
1283                 u64 d0, d1, d2, d3, d4;
1284
1285                 /* h += *msg */
1286                 h0 += (get_unaligned_le32(msg +  0) >> 0) & limb_mask;
1287                 h1 += (get_unaligned_le32(msg +  3) >> 2) & limb_mask;
1288                 h2 += (get_unaligned_le32(msg +  6) >> 4) & limb_mask;
1289                 h3 += (get_unaligned_le32(msg +  9) >> 6) & limb_mask;
1290                 h4 += (get_unaligned_le32(msg + 12) >> 8) | (1 << 24);
1291
1292                 /* h *= r */
1293                 d0 = h0*r0 + h1*5*r4 + h2*5*r3 + h3*5*r2 + h4*5*r1;
1294                 d1 = h0*r1 + h1*r0   + h2*5*r4 + h3*5*r3 + h4*5*r2;
1295                 d2 = h0*r2 + h1*r1   + h2*r0   + h3*5*r4 + h4*5*r3;
1296                 d3 = h0*r3 + h1*r2   + h2*r1   + h3*r0   + h4*5*r4;
1297                 d4 = h0*r4 + h1*r3   + h2*r2   + h3*r1   + h4*r0;
1298
1299                 /* (partial) h %= 2^130 - 5 */
1300                 d1 += d0 >> 26;         h0 = d0 & limb_mask;
1301                 d2 += d1 >> 26;         h1 = d1 & limb_mask;
1302                 d3 += d2 >> 26;         h2 = d2 & limb_mask;
1303                 d4 += d3 >> 26;         h3 = d3 & limb_mask;
1304                 h0 += (d4 >> 26) * 5;   h4 = d4 & limb_mask;
1305                 h1 += h0 >> 26;         h0 &= limb_mask;
1306
1307                 msg += POLY1305_BLOCK_SIZE;
1308                 msglen -= POLY1305_BLOCK_SIZE;
1309         }
1310
1311         /* fully carry h */
1312         h2 += (h1 >> 26);       h1 &= limb_mask;
1313         h3 += (h2 >> 26);       h2 &= limb_mask;
1314         h4 += (h3 >> 26);       h3 &= limb_mask;
1315         h0 += (h4 >> 26) * 5;   h4 &= limb_mask;
1316         h1 += (h0 >> 26);       h0 &= limb_mask;
1317
1318         /* if (h >= 2^130 - 5) h -= 2^130 - 5; */
1319         g0 = h0 + 5;
1320         g1 = h1 + (g0 >> 26);   g0 &= limb_mask;
1321         g2 = h2 + (g1 >> 26);   g1 &= limb_mask;
1322         g3 = h3 + (g2 >> 26);   g2 &= limb_mask;
1323         g4 = h4 + (g3 >> 26);   g3 &= limb_mask;
1324         ge_p_mask = ~((g4 >> 26) - 1); /* all 1's if h >= 2^130 - 5, else 0 */
1325         h0 = (h0 & ~ge_p_mask) | (g0 & ge_p_mask);
1326         h1 = (h1 & ~ge_p_mask) | (g1 & ge_p_mask);
1327         h2 = (h2 & ~ge_p_mask) | (g2 & ge_p_mask);
1328         h3 = (h3 & ~ge_p_mask) | (g3 & ge_p_mask);
1329         h4 = (h4 & ~ge_p_mask) | (g4 & ge_p_mask & limb_mask);
1330
1331         /* h %= 2^128 */
1332         out->lo = cpu_to_le64(((u64)h2 << 52) | ((u64)h1 << 26) | h0);
1333         out->hi = cpu_to_le64(((u64)h4 << 40) | ((u64)h3 << 14) | (h2 >> 12));
1334 }
1335
1336 /*----------------------------------------------------------------------------*
1337  *                          Adiantum encryption mode                          *
1338  *----------------------------------------------------------------------------*/
1339
1340 /*
1341  * Reference: "Adiantum: length-preserving encryption for entry-level processors"
1342  *      https://tosc.iacr.org/index.php/ToSC/article/view/7360
1343  */
1344
1345 #define ADIANTUM_KEY_SIZE       32
1346 #define ADIANTUM_IV_SIZE        32
1347 #define ADIANTUM_HASH_KEY_SIZE  ((2 * POLY1305_KEY_SIZE) + NH_KEY_SIZE)
1348
1349 #define NH_KEY_SIZE             1072
1350 #define NH_KEY_WORDS            (NH_KEY_SIZE / sizeof(u32))
1351 #define NH_BLOCK_SIZE           1024
1352 #define NH_HASH_SIZE            32
1353 #define NH_MESSAGE_UNIT         16
1354
1355 static u64 nh_pass(const u32 *key, const u8 *msg, size_t msglen)
1356 {
1357         u64 sum = 0;
1358
1359         ASSERT(msglen % NH_MESSAGE_UNIT == 0);
1360         while (msglen) {
1361                 sum += (u64)(u32)(get_unaligned_le32(msg +  0) + key[0]) *
1362                             (u32)(get_unaligned_le32(msg +  8) + key[2]);
1363                 sum += (u64)(u32)(get_unaligned_le32(msg +  4) + key[1]) *
1364                             (u32)(get_unaligned_le32(msg + 12) + key[3]);
1365                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1366                 msg += NH_MESSAGE_UNIT;
1367                 msglen -= NH_MESSAGE_UNIT;
1368         }
1369         return sum;
1370 }
1371
1372 /* NH Îµ-almost-universal hash function */
1373 static void nh(const u32 *key, const u8 *msg, size_t msglen,
1374                u8 result[NH_HASH_SIZE])
1375 {
1376         size_t i;
1377
1378         for (i = 0; i < NH_HASH_SIZE; i += sizeof(__le64)) {
1379                 put_unaligned_le64(nh_pass(key, msg, msglen), &result[i]);
1380                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1381         }
1382 }
1383
1384 /* Adiantum's Îµ-almost-∆-universal hash function */
1385 static void adiantum_hash(const u8 key[ADIANTUM_HASH_KEY_SIZE],
1386                           const u8 iv[ADIANTUM_IV_SIZE],
1387                           const u8 *msg, size_t msglen, le128 *result)
1388 {
1389         const u8 *header_poly_key = key;
1390         const u8 *msg_poly_key = header_poly_key + POLY1305_KEY_SIZE;
1391         const u8 *nh_key = msg_poly_key + POLY1305_KEY_SIZE;
1392         u32 nh_key_words[NH_KEY_WORDS];
1393         u8 header[POLY1305_BLOCK_SIZE + ADIANTUM_IV_SIZE];
1394         const size_t num_nh_blocks = DIV_ROUND_UP(msglen, NH_BLOCK_SIZE);
1395         u8 *nh_hashes = xmalloc(num_nh_blocks * NH_HASH_SIZE);
1396         const size_t padded_msglen = ROUND_UP(msglen, NH_MESSAGE_UNIT);
1397         u8 *padded_msg = xmalloc(padded_msglen);
1398         le128 hash1, hash2;
1399         size_t i;
1400
1401         for (i = 0; i < NH_KEY_WORDS; i++)
1402                 nh_key_words[i] = get_unaligned_le32(&nh_key[i * sizeof(u32)]);
1403
1404         /* Hash tweak and message length with first Poly1305 key */
1405         put_unaligned_le64((u64)msglen * 8, header);
1406         put_unaligned_le64(0, &header[sizeof(__le64)]);
1407         memcpy(&header[POLY1305_BLOCK_SIZE], iv, ADIANTUM_IV_SIZE);
1408         poly1305(header_poly_key, header, sizeof(header), &hash1);
1409
1410         /* Hash NH hashes of message blocks using second Poly1305 key */
1411         /* (using a super naive way of handling the padding) */
1412         memcpy(padded_msg, msg, msglen);
1413         memset(&padded_msg[msglen], 0, padded_msglen - msglen);
1414         for (i = 0; i < num_nh_blocks; i++) {
1415                 nh(nh_key_words, &padded_msg[i * NH_BLOCK_SIZE],
1416                    MIN(NH_BLOCK_SIZE, padded_msglen - (i * NH_BLOCK_SIZE)),
1417                    &nh_hashes[i * NH_HASH_SIZE]);
1418         }
1419         poly1305(msg_poly_key, nh_hashes, num_nh_blocks * NH_HASH_SIZE, &hash2);
1420
1421         /* Add the two hashes together to get the final hash */
1422         le128_add(result, &hash1, &hash2);
1423
1424         free(nh_hashes);
1425         free(padded_msg);
1426 }
1427
1428 static void adiantum_crypt(const u8 key[ADIANTUM_KEY_SIZE],
1429                            const u8 iv[ADIANTUM_IV_SIZE], const u8 *src,
1430                            u8 *dst, size_t nbytes, bool decrypting)
1431 {
1432         u8 subkeys[AES_256_KEY_SIZE + ADIANTUM_HASH_KEY_SIZE] = { 0 };
1433         struct aes_key aes_key;
1434         union {
1435                 u8 nonce[XCHACHA_NONCE_SIZE];
1436                 le128 block;
1437         } u = { .nonce = { 1 } };
1438         const size_t bulk_len = nbytes - sizeof(u.block);
1439         le128 hash;
1440
1441         ASSERT(nbytes >= sizeof(u.block));
1442
1443         /* Derive subkeys */
1444         xchacha12(key, u.nonce, subkeys, subkeys, sizeof(subkeys));
1445         aes_setkey(&aes_key, subkeys, AES_256_KEY_SIZE);
1446
1447         /* Hash left part and add to right part */
1448         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, src, bulk_len, &hash);
1449         memcpy(&u.block, &src[bulk_len], sizeof(u.block));
1450         le128_add(&u.block, &u.block, &hash);
1451
1452         if (!decrypting) /* Encrypt right part with block cipher */
1453                 aes_encrypt(&aes_key, u.nonce, u.nonce);
1454
1455         /* Encrypt left part with stream cipher, using the computed nonce */
1456         u.nonce[sizeof(u.block)] = 1;
1457         xchacha12(key, u.nonce, src, dst, bulk_len);
1458
1459         if (decrypting) /* Decrypt right part with block cipher */
1460                 aes_decrypt(&aes_key, u.nonce, u.nonce);
1461
1462         /* Finalize right part by subtracting hash of left part */
1463         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, dst, bulk_len, &hash);
1464         le128_sub(&u.block, &u.block, &hash);
1465         memcpy(&dst[bulk_len], &u.block, sizeof(u.block));
1466 }
1467
1468 static void adiantum_encrypt(const u8 key[ADIANTUM_KEY_SIZE],
1469                              const u8 iv[ADIANTUM_IV_SIZE],
1470                              const u8 *src, u8 *dst, size_t nbytes)
1471 {
1472         adiantum_crypt(key, iv, src, dst, nbytes, false);
1473 }
1474
1475 static void adiantum_decrypt(const u8 key[ADIANTUM_KEY_SIZE],
1476                              const u8 iv[ADIANTUM_IV_SIZE],
1477                              const u8 *src, u8 *dst, size_t nbytes)
1478 {
1479         adiantum_crypt(key, iv, src, dst, nbytes, true);
1480 }
1481
1482 #ifdef ENABLE_ALG_TESTS
1483 #include <linux/if_alg.h>
1484 #include <sys/socket.h>
1485 #define SOL_ALG 279
1486 static void af_alg_crypt(int algfd, int op, const u8 *key, size_t keylen,
1487                          const u8 *iv, size_t ivlen,
1488                          const u8 *src, u8 *dst, size_t datalen)
1489 {
1490         size_t controllen = CMSG_SPACE(sizeof(int)) +
1491                             CMSG_SPACE(sizeof(struct af_alg_iv) + ivlen);
1492         u8 *control = xmalloc(controllen);
1493         struct iovec iov = { .iov_base = (u8 *)src, .iov_len = datalen };
1494         struct msghdr msg = {
1495                 .msg_iov = &iov,
1496                 .msg_iovlen = 1,
1497                 .msg_control = control,
1498                 .msg_controllen = controllen,
1499         };
1500         struct cmsghdr *cmsg;
1501         struct af_alg_iv *algiv;
1502         int reqfd;
1503
1504         memset(control, 0, controllen);
1505
1506         cmsg = CMSG_FIRSTHDR(&msg);
1507         cmsg->cmsg_len = CMSG_LEN(sizeof(int));
1508         cmsg->cmsg_level = SOL_ALG;
1509         cmsg->cmsg_type = ALG_SET_OP;
1510         *(int *)CMSG_DATA(cmsg) = op;
1511
1512         cmsg = CMSG_NXTHDR(&msg, cmsg);
1513         cmsg->cmsg_len = CMSG_LEN(sizeof(struct af_alg_iv) + ivlen);
1514         cmsg->cmsg_level = SOL_ALG;
1515         cmsg->cmsg_type = ALG_SET_IV;
1516         algiv = (struct af_alg_iv *)CMSG_DATA(cmsg);
1517         algiv->ivlen = ivlen;
1518         memcpy(algiv->iv, iv, ivlen);
1519
1520         if (setsockopt(algfd, SOL_ALG, ALG_SET_KEY, key, keylen) != 0)
1521                 die_errno("can't set key on AF_ALG socket");
1522
1523         reqfd = accept(algfd, NULL, NULL);
1524         if (reqfd < 0)
1525                 die_errno("can't accept() AF_ALG socket");
1526         if (sendmsg(reqfd, &msg, 0) != datalen)
1527                 die_errno("can't sendmsg() AF_ALG request socket");
1528         if (xread(reqfd, dst, datalen) != datalen)
1529                 die("short read from AF_ALG request socket");
1530         close(reqfd);
1531
1532         free(control);
1533 }
1534
1535 static void test_adiantum(void)
1536 {
1537         int algfd = socket(AF_ALG, SOCK_SEQPACKET, 0);
1538         struct sockaddr_alg addr = {
1539                 .salg_type = "skcipher",
1540                 .salg_name = "adiantum(xchacha12,aes)",
1541         };
1542         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1543
1544         if (algfd < 0)
1545                 die_errno("can't create AF_ALG socket");
1546         if (bind(algfd, (struct sockaddr *)&addr, sizeof(addr)) != 0)
1547                 die_errno("can't bind AF_ALG socket to Adiantum algorithm");
1548
1549         while (num_tests--) {
1550                 u8 key[ADIANTUM_KEY_SIZE];
1551                 u8 iv[ADIANTUM_IV_SIZE];
1552                 u8 ptext[4096];
1553                 u8 ctext[sizeof(ptext)];
1554                 u8 ref_ctext[sizeof(ptext)];
1555                 u8 decrypted[sizeof(ptext)];
1556                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1557
1558                 rand_bytes(key, sizeof(key));
1559                 rand_bytes(iv, sizeof(iv));
1560                 rand_bytes(ptext, datalen);
1561
1562                 adiantum_encrypt(key, iv, ptext, ctext, datalen);
1563                 af_alg_crypt(algfd, ALG_OP_ENCRYPT, key, sizeof(key),
1564                              iv, sizeof(iv), ptext, ref_ctext, datalen);
1565                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1566
1567                 adiantum_decrypt(key, iv, ctext, decrypted, datalen);
1568                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1569         }
1570         close(algfd);
1571 }
1572 #endif /* ENABLE_ALG_TESTS */
1573
1574 /*----------------------------------------------------------------------------*
1575  *                               Main program                                 *
1576  *----------------------------------------------------------------------------*/
1577
1578 #define FILE_NONCE_SIZE         16
1579 #define MAX_KEY_SIZE            64
1580
1581 static const struct fscrypt_cipher {
1582         const char *name;
1583         void (*encrypt)(const u8 *key, const u8 *iv, const u8 *src,
1584                         u8 *dst, size_t nbytes);
1585         void (*decrypt)(const u8 *key, const u8 *iv, const u8 *src,
1586                         u8 *dst, size_t nbytes);
1587         int keysize;
1588         int min_input_size;
1589 } fscrypt_ciphers[] = {
1590         {
1591                 .name = "AES-256-XTS",
1592                 .encrypt = aes_256_xts_encrypt,
1593                 .decrypt = aes_256_xts_decrypt,
1594                 .keysize = 2 * AES_256_KEY_SIZE,
1595         }, {
1596                 .name = "AES-256-CTS-CBC",
1597                 .encrypt = aes_256_cts_cbc_encrypt,
1598                 .decrypt = aes_256_cts_cbc_decrypt,
1599                 .keysize = AES_256_KEY_SIZE,
1600                 .min_input_size = AES_BLOCK_SIZE,
1601         }, {
1602                 .name = "AES-128-CBC-ESSIV",
1603                 .encrypt = aes_128_cbc_essiv_encrypt,
1604                 .decrypt = aes_128_cbc_essiv_decrypt,
1605                 .keysize = AES_128_KEY_SIZE,
1606         }, {
1607                 .name = "AES-128-CTS-CBC",
1608                 .encrypt = aes_128_cts_cbc_encrypt,
1609                 .decrypt = aes_128_cts_cbc_decrypt,
1610                 .keysize = AES_128_KEY_SIZE,
1611                 .min_input_size = AES_BLOCK_SIZE,
1612         }, {
1613                 .name = "Adiantum",
1614                 .encrypt = adiantum_encrypt,
1615                 .decrypt = adiantum_decrypt,
1616                 .keysize = ADIANTUM_KEY_SIZE,
1617                 .min_input_size = AES_BLOCK_SIZE,
1618         }
1619 };
1620
1621 static const struct fscrypt_cipher *find_fscrypt_cipher(const char *name)
1622 {
1623         size_t i;
1624
1625         for (i = 0; i < ARRAY_SIZE(fscrypt_ciphers); i++) {
1626                 if (strcmp(fscrypt_ciphers[i].name, name) == 0)
1627                         return &fscrypt_ciphers[i];
1628         }
1629         return NULL;
1630 }
1631
1632 struct fscrypt_iv {
1633         union {
1634                 __le64 block_num;
1635                 u8 bytes[32];
1636         };
1637 };
1638
1639 static void crypt_loop(const struct fscrypt_cipher *cipher, const u8 *key,
1640                        struct fscrypt_iv *iv, bool decrypting,
1641                        size_t block_size, size_t padding)
1642 {
1643         u8 *buf = xmalloc(block_size);
1644         size_t res;
1645
1646         while ((res = xread(STDIN_FILENO, buf, block_size)) > 0) {
1647                 size_t crypt_len = block_size;
1648
1649                 if (padding > 0) {
1650                         crypt_len = MAX(res, cipher->min_input_size);
1651                         crypt_len = ROUND_UP(crypt_len, padding);
1652                         crypt_len = MIN(crypt_len, block_size);
1653                 }
1654                 ASSERT(crypt_len >= res);
1655                 memset(&buf[res], 0, crypt_len - res);
1656
1657                 if (decrypting)
1658                         cipher->decrypt(key, iv->bytes, buf, buf, crypt_len);
1659                 else
1660                         cipher->encrypt(key, iv->bytes, buf, buf, crypt_len);
1661
1662                 full_write(STDOUT_FILENO, buf, crypt_len);
1663
1664                 iv->block_num = cpu_to_le64(le64_to_cpu(iv->block_num) + 1);
1665         }
1666         free(buf);
1667 }
1668
1669 /* The supported key derivation functions */
1670 enum kdf_algorithm {
1671         KDF_NONE,
1672         KDF_AES_128_ECB,
1673         KDF_HKDF_SHA512,
1674 };
1675
1676 static enum kdf_algorithm parse_kdf_algorithm(const char *arg)
1677 {
1678         if (strcmp(arg, "none") == 0)
1679                 return KDF_NONE;
1680         if (strcmp(arg, "AES-128-ECB") == 0)
1681                 return KDF_AES_128_ECB;
1682         if (strcmp(arg, "HKDF-SHA512") == 0)
1683                 return KDF_HKDF_SHA512;
1684         die("Unknown KDF: %s", arg);
1685 }
1686
1687 static u8 parse_mode_number(const char *arg)
1688 {
1689         char *tmp;
1690         long num = strtol(arg, &tmp, 10);
1691
1692         if (num <= 0 || *tmp || (u8)num != num)
1693                 die("Invalid mode number: %s", arg);
1694         return num;
1695 }
1696
1697 /*
1698  * Get the key and starting IV with which the encryption will actually be done.
1699  * If a KDF was specified, a subkey is derived from the master key and the mode
1700  * number or file nonce.  Otherwise, the master key is used directly.
1701  */
1702 static void get_key_and_iv(const u8 *master_key, size_t master_key_size,
1703                            enum kdf_algorithm kdf,
1704                            u8 mode_num, const u8 nonce[FILE_NONCE_SIZE],
1705                            u8 *real_key, size_t real_key_size,
1706                            struct fscrypt_iv *iv)
1707 {
1708         bool nonce_in_iv = false;
1709         struct aes_key aes_key;
1710         u8 info[8 + 1 + FILE_NONCE_SIZE] = "fscrypt";
1711         size_t infolen = 8;
1712         size_t i;
1713
1714         ASSERT(real_key_size <= master_key_size);
1715
1716         memset(iv, 0, sizeof(*iv));
1717
1718         switch (kdf) {
1719         case KDF_NONE:
1720                 if (mode_num != 0)
1721                         die("--mode-num isn't supported with --kdf=none");
1722                 memcpy(real_key, master_key, real_key_size);
1723                 nonce_in_iv = true;
1724                 break;
1725         case KDF_AES_128_ECB:
1726                 if (nonce == NULL)
1727                         die("--file-nonce is required with --kdf=AES-128-ECB");
1728                 if (mode_num != 0)
1729                         die("--mode-num isn't supported with --kdf=AES-128-ECB");
1730                 STATIC_ASSERT(FILE_NONCE_SIZE == AES_128_KEY_SIZE);
1731                 ASSERT(real_key_size % AES_BLOCK_SIZE == 0);
1732                 aes_setkey(&aes_key, nonce, AES_128_KEY_SIZE);
1733                 for (i = 0; i < real_key_size; i += AES_BLOCK_SIZE)
1734                         aes_encrypt(&aes_key, &master_key[i], &real_key[i]);
1735                 break;
1736         case KDF_HKDF_SHA512:
1737                 if (mode_num != 0) {
1738                         info[infolen++] = 3; /* HKDF_CONTEXT_PER_MODE_KEY */
1739                         info[infolen++] = mode_num;
1740                         nonce_in_iv = true;
1741                 } else if (nonce != NULL) {
1742                         info[infolen++] = 2; /* HKDF_CONTEXT_PER_FILE_KEY */
1743                         memcpy(&info[infolen], nonce, FILE_NONCE_SIZE);
1744                         infolen += FILE_NONCE_SIZE;
1745                 } else {
1746                         die("With --kdf=HKDF-SHA512, at least one of --file-nonce and --mode-num must be specified");
1747                 }
1748                 hkdf_sha512(master_key, master_key_size, NULL, 0,
1749                             info, infolen, real_key, real_key_size);
1750                 break;
1751         default:
1752                 ASSERT(0);
1753         }
1754
1755         if (nonce_in_iv && nonce != NULL)
1756                 memcpy(&iv->bytes[8], nonce, FILE_NONCE_SIZE);
1757 }
1758
1759 enum {
1760         OPT_BLOCK_SIZE,
1761         OPT_DECRYPT,
1762         OPT_FILE_NONCE,
1763         OPT_HELP,
1764         OPT_KDF,
1765         OPT_MODE_NUM,
1766         OPT_PADDING,
1767 };
1768
1769 static const struct option longopts[] = {
1770         { "block-size",      required_argument, NULL, OPT_BLOCK_SIZE },
1771         { "decrypt",         no_argument,       NULL, OPT_DECRYPT },
1772         { "file-nonce",      required_argument, NULL, OPT_FILE_NONCE },
1773         { "help",            no_argument,       NULL, OPT_HELP },
1774         { "kdf",             required_argument, NULL, OPT_KDF },
1775         { "mode-num",        required_argument, NULL, OPT_MODE_NUM },
1776         { "padding",         required_argument, NULL, OPT_PADDING },
1777         { NULL, 0, NULL, 0 },
1778 };
1779
1780 int main(int argc, char *argv[])
1781 {
1782         size_t block_size = 4096;
1783         bool decrypting = false;
1784         u8 _file_nonce[FILE_NONCE_SIZE];
1785         u8 *file_nonce = NULL;
1786         enum kdf_algorithm kdf = KDF_NONE;
1787         u8 mode_num = 0;
1788         size_t padding = 0;
1789         const struct fscrypt_cipher *cipher;
1790         u8 master_key[MAX_KEY_SIZE];
1791         int master_key_size;
1792         u8 real_key[MAX_KEY_SIZE];
1793         struct fscrypt_iv iv;
1794         char *tmp;
1795         int c;
1796
1797         aes_init();
1798
1799 #ifdef ENABLE_ALG_TESTS
1800         test_aes();
1801         test_sha2();
1802         test_hkdf_sha512();
1803         test_aes_256_xts();
1804         test_aes_256_cts_cbc();
1805         test_adiantum();
1806 #endif
1807
1808         while ((c = getopt_long(argc, argv, "", longopts, NULL)) != -1) {
1809                 switch (c) {
1810                 case OPT_BLOCK_SIZE:
1811                         block_size = strtoul(optarg, &tmp, 10);
1812                         if (block_size <= 0 || *tmp)
1813                                 die("Invalid block size: %s", optarg);
1814                         break;
1815                 case OPT_DECRYPT:
1816                         decrypting = true;
1817                         break;
1818                 case OPT_FILE_NONCE:
1819                         if (hex2bin(optarg, _file_nonce, FILE_NONCE_SIZE) !=
1820                             FILE_NONCE_SIZE)
1821                                 die("Invalid file nonce: %s", optarg);
1822                         file_nonce = _file_nonce;
1823                         break;
1824                 case OPT_HELP:
1825                         usage(stdout);
1826                         return 0;
1827                 case OPT_KDF:
1828                         kdf = parse_kdf_algorithm(optarg);
1829                         break;
1830                 case OPT_MODE_NUM:
1831                         mode_num = parse_mode_number(optarg);
1832                         break;
1833                 case OPT_PADDING:
1834                         padding = strtoul(optarg, &tmp, 10);
1835                         if (padding <= 0 || *tmp || !is_power_of_2(padding) ||
1836                             padding > INT_MAX)
1837                                 die("Invalid padding amount: %s", optarg);
1838                         break;
1839                 default:
1840                         usage(stderr);
1841                         return 2;
1842                 }
1843         }
1844         argc -= optind;
1845         argv += optind;
1846
1847         if (argc != 2) {
1848                 usage(stderr);
1849                 return 2;
1850         }
1851
1852         cipher = find_fscrypt_cipher(argv[0]);
1853         if (cipher == NULL)
1854                 die("Unknown cipher: %s", argv[0]);
1855
1856         if (block_size < cipher->min_input_size)
1857                 die("Block size of %zu bytes is too small for cipher %s",
1858                     block_size, cipher->name);
1859
1860         master_key_size = hex2bin(argv[1], master_key, MAX_KEY_SIZE);
1861         if (master_key_size < 0)
1862                 die("Invalid master_key: %s", argv[1]);
1863         if (master_key_size < cipher->keysize)
1864                 die("Master key is too short for cipher %s", cipher->name);
1865
1866         get_key_and_iv(master_key, master_key_size, kdf, mode_num, file_nonce,
1867                        real_key, cipher->keysize, &iv);
1868
1869         crypt_loop(cipher, real_key, &iv, decrypting, block_size, padding);
1870         return 0;
1871 }