fscrypt-crypt-util: clean up parsing --block-size and --inode-number
[xfstests-dev.git] / src / fscrypt-crypt-util.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * fscrypt-crypt-util.c - utility for verifying fscrypt-encrypted data
4  *
5  * Copyright 2019 Google LLC
6  */
7
8 /*
9  * This program implements all crypto algorithms supported by fscrypt (a.k.a.
10  * ext4, f2fs, and ubifs encryption), for the purpose of verifying the
11  * correctness of the ciphertext stored on-disk.  See usage() below.
12  *
13  * All algorithms are implemented in portable C code to avoid depending on
14  * libcrypto (OpenSSL), and because some fscrypt-supported algorithms aren't
15  * available in libcrypto anyway (e.g. Adiantum), or are only supported in
16  * recent versions (e.g. HKDF-SHA512).  For simplicity, all crypto code here
17  * tries to follow the mathematical definitions directly, without optimizing for
18  * performance or worrying about following security best practices such as
19  * mitigating side-channel attacks.  So, only use this program for testing!
20  */
21
22 #include <asm/byteorder.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <limits.h>
26 #include <linux/types.h>
27 #include <stdarg.h>
28 #include <stdbool.h>
29 #include <stdint.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <unistd.h>
34
35 #define PROGRAM_NAME "fscrypt-crypt-util"
36
37 /*
38  * Define to enable the tests of the crypto code in this file.  If enabled, you
39  * must link this program with OpenSSL (-lcrypto) v1.1.0 or later, and your
40  * kernel needs CONFIG_CRYPTO_USER_API_SKCIPHER=y and CONFIG_CRYPTO_ADIANTUM=y.
41  */
42 #undef ENABLE_ALG_TESTS
43
44 #define NUM_ALG_TEST_ITERATIONS 10000
45
46 static void usage(FILE *fp)
47 {
48         fputs(
49 "Usage: " PROGRAM_NAME " [OPTION]... CIPHER MASTER_KEY\n"
50 "\n"
51 "Utility for verifying fscrypt-encrypted data.  This program encrypts\n"
52 "(or decrypts) the data on stdin using the given CIPHER with the given\n"
53 "MASTER_KEY (or a key derived from it, if a KDF is specified), and writes the\n"
54 "resulting ciphertext (or plaintext) to stdout.\n"
55 "\n"
56 "CIPHER can be AES-256-XTS, AES-256-CTS-CBC, AES-128-CBC-ESSIV, AES-128-CTS-CBC,\n"
57 "or Adiantum.  MASTER_KEY must be a hex string long enough for the cipher.\n"
58 "\n"
59 "WARNING: this program is only meant for testing, not for \"real\" use!\n"
60 "\n"
61 "Options:\n"
62 "  --block-size=BLOCK_SIZE     Encrypt each BLOCK_SIZE bytes independently.\n"
63 "                                Default: 4096 bytes\n"
64 "  --decrypt                   Decrypt instead of encrypt\n"
65 "  --file-nonce=NONCE          File's nonce as a 32-character hex string\n"
66 "  --fs-uuid=UUID              The filesystem UUID as a 32-character hex string.\n"
67 "                                Required for --iv-ino-lblk-32 and\n"
68 "                                --iv-ino-lblk-64; otherwise is unused.\n"
69 "  --help                      Show this help\n"
70 "  --inode-number=INUM         The file's inode number.  Required for\n"
71 "                                --iv-ino-lblk-32 and --iv-ino-lblk-64;\n"
72 "                                otherwise is unused.\n"
73 "  --iv-ino-lblk-32            Similar to --iv-ino-lblk-64, but selects the\n"
74 "                                32-bit variant.\n"
75 "  --iv-ino-lblk-64            Use the format where the IVs include the inode\n"
76 "                                number and the same key is shared across files.\n"
77 "                                Requires --kdf=HKDF-SHA512, --fs-uuid,\n"
78 "                                --inode-number, and --mode-num.\n"
79 "  --kdf=KDF                   Key derivation function to use: AES-128-ECB,\n"
80 "                                HKDF-SHA512, or none.  Default: none\n"
81 "  --mode-num=NUM              Derive per-mode key using mode number NUM\n"
82 "  --padding=PADDING           If last block is partial, zero-pad it to next\n"
83 "                                PADDING-byte boundary.  Default: BLOCK_SIZE\n"
84         , fp);
85 }
86
87 /*----------------------------------------------------------------------------*
88  *                                 Utilities                                  *
89  *----------------------------------------------------------------------------*/
90
91 #define ARRAY_SIZE(A)           (sizeof(A) / sizeof((A)[0]))
92 #define MIN(x, y)               ((x) < (y) ? (x) : (y))
93 #define MAX(x, y)               ((x) > (y) ? (x) : (y))
94 #define ROUND_DOWN(x, y)        ((x) & ~((y) - 1))
95 #define ROUND_UP(x, y)          (((x) + (y) - 1) & ~((y) - 1))
96 #define DIV_ROUND_UP(n, d)      (((n) + (d) - 1) / (d))
97 #define STATIC_ASSERT(e)        ((void)sizeof(char[1 - 2*!(e)]))
98
99 typedef __u8                    u8;
100 typedef __u16                   u16;
101 typedef __u32                   u32;
102 typedef __u64                   u64;
103
104 #define cpu_to_le32             __cpu_to_le32
105 #define cpu_to_be32             __cpu_to_be32
106 #define cpu_to_le64             __cpu_to_le64
107 #define cpu_to_be64             __cpu_to_be64
108 #define le32_to_cpu             __le32_to_cpu
109 #define be32_to_cpu             __be32_to_cpu
110 #define le64_to_cpu             __le64_to_cpu
111 #define be64_to_cpu             __be64_to_cpu
112
113 #define DEFINE_UNALIGNED_ACCESS_HELPERS(type, native_type)      \
114 static inline native_type __attribute__((unused))               \
115 get_unaligned_##type(const void *p)                             \
116 {                                                               \
117         __##type x;                                             \
118                                                                 \
119         memcpy(&x, p, sizeof(x));                               \
120         return type##_to_cpu(x);                                \
121 }                                                               \
122                                                                 \
123 static inline void __attribute__((unused))                      \
124 put_unaligned_##type(native_type v, void *p)                    \
125 {                                                               \
126         __##type x = cpu_to_##type(v);                          \
127                                                                 \
128         memcpy(p, &x, sizeof(x));                               \
129 }
130
131 DEFINE_UNALIGNED_ACCESS_HELPERS(le32, u32)
132 DEFINE_UNALIGNED_ACCESS_HELPERS(be32, u32)
133 DEFINE_UNALIGNED_ACCESS_HELPERS(le64, u64)
134 DEFINE_UNALIGNED_ACCESS_HELPERS(be64, u64)
135
136 static inline bool is_power_of_2(unsigned long v)
137 {
138         return v != 0 && (v & (v - 1)) == 0;
139 }
140
141 static inline u32 rol32(u32 v, int n)
142 {
143         return (v << n) | (v >> (32 - n));
144 }
145
146 static inline u32 ror32(u32 v, int n)
147 {
148         return (v >> n) | (v << (32 - n));
149 }
150
151 static inline u64 rol64(u64 v, int n)
152 {
153         return (v << n) | (v >> (64 - n));
154 }
155
156 static inline u64 ror64(u64 v, int n)
157 {
158         return (v >> n) | (v << (64 - n));
159 }
160
161 static inline void xor(u8 *res, const u8 *a, const u8 *b, size_t count)
162 {
163         while (count--)
164                 *res++ = *a++ ^ *b++;
165 }
166
167 static void __attribute__((noreturn, format(printf, 2, 3)))
168 do_die(int err, const char *format, ...)
169 {
170         va_list va;
171
172         va_start(va, format);
173         fputs("[" PROGRAM_NAME "] ERROR: ", stderr);
174         vfprintf(stderr, format, va);
175         if (err)
176                 fprintf(stderr, ": %s", strerror(errno));
177         putc('\n', stderr);
178         va_end(va);
179         exit(1);
180 }
181
182 #define die(format, ...)        do_die(0,     (format), ##__VA_ARGS__)
183 #define die_errno(format, ...)  do_die(errno, (format), ##__VA_ARGS__)
184
185 static __attribute__((noreturn)) void
186 assertion_failed(const char *expr, const char *file, int line)
187 {
188         die("Assertion failed: %s at %s:%d", expr, file, line);
189 }
190
191 #define ASSERT(e) ({ if (!(e)) assertion_failed(#e, __FILE__, __LINE__); })
192
193 static void *xmalloc(size_t size)
194 {
195         void *p = malloc(size);
196
197         ASSERT(p != NULL);
198         return p;
199 }
200
201 static int hexchar2bin(char c)
202 {
203         if (c >= 'a' && c <= 'f')
204                 return 10 + c - 'a';
205         if (c >= 'A' && c <= 'F')
206                 return 10 + c - 'A';
207         if (c >= '0' && c <= '9')
208                 return c - '0';
209         return -1;
210 }
211
212 static int hex2bin(const char *hex, u8 *bin, int max_bin_size)
213 {
214         size_t len = strlen(hex);
215         size_t i;
216
217         if (len & 1)
218                 return -1;
219         len /= 2;
220         if (len > max_bin_size)
221                 return -1;
222
223         for (i = 0; i < len; i++) {
224                 int high = hexchar2bin(hex[2 * i]);
225                 int low = hexchar2bin(hex[2 * i + 1]);
226
227                 if (high < 0 || low < 0)
228                         return -1;
229                 bin[i] = (high << 4) | low;
230         }
231         return len;
232 }
233
234 static size_t xread(int fd, void *buf, size_t count)
235 {
236         const size_t orig_count = count;
237
238         while (count) {
239                 ssize_t res = read(fd, buf, count);
240
241                 if (res < 0)
242                         die_errno("read error");
243                 if (res == 0)
244                         break;
245                 buf += res;
246                 count -= res;
247         }
248         return orig_count - count;
249 }
250
251 static void full_write(int fd, const void *buf, size_t count)
252 {
253         while (count) {
254                 ssize_t res = write(fd, buf, count);
255
256                 if (res < 0)
257                         die_errno("write error");
258                 buf += res;
259                 count -= res;
260         }
261 }
262
263 #ifdef ENABLE_ALG_TESTS
264 static void rand_bytes(u8 *buf, size_t count)
265 {
266         while (count--)
267                 *buf++ = rand();
268 }
269 #endif
270
271 /*----------------------------------------------------------------------------*
272  *                          Finite field arithmetic                           *
273  *----------------------------------------------------------------------------*/
274
275 /* Multiply a GF(2^8) element by the polynomial 'x' */
276 static inline u8 gf2_8_mul_x(u8 b)
277 {
278         return (b << 1) ^ ((b & 0x80) ? 0x1B : 0);
279 }
280
281 /* Multiply four packed GF(2^8) elements by the polynomial 'x' */
282 static inline u32 gf2_8_mul_x_4way(u32 w)
283 {
284         return ((w & 0x7F7F7F7F) << 1) ^ (((w & 0x80808080) >> 7) * 0x1B);
285 }
286
287 /* Element of GF(2^128) */
288 typedef struct {
289         __le64 lo;
290         __le64 hi;
291 } ble128;
292
293 /* Multiply a GF(2^128) element by the polynomial 'x' */
294 static inline void gf2_128_mul_x(ble128 *t)
295 {
296         u64 lo = le64_to_cpu(t->lo);
297         u64 hi = le64_to_cpu(t->hi);
298
299         t->hi = cpu_to_le64((hi << 1) | (lo >> 63));
300         t->lo = cpu_to_le64((lo << 1) ^ ((hi & (1ULL << 63)) ? 0x87 : 0));
301 }
302
303 /*----------------------------------------------------------------------------*
304  *                             Group arithmetic                               *
305  *----------------------------------------------------------------------------*/
306
307 /* Element of Z/(2^{128}Z)  (a.k.a. the integers modulo 2^128) */
308 typedef struct {
309         __le64 lo;
310         __le64 hi;
311 } le128;
312
313 static inline void le128_add(le128 *res, const le128 *a, const le128 *b)
314 {
315         u64 a_lo = le64_to_cpu(a->lo);
316         u64 b_lo = le64_to_cpu(b->lo);
317
318         res->lo = cpu_to_le64(a_lo + b_lo);
319         res->hi = cpu_to_le64(le64_to_cpu(a->hi) + le64_to_cpu(b->hi) +
320                               (a_lo + b_lo < a_lo));
321 }
322
323 static inline void le128_sub(le128 *res, const le128 *a, const le128 *b)
324 {
325         u64 a_lo = le64_to_cpu(a->lo);
326         u64 b_lo = le64_to_cpu(b->lo);
327
328         res->lo = cpu_to_le64(a_lo - b_lo);
329         res->hi = cpu_to_le64(le64_to_cpu(a->hi) - le64_to_cpu(b->hi) -
330                               (a_lo - b_lo > a_lo));
331 }
332
333 /*----------------------------------------------------------------------------*
334  *                              AES block cipher                              *
335  *----------------------------------------------------------------------------*/
336
337 /*
338  * Reference: "FIPS 197, Advanced Encryption Standard"
339  *      https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.197.pdf
340  */
341
342 #define AES_BLOCK_SIZE          16
343 #define AES_128_KEY_SIZE        16
344 #define AES_192_KEY_SIZE        24
345 #define AES_256_KEY_SIZE        32
346
347 static inline void AddRoundKey(u32 state[4], const u32 *rk)
348 {
349         int i;
350
351         for (i = 0; i < 4; i++)
352                 state[i] ^= rk[i];
353 }
354
355 static const u8 aes_sbox[256] = {
356         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b,
357         0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0,
358         0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26,
359         0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
360         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2,
361         0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0,
362         0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed,
363         0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
364         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f,
365         0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
366         0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec,
367         0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
368         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14,
369         0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c,
370         0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d,
371         0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
372         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f,
373         0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e,
374         0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11,
375         0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
376         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f,
377         0xb0, 0x54, 0xbb, 0x16,
378 };
379
380 static u8 aes_inverse_sbox[256];
381
382 static void aes_init(void)
383 {
384         int i;
385
386         for (i = 0; i < 256; i++)
387                 aes_inverse_sbox[aes_sbox[i]] = i;
388 }
389
390 static inline u32 DoSubWord(u32 w, const u8 sbox[256])
391 {
392         return ((u32)sbox[(u8)(w >> 24)] << 24) |
393                ((u32)sbox[(u8)(w >> 16)] << 16) |
394                ((u32)sbox[(u8)(w >>  8)] <<  8) |
395                ((u32)sbox[(u8)(w >>  0)] <<  0);
396 }
397
398 static inline u32 SubWord(u32 w)
399 {
400         return DoSubWord(w, aes_sbox);
401 }
402
403 static inline u32 InvSubWord(u32 w)
404 {
405         return DoSubWord(w, aes_inverse_sbox);
406 }
407
408 static inline void SubBytes(u32 state[4])
409 {
410         int i;
411
412         for (i = 0; i < 4; i++)
413                 state[i] = SubWord(state[i]);
414 }
415
416 static inline void InvSubBytes(u32 state[4])
417 {
418         int i;
419
420         for (i = 0; i < 4; i++)
421                 state[i] = InvSubWord(state[i]);
422 }
423
424 static inline void DoShiftRows(u32 state[4], int direction)
425 {
426         u32 newstate[4];
427         int i;
428
429         for (i = 0; i < 4; i++)
430                 newstate[i] = (state[(i + direction*0) & 3] & 0xff) |
431                               (state[(i + direction*1) & 3] & 0xff00) |
432                               (state[(i + direction*2) & 3] & 0xff0000) |
433                               (state[(i + direction*3) & 3] & 0xff000000);
434         memcpy(state, newstate, 16);
435 }
436
437 static inline void ShiftRows(u32 state[4])
438 {
439         DoShiftRows(state, 1);
440 }
441
442 static inline void InvShiftRows(u32 state[4])
443 {
444         DoShiftRows(state, -1);
445 }
446
447 /*
448  * Mix one column by doing the following matrix multiplication in GF(2^8):
449  *
450  *     | 2 3 1 1 |   | w[0] |
451  *     | 1 2 3 1 |   | w[1] |
452  *     | 1 1 2 3 | x | w[2] |
453  *     | 3 1 1 2 |   | w[3] |
454  *
455  * a.k.a. w[i] = 2*w[i] + 3*w[(i+1)%4] + w[(i+2)%4] + w[(i+3)%4]
456  */
457 static inline u32 MixColumn(u32 w)
458 {
459         u32 _2w0_w2 = gf2_8_mul_x_4way(w) ^ ror32(w, 16);
460         u32 _3w1_w3 = ror32(_2w0_w2 ^ w, 8);
461
462         return _2w0_w2 ^ _3w1_w3;
463 }
464
465 /*
466  *           ( | 5 0 4 0 |   | w[0] | )
467  *          (  | 0 5 0 4 |   | w[1] |  )
468  * MixColumn(  | 4 0 5 0 | x | w[2] |  )
469  *           ( | 0 4 0 5 |   | w[3] | )
470  */
471 static inline u32 InvMixColumn(u32 w)
472 {
473         u32 _4w = gf2_8_mul_x_4way(gf2_8_mul_x_4way(w));
474
475         return MixColumn(_4w ^ w ^ ror32(_4w, 16));
476 }
477
478 static inline void MixColumns(u32 state[4])
479 {
480         int i;
481
482         for (i = 0; i < 4; i++)
483                 state[i] = MixColumn(state[i]);
484 }
485
486 static inline void InvMixColumns(u32 state[4])
487 {
488         int i;
489
490         for (i = 0; i < 4; i++)
491                 state[i] = InvMixColumn(state[i]);
492 }
493
494 struct aes_key {
495         u32 round_keys[15 * 4];
496         int nrounds;
497 };
498
499 /* Expand an AES key */
500 static void aes_setkey(struct aes_key *k, const u8 *key, int keysize)
501 {
502         const int N = keysize / 4;
503         u32 * const rk = k->round_keys;
504         u8 rcon = 1;
505         int i;
506
507         ASSERT(keysize == 16 || keysize == 24 || keysize == 32);
508         k->nrounds = 6 + N;
509         for (i = 0; i < 4 * (k->nrounds + 1); i++) {
510                 if (i < N) {
511                         rk[i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
512                 } else if (i % N == 0) {
513                         rk[i] = rk[i - N] ^ SubWord(ror32(rk[i - 1], 8)) ^ rcon;
514                         rcon = gf2_8_mul_x(rcon);
515                 } else if (N > 6 && i % N == 4) {
516                         rk[i] = rk[i - N] ^ SubWord(rk[i - 1]);
517                 } else {
518                         rk[i] = rk[i - N] ^ rk[i - 1];
519                 }
520         }
521 }
522
523 /* Encrypt one 16-byte block with AES */
524 static void aes_encrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
525                         u8 dst[AES_BLOCK_SIZE])
526 {
527         u32 state[4];
528         int i;
529
530         for (i = 0; i < 4; i++)
531                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
532
533         AddRoundKey(state, k->round_keys);
534         for (i = 1; i < k->nrounds; i++) {
535                 SubBytes(state);
536                 ShiftRows(state);
537                 MixColumns(state);
538                 AddRoundKey(state, &k->round_keys[4 * i]);
539         }
540         SubBytes(state);
541         ShiftRows(state);
542         AddRoundKey(state, &k->round_keys[4 * i]);
543
544         for (i = 0; i < 4; i++)
545                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
546 }
547
548 /* Decrypt one 16-byte block with AES */
549 static void aes_decrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
550                         u8 dst[AES_BLOCK_SIZE])
551 {
552         u32 state[4];
553         int i;
554
555         for (i = 0; i < 4; i++)
556                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
557
558         AddRoundKey(state, &k->round_keys[4 * k->nrounds]);
559         InvShiftRows(state);
560         InvSubBytes(state);
561         for (i = k->nrounds - 1; i >= 1; i--) {
562                 AddRoundKey(state, &k->round_keys[4 * i]);
563                 InvMixColumns(state);
564                 InvShiftRows(state);
565                 InvSubBytes(state);
566         }
567         AddRoundKey(state, k->round_keys);
568
569         for (i = 0; i < 4; i++)
570                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
571 }
572
573 #ifdef ENABLE_ALG_TESTS
574 #include <openssl/aes.h>
575 static void test_aes_keysize(int keysize)
576 {
577         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
578
579         while (num_tests--) {
580                 struct aes_key k;
581                 AES_KEY ref_k;
582                 u8 key[AES_256_KEY_SIZE];
583                 u8 ptext[AES_BLOCK_SIZE];
584                 u8 ctext[AES_BLOCK_SIZE];
585                 u8 ref_ctext[AES_BLOCK_SIZE];
586                 u8 decrypted[AES_BLOCK_SIZE];
587
588                 rand_bytes(key, keysize);
589                 rand_bytes(ptext, AES_BLOCK_SIZE);
590
591                 aes_setkey(&k, key, keysize);
592                 aes_encrypt(&k, ptext, ctext);
593
594                 ASSERT(AES_set_encrypt_key(key, keysize*8, &ref_k) == 0);
595                 AES_encrypt(ptext, ref_ctext, &ref_k);
596
597                 ASSERT(memcmp(ctext, ref_ctext, AES_BLOCK_SIZE) == 0);
598
599                 aes_decrypt(&k, ctext, decrypted);
600                 ASSERT(memcmp(ptext, decrypted, AES_BLOCK_SIZE) == 0);
601         }
602 }
603
604 static void test_aes(void)
605 {
606         test_aes_keysize(AES_128_KEY_SIZE);
607         test_aes_keysize(AES_192_KEY_SIZE);
608         test_aes_keysize(AES_256_KEY_SIZE);
609 }
610 #endif /* ENABLE_ALG_TESTS */
611
612 /*----------------------------------------------------------------------------*
613  *                            SHA-512 and SHA-256                             *
614  *----------------------------------------------------------------------------*/
615
616 /*
617  * Reference: "FIPS 180-2, Secure Hash Standard"
618  *      https://csrc.nist.gov/csrc/media/publications/fips/180/2/archive/2002-08-01/documents/fips180-2withchangenotice.pdf
619  */
620
621 #define SHA512_DIGEST_SIZE      64
622 #define SHA512_BLOCK_SIZE       128
623
624 #define SHA256_DIGEST_SIZE      32
625 #define SHA256_BLOCK_SIZE       64
626
627 #define Ch(x, y, z)     (((x) & (y)) ^ (~(x) & (z)))
628 #define Maj(x, y, z)    (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z)))
629
630 #define Sigma512_0(x)   (ror64((x), 28) ^ ror64((x), 34) ^ ror64((x), 39))
631 #define Sigma512_1(x)   (ror64((x), 14) ^ ror64((x), 18) ^ ror64((x), 41))
632 #define sigma512_0(x)   (ror64((x),  1) ^ ror64((x),  8) ^ ((x) >> 7))
633 #define sigma512_1(x)   (ror64((x), 19) ^ ror64((x), 61) ^ ((x) >> 6))
634
635 #define Sigma256_0(x)   (ror32((x),  2) ^ ror32((x), 13) ^ ror32((x), 22))
636 #define Sigma256_1(x)   (ror32((x),  6) ^ ror32((x), 11) ^ ror32((x), 25))
637 #define sigma256_0(x)   (ror32((x),  7) ^ ror32((x), 18) ^ ((x) >>  3))
638 #define sigma256_1(x)   (ror32((x), 17) ^ ror32((x), 19) ^ ((x) >> 10))
639
640 static const u64 sha512_iv[8] = {
641         0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b,
642         0xa54ff53a5f1d36f1, 0x510e527fade682d1, 0x9b05688c2b3e6c1f,
643         0x1f83d9abfb41bd6b, 0x5be0cd19137e2179,
644 };
645
646 static const u64 sha512_round_constants[80] = {
647         0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f,
648         0xe9b5dba58189dbbc, 0x3956c25bf348b538, 0x59f111f1b605d019,
649         0x923f82a4af194f9b, 0xab1c5ed5da6d8118, 0xd807aa98a3030242,
650         0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2,
651         0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235,
652         0xc19bf174cf692694, 0xe49b69c19ef14ad2, 0xefbe4786384f25e3,
653         0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, 0x2de92c6f592b0275,
654         0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5,
655         0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f,
656         0xbf597fc7beef0ee4, 0xc6e00bf33da88fc2, 0xd5a79147930aa725,
657         0x06ca6351e003826f, 0x142929670a0e6e70, 0x27b70a8546d22ffc,
658         0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df,
659         0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6,
660         0x92722c851482353b, 0xa2bfe8a14cf10364, 0xa81a664bbc423001,
661         0xc24b8b70d0f89791, 0xc76c51a30654be30, 0xd192e819d6ef5218,
662         0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8,
663         0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99,
664         0x34b0bcb5e19b48a8, 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb,
665         0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, 0x748f82ee5defb2fc,
666         0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec,
667         0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915,
668         0xc67178f2e372532b, 0xca273eceea26619c, 0xd186b8c721c0c207,
669         0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, 0x06f067aa72176fba,
670         0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b,
671         0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc,
672         0x431d67c49c100d4c, 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a,
673         0x5fcb6fab3ad6faec, 0x6c44198c4a475817,
674 };
675
676 /* Compute the SHA-512 digest of the given buffer */
677 static void sha512(const u8 *in, size_t inlen, u8 out[SHA512_DIGEST_SIZE])
678 {
679         const size_t msglen = ROUND_UP(inlen + 17, SHA512_BLOCK_SIZE);
680         u8 * const msg = xmalloc(msglen);
681         u64 H[8];
682         int i;
683
684         /* super naive way of handling the padding */
685         memcpy(msg, in, inlen);
686         memset(&msg[inlen], 0, msglen - inlen);
687         msg[inlen] = 0x80;
688         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
689         in = msg;
690
691         memcpy(H, sha512_iv, sizeof(H));
692         do {
693                 u64 a = H[0], b = H[1], c = H[2], d = H[3],
694                     e = H[4], f = H[5], g = H[6], h = H[7];
695                 u64 W[80];
696
697                 for (i = 0; i < 16; i++)
698                         W[i] = get_unaligned_be64(&in[i * sizeof(__be64)]);
699                 for (; i < ARRAY_SIZE(W); i++)
700                         W[i] = sigma512_1(W[i - 2]) + W[i - 7] +
701                                sigma512_0(W[i - 15]) + W[i - 16];
702                 for (i = 0; i < ARRAY_SIZE(W); i++) {
703                         u64 T1 = h + Sigma512_1(e) + Ch(e, f, g) +
704                                  sha512_round_constants[i] + W[i];
705                         u64 T2 = Sigma512_0(a) + Maj(a, b, c);
706
707                         h = g; g = f; f = e; e = d + T1;
708                         d = c; c = b; b = a; a = T1 + T2;
709                 }
710                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
711                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
712         } while ((in += SHA512_BLOCK_SIZE) != &msg[msglen]);
713
714         for (i = 0; i < ARRAY_SIZE(H); i++)
715                 put_unaligned_be64(H[i], &out[i * sizeof(__be64)]);
716         free(msg);
717 }
718
719 /* Compute the SHA-256 digest of the given buffer */
720 static void sha256(const u8 *in, size_t inlen, u8 out[SHA256_DIGEST_SIZE])
721 {
722         const size_t msglen = ROUND_UP(inlen + 9, SHA256_BLOCK_SIZE);
723         u8 * const msg = xmalloc(msglen);
724         u32 H[8];
725         int i;
726
727         /* super naive way of handling the padding */
728         memcpy(msg, in, inlen);
729         memset(&msg[inlen], 0, msglen - inlen);
730         msg[inlen] = 0x80;
731         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
732         in = msg;
733
734         for (i = 0; i < ARRAY_SIZE(H); i++)
735                 H[i] = (u32)(sha512_iv[i] >> 32);
736         do {
737                 u32 a = H[0], b = H[1], c = H[2], d = H[3],
738                     e = H[4], f = H[5], g = H[6], h = H[7];
739                 u32 W[64];
740
741                 for (i = 0; i < 16; i++)
742                         W[i] = get_unaligned_be32(&in[i * sizeof(__be32)]);
743                 for (; i < ARRAY_SIZE(W); i++)
744                         W[i] = sigma256_1(W[i - 2]) + W[i - 7] +
745                                sigma256_0(W[i - 15]) + W[i - 16];
746                 for (i = 0; i < ARRAY_SIZE(W); i++) {
747                         u32 T1 = h + Sigma256_1(e) + Ch(e, f, g) +
748                                  (u32)(sha512_round_constants[i] >> 32) + W[i];
749                         u32 T2 = Sigma256_0(a) + Maj(a, b, c);
750
751                         h = g; g = f; f = e; e = d + T1;
752                         d = c; c = b; b = a; a = T1 + T2;
753                 }
754                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
755                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
756         } while ((in += SHA256_BLOCK_SIZE) != &msg[msglen]);
757
758         for (i = 0; i < ARRAY_SIZE(H); i++)
759                 put_unaligned_be32(H[i], &out[i * sizeof(__be32)]);
760         free(msg);
761 }
762
763 #ifdef ENABLE_ALG_TESTS
764 #include <openssl/sha.h>
765 static void test_sha2(void)
766 {
767         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
768
769         while (num_tests--) {
770                 u8 in[4096];
771                 u8 digest[SHA512_DIGEST_SIZE];
772                 u8 ref_digest[SHA512_DIGEST_SIZE];
773                 const size_t inlen = rand() % (1 + sizeof(in));
774
775                 rand_bytes(in, inlen);
776
777                 sha256(in, inlen, digest);
778                 SHA256(in, inlen, ref_digest);
779                 ASSERT(memcmp(digest, ref_digest, SHA256_DIGEST_SIZE) == 0);
780
781                 sha512(in, inlen, digest);
782                 SHA512(in, inlen, ref_digest);
783                 ASSERT(memcmp(digest, ref_digest, SHA512_DIGEST_SIZE) == 0);
784         }
785 }
786 #endif /* ENABLE_ALG_TESTS */
787
788 /*----------------------------------------------------------------------------*
789  *                            HKDF implementation                             *
790  *----------------------------------------------------------------------------*/
791
792 static void hmac_sha512(const u8 *key, size_t keylen, const u8 *msg,
793                         size_t msglen, u8 mac[SHA512_DIGEST_SIZE])
794 {
795         u8 *ibuf = xmalloc(SHA512_BLOCK_SIZE + msglen);
796         u8 obuf[SHA512_BLOCK_SIZE + SHA512_DIGEST_SIZE];
797
798         ASSERT(keylen <= SHA512_BLOCK_SIZE); /* keylen > bs not implemented */
799
800         memset(ibuf, 0x36, SHA512_BLOCK_SIZE);
801         xor(ibuf, ibuf, key, keylen);
802         memcpy(&ibuf[SHA512_BLOCK_SIZE], msg, msglen);
803
804         memset(obuf, 0x5c, SHA512_BLOCK_SIZE);
805         xor(obuf, obuf, key, keylen);
806         sha512(ibuf, SHA512_BLOCK_SIZE + msglen, &obuf[SHA512_BLOCK_SIZE]);
807         sha512(obuf, sizeof(obuf), mac);
808
809         free(ibuf);
810 }
811
812 static void hkdf_sha512(const u8 *ikm, size_t ikmlen,
813                         const u8 *salt, size_t saltlen,
814                         const u8 *info, size_t infolen,
815                         u8 *output, size_t outlen)
816 {
817         static const u8 default_salt[SHA512_DIGEST_SIZE];
818         u8 prk[SHA512_DIGEST_SIZE]; /* pseudorandom key */
819         u8 *buf = xmalloc(1 + infolen + SHA512_DIGEST_SIZE);
820         u8 counter = 1;
821         size_t i;
822
823         if (saltlen == 0) {
824                 salt = default_salt;
825                 saltlen = sizeof(default_salt);
826         }
827
828         /* HKDF-Extract */
829         ASSERT(ikmlen > 0);
830         hmac_sha512(salt, saltlen, ikm, ikmlen, prk);
831
832         /* HKDF-Expand */
833         for (i = 0; i < outlen; i += SHA512_DIGEST_SIZE) {
834                 u8 *p = buf;
835                 u8 tmp[SHA512_DIGEST_SIZE];
836
837                 ASSERT(counter != 0);
838                 if (i > 0) {
839                         memcpy(p, &output[i - SHA512_DIGEST_SIZE],
840                                SHA512_DIGEST_SIZE);
841                         p += SHA512_DIGEST_SIZE;
842                 }
843                 memcpy(p, info, infolen);
844                 p += infolen;
845                 *p++ = counter++;
846                 hmac_sha512(prk, sizeof(prk), buf, p - buf, tmp);
847                 memcpy(&output[i], tmp, MIN(sizeof(tmp), outlen - i));
848         }
849         free(buf);
850 }
851
852 #ifdef ENABLE_ALG_TESTS
853 #include <openssl/evp.h>
854 #include <openssl/kdf.h>
855 static void openssl_hkdf_sha512(const u8 *ikm, size_t ikmlen,
856                                 const u8 *salt, size_t saltlen,
857                                 const u8 *info, size_t infolen,
858                                 u8 *output, size_t outlen)
859 {
860         EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, NULL);
861         size_t actual_outlen = outlen;
862
863         ASSERT(pctx != NULL);
864         ASSERT(EVP_PKEY_derive_init(pctx) > 0);
865         ASSERT(EVP_PKEY_CTX_set_hkdf_md(pctx, EVP_sha512()) > 0);
866         ASSERT(EVP_PKEY_CTX_set1_hkdf_key(pctx, ikm, ikmlen) > 0);
867         ASSERT(EVP_PKEY_CTX_set1_hkdf_salt(pctx, salt, saltlen) > 0);
868         ASSERT(EVP_PKEY_CTX_add1_hkdf_info(pctx, info, infolen) > 0);
869         ASSERT(EVP_PKEY_derive(pctx, output, &actual_outlen) > 0);
870         ASSERT(actual_outlen == outlen);
871         EVP_PKEY_CTX_free(pctx);
872 }
873
874 static void test_hkdf_sha512(void)
875 {
876         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
877
878         while (num_tests--) {
879                 u8 ikm[SHA512_DIGEST_SIZE];
880                 u8 salt[SHA512_DIGEST_SIZE];
881                 u8 info[128];
882                 u8 actual_output[512];
883                 u8 expected_output[sizeof(actual_output)];
884                 size_t ikmlen = 1 + (rand() % sizeof(ikm));
885                 size_t saltlen = rand() % (1 + sizeof(salt));
886                 size_t infolen = rand() % (1 + sizeof(info));
887                 size_t outlen = rand() % (1 + sizeof(actual_output));
888
889                 rand_bytes(ikm, ikmlen);
890                 rand_bytes(salt, saltlen);
891                 rand_bytes(info, infolen);
892
893                 hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
894                             actual_output, outlen);
895                 openssl_hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
896                                     expected_output, outlen);
897                 ASSERT(memcmp(actual_output, expected_output, outlen) == 0);
898         }
899 }
900 #endif /* ENABLE_ALG_TESTS */
901
902 /*----------------------------------------------------------------------------*
903  *                            AES encryption modes                            *
904  *----------------------------------------------------------------------------*/
905
906 static void aes_256_xts_crypt(const u8 key[2 * AES_256_KEY_SIZE],
907                               const u8 iv[AES_BLOCK_SIZE], const u8 *src,
908                               u8 *dst, size_t nbytes, bool decrypting)
909 {
910         struct aes_key tweak_key, cipher_key;
911         ble128 t;
912         size_t i;
913
914         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
915         aes_setkey(&cipher_key, key, AES_256_KEY_SIZE);
916         aes_setkey(&tweak_key, &key[AES_256_KEY_SIZE], AES_256_KEY_SIZE);
917         aes_encrypt(&tweak_key, iv, (u8 *)&t);
918         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
919                 xor(&dst[i], &src[i], (const u8 *)&t, AES_BLOCK_SIZE);
920                 if (decrypting)
921                         aes_decrypt(&cipher_key, &dst[i], &dst[i]);
922                 else
923                         aes_encrypt(&cipher_key, &dst[i], &dst[i]);
924                 xor(&dst[i], &dst[i], (const u8 *)&t, AES_BLOCK_SIZE);
925                 gf2_128_mul_x(&t);
926         }
927 }
928
929 static void aes_256_xts_encrypt(const u8 key[2 * AES_256_KEY_SIZE],
930                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
931                                 u8 *dst, size_t nbytes)
932 {
933         aes_256_xts_crypt(key, iv, src, dst, nbytes, false);
934 }
935
936 static void aes_256_xts_decrypt(const u8 key[2 * AES_256_KEY_SIZE],
937                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
938                                 u8 *dst, size_t nbytes)
939 {
940         aes_256_xts_crypt(key, iv, src, dst, nbytes, true);
941 }
942
943 #ifdef ENABLE_ALG_TESTS
944 #include <openssl/evp.h>
945 static void test_aes_256_xts(void)
946 {
947         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
948         EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
949
950         ASSERT(ctx != NULL);
951         while (num_tests--) {
952                 u8 key[2 * AES_256_KEY_SIZE];
953                 u8 iv[AES_BLOCK_SIZE];
954                 u8 ptext[512];
955                 u8 ctext[sizeof(ptext)];
956                 u8 ref_ctext[sizeof(ptext)];
957                 u8 decrypted[sizeof(ptext)];
958                 const size_t datalen = ROUND_DOWN(rand() % (1 + sizeof(ptext)),
959                                                   AES_BLOCK_SIZE);
960                 int outl, res;
961
962                 rand_bytes(key, sizeof(key));
963                 rand_bytes(iv, sizeof(iv));
964                 rand_bytes(ptext, datalen);
965
966                 aes_256_xts_encrypt(key, iv, ptext, ctext, datalen);
967                 res = EVP_EncryptInit_ex(ctx, EVP_aes_256_xts(), NULL, key, iv);
968                 ASSERT(res > 0);
969                 res = EVP_EncryptUpdate(ctx, ref_ctext, &outl, ptext, datalen);
970                 ASSERT(res > 0);
971                 ASSERT(outl == datalen);
972                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
973
974                 aes_256_xts_decrypt(key, iv, ctext, decrypted, datalen);
975                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
976         }
977         EVP_CIPHER_CTX_free(ctx);
978 }
979 #endif /* ENABLE_ALG_TESTS */
980
981 static void aes_cbc_encrypt(const struct aes_key *k,
982                             const u8 iv[AES_BLOCK_SIZE],
983                             const u8 *src, u8 *dst, size_t nbytes)
984 {
985         size_t i;
986
987         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
988         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
989                 xor(&dst[i], &src[i], (i == 0 ? iv : &dst[i - AES_BLOCK_SIZE]),
990                     AES_BLOCK_SIZE);
991                 aes_encrypt(k, &dst[i], &dst[i]);
992         }
993 }
994
995 static void aes_cbc_decrypt(const struct aes_key *k,
996                             const u8 iv[AES_BLOCK_SIZE],
997                             const u8 *src, u8 *dst, size_t nbytes)
998 {
999         size_t i = nbytes;
1000
1001         ASSERT(i % AES_BLOCK_SIZE == 0);
1002         while (i) {
1003                 i -= AES_BLOCK_SIZE;
1004                 aes_decrypt(k, &src[i], &dst[i]);
1005                 xor(&dst[i], &dst[i], (i == 0 ? iv : &src[i - AES_BLOCK_SIZE]),
1006                     AES_BLOCK_SIZE);
1007         }
1008 }
1009
1010 static void aes_cts_cbc_encrypt(const u8 *key, int keysize,
1011                                 const u8 iv[AES_BLOCK_SIZE],
1012                                 const u8 *src, u8 *dst, size_t nbytes)
1013 {
1014         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1015         const size_t final_bsize = nbytes - offset;
1016         struct aes_key k;
1017         u8 *pad;
1018         u8 buf[AES_BLOCK_SIZE];
1019
1020         ASSERT(nbytes >= AES_BLOCK_SIZE);
1021
1022         aes_setkey(&k, key, keysize);
1023
1024         if (nbytes == AES_BLOCK_SIZE)
1025                 return aes_cbc_encrypt(&k, iv, src, dst, nbytes);
1026
1027         aes_cbc_encrypt(&k, iv, src, dst, offset);
1028         pad = &dst[offset - AES_BLOCK_SIZE];
1029
1030         memcpy(buf, pad, AES_BLOCK_SIZE);
1031         xor(buf, buf, &src[offset], final_bsize);
1032         memcpy(&dst[offset], pad, final_bsize);
1033         aes_encrypt(&k, buf, pad);
1034 }
1035
1036 static void aes_cts_cbc_decrypt(const u8 *key, int keysize,
1037                                 const u8 iv[AES_BLOCK_SIZE],
1038                                 const u8 *src, u8 *dst, size_t nbytes)
1039 {
1040         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1041         const size_t final_bsize = nbytes - offset;
1042         struct aes_key k;
1043         u8 *pad;
1044
1045         ASSERT(nbytes >= AES_BLOCK_SIZE);
1046
1047         aes_setkey(&k, key, keysize);
1048
1049         if (nbytes == AES_BLOCK_SIZE)
1050                 return aes_cbc_decrypt(&k, iv, src, dst, nbytes);
1051
1052         pad = &dst[offset - AES_BLOCK_SIZE];
1053         aes_decrypt(&k, &src[offset - AES_BLOCK_SIZE], pad);
1054         xor(&dst[offset], &src[offset], pad, final_bsize);
1055         xor(pad, pad, &dst[offset], final_bsize);
1056
1057         aes_cbc_decrypt(&k, (offset == AES_BLOCK_SIZE ?
1058                              iv : &src[offset - 2 * AES_BLOCK_SIZE]),
1059                         pad, pad, AES_BLOCK_SIZE);
1060         aes_cbc_decrypt(&k, iv, src, dst, offset - AES_BLOCK_SIZE);
1061 }
1062
1063 static void aes_256_cts_cbc_encrypt(const u8 key[AES_256_KEY_SIZE],
1064                                     const u8 iv[AES_BLOCK_SIZE],
1065                                     const u8 *src, u8 *dst, size_t nbytes)
1066 {
1067         aes_cts_cbc_encrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1068 }
1069
1070 static void aes_256_cts_cbc_decrypt(const u8 key[AES_256_KEY_SIZE],
1071                                     const u8 iv[AES_BLOCK_SIZE],
1072                                     const u8 *src, u8 *dst, size_t nbytes)
1073 {
1074         aes_cts_cbc_decrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1075 }
1076
1077 #ifdef ENABLE_ALG_TESTS
1078 #include <openssl/modes.h>
1079 static void aes_block128_f(const unsigned char in[16],
1080                            unsigned char out[16], const void *key)
1081 {
1082         aes_encrypt(key, in, out);
1083 }
1084
1085 static void test_aes_256_cts_cbc(void)
1086 {
1087         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1088
1089         while (num_tests--) {
1090                 u8 key[AES_256_KEY_SIZE];
1091                 u8 iv[AES_BLOCK_SIZE];
1092                 u8 iv_copy[AES_BLOCK_SIZE];
1093                 u8 ptext[512];
1094                 u8 ctext[sizeof(ptext)];
1095                 u8 ref_ctext[sizeof(ptext)];
1096                 u8 decrypted[sizeof(ptext)];
1097                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1098                 struct aes_key k;
1099
1100                 rand_bytes(key, sizeof(key));
1101                 rand_bytes(iv, sizeof(iv));
1102                 rand_bytes(ptext, datalen);
1103
1104                 aes_256_cts_cbc_encrypt(key, iv, ptext, ctext, datalen);
1105
1106                 /* OpenSSL doesn't allow datalen=AES_BLOCK_SIZE; Linux does */
1107                 if (datalen != AES_BLOCK_SIZE) {
1108                         aes_setkey(&k, key, sizeof(key));
1109                         memcpy(iv_copy, iv, sizeof(iv));
1110                         ASSERT(CRYPTO_cts128_encrypt_block(ptext, ref_ctext,
1111                                                            datalen, &k, iv_copy,
1112                                                            aes_block128_f)
1113                                == datalen);
1114                         ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1115                 }
1116                 aes_256_cts_cbc_decrypt(key, iv, ctext, decrypted, datalen);
1117                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1118         }
1119 }
1120 #endif /* ENABLE_ALG_TESTS */
1121
1122 static void essiv_generate_iv(const u8 orig_key[AES_128_KEY_SIZE],
1123                               const u8 orig_iv[AES_BLOCK_SIZE],
1124                               u8 real_iv[AES_BLOCK_SIZE])
1125 {
1126         u8 essiv_key[SHA256_DIGEST_SIZE];
1127         struct aes_key essiv;
1128
1129         /* AES encrypt the original IV using a hash of the original key */
1130         STATIC_ASSERT(SHA256_DIGEST_SIZE == AES_256_KEY_SIZE);
1131         sha256(orig_key, AES_128_KEY_SIZE, essiv_key);
1132         aes_setkey(&essiv, essiv_key, AES_256_KEY_SIZE);
1133         aes_encrypt(&essiv, orig_iv, real_iv);
1134 }
1135
1136 static void aes_128_cbc_essiv_encrypt(const u8 key[AES_128_KEY_SIZE],
1137                                       const u8 iv[AES_BLOCK_SIZE],
1138                                       const u8 *src, u8 *dst, size_t nbytes)
1139 {
1140         struct aes_key k;
1141         u8 real_iv[AES_BLOCK_SIZE];
1142
1143         aes_setkey(&k, key, AES_128_KEY_SIZE);
1144         essiv_generate_iv(key, iv, real_iv);
1145         aes_cbc_encrypt(&k, real_iv, src, dst, nbytes);
1146 }
1147
1148 static void aes_128_cbc_essiv_decrypt(const u8 key[AES_128_KEY_SIZE],
1149                                       const u8 iv[AES_BLOCK_SIZE],
1150                                       const u8 *src, u8 *dst, size_t nbytes)
1151 {
1152         struct aes_key k;
1153         u8 real_iv[AES_BLOCK_SIZE];
1154
1155         aes_setkey(&k, key, AES_128_KEY_SIZE);
1156         essiv_generate_iv(key, iv, real_iv);
1157         aes_cbc_decrypt(&k, real_iv, src, dst, nbytes);
1158 }
1159
1160 static void aes_128_cts_cbc_encrypt(const u8 key[AES_128_KEY_SIZE],
1161                                     const u8 iv[AES_BLOCK_SIZE],
1162                                     const u8 *src, u8 *dst, size_t nbytes)
1163 {
1164         aes_cts_cbc_encrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1165 }
1166
1167 static void aes_128_cts_cbc_decrypt(const u8 key[AES_128_KEY_SIZE],
1168                                     const u8 iv[AES_BLOCK_SIZE],
1169                                     const u8 *src, u8 *dst, size_t nbytes)
1170 {
1171         aes_cts_cbc_decrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1172 }
1173
1174 /*----------------------------------------------------------------------------*
1175  *                           XChaCha12 stream cipher                          *
1176  *----------------------------------------------------------------------------*/
1177
1178 /*
1179  * References:
1180  *   - "XChaCha: eXtended-nonce ChaCha and AEAD_XChaCha20_Poly1305"
1181  *      https://tools.ietf.org/html/draft-arciszewski-xchacha-03
1182  *
1183  *   - "ChaCha, a variant of Salsa20"
1184  *      https://cr.yp.to/chacha/chacha-20080128.pdf
1185  *
1186  *   - "Extending the Salsa20 nonce"
1187  *      https://cr.yp.to/snuffle/xsalsa-20081128.pdf
1188  */
1189
1190 #define CHACHA_KEY_SIZE         32
1191 #define XCHACHA_KEY_SIZE        CHACHA_KEY_SIZE
1192 #define XCHACHA_NONCE_SIZE      24
1193
1194 static void chacha_init_state(u32 state[16], const u8 key[CHACHA_KEY_SIZE],
1195                               const u8 iv[16])
1196 {
1197         static const u8 consts[16] = "expand 32-byte k";
1198         int i;
1199
1200         for (i = 0; i < 4; i++)
1201                 state[i] = get_unaligned_le32(&consts[i * sizeof(__le32)]);
1202         for (i = 0; i < 8; i++)
1203                 state[4 + i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
1204         for (i = 0; i < 4; i++)
1205                 state[12 + i] = get_unaligned_le32(&iv[i * sizeof(__le32)]);
1206 }
1207
1208 #define CHACHA_QUARTERROUND(a, b, c, d)         \
1209         do {                                    \
1210                 a += b; d = rol32(d ^ a, 16);   \
1211                 c += d; b = rol32(b ^ c, 12);   \
1212                 a += b; d = rol32(d ^ a, 8);    \
1213                 c += d; b = rol32(b ^ c, 7);    \
1214         } while (0)
1215
1216 static void chacha_permute(u32 x[16], int nrounds)
1217 {
1218         do {
1219                 /* column round */
1220                 CHACHA_QUARTERROUND(x[0], x[4], x[8], x[12]);
1221                 CHACHA_QUARTERROUND(x[1], x[5], x[9], x[13]);
1222                 CHACHA_QUARTERROUND(x[2], x[6], x[10], x[14]);
1223                 CHACHA_QUARTERROUND(x[3], x[7], x[11], x[15]);
1224
1225                 /* diagonal round */
1226                 CHACHA_QUARTERROUND(x[0], x[5], x[10], x[15]);
1227                 CHACHA_QUARTERROUND(x[1], x[6], x[11], x[12]);
1228                 CHACHA_QUARTERROUND(x[2], x[7], x[8], x[13]);
1229                 CHACHA_QUARTERROUND(x[3], x[4], x[9], x[14]);
1230         } while ((nrounds -= 2) != 0);
1231 }
1232
1233 static void xchacha(const u8 key[XCHACHA_KEY_SIZE],
1234                     const u8 nonce[XCHACHA_NONCE_SIZE],
1235                     const u8 *src, u8 *dst, size_t nbytes, int nrounds)
1236 {
1237         u32 state[16];
1238         u8 real_key[CHACHA_KEY_SIZE];
1239         u8 real_iv[16] = { 0 };
1240         size_t i, j;
1241
1242         /* Compute real key using original key and first 128 nonce bits */
1243         chacha_init_state(state, key, nonce);
1244         chacha_permute(state, nrounds);
1245         for (i = 0; i < 8; i++) /* state words 0..3, 12..15 */
1246                 put_unaligned_le32(state[(i < 4 ? 0 : 8) + i],
1247                                    &real_key[i * sizeof(__le32)]);
1248
1249         /* Now do regular ChaCha, using real key and remaining nonce bits */
1250         memcpy(&real_iv[8], nonce + 16, 8);
1251         chacha_init_state(state, real_key, real_iv);
1252         for (i = 0; i < nbytes; i += 64) {
1253                 u32 x[16];
1254                 __le32 keystream[16];
1255
1256                 memcpy(x, state, 64);
1257                 chacha_permute(x, nrounds);
1258                 for (j = 0; j < 16; j++)
1259                         keystream[j] = cpu_to_le32(x[j] + state[j]);
1260                 xor(&dst[i], &src[i], (u8 *)keystream, MIN(nbytes - i, 64));
1261                 if (++state[12] == 0)
1262                         state[13]++;
1263         }
1264 }
1265
1266 static void xchacha12(const u8 key[XCHACHA_KEY_SIZE],
1267                       const u8 nonce[XCHACHA_NONCE_SIZE],
1268                       const u8 *src, u8 *dst, size_t nbytes)
1269 {
1270         xchacha(key, nonce, src, dst, nbytes, 12);
1271 }
1272
1273 /*----------------------------------------------------------------------------*
1274  *                                 Poly1305                                   *
1275  *----------------------------------------------------------------------------*/
1276
1277 /*
1278  * Note: this is only the Poly1305 Îµ-almost-∆-universal hash function, not the
1279  * full Poly1305 MAC.  I.e., it doesn't add anything at the end.
1280  */
1281
1282 #define POLY1305_KEY_SIZE       16
1283 #define POLY1305_BLOCK_SIZE     16
1284
1285 static void poly1305(const u8 key[POLY1305_KEY_SIZE],
1286                      const u8 *msg, size_t msglen, le128 *out)
1287 {
1288         const u32 limb_mask = 0x3ffffff;        /* limbs are base 2^26 */
1289         const u64 r0 = (get_unaligned_le32(key +  0) >> 0) & 0x3ffffff;
1290         const u64 r1 = (get_unaligned_le32(key +  3) >> 2) & 0x3ffff03;
1291         const u64 r2 = (get_unaligned_le32(key +  6) >> 4) & 0x3ffc0ff;
1292         const u64 r3 = (get_unaligned_le32(key +  9) >> 6) & 0x3f03fff;
1293         const u64 r4 = (get_unaligned_le32(key + 12) >> 8) & 0x00fffff;
1294         u32 h0 = 0, h1 = 0, h2 = 0, h3 = 0, h4 = 0;
1295         u32 g0, g1, g2, g3, g4, ge_p_mask;
1296
1297         /* Partial block support is not necessary for Adiantum */
1298         ASSERT(msglen % POLY1305_BLOCK_SIZE == 0);
1299
1300         while (msglen) {
1301                 u64 d0, d1, d2, d3, d4;
1302
1303                 /* h += *msg */
1304                 h0 += (get_unaligned_le32(msg +  0) >> 0) & limb_mask;
1305                 h1 += (get_unaligned_le32(msg +  3) >> 2) & limb_mask;
1306                 h2 += (get_unaligned_le32(msg +  6) >> 4) & limb_mask;
1307                 h3 += (get_unaligned_le32(msg +  9) >> 6) & limb_mask;
1308                 h4 += (get_unaligned_le32(msg + 12) >> 8) | (1 << 24);
1309
1310                 /* h *= r */
1311                 d0 = h0*r0 + h1*5*r4 + h2*5*r3 + h3*5*r2 + h4*5*r1;
1312                 d1 = h0*r1 + h1*r0   + h2*5*r4 + h3*5*r3 + h4*5*r2;
1313                 d2 = h0*r2 + h1*r1   + h2*r0   + h3*5*r4 + h4*5*r3;
1314                 d3 = h0*r3 + h1*r2   + h2*r1   + h3*r0   + h4*5*r4;
1315                 d4 = h0*r4 + h1*r3   + h2*r2   + h3*r1   + h4*r0;
1316
1317                 /* (partial) h %= 2^130 - 5 */
1318                 d1 += d0 >> 26;         h0 = d0 & limb_mask;
1319                 d2 += d1 >> 26;         h1 = d1 & limb_mask;
1320                 d3 += d2 >> 26;         h2 = d2 & limb_mask;
1321                 d4 += d3 >> 26;         h3 = d3 & limb_mask;
1322                 h0 += (d4 >> 26) * 5;   h4 = d4 & limb_mask;
1323                 h1 += h0 >> 26;         h0 &= limb_mask;
1324
1325                 msg += POLY1305_BLOCK_SIZE;
1326                 msglen -= POLY1305_BLOCK_SIZE;
1327         }
1328
1329         /* fully carry h */
1330         h2 += (h1 >> 26);       h1 &= limb_mask;
1331         h3 += (h2 >> 26);       h2 &= limb_mask;
1332         h4 += (h3 >> 26);       h3 &= limb_mask;
1333         h0 += (h4 >> 26) * 5;   h4 &= limb_mask;
1334         h1 += (h0 >> 26);       h0 &= limb_mask;
1335
1336         /* if (h >= 2^130 - 5) h -= 2^130 - 5; */
1337         g0 = h0 + 5;
1338         g1 = h1 + (g0 >> 26);   g0 &= limb_mask;
1339         g2 = h2 + (g1 >> 26);   g1 &= limb_mask;
1340         g3 = h3 + (g2 >> 26);   g2 &= limb_mask;
1341         g4 = h4 + (g3 >> 26);   g3 &= limb_mask;
1342         ge_p_mask = ~((g4 >> 26) - 1); /* all 1's if h >= 2^130 - 5, else 0 */
1343         h0 = (h0 & ~ge_p_mask) | (g0 & ge_p_mask);
1344         h1 = (h1 & ~ge_p_mask) | (g1 & ge_p_mask);
1345         h2 = (h2 & ~ge_p_mask) | (g2 & ge_p_mask);
1346         h3 = (h3 & ~ge_p_mask) | (g3 & ge_p_mask);
1347         h4 = (h4 & ~ge_p_mask) | (g4 & ge_p_mask & limb_mask);
1348
1349         /* h %= 2^128 */
1350         out->lo = cpu_to_le64(((u64)h2 << 52) | ((u64)h1 << 26) | h0);
1351         out->hi = cpu_to_le64(((u64)h4 << 40) | ((u64)h3 << 14) | (h2 >> 12));
1352 }
1353
1354 /*----------------------------------------------------------------------------*
1355  *                          Adiantum encryption mode                          *
1356  *----------------------------------------------------------------------------*/
1357
1358 /*
1359  * Reference: "Adiantum: length-preserving encryption for entry-level processors"
1360  *      https://tosc.iacr.org/index.php/ToSC/article/view/7360
1361  */
1362
1363 #define ADIANTUM_KEY_SIZE       32
1364 #define ADIANTUM_IV_SIZE        32
1365 #define ADIANTUM_HASH_KEY_SIZE  ((2 * POLY1305_KEY_SIZE) + NH_KEY_SIZE)
1366
1367 #define NH_KEY_SIZE             1072
1368 #define NH_KEY_WORDS            (NH_KEY_SIZE / sizeof(u32))
1369 #define NH_BLOCK_SIZE           1024
1370 #define NH_HASH_SIZE            32
1371 #define NH_MESSAGE_UNIT         16
1372
1373 static u64 nh_pass(const u32 *key, const u8 *msg, size_t msglen)
1374 {
1375         u64 sum = 0;
1376
1377         ASSERT(msglen % NH_MESSAGE_UNIT == 0);
1378         while (msglen) {
1379                 sum += (u64)(u32)(get_unaligned_le32(msg +  0) + key[0]) *
1380                             (u32)(get_unaligned_le32(msg +  8) + key[2]);
1381                 sum += (u64)(u32)(get_unaligned_le32(msg +  4) + key[1]) *
1382                             (u32)(get_unaligned_le32(msg + 12) + key[3]);
1383                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1384                 msg += NH_MESSAGE_UNIT;
1385                 msglen -= NH_MESSAGE_UNIT;
1386         }
1387         return sum;
1388 }
1389
1390 /* NH Îµ-almost-universal hash function */
1391 static void nh(const u32 *key, const u8 *msg, size_t msglen,
1392                u8 result[NH_HASH_SIZE])
1393 {
1394         size_t i;
1395
1396         for (i = 0; i < NH_HASH_SIZE; i += sizeof(__le64)) {
1397                 put_unaligned_le64(nh_pass(key, msg, msglen), &result[i]);
1398                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1399         }
1400 }
1401
1402 /* Adiantum's Îµ-almost-∆-universal hash function */
1403 static void adiantum_hash(const u8 key[ADIANTUM_HASH_KEY_SIZE],
1404                           const u8 iv[ADIANTUM_IV_SIZE],
1405                           const u8 *msg, size_t msglen, le128 *result)
1406 {
1407         const u8 *header_poly_key = key;
1408         const u8 *msg_poly_key = header_poly_key + POLY1305_KEY_SIZE;
1409         const u8 *nh_key = msg_poly_key + POLY1305_KEY_SIZE;
1410         u32 nh_key_words[NH_KEY_WORDS];
1411         u8 header[POLY1305_BLOCK_SIZE + ADIANTUM_IV_SIZE];
1412         const size_t num_nh_blocks = DIV_ROUND_UP(msglen, NH_BLOCK_SIZE);
1413         u8 *nh_hashes = xmalloc(num_nh_blocks * NH_HASH_SIZE);
1414         const size_t padded_msglen = ROUND_UP(msglen, NH_MESSAGE_UNIT);
1415         u8 *padded_msg = xmalloc(padded_msglen);
1416         le128 hash1, hash2;
1417         size_t i;
1418
1419         for (i = 0; i < NH_KEY_WORDS; i++)
1420                 nh_key_words[i] = get_unaligned_le32(&nh_key[i * sizeof(u32)]);
1421
1422         /* Hash tweak and message length with first Poly1305 key */
1423         put_unaligned_le64((u64)msglen * 8, header);
1424         put_unaligned_le64(0, &header[sizeof(__le64)]);
1425         memcpy(&header[POLY1305_BLOCK_SIZE], iv, ADIANTUM_IV_SIZE);
1426         poly1305(header_poly_key, header, sizeof(header), &hash1);
1427
1428         /* Hash NH hashes of message blocks using second Poly1305 key */
1429         /* (using a super naive way of handling the padding) */
1430         memcpy(padded_msg, msg, msglen);
1431         memset(&padded_msg[msglen], 0, padded_msglen - msglen);
1432         for (i = 0; i < num_nh_blocks; i++) {
1433                 nh(nh_key_words, &padded_msg[i * NH_BLOCK_SIZE],
1434                    MIN(NH_BLOCK_SIZE, padded_msglen - (i * NH_BLOCK_SIZE)),
1435                    &nh_hashes[i * NH_HASH_SIZE]);
1436         }
1437         poly1305(msg_poly_key, nh_hashes, num_nh_blocks * NH_HASH_SIZE, &hash2);
1438
1439         /* Add the two hashes together to get the final hash */
1440         le128_add(result, &hash1, &hash2);
1441
1442         free(nh_hashes);
1443         free(padded_msg);
1444 }
1445
1446 static void adiantum_crypt(const u8 key[ADIANTUM_KEY_SIZE],
1447                            const u8 iv[ADIANTUM_IV_SIZE], const u8 *src,
1448                            u8 *dst, size_t nbytes, bool decrypting)
1449 {
1450         u8 subkeys[AES_256_KEY_SIZE + ADIANTUM_HASH_KEY_SIZE] = { 0 };
1451         struct aes_key aes_key;
1452         union {
1453                 u8 nonce[XCHACHA_NONCE_SIZE];
1454                 le128 block;
1455         } u = { .nonce = { 1 } };
1456         const size_t bulk_len = nbytes - sizeof(u.block);
1457         le128 hash;
1458
1459         ASSERT(nbytes >= sizeof(u.block));
1460
1461         /* Derive subkeys */
1462         xchacha12(key, u.nonce, subkeys, subkeys, sizeof(subkeys));
1463         aes_setkey(&aes_key, subkeys, AES_256_KEY_SIZE);
1464
1465         /* Hash left part and add to right part */
1466         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, src, bulk_len, &hash);
1467         memcpy(&u.block, &src[bulk_len], sizeof(u.block));
1468         le128_add(&u.block, &u.block, &hash);
1469
1470         if (!decrypting) /* Encrypt right part with block cipher */
1471                 aes_encrypt(&aes_key, u.nonce, u.nonce);
1472
1473         /* Encrypt left part with stream cipher, using the computed nonce */
1474         u.nonce[sizeof(u.block)] = 1;
1475         xchacha12(key, u.nonce, src, dst, bulk_len);
1476
1477         if (decrypting) /* Decrypt right part with block cipher */
1478                 aes_decrypt(&aes_key, u.nonce, u.nonce);
1479
1480         /* Finalize right part by subtracting hash of left part */
1481         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, dst, bulk_len, &hash);
1482         le128_sub(&u.block, &u.block, &hash);
1483         memcpy(&dst[bulk_len], &u.block, sizeof(u.block));
1484 }
1485
1486 static void adiantum_encrypt(const u8 key[ADIANTUM_KEY_SIZE],
1487                              const u8 iv[ADIANTUM_IV_SIZE],
1488                              const u8 *src, u8 *dst, size_t nbytes)
1489 {
1490         adiantum_crypt(key, iv, src, dst, nbytes, false);
1491 }
1492
1493 static void adiantum_decrypt(const u8 key[ADIANTUM_KEY_SIZE],
1494                              const u8 iv[ADIANTUM_IV_SIZE],
1495                              const u8 *src, u8 *dst, size_t nbytes)
1496 {
1497         adiantum_crypt(key, iv, src, dst, nbytes, true);
1498 }
1499
1500 #ifdef ENABLE_ALG_TESTS
1501 #include <linux/if_alg.h>
1502 #include <sys/socket.h>
1503 #define SOL_ALG 279
1504 static void af_alg_crypt(int algfd, int op, const u8 *key, size_t keylen,
1505                          const u8 *iv, size_t ivlen,
1506                          const u8 *src, u8 *dst, size_t datalen)
1507 {
1508         size_t controllen = CMSG_SPACE(sizeof(int)) +
1509                             CMSG_SPACE(sizeof(struct af_alg_iv) + ivlen);
1510         u8 *control = xmalloc(controllen);
1511         struct iovec iov = { .iov_base = (u8 *)src, .iov_len = datalen };
1512         struct msghdr msg = {
1513                 .msg_iov = &iov,
1514                 .msg_iovlen = 1,
1515                 .msg_control = control,
1516                 .msg_controllen = controllen,
1517         };
1518         struct cmsghdr *cmsg;
1519         struct af_alg_iv *algiv;
1520         int reqfd;
1521
1522         memset(control, 0, controllen);
1523
1524         cmsg = CMSG_FIRSTHDR(&msg);
1525         cmsg->cmsg_len = CMSG_LEN(sizeof(int));
1526         cmsg->cmsg_level = SOL_ALG;
1527         cmsg->cmsg_type = ALG_SET_OP;
1528         *(int *)CMSG_DATA(cmsg) = op;
1529
1530         cmsg = CMSG_NXTHDR(&msg, cmsg);
1531         cmsg->cmsg_len = CMSG_LEN(sizeof(struct af_alg_iv) + ivlen);
1532         cmsg->cmsg_level = SOL_ALG;
1533         cmsg->cmsg_type = ALG_SET_IV;
1534         algiv = (struct af_alg_iv *)CMSG_DATA(cmsg);
1535         algiv->ivlen = ivlen;
1536         memcpy(algiv->iv, iv, ivlen);
1537
1538         if (setsockopt(algfd, SOL_ALG, ALG_SET_KEY, key, keylen) != 0)
1539                 die_errno("can't set key on AF_ALG socket");
1540
1541         reqfd = accept(algfd, NULL, NULL);
1542         if (reqfd < 0)
1543                 die_errno("can't accept() AF_ALG socket");
1544         if (sendmsg(reqfd, &msg, 0) != datalen)
1545                 die_errno("can't sendmsg() AF_ALG request socket");
1546         if (xread(reqfd, dst, datalen) != datalen)
1547                 die("short read from AF_ALG request socket");
1548         close(reqfd);
1549
1550         free(control);
1551 }
1552
1553 static void test_adiantum(void)
1554 {
1555         int algfd = socket(AF_ALG, SOCK_SEQPACKET, 0);
1556         struct sockaddr_alg addr = {
1557                 .salg_type = "skcipher",
1558                 .salg_name = "adiantum(xchacha12,aes)",
1559         };
1560         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1561
1562         if (algfd < 0)
1563                 die_errno("can't create AF_ALG socket");
1564         if (bind(algfd, (struct sockaddr *)&addr, sizeof(addr)) != 0)
1565                 die_errno("can't bind AF_ALG socket to Adiantum algorithm");
1566
1567         while (num_tests--) {
1568                 u8 key[ADIANTUM_KEY_SIZE];
1569                 u8 iv[ADIANTUM_IV_SIZE];
1570                 u8 ptext[4096];
1571                 u8 ctext[sizeof(ptext)];
1572                 u8 ref_ctext[sizeof(ptext)];
1573                 u8 decrypted[sizeof(ptext)];
1574                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1575
1576                 rand_bytes(key, sizeof(key));
1577                 rand_bytes(iv, sizeof(iv));
1578                 rand_bytes(ptext, datalen);
1579
1580                 adiantum_encrypt(key, iv, ptext, ctext, datalen);
1581                 af_alg_crypt(algfd, ALG_OP_ENCRYPT, key, sizeof(key),
1582                              iv, sizeof(iv), ptext, ref_ctext, datalen);
1583                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1584
1585                 adiantum_decrypt(key, iv, ctext, decrypted, datalen);
1586                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1587         }
1588         close(algfd);
1589 }
1590 #endif /* ENABLE_ALG_TESTS */
1591
1592 /*----------------------------------------------------------------------------*
1593  *                               SipHash-2-4                                  *
1594  *----------------------------------------------------------------------------*/
1595
1596 /*
1597  * Reference: "SipHash: a fast short-input PRF"
1598  *      https://cr.yp.to/siphash/siphash-20120918.pdf
1599  */
1600
1601 #define SIPROUND                                                \
1602         do {                                                    \
1603                 v0 += v1;           v2 += v3;                   \
1604                 v1 = rol64(v1, 13); v3 = rol64(v3, 16);         \
1605                 v1 ^= v0;           v3 ^= v2;                   \
1606                 v0 = rol64(v0, 32);                             \
1607                 v2 += v1;           v0 += v3;                   \
1608                 v1 = rol64(v1, 17); v3 = rol64(v3, 21);         \
1609                 v1 ^= v2;           v3 ^= v0;                   \
1610                 v2 = rol64(v2, 32);                             \
1611         } while (0)
1612
1613 /* Compute the SipHash-2-4 of a 64-bit number when formatted as little endian */
1614 static u64 siphash_1u64(const u64 key[2], u64 data)
1615 {
1616         u64 v0 = key[0] ^ 0x736f6d6570736575ULL;
1617         u64 v1 = key[1] ^ 0x646f72616e646f6dULL;
1618         u64 v2 = key[0] ^ 0x6c7967656e657261ULL;
1619         u64 v3 = key[1] ^ 0x7465646279746573ULL;
1620         u64 m[2] = {data, (u64)sizeof(data) << 56};
1621         size_t i;
1622
1623         for (i = 0; i < ARRAY_SIZE(m); i++) {
1624                 v3 ^= m[i];
1625                 SIPROUND;
1626                 SIPROUND;
1627                 v0 ^= m[i];
1628         }
1629
1630         v2 ^= 0xff;
1631         for (i = 0; i < 4; i++)
1632                 SIPROUND;
1633         return v0 ^ v1 ^ v2 ^ v3;
1634 }
1635
1636 /*----------------------------------------------------------------------------*
1637  *                               Main program                                 *
1638  *----------------------------------------------------------------------------*/
1639
1640 #define FILE_NONCE_SIZE         16
1641 #define UUID_SIZE               16
1642 #define MAX_KEY_SIZE            64
1643
1644 static const struct fscrypt_cipher {
1645         const char *name;
1646         void (*encrypt)(const u8 *key, const u8 *iv, const u8 *src,
1647                         u8 *dst, size_t nbytes);
1648         void (*decrypt)(const u8 *key, const u8 *iv, const u8 *src,
1649                         u8 *dst, size_t nbytes);
1650         int keysize;
1651         int min_input_size;
1652 } fscrypt_ciphers[] = {
1653         {
1654                 .name = "AES-256-XTS",
1655                 .encrypt = aes_256_xts_encrypt,
1656                 .decrypt = aes_256_xts_decrypt,
1657                 .keysize = 2 * AES_256_KEY_SIZE,
1658         }, {
1659                 .name = "AES-256-CTS-CBC",
1660                 .encrypt = aes_256_cts_cbc_encrypt,
1661                 .decrypt = aes_256_cts_cbc_decrypt,
1662                 .keysize = AES_256_KEY_SIZE,
1663                 .min_input_size = AES_BLOCK_SIZE,
1664         }, {
1665                 .name = "AES-128-CBC-ESSIV",
1666                 .encrypt = aes_128_cbc_essiv_encrypt,
1667                 .decrypt = aes_128_cbc_essiv_decrypt,
1668                 .keysize = AES_128_KEY_SIZE,
1669         }, {
1670                 .name = "AES-128-CTS-CBC",
1671                 .encrypt = aes_128_cts_cbc_encrypt,
1672                 .decrypt = aes_128_cts_cbc_decrypt,
1673                 .keysize = AES_128_KEY_SIZE,
1674                 .min_input_size = AES_BLOCK_SIZE,
1675         }, {
1676                 .name = "Adiantum",
1677                 .encrypt = adiantum_encrypt,
1678                 .decrypt = adiantum_decrypt,
1679                 .keysize = ADIANTUM_KEY_SIZE,
1680                 .min_input_size = AES_BLOCK_SIZE,
1681         }
1682 };
1683
1684 static const struct fscrypt_cipher *find_fscrypt_cipher(const char *name)
1685 {
1686         size_t i;
1687
1688         for (i = 0; i < ARRAY_SIZE(fscrypt_ciphers); i++) {
1689                 if (strcmp(fscrypt_ciphers[i].name, name) == 0)
1690                         return &fscrypt_ciphers[i];
1691         }
1692         return NULL;
1693 }
1694
1695 struct fscrypt_iv {
1696         union {
1697                 __le64 block_num;
1698                 u8 bytes[32];
1699         };
1700 };
1701
1702 static void crypt_loop(const struct fscrypt_cipher *cipher, const u8 *key,
1703                        struct fscrypt_iv *iv, bool decrypting,
1704                        size_t block_size, size_t padding)
1705 {
1706         u8 *buf = xmalloc(block_size);
1707         size_t res;
1708
1709         while ((res = xread(STDIN_FILENO, buf, block_size)) > 0) {
1710                 size_t crypt_len = block_size;
1711
1712                 if (padding > 0) {
1713                         crypt_len = MAX(res, cipher->min_input_size);
1714                         crypt_len = ROUND_UP(crypt_len, padding);
1715                         crypt_len = MIN(crypt_len, block_size);
1716                 }
1717                 ASSERT(crypt_len >= res);
1718                 memset(&buf[res], 0, crypt_len - res);
1719
1720                 if (decrypting)
1721                         cipher->decrypt(key, iv->bytes, buf, buf, crypt_len);
1722                 else
1723                         cipher->encrypt(key, iv->bytes, buf, buf, crypt_len);
1724
1725                 full_write(STDOUT_FILENO, buf, crypt_len);
1726
1727                 iv->block_num = cpu_to_le64(le64_to_cpu(iv->block_num) + 1);
1728         }
1729         free(buf);
1730 }
1731
1732 /* The supported key derivation functions */
1733 enum kdf_algorithm {
1734         KDF_NONE,
1735         KDF_AES_128_ECB,
1736         KDF_HKDF_SHA512,
1737 };
1738
1739 static enum kdf_algorithm parse_kdf_algorithm(const char *arg)
1740 {
1741         if (strcmp(arg, "none") == 0)
1742                 return KDF_NONE;
1743         if (strcmp(arg, "AES-128-ECB") == 0)
1744                 return KDF_AES_128_ECB;
1745         if (strcmp(arg, "HKDF-SHA512") == 0)
1746                 return KDF_HKDF_SHA512;
1747         die("Unknown KDF: %s", arg);
1748 }
1749
1750 static u8 parse_mode_number(const char *arg)
1751 {
1752         char *tmp;
1753         long num = strtol(arg, &tmp, 10);
1754
1755         if (num <= 0 || *tmp || (u8)num != num)
1756                 die("Invalid mode number: %s", arg);
1757         return num;
1758 }
1759
1760 struct key_and_iv_params {
1761         u8 master_key[MAX_KEY_SIZE];
1762         int master_key_size;
1763         enum kdf_algorithm kdf;
1764         u8 mode_num;
1765         u8 file_nonce[FILE_NONCE_SIZE];
1766         bool file_nonce_specified;
1767         bool iv_ino_lblk_64;
1768         bool iv_ino_lblk_32;
1769         u64 inode_number;
1770         u8 fs_uuid[UUID_SIZE];
1771         bool fs_uuid_specified;
1772 };
1773
1774 #define HKDF_CONTEXT_KEY_IDENTIFIER     1
1775 #define HKDF_CONTEXT_PER_FILE_ENC_KEY   2
1776 #define HKDF_CONTEXT_DIRECT_KEY         3
1777 #define HKDF_CONTEXT_IV_INO_LBLK_64_KEY 4
1778 #define HKDF_CONTEXT_DIRHASH_KEY        5
1779 #define HKDF_CONTEXT_IV_INO_LBLK_32_KEY 6
1780 #define HKDF_CONTEXT_INODE_HASH_KEY     7
1781
1782 /* Hash the file's inode number using SipHash keyed by a derived key */
1783 static u32 hash_inode_number(const struct key_and_iv_params *params)
1784 {
1785         u8 info[9] = "fscrypt";
1786         union {
1787                 u64 words[2];
1788                 u8 bytes[16];
1789         } hash_key;
1790
1791         info[8] = HKDF_CONTEXT_INODE_HASH_KEY;
1792         hkdf_sha512(params->master_key, params->master_key_size,
1793                     NULL, 0, info, sizeof(info),
1794                     hash_key.bytes, sizeof(hash_key));
1795
1796         hash_key.words[0] = get_unaligned_le64(&hash_key.bytes[0]);
1797         hash_key.words[1] = get_unaligned_le64(&hash_key.bytes[8]);
1798
1799         return (u32)siphash_1u64(hash_key.words, params->inode_number);
1800 }
1801
1802 /*
1803  * Get the key and starting IV with which the encryption will actually be done.
1804  * If a KDF was specified, a subkey is derived from the master key and the mode
1805  * number or file nonce.  Otherwise, the master key is used directly.
1806  */
1807 static void get_key_and_iv(const struct key_and_iv_params *params,
1808                            u8 *real_key, size_t real_key_size,
1809                            struct fscrypt_iv *iv)
1810 {
1811         bool file_nonce_in_iv = false;
1812         struct aes_key aes_key;
1813         u8 info[8 + 1 + 1 + UUID_SIZE] = "fscrypt";
1814         size_t infolen = 8;
1815         size_t i;
1816
1817         ASSERT(real_key_size <= params->master_key_size);
1818
1819         memset(iv, 0, sizeof(*iv));
1820
1821         if (params->iv_ino_lblk_64 || params->iv_ino_lblk_32) {
1822                 const char *opt = params->iv_ino_lblk_64 ? "--iv-ino-lblk-64" :
1823                                                            "--iv-ino-lblk-32";
1824                 if (params->iv_ino_lblk_64 && params->iv_ino_lblk_32)
1825                         die("--iv-ino-lblk-64 and --iv-ino-lblk-32 are mutually exclusive");
1826                 if (params->kdf != KDF_HKDF_SHA512)
1827                         die("%s requires --kdf=HKDF-SHA512", opt);
1828                 if (!params->fs_uuid_specified)
1829                         die("%s requires --fs-uuid", opt);
1830                 if (params->inode_number == 0)
1831                         die("%s requires --inode-number", opt);
1832                 if (params->mode_num == 0)
1833                         die("%s requires --mode-num", opt);
1834                 if (params->inode_number > UINT32_MAX)
1835                         die("%s can't use --inode-number > UINT32_MAX", opt);
1836         }
1837
1838         switch (params->kdf) {
1839         case KDF_NONE:
1840                 if (params->mode_num != 0)
1841                         die("--mode-num isn't supported with --kdf=none");
1842                 memcpy(real_key, params->master_key, real_key_size);
1843                 file_nonce_in_iv = true;
1844                 break;
1845         case KDF_AES_128_ECB:
1846                 if (!params->file_nonce_specified)
1847                         die("--file-nonce is required with --kdf=AES-128-ECB");
1848                 if (params->mode_num != 0)
1849                         die("--mode-num isn't supported with --kdf=AES-128-ECB");
1850                 STATIC_ASSERT(FILE_NONCE_SIZE == AES_128_KEY_SIZE);
1851                 ASSERT(real_key_size % AES_BLOCK_SIZE == 0);
1852                 aes_setkey(&aes_key, params->file_nonce, AES_128_KEY_SIZE);
1853                 for (i = 0; i < real_key_size; i += AES_BLOCK_SIZE)
1854                         aes_encrypt(&aes_key, &params->master_key[i],
1855                                     &real_key[i]);
1856                 break;
1857         case KDF_HKDF_SHA512:
1858                 if (params->iv_ino_lblk_64) {
1859                         info[infolen++] = HKDF_CONTEXT_IV_INO_LBLK_64_KEY;
1860                         info[infolen++] = params->mode_num;
1861                         memcpy(&info[infolen], params->fs_uuid, UUID_SIZE);
1862                         infolen += UUID_SIZE;
1863                         put_unaligned_le32(params->inode_number, &iv->bytes[4]);
1864                 } else if (params->iv_ino_lblk_32) {
1865                         info[infolen++] = HKDF_CONTEXT_IV_INO_LBLK_32_KEY;
1866                         info[infolen++] = params->mode_num;
1867                         memcpy(&info[infolen], params->fs_uuid, UUID_SIZE);
1868                         infolen += UUID_SIZE;
1869                         put_unaligned_le32(hash_inode_number(params),
1870                                            iv->bytes);
1871                 } else if (params->mode_num != 0) {
1872                         info[infolen++] = HKDF_CONTEXT_DIRECT_KEY;
1873                         info[infolen++] = params->mode_num;
1874                         file_nonce_in_iv = true;
1875                 } else if (params->file_nonce_specified) {
1876                         info[infolen++] = HKDF_CONTEXT_PER_FILE_ENC_KEY;
1877                         memcpy(&info[infolen], params->file_nonce,
1878                                FILE_NONCE_SIZE);
1879                         infolen += FILE_NONCE_SIZE;
1880                 } else {
1881                         die("With --kdf=HKDF-SHA512, at least one of --file-nonce and --mode-num must be specified");
1882                 }
1883                 hkdf_sha512(params->master_key, params->master_key_size,
1884                             NULL, 0, info, infolen, real_key, real_key_size);
1885                 break;
1886         default:
1887                 ASSERT(0);
1888         }
1889
1890         if (file_nonce_in_iv && params->file_nonce_specified)
1891                 memcpy(&iv->bytes[8], params->file_nonce, FILE_NONCE_SIZE);
1892 }
1893
1894 enum {
1895         OPT_BLOCK_SIZE,
1896         OPT_DECRYPT,
1897         OPT_FILE_NONCE,
1898         OPT_FS_UUID,
1899         OPT_HELP,
1900         OPT_INODE_NUMBER,
1901         OPT_IV_INO_LBLK_32,
1902         OPT_IV_INO_LBLK_64,
1903         OPT_KDF,
1904         OPT_MODE_NUM,
1905         OPT_PADDING,
1906 };
1907
1908 static const struct option longopts[] = {
1909         { "block-size",      required_argument, NULL, OPT_BLOCK_SIZE },
1910         { "decrypt",         no_argument,       NULL, OPT_DECRYPT },
1911         { "file-nonce",      required_argument, NULL, OPT_FILE_NONCE },
1912         { "fs-uuid",         required_argument, NULL, OPT_FS_UUID },
1913         { "help",            no_argument,       NULL, OPT_HELP },
1914         { "inode-number",    required_argument, NULL, OPT_INODE_NUMBER },
1915         { "iv-ino-lblk-32",  no_argument,       NULL, OPT_IV_INO_LBLK_32 },
1916         { "iv-ino-lblk-64",  no_argument,       NULL, OPT_IV_INO_LBLK_64 },
1917         { "kdf",             required_argument, NULL, OPT_KDF },
1918         { "mode-num",        required_argument, NULL, OPT_MODE_NUM },
1919         { "padding",         required_argument, NULL, OPT_PADDING },
1920         { NULL, 0, NULL, 0 },
1921 };
1922
1923 int main(int argc, char *argv[])
1924 {
1925         size_t block_size = 4096;
1926         bool decrypting = false;
1927         struct key_and_iv_params params;
1928         size_t padding = 0;
1929         const struct fscrypt_cipher *cipher;
1930         u8 real_key[MAX_KEY_SIZE];
1931         struct fscrypt_iv iv;
1932         char *tmp;
1933         int c;
1934
1935         memset(&params, 0, sizeof(params));
1936
1937         aes_init();
1938
1939 #ifdef ENABLE_ALG_TESTS
1940         test_aes();
1941         test_sha2();
1942         test_hkdf_sha512();
1943         test_aes_256_xts();
1944         test_aes_256_cts_cbc();
1945         test_adiantum();
1946 #endif
1947
1948         while ((c = getopt_long(argc, argv, "", longopts, NULL)) != -1) {
1949                 switch (c) {
1950                 case OPT_BLOCK_SIZE:
1951                         errno = 0;
1952                         block_size = strtoul(optarg, &tmp, 10);
1953                         if (block_size <= 0 || *tmp || errno)
1954                                 die("Invalid block size: %s", optarg);
1955                         break;
1956                 case OPT_DECRYPT:
1957                         decrypting = true;
1958                         break;
1959                 case OPT_FILE_NONCE:
1960                         if (hex2bin(optarg, params.file_nonce, FILE_NONCE_SIZE)
1961                             != FILE_NONCE_SIZE)
1962                                 die("Invalid file nonce: %s", optarg);
1963                         params.file_nonce_specified = true;
1964                         break;
1965                 case OPT_FS_UUID:
1966                         if (hex2bin(optarg, params.fs_uuid, UUID_SIZE)
1967                             != UUID_SIZE)
1968                                 die("Invalid filesystem UUID: %s", optarg);
1969                         params.fs_uuid_specified = true;
1970                         break;
1971                 case OPT_HELP:
1972                         usage(stdout);
1973                         return 0;
1974                 case OPT_INODE_NUMBER:
1975                         errno = 0;
1976                         params.inode_number = strtoull(optarg, &tmp, 10);
1977                         if (params.inode_number <= 0 || *tmp || errno)
1978                                 die("Invalid inode number: %s", optarg);
1979                         break;
1980                 case OPT_IV_INO_LBLK_32:
1981                         params.iv_ino_lblk_32 = true;
1982                         break;
1983                 case OPT_IV_INO_LBLK_64:
1984                         params.iv_ino_lblk_64 = true;
1985                         break;
1986                 case OPT_KDF:
1987                         params.kdf = parse_kdf_algorithm(optarg);
1988                         break;
1989                 case OPT_MODE_NUM:
1990                         params.mode_num = parse_mode_number(optarg);
1991                         break;
1992                 case OPT_PADDING:
1993                         padding = strtoul(optarg, &tmp, 10);
1994                         if (padding <= 0 || *tmp || !is_power_of_2(padding) ||
1995                             padding > INT_MAX)
1996                                 die("Invalid padding amount: %s", optarg);
1997                         break;
1998                 default:
1999                         usage(stderr);
2000                         return 2;
2001                 }
2002         }
2003         argc -= optind;
2004         argv += optind;
2005
2006         if (argc != 2) {
2007                 usage(stderr);
2008                 return 2;
2009         }
2010
2011         cipher = find_fscrypt_cipher(argv[0]);
2012         if (cipher == NULL)
2013                 die("Unknown cipher: %s", argv[0]);
2014
2015         if (block_size < cipher->min_input_size)
2016                 die("Block size of %zu bytes is too small for cipher %s",
2017                     block_size, cipher->name);
2018
2019         params.master_key_size = hex2bin(argv[1], params.master_key,
2020                                          MAX_KEY_SIZE);
2021         if (params.master_key_size < 0)
2022                 die("Invalid master_key: %s", argv[1]);
2023         if (params.master_key_size < cipher->keysize)
2024                 die("Master key is too short for cipher %s", cipher->name);
2025
2026         get_key_and_iv(&params, real_key, cipher->keysize, &iv);
2027
2028         crypt_loop(cipher, real_key, &iv, decrypting, block_size, padding);
2029         return 0;
2030 }