fstests: explicitly specify xattr namespace
[xfstests-dev.git] / src / fscrypt-crypt-util.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * fscrypt-crypt-util.c - utility for verifying fscrypt-encrypted data
4  *
5  * Copyright 2019 Google LLC
6  */
7
8 /*
9  * This program implements all crypto algorithms supported by fscrypt (a.k.a.
10  * ext4, f2fs, and ubifs encryption), for the purpose of verifying the
11  * correctness of the ciphertext stored on-disk.  See usage() below.
12  *
13  * All algorithms are implemented in portable C code to avoid depending on
14  * libcrypto (OpenSSL), and because some fscrypt-supported algorithms aren't
15  * available in libcrypto anyway (e.g. Adiantum), or are only supported in
16  * recent versions (e.g. HKDF-SHA512).  For simplicity, all crypto code here
17  * tries to follow the mathematical definitions directly, without optimizing for
18  * performance or worrying about following security best practices such as
19  * mitigating side-channel attacks.  So, only use this program for testing!
20  */
21
22 #include <asm/byteorder.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <limits.h>
26 #include <linux/types.h>
27 #include <stdarg.h>
28 #include <stdbool.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 #include <unistd.h>
33
34 #define PROGRAM_NAME "fscrypt-crypt-util"
35
36 /*
37  * Define to enable the tests of the crypto code in this file.  If enabled, you
38  * must link this program with OpenSSL (-lcrypto) v1.1.0 or later, and your
39  * kernel needs CONFIG_CRYPTO_USER_API_SKCIPHER=y and CONFIG_CRYPTO_ADIANTUM=y.
40  */
41 #undef ENABLE_ALG_TESTS
42
43 #define NUM_ALG_TEST_ITERATIONS 10000
44
45 static void usage(FILE *fp)
46 {
47         fputs(
48 "Usage: " PROGRAM_NAME " [OPTION]... CIPHER MASTER_KEY\n"
49 "\n"
50 "Utility for verifying fscrypt-encrypted data.  This program encrypts\n"
51 "(or decrypts) the data on stdin using the given CIPHER with the given\n"
52 "MASTER_KEY (or a key derived from it, if a KDF is specified), and writes the\n"
53 "resulting ciphertext (or plaintext) to stdout.\n"
54 "\n"
55 "CIPHER can be AES-256-XTS, AES-256-CTS-CBC, AES-128-CBC-ESSIV, AES-128-CTS-CBC,\n"
56 "or Adiantum.  MASTER_KEY must be a hex string long enough for the cipher.\n"
57 "\n"
58 "WARNING: this program is only meant for testing, not for \"real\" use!\n"
59 "\n"
60 "Options:\n"
61 "  --block-size=BLOCK_SIZE     Encrypt each BLOCK_SIZE bytes independently.\n"
62 "                                Default: 4096 bytes\n"
63 "  --decrypt                   Decrypt instead of encrypt\n"
64 "  --file-nonce=NONCE          File's nonce as a 32-character hex string\n"
65 "  --fs-uuid=UUID              The filesystem UUID as a 32-character hex string.\n"
66 "                                Required for --iv-ino-lblk-32 and\n"
67 "                                --iv-ino-lblk-64; otherwise is unused.\n"
68 "  --help                      Show this help\n"
69 "  --inode-number=INUM         The file's inode number.  Required for\n"
70 "                                --iv-ino-lblk-32 and --iv-ino-lblk-64;\n"
71 "                                otherwise is unused.\n"
72 "  --iv-ino-lblk-32            Similar to --iv-ino-lblk-64, but selects the\n"
73 "                                32-bit variant.\n"
74 "  --iv-ino-lblk-64            Use the format where the IVs include the inode\n"
75 "                                number and the same key is shared across files.\n"
76 "                                Requires --kdf=HKDF-SHA512, --fs-uuid,\n"
77 "                                --inode-number, and --mode-num.\n"
78 "  --kdf=KDF                   Key derivation function to use: AES-128-ECB,\n"
79 "                                HKDF-SHA512, or none.  Default: none\n"
80 "  --mode-num=NUM              Derive per-mode key using mode number NUM\n"
81 "  --padding=PADDING           If last block is partial, zero-pad it to next\n"
82 "                                PADDING-byte boundary.  Default: BLOCK_SIZE\n"
83         , fp);
84 }
85
86 /*----------------------------------------------------------------------------*
87  *                                 Utilities                                  *
88  *----------------------------------------------------------------------------*/
89
90 #define ARRAY_SIZE(A)           (sizeof(A) / sizeof((A)[0]))
91 #define MIN(x, y)               ((x) < (y) ? (x) : (y))
92 #define MAX(x, y)               ((x) > (y) ? (x) : (y))
93 #define ROUND_DOWN(x, y)        ((x) & ~((y) - 1))
94 #define ROUND_UP(x, y)          (((x) + (y) - 1) & ~((y) - 1))
95 #define DIV_ROUND_UP(n, d)      (((n) + (d) - 1) / (d))
96 #define STATIC_ASSERT(e)        ((void)sizeof(char[1 - 2*!(e)]))
97
98 typedef __u8                    u8;
99 typedef __u16                   u16;
100 typedef __u32                   u32;
101 typedef __u64                   u64;
102
103 #define cpu_to_le32             __cpu_to_le32
104 #define cpu_to_be32             __cpu_to_be32
105 #define cpu_to_le64             __cpu_to_le64
106 #define cpu_to_be64             __cpu_to_be64
107 #define le32_to_cpu             __le32_to_cpu
108 #define be32_to_cpu             __be32_to_cpu
109 #define le64_to_cpu             __le64_to_cpu
110 #define be64_to_cpu             __be64_to_cpu
111
112 #define DEFINE_UNALIGNED_ACCESS_HELPERS(type, native_type)      \
113 static inline native_type __attribute__((unused))               \
114 get_unaligned_##type(const void *p)                             \
115 {                                                               \
116         __##type x;                                             \
117                                                                 \
118         memcpy(&x, p, sizeof(x));                               \
119         return type##_to_cpu(x);                                \
120 }                                                               \
121                                                                 \
122 static inline void __attribute__((unused))                      \
123 put_unaligned_##type(native_type v, void *p)                    \
124 {                                                               \
125         __##type x = cpu_to_##type(v);                          \
126                                                                 \
127         memcpy(p, &x, sizeof(x));                               \
128 }
129
130 DEFINE_UNALIGNED_ACCESS_HELPERS(le32, u32)
131 DEFINE_UNALIGNED_ACCESS_HELPERS(be32, u32)
132 DEFINE_UNALIGNED_ACCESS_HELPERS(le64, u64)
133 DEFINE_UNALIGNED_ACCESS_HELPERS(be64, u64)
134
135 static inline bool is_power_of_2(unsigned long v)
136 {
137         return v != 0 && (v & (v - 1)) == 0;
138 }
139
140 static inline u32 rol32(u32 v, int n)
141 {
142         return (v << n) | (v >> (32 - n));
143 }
144
145 static inline u32 ror32(u32 v, int n)
146 {
147         return (v >> n) | (v << (32 - n));
148 }
149
150 static inline u64 rol64(u64 v, int n)
151 {
152         return (v << n) | (v >> (64 - n));
153 }
154
155 static inline u64 ror64(u64 v, int n)
156 {
157         return (v >> n) | (v << (64 - n));
158 }
159
160 static inline void xor(u8 *res, const u8 *a, const u8 *b, size_t count)
161 {
162         while (count--)
163                 *res++ = *a++ ^ *b++;
164 }
165
166 static void __attribute__((noreturn, format(printf, 2, 3)))
167 do_die(int err, const char *format, ...)
168 {
169         va_list va;
170
171         va_start(va, format);
172         fputs("[" PROGRAM_NAME "] ERROR: ", stderr);
173         vfprintf(stderr, format, va);
174         if (err)
175                 fprintf(stderr, ": %s", strerror(errno));
176         putc('\n', stderr);
177         va_end(va);
178         exit(1);
179 }
180
181 #define die(format, ...)        do_die(0,     (format), ##__VA_ARGS__)
182 #define die_errno(format, ...)  do_die(errno, (format), ##__VA_ARGS__)
183
184 static __attribute__((noreturn)) void
185 assertion_failed(const char *expr, const char *file, int line)
186 {
187         die("Assertion failed: %s at %s:%d", expr, file, line);
188 }
189
190 #define ASSERT(e) ({ if (!(e)) assertion_failed(#e, __FILE__, __LINE__); })
191
192 static void *xmalloc(size_t size)
193 {
194         void *p = malloc(size);
195
196         ASSERT(p != NULL);
197         return p;
198 }
199
200 static int hexchar2bin(char c)
201 {
202         if (c >= 'a' && c <= 'f')
203                 return 10 + c - 'a';
204         if (c >= 'A' && c <= 'F')
205                 return 10 + c - 'A';
206         if (c >= '0' && c <= '9')
207                 return c - '0';
208         return -1;
209 }
210
211 static int hex2bin(const char *hex, u8 *bin, int max_bin_size)
212 {
213         size_t len = strlen(hex);
214         size_t i;
215
216         if (len & 1)
217                 return -1;
218         len /= 2;
219         if (len > max_bin_size)
220                 return -1;
221
222         for (i = 0; i < len; i++) {
223                 int high = hexchar2bin(hex[2 * i]);
224                 int low = hexchar2bin(hex[2 * i + 1]);
225
226                 if (high < 0 || low < 0)
227                         return -1;
228                 bin[i] = (high << 4) | low;
229         }
230         return len;
231 }
232
233 static size_t xread(int fd, void *buf, size_t count)
234 {
235         const size_t orig_count = count;
236
237         while (count) {
238                 ssize_t res = read(fd, buf, count);
239
240                 if (res < 0)
241                         die_errno("read error");
242                 if (res == 0)
243                         break;
244                 buf += res;
245                 count -= res;
246         }
247         return orig_count - count;
248 }
249
250 static void full_write(int fd, const void *buf, size_t count)
251 {
252         while (count) {
253                 ssize_t res = write(fd, buf, count);
254
255                 if (res < 0)
256                         die_errno("write error");
257                 buf += res;
258                 count -= res;
259         }
260 }
261
262 #ifdef ENABLE_ALG_TESTS
263 static void rand_bytes(u8 *buf, size_t count)
264 {
265         while (count--)
266                 *buf++ = rand();
267 }
268 #endif
269
270 /*----------------------------------------------------------------------------*
271  *                          Finite field arithmetic                           *
272  *----------------------------------------------------------------------------*/
273
274 /* Multiply a GF(2^8) element by the polynomial 'x' */
275 static inline u8 gf2_8_mul_x(u8 b)
276 {
277         return (b << 1) ^ ((b & 0x80) ? 0x1B : 0);
278 }
279
280 /* Multiply four packed GF(2^8) elements by the polynomial 'x' */
281 static inline u32 gf2_8_mul_x_4way(u32 w)
282 {
283         return ((w & 0x7F7F7F7F) << 1) ^ (((w & 0x80808080) >> 7) * 0x1B);
284 }
285
286 /* Element of GF(2^128) */
287 typedef struct {
288         __le64 lo;
289         __le64 hi;
290 } ble128;
291
292 /* Multiply a GF(2^128) element by the polynomial 'x' */
293 static inline void gf2_128_mul_x(ble128 *t)
294 {
295         u64 lo = le64_to_cpu(t->lo);
296         u64 hi = le64_to_cpu(t->hi);
297
298         t->hi = cpu_to_le64((hi << 1) | (lo >> 63));
299         t->lo = cpu_to_le64((lo << 1) ^ ((hi & (1ULL << 63)) ? 0x87 : 0));
300 }
301
302 /*----------------------------------------------------------------------------*
303  *                             Group arithmetic                               *
304  *----------------------------------------------------------------------------*/
305
306 /* Element of Z/(2^{128}Z)  (a.k.a. the integers modulo 2^128) */
307 typedef struct {
308         __le64 lo;
309         __le64 hi;
310 } le128;
311
312 static inline void le128_add(le128 *res, const le128 *a, const le128 *b)
313 {
314         u64 a_lo = le64_to_cpu(a->lo);
315         u64 b_lo = le64_to_cpu(b->lo);
316
317         res->lo = cpu_to_le64(a_lo + b_lo);
318         res->hi = cpu_to_le64(le64_to_cpu(a->hi) + le64_to_cpu(b->hi) +
319                               (a_lo + b_lo < a_lo));
320 }
321
322 static inline void le128_sub(le128 *res, const le128 *a, const le128 *b)
323 {
324         u64 a_lo = le64_to_cpu(a->lo);
325         u64 b_lo = le64_to_cpu(b->lo);
326
327         res->lo = cpu_to_le64(a_lo - b_lo);
328         res->hi = cpu_to_le64(le64_to_cpu(a->hi) - le64_to_cpu(b->hi) -
329                               (a_lo - b_lo > a_lo));
330 }
331
332 /*----------------------------------------------------------------------------*
333  *                              AES block cipher                              *
334  *----------------------------------------------------------------------------*/
335
336 /*
337  * Reference: "FIPS 197, Advanced Encryption Standard"
338  *      https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.197.pdf
339  */
340
341 #define AES_BLOCK_SIZE          16
342 #define AES_128_KEY_SIZE        16
343 #define AES_192_KEY_SIZE        24
344 #define AES_256_KEY_SIZE        32
345
346 static inline void AddRoundKey(u32 state[4], const u32 *rk)
347 {
348         int i;
349
350         for (i = 0; i < 4; i++)
351                 state[i] ^= rk[i];
352 }
353
354 static const u8 aes_sbox[256] = {
355         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b,
356         0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0,
357         0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26,
358         0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
359         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2,
360         0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0,
361         0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed,
362         0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
363         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f,
364         0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
365         0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec,
366         0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
367         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14,
368         0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c,
369         0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d,
370         0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
371         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f,
372         0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e,
373         0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11,
374         0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
375         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f,
376         0xb0, 0x54, 0xbb, 0x16,
377 };
378
379 static u8 aes_inverse_sbox[256];
380
381 static void aes_init(void)
382 {
383         int i;
384
385         for (i = 0; i < 256; i++)
386                 aes_inverse_sbox[aes_sbox[i]] = i;
387 }
388
389 static inline u32 DoSubWord(u32 w, const u8 sbox[256])
390 {
391         return ((u32)sbox[(u8)(w >> 24)] << 24) |
392                ((u32)sbox[(u8)(w >> 16)] << 16) |
393                ((u32)sbox[(u8)(w >>  8)] <<  8) |
394                ((u32)sbox[(u8)(w >>  0)] <<  0);
395 }
396
397 static inline u32 SubWord(u32 w)
398 {
399         return DoSubWord(w, aes_sbox);
400 }
401
402 static inline u32 InvSubWord(u32 w)
403 {
404         return DoSubWord(w, aes_inverse_sbox);
405 }
406
407 static inline void SubBytes(u32 state[4])
408 {
409         int i;
410
411         for (i = 0; i < 4; i++)
412                 state[i] = SubWord(state[i]);
413 }
414
415 static inline void InvSubBytes(u32 state[4])
416 {
417         int i;
418
419         for (i = 0; i < 4; i++)
420                 state[i] = InvSubWord(state[i]);
421 }
422
423 static inline void DoShiftRows(u32 state[4], int direction)
424 {
425         u32 newstate[4];
426         int i;
427
428         for (i = 0; i < 4; i++)
429                 newstate[i] = (state[(i + direction*0) & 3] & 0xff) |
430                               (state[(i + direction*1) & 3] & 0xff00) |
431                               (state[(i + direction*2) & 3] & 0xff0000) |
432                               (state[(i + direction*3) & 3] & 0xff000000);
433         memcpy(state, newstate, 16);
434 }
435
436 static inline void ShiftRows(u32 state[4])
437 {
438         DoShiftRows(state, 1);
439 }
440
441 static inline void InvShiftRows(u32 state[4])
442 {
443         DoShiftRows(state, -1);
444 }
445
446 /*
447  * Mix one column by doing the following matrix multiplication in GF(2^8):
448  *
449  *     | 2 3 1 1 |   | w[0] |
450  *     | 1 2 3 1 |   | w[1] |
451  *     | 1 1 2 3 | x | w[2] |
452  *     | 3 1 1 2 |   | w[3] |
453  *
454  * a.k.a. w[i] = 2*w[i] + 3*w[(i+1)%4] + w[(i+2)%4] + w[(i+3)%4]
455  */
456 static inline u32 MixColumn(u32 w)
457 {
458         u32 _2w0_w2 = gf2_8_mul_x_4way(w) ^ ror32(w, 16);
459         u32 _3w1_w3 = ror32(_2w0_w2 ^ w, 8);
460
461         return _2w0_w2 ^ _3w1_w3;
462 }
463
464 /*
465  *           ( | 5 0 4 0 |   | w[0] | )
466  *          (  | 0 5 0 4 |   | w[1] |  )
467  * MixColumn(  | 4 0 5 0 | x | w[2] |  )
468  *           ( | 0 4 0 5 |   | w[3] | )
469  */
470 static inline u32 InvMixColumn(u32 w)
471 {
472         u32 _4w = gf2_8_mul_x_4way(gf2_8_mul_x_4way(w));
473
474         return MixColumn(_4w ^ w ^ ror32(_4w, 16));
475 }
476
477 static inline void MixColumns(u32 state[4])
478 {
479         int i;
480
481         for (i = 0; i < 4; i++)
482                 state[i] = MixColumn(state[i]);
483 }
484
485 static inline void InvMixColumns(u32 state[4])
486 {
487         int i;
488
489         for (i = 0; i < 4; i++)
490                 state[i] = InvMixColumn(state[i]);
491 }
492
493 struct aes_key {
494         u32 round_keys[15 * 4];
495         int nrounds;
496 };
497
498 /* Expand an AES key */
499 static void aes_setkey(struct aes_key *k, const u8 *key, int keysize)
500 {
501         const int N = keysize / 4;
502         u32 * const rk = k->round_keys;
503         u8 rcon = 1;
504         int i;
505
506         ASSERT(keysize == 16 || keysize == 24 || keysize == 32);
507         k->nrounds = 6 + N;
508         for (i = 0; i < 4 * (k->nrounds + 1); i++) {
509                 if (i < N) {
510                         rk[i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
511                 } else if (i % N == 0) {
512                         rk[i] = rk[i - N] ^ SubWord(ror32(rk[i - 1], 8)) ^ rcon;
513                         rcon = gf2_8_mul_x(rcon);
514                 } else if (N > 6 && i % N == 4) {
515                         rk[i] = rk[i - N] ^ SubWord(rk[i - 1]);
516                 } else {
517                         rk[i] = rk[i - N] ^ rk[i - 1];
518                 }
519         }
520 }
521
522 /* Encrypt one 16-byte block with AES */
523 static void aes_encrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
524                         u8 dst[AES_BLOCK_SIZE])
525 {
526         u32 state[4];
527         int i;
528
529         for (i = 0; i < 4; i++)
530                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
531
532         AddRoundKey(state, k->round_keys);
533         for (i = 1; i < k->nrounds; i++) {
534                 SubBytes(state);
535                 ShiftRows(state);
536                 MixColumns(state);
537                 AddRoundKey(state, &k->round_keys[4 * i]);
538         }
539         SubBytes(state);
540         ShiftRows(state);
541         AddRoundKey(state, &k->round_keys[4 * i]);
542
543         for (i = 0; i < 4; i++)
544                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
545 }
546
547 /* Decrypt one 16-byte block with AES */
548 static void aes_decrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
549                         u8 dst[AES_BLOCK_SIZE])
550 {
551         u32 state[4];
552         int i;
553
554         for (i = 0; i < 4; i++)
555                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
556
557         AddRoundKey(state, &k->round_keys[4 * k->nrounds]);
558         InvShiftRows(state);
559         InvSubBytes(state);
560         for (i = k->nrounds - 1; i >= 1; i--) {
561                 AddRoundKey(state, &k->round_keys[4 * i]);
562                 InvMixColumns(state);
563                 InvShiftRows(state);
564                 InvSubBytes(state);
565         }
566         AddRoundKey(state, k->round_keys);
567
568         for (i = 0; i < 4; i++)
569                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
570 }
571
572 #ifdef ENABLE_ALG_TESTS
573 #include <openssl/aes.h>
574 static void test_aes_keysize(int keysize)
575 {
576         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
577
578         while (num_tests--) {
579                 struct aes_key k;
580                 AES_KEY ref_k;
581                 u8 key[AES_256_KEY_SIZE];
582                 u8 ptext[AES_BLOCK_SIZE];
583                 u8 ctext[AES_BLOCK_SIZE];
584                 u8 ref_ctext[AES_BLOCK_SIZE];
585                 u8 decrypted[AES_BLOCK_SIZE];
586
587                 rand_bytes(key, keysize);
588                 rand_bytes(ptext, AES_BLOCK_SIZE);
589
590                 aes_setkey(&k, key, keysize);
591                 aes_encrypt(&k, ptext, ctext);
592
593                 ASSERT(AES_set_encrypt_key(key, keysize*8, &ref_k) == 0);
594                 AES_encrypt(ptext, ref_ctext, &ref_k);
595
596                 ASSERT(memcmp(ctext, ref_ctext, AES_BLOCK_SIZE) == 0);
597
598                 aes_decrypt(&k, ctext, decrypted);
599                 ASSERT(memcmp(ptext, decrypted, AES_BLOCK_SIZE) == 0);
600         }
601 }
602
603 static void test_aes(void)
604 {
605         test_aes_keysize(AES_128_KEY_SIZE);
606         test_aes_keysize(AES_192_KEY_SIZE);
607         test_aes_keysize(AES_256_KEY_SIZE);
608 }
609 #endif /* ENABLE_ALG_TESTS */
610
611 /*----------------------------------------------------------------------------*
612  *                            SHA-512 and SHA-256                             *
613  *----------------------------------------------------------------------------*/
614
615 /*
616  * Reference: "FIPS 180-2, Secure Hash Standard"
617  *      https://csrc.nist.gov/csrc/media/publications/fips/180/2/archive/2002-08-01/documents/fips180-2withchangenotice.pdf
618  */
619
620 #define SHA512_DIGEST_SIZE      64
621 #define SHA512_BLOCK_SIZE       128
622
623 #define SHA256_DIGEST_SIZE      32
624 #define SHA256_BLOCK_SIZE       64
625
626 #define Ch(x, y, z)     (((x) & (y)) ^ (~(x) & (z)))
627 #define Maj(x, y, z)    (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z)))
628
629 #define Sigma512_0(x)   (ror64((x), 28) ^ ror64((x), 34) ^ ror64((x), 39))
630 #define Sigma512_1(x)   (ror64((x), 14) ^ ror64((x), 18) ^ ror64((x), 41))
631 #define sigma512_0(x)   (ror64((x),  1) ^ ror64((x),  8) ^ ((x) >> 7))
632 #define sigma512_1(x)   (ror64((x), 19) ^ ror64((x), 61) ^ ((x) >> 6))
633
634 #define Sigma256_0(x)   (ror32((x),  2) ^ ror32((x), 13) ^ ror32((x), 22))
635 #define Sigma256_1(x)   (ror32((x),  6) ^ ror32((x), 11) ^ ror32((x), 25))
636 #define sigma256_0(x)   (ror32((x),  7) ^ ror32((x), 18) ^ ((x) >>  3))
637 #define sigma256_1(x)   (ror32((x), 17) ^ ror32((x), 19) ^ ((x) >> 10))
638
639 static const u64 sha512_iv[8] = {
640         0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b,
641         0xa54ff53a5f1d36f1, 0x510e527fade682d1, 0x9b05688c2b3e6c1f,
642         0x1f83d9abfb41bd6b, 0x5be0cd19137e2179,
643 };
644
645 static const u64 sha512_round_constants[80] = {
646         0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f,
647         0xe9b5dba58189dbbc, 0x3956c25bf348b538, 0x59f111f1b605d019,
648         0x923f82a4af194f9b, 0xab1c5ed5da6d8118, 0xd807aa98a3030242,
649         0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2,
650         0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235,
651         0xc19bf174cf692694, 0xe49b69c19ef14ad2, 0xefbe4786384f25e3,
652         0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, 0x2de92c6f592b0275,
653         0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5,
654         0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f,
655         0xbf597fc7beef0ee4, 0xc6e00bf33da88fc2, 0xd5a79147930aa725,
656         0x06ca6351e003826f, 0x142929670a0e6e70, 0x27b70a8546d22ffc,
657         0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df,
658         0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6,
659         0x92722c851482353b, 0xa2bfe8a14cf10364, 0xa81a664bbc423001,
660         0xc24b8b70d0f89791, 0xc76c51a30654be30, 0xd192e819d6ef5218,
661         0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8,
662         0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99,
663         0x34b0bcb5e19b48a8, 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb,
664         0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, 0x748f82ee5defb2fc,
665         0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec,
666         0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915,
667         0xc67178f2e372532b, 0xca273eceea26619c, 0xd186b8c721c0c207,
668         0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, 0x06f067aa72176fba,
669         0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b,
670         0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc,
671         0x431d67c49c100d4c, 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a,
672         0x5fcb6fab3ad6faec, 0x6c44198c4a475817,
673 };
674
675 /* Compute the SHA-512 digest of the given buffer */
676 static void sha512(const u8 *in, size_t inlen, u8 out[SHA512_DIGEST_SIZE])
677 {
678         const size_t msglen = ROUND_UP(inlen + 17, SHA512_BLOCK_SIZE);
679         u8 * const msg = xmalloc(msglen);
680         u64 H[8];
681         int i;
682
683         /* super naive way of handling the padding */
684         memcpy(msg, in, inlen);
685         memset(&msg[inlen], 0, msglen - inlen);
686         msg[inlen] = 0x80;
687         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
688         in = msg;
689
690         memcpy(H, sha512_iv, sizeof(H));
691         do {
692                 u64 a = H[0], b = H[1], c = H[2], d = H[3],
693                     e = H[4], f = H[5], g = H[6], h = H[7];
694                 u64 W[80];
695
696                 for (i = 0; i < 16; i++)
697                         W[i] = get_unaligned_be64(&in[i * sizeof(__be64)]);
698                 for (; i < ARRAY_SIZE(W); i++)
699                         W[i] = sigma512_1(W[i - 2]) + W[i - 7] +
700                                sigma512_0(W[i - 15]) + W[i - 16];
701                 for (i = 0; i < ARRAY_SIZE(W); i++) {
702                         u64 T1 = h + Sigma512_1(e) + Ch(e, f, g) +
703                                  sha512_round_constants[i] + W[i];
704                         u64 T2 = Sigma512_0(a) + Maj(a, b, c);
705
706                         h = g; g = f; f = e; e = d + T1;
707                         d = c; c = b; b = a; a = T1 + T2;
708                 }
709                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
710                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
711         } while ((in += SHA512_BLOCK_SIZE) != &msg[msglen]);
712
713         for (i = 0; i < ARRAY_SIZE(H); i++)
714                 put_unaligned_be64(H[i], &out[i * sizeof(__be64)]);
715         free(msg);
716 }
717
718 /* Compute the SHA-256 digest of the given buffer */
719 static void sha256(const u8 *in, size_t inlen, u8 out[SHA256_DIGEST_SIZE])
720 {
721         const size_t msglen = ROUND_UP(inlen + 9, SHA256_BLOCK_SIZE);
722         u8 * const msg = xmalloc(msglen);
723         u32 H[8];
724         int i;
725
726         /* super naive way of handling the padding */
727         memcpy(msg, in, inlen);
728         memset(&msg[inlen], 0, msglen - inlen);
729         msg[inlen] = 0x80;
730         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
731         in = msg;
732
733         for (i = 0; i < ARRAY_SIZE(H); i++)
734                 H[i] = (u32)(sha512_iv[i] >> 32);
735         do {
736                 u32 a = H[0], b = H[1], c = H[2], d = H[3],
737                     e = H[4], f = H[5], g = H[6], h = H[7];
738                 u32 W[64];
739
740                 for (i = 0; i < 16; i++)
741                         W[i] = get_unaligned_be32(&in[i * sizeof(__be32)]);
742                 for (; i < ARRAY_SIZE(W); i++)
743                         W[i] = sigma256_1(W[i - 2]) + W[i - 7] +
744                                sigma256_0(W[i - 15]) + W[i - 16];
745                 for (i = 0; i < ARRAY_SIZE(W); i++) {
746                         u32 T1 = h + Sigma256_1(e) + Ch(e, f, g) +
747                                  (u32)(sha512_round_constants[i] >> 32) + W[i];
748                         u32 T2 = Sigma256_0(a) + Maj(a, b, c);
749
750                         h = g; g = f; f = e; e = d + T1;
751                         d = c; c = b; b = a; a = T1 + T2;
752                 }
753                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
754                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
755         } while ((in += SHA256_BLOCK_SIZE) != &msg[msglen]);
756
757         for (i = 0; i < ARRAY_SIZE(H); i++)
758                 put_unaligned_be32(H[i], &out[i * sizeof(__be32)]);
759         free(msg);
760 }
761
762 #ifdef ENABLE_ALG_TESTS
763 #include <openssl/sha.h>
764 static void test_sha2(void)
765 {
766         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
767
768         while (num_tests--) {
769                 u8 in[4096];
770                 u8 digest[SHA512_DIGEST_SIZE];
771                 u8 ref_digest[SHA512_DIGEST_SIZE];
772                 const size_t inlen = rand() % (1 + sizeof(in));
773
774                 rand_bytes(in, inlen);
775
776                 sha256(in, inlen, digest);
777                 SHA256(in, inlen, ref_digest);
778                 ASSERT(memcmp(digest, ref_digest, SHA256_DIGEST_SIZE) == 0);
779
780                 sha512(in, inlen, digest);
781                 SHA512(in, inlen, ref_digest);
782                 ASSERT(memcmp(digest, ref_digest, SHA512_DIGEST_SIZE) == 0);
783         }
784 }
785 #endif /* ENABLE_ALG_TESTS */
786
787 /*----------------------------------------------------------------------------*
788  *                            HKDF implementation                             *
789  *----------------------------------------------------------------------------*/
790
791 static void hmac_sha512(const u8 *key, size_t keylen, const u8 *msg,
792                         size_t msglen, u8 mac[SHA512_DIGEST_SIZE])
793 {
794         u8 *ibuf = xmalloc(SHA512_BLOCK_SIZE + msglen);
795         u8 obuf[SHA512_BLOCK_SIZE + SHA512_DIGEST_SIZE];
796
797         ASSERT(keylen <= SHA512_BLOCK_SIZE); /* keylen > bs not implemented */
798
799         memset(ibuf, 0x36, SHA512_BLOCK_SIZE);
800         xor(ibuf, ibuf, key, keylen);
801         memcpy(&ibuf[SHA512_BLOCK_SIZE], msg, msglen);
802
803         memset(obuf, 0x5c, SHA512_BLOCK_SIZE);
804         xor(obuf, obuf, key, keylen);
805         sha512(ibuf, SHA512_BLOCK_SIZE + msglen, &obuf[SHA512_BLOCK_SIZE]);
806         sha512(obuf, sizeof(obuf), mac);
807
808         free(ibuf);
809 }
810
811 static void hkdf_sha512(const u8 *ikm, size_t ikmlen,
812                         const u8 *salt, size_t saltlen,
813                         const u8 *info, size_t infolen,
814                         u8 *output, size_t outlen)
815 {
816         static const u8 default_salt[SHA512_DIGEST_SIZE];
817         u8 prk[SHA512_DIGEST_SIZE]; /* pseudorandom key */
818         u8 *buf = xmalloc(1 + infolen + SHA512_DIGEST_SIZE);
819         u8 counter = 1;
820         size_t i;
821
822         if (saltlen == 0) {
823                 salt = default_salt;
824                 saltlen = sizeof(default_salt);
825         }
826
827         /* HKDF-Extract */
828         ASSERT(ikmlen > 0);
829         hmac_sha512(salt, saltlen, ikm, ikmlen, prk);
830
831         /* HKDF-Expand */
832         for (i = 0; i < outlen; i += SHA512_DIGEST_SIZE) {
833                 u8 *p = buf;
834                 u8 tmp[SHA512_DIGEST_SIZE];
835
836                 ASSERT(counter != 0);
837                 if (i > 0) {
838                         memcpy(p, &output[i - SHA512_DIGEST_SIZE],
839                                SHA512_DIGEST_SIZE);
840                         p += SHA512_DIGEST_SIZE;
841                 }
842                 memcpy(p, info, infolen);
843                 p += infolen;
844                 *p++ = counter++;
845                 hmac_sha512(prk, sizeof(prk), buf, p - buf, tmp);
846                 memcpy(&output[i], tmp, MIN(sizeof(tmp), outlen - i));
847         }
848         free(buf);
849 }
850
851 #ifdef ENABLE_ALG_TESTS
852 #include <openssl/evp.h>
853 #include <openssl/kdf.h>
854 static void openssl_hkdf_sha512(const u8 *ikm, size_t ikmlen,
855                                 const u8 *salt, size_t saltlen,
856                                 const u8 *info, size_t infolen,
857                                 u8 *output, size_t outlen)
858 {
859         EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, NULL);
860         size_t actual_outlen = outlen;
861
862         ASSERT(pctx != NULL);
863         ASSERT(EVP_PKEY_derive_init(pctx) > 0);
864         ASSERT(EVP_PKEY_CTX_set_hkdf_md(pctx, EVP_sha512()) > 0);
865         ASSERT(EVP_PKEY_CTX_set1_hkdf_key(pctx, ikm, ikmlen) > 0);
866         ASSERT(EVP_PKEY_CTX_set1_hkdf_salt(pctx, salt, saltlen) > 0);
867         ASSERT(EVP_PKEY_CTX_add1_hkdf_info(pctx, info, infolen) > 0);
868         ASSERT(EVP_PKEY_derive(pctx, output, &actual_outlen) > 0);
869         ASSERT(actual_outlen == outlen);
870         EVP_PKEY_CTX_free(pctx);
871 }
872
873 static void test_hkdf_sha512(void)
874 {
875         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
876
877         while (num_tests--) {
878                 u8 ikm[SHA512_DIGEST_SIZE];
879                 u8 salt[SHA512_DIGEST_SIZE];
880                 u8 info[128];
881                 u8 actual_output[512];
882                 u8 expected_output[sizeof(actual_output)];
883                 size_t ikmlen = 1 + (rand() % sizeof(ikm));
884                 size_t saltlen = rand() % (1 + sizeof(salt));
885                 size_t infolen = rand() % (1 + sizeof(info));
886                 size_t outlen = rand() % (1 + sizeof(actual_output));
887
888                 rand_bytes(ikm, ikmlen);
889                 rand_bytes(salt, saltlen);
890                 rand_bytes(info, infolen);
891
892                 hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
893                             actual_output, outlen);
894                 openssl_hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
895                                     expected_output, outlen);
896                 ASSERT(memcmp(actual_output, expected_output, outlen) == 0);
897         }
898 }
899 #endif /* ENABLE_ALG_TESTS */
900
901 /*----------------------------------------------------------------------------*
902  *                            AES encryption modes                            *
903  *----------------------------------------------------------------------------*/
904
905 static void aes_256_xts_crypt(const u8 key[2 * AES_256_KEY_SIZE],
906                               const u8 iv[AES_BLOCK_SIZE], const u8 *src,
907                               u8 *dst, size_t nbytes, bool decrypting)
908 {
909         struct aes_key tweak_key, cipher_key;
910         ble128 t;
911         size_t i;
912
913         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
914         aes_setkey(&cipher_key, key, AES_256_KEY_SIZE);
915         aes_setkey(&tweak_key, &key[AES_256_KEY_SIZE], AES_256_KEY_SIZE);
916         aes_encrypt(&tweak_key, iv, (u8 *)&t);
917         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
918                 xor(&dst[i], &src[i], (const u8 *)&t, AES_BLOCK_SIZE);
919                 if (decrypting)
920                         aes_decrypt(&cipher_key, &dst[i], &dst[i]);
921                 else
922                         aes_encrypt(&cipher_key, &dst[i], &dst[i]);
923                 xor(&dst[i], &dst[i], (const u8 *)&t, AES_BLOCK_SIZE);
924                 gf2_128_mul_x(&t);
925         }
926 }
927
928 static void aes_256_xts_encrypt(const u8 key[2 * AES_256_KEY_SIZE],
929                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
930                                 u8 *dst, size_t nbytes)
931 {
932         aes_256_xts_crypt(key, iv, src, dst, nbytes, false);
933 }
934
935 static void aes_256_xts_decrypt(const u8 key[2 * AES_256_KEY_SIZE],
936                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
937                                 u8 *dst, size_t nbytes)
938 {
939         aes_256_xts_crypt(key, iv, src, dst, nbytes, true);
940 }
941
942 #ifdef ENABLE_ALG_TESTS
943 #include <openssl/evp.h>
944 static void test_aes_256_xts(void)
945 {
946         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
947         EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
948
949         ASSERT(ctx != NULL);
950         while (num_tests--) {
951                 u8 key[2 * AES_256_KEY_SIZE];
952                 u8 iv[AES_BLOCK_SIZE];
953                 u8 ptext[512];
954                 u8 ctext[sizeof(ptext)];
955                 u8 ref_ctext[sizeof(ptext)];
956                 u8 decrypted[sizeof(ptext)];
957                 const size_t datalen = ROUND_DOWN(rand() % (1 + sizeof(ptext)),
958                                                   AES_BLOCK_SIZE);
959                 int outl, res;
960
961                 rand_bytes(key, sizeof(key));
962                 rand_bytes(iv, sizeof(iv));
963                 rand_bytes(ptext, datalen);
964
965                 aes_256_xts_encrypt(key, iv, ptext, ctext, datalen);
966                 res = EVP_EncryptInit_ex(ctx, EVP_aes_256_xts(), NULL, key, iv);
967                 ASSERT(res > 0);
968                 res = EVP_EncryptUpdate(ctx, ref_ctext, &outl, ptext, datalen);
969                 ASSERT(res > 0);
970                 ASSERT(outl == datalen);
971                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
972
973                 aes_256_xts_decrypt(key, iv, ctext, decrypted, datalen);
974                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
975         }
976         EVP_CIPHER_CTX_free(ctx);
977 }
978 #endif /* ENABLE_ALG_TESTS */
979
980 static void aes_cbc_encrypt(const struct aes_key *k,
981                             const u8 iv[AES_BLOCK_SIZE],
982                             const u8 *src, u8 *dst, size_t nbytes)
983 {
984         size_t i;
985
986         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
987         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
988                 xor(&dst[i], &src[i], (i == 0 ? iv : &dst[i - AES_BLOCK_SIZE]),
989                     AES_BLOCK_SIZE);
990                 aes_encrypt(k, &dst[i], &dst[i]);
991         }
992 }
993
994 static void aes_cbc_decrypt(const struct aes_key *k,
995                             const u8 iv[AES_BLOCK_SIZE],
996                             const u8 *src, u8 *dst, size_t nbytes)
997 {
998         size_t i = nbytes;
999
1000         ASSERT(i % AES_BLOCK_SIZE == 0);
1001         while (i) {
1002                 i -= AES_BLOCK_SIZE;
1003                 aes_decrypt(k, &src[i], &dst[i]);
1004                 xor(&dst[i], &dst[i], (i == 0 ? iv : &src[i - AES_BLOCK_SIZE]),
1005                     AES_BLOCK_SIZE);
1006         }
1007 }
1008
1009 static void aes_cts_cbc_encrypt(const u8 *key, int keysize,
1010                                 const u8 iv[AES_BLOCK_SIZE],
1011                                 const u8 *src, u8 *dst, size_t nbytes)
1012 {
1013         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1014         const size_t final_bsize = nbytes - offset;
1015         struct aes_key k;
1016         u8 *pad;
1017         u8 buf[AES_BLOCK_SIZE];
1018
1019         ASSERT(nbytes >= AES_BLOCK_SIZE);
1020
1021         aes_setkey(&k, key, keysize);
1022
1023         if (nbytes == AES_BLOCK_SIZE)
1024                 return aes_cbc_encrypt(&k, iv, src, dst, nbytes);
1025
1026         aes_cbc_encrypt(&k, iv, src, dst, offset);
1027         pad = &dst[offset - AES_BLOCK_SIZE];
1028
1029         memcpy(buf, pad, AES_BLOCK_SIZE);
1030         xor(buf, buf, &src[offset], final_bsize);
1031         memcpy(&dst[offset], pad, final_bsize);
1032         aes_encrypt(&k, buf, pad);
1033 }
1034
1035 static void aes_cts_cbc_decrypt(const u8 *key, int keysize,
1036                                 const u8 iv[AES_BLOCK_SIZE],
1037                                 const u8 *src, u8 *dst, size_t nbytes)
1038 {
1039         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1040         const size_t final_bsize = nbytes - offset;
1041         struct aes_key k;
1042         u8 *pad;
1043
1044         ASSERT(nbytes >= AES_BLOCK_SIZE);
1045
1046         aes_setkey(&k, key, keysize);
1047
1048         if (nbytes == AES_BLOCK_SIZE)
1049                 return aes_cbc_decrypt(&k, iv, src, dst, nbytes);
1050
1051         pad = &dst[offset - AES_BLOCK_SIZE];
1052         aes_decrypt(&k, &src[offset - AES_BLOCK_SIZE], pad);
1053         xor(&dst[offset], &src[offset], pad, final_bsize);
1054         xor(pad, pad, &dst[offset], final_bsize);
1055
1056         aes_cbc_decrypt(&k, (offset == AES_BLOCK_SIZE ?
1057                              iv : &src[offset - 2 * AES_BLOCK_SIZE]),
1058                         pad, pad, AES_BLOCK_SIZE);
1059         aes_cbc_decrypt(&k, iv, src, dst, offset - AES_BLOCK_SIZE);
1060 }
1061
1062 static void aes_256_cts_cbc_encrypt(const u8 key[AES_256_KEY_SIZE],
1063                                     const u8 iv[AES_BLOCK_SIZE],
1064                                     const u8 *src, u8 *dst, size_t nbytes)
1065 {
1066         aes_cts_cbc_encrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1067 }
1068
1069 static void aes_256_cts_cbc_decrypt(const u8 key[AES_256_KEY_SIZE],
1070                                     const u8 iv[AES_BLOCK_SIZE],
1071                                     const u8 *src, u8 *dst, size_t nbytes)
1072 {
1073         aes_cts_cbc_decrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1074 }
1075
1076 #ifdef ENABLE_ALG_TESTS
1077 #include <openssl/modes.h>
1078 static void aes_block128_f(const unsigned char in[16],
1079                            unsigned char out[16], const void *key)
1080 {
1081         aes_encrypt(key, in, out);
1082 }
1083
1084 static void test_aes_256_cts_cbc(void)
1085 {
1086         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1087
1088         while (num_tests--) {
1089                 u8 key[AES_256_KEY_SIZE];
1090                 u8 iv[AES_BLOCK_SIZE];
1091                 u8 iv_copy[AES_BLOCK_SIZE];
1092                 u8 ptext[512];
1093                 u8 ctext[sizeof(ptext)];
1094                 u8 ref_ctext[sizeof(ptext)];
1095                 u8 decrypted[sizeof(ptext)];
1096                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1097                 struct aes_key k;
1098
1099                 rand_bytes(key, sizeof(key));
1100                 rand_bytes(iv, sizeof(iv));
1101                 rand_bytes(ptext, datalen);
1102
1103                 aes_256_cts_cbc_encrypt(key, iv, ptext, ctext, datalen);
1104
1105                 /* OpenSSL doesn't allow datalen=AES_BLOCK_SIZE; Linux does */
1106                 if (datalen != AES_BLOCK_SIZE) {
1107                         aes_setkey(&k, key, sizeof(key));
1108                         memcpy(iv_copy, iv, sizeof(iv));
1109                         ASSERT(CRYPTO_cts128_encrypt_block(ptext, ref_ctext,
1110                                                            datalen, &k, iv_copy,
1111                                                            aes_block128_f)
1112                                == datalen);
1113                         ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1114                 }
1115                 aes_256_cts_cbc_decrypt(key, iv, ctext, decrypted, datalen);
1116                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1117         }
1118 }
1119 #endif /* ENABLE_ALG_TESTS */
1120
1121 static void essiv_generate_iv(const u8 orig_key[AES_128_KEY_SIZE],
1122                               const u8 orig_iv[AES_BLOCK_SIZE],
1123                               u8 real_iv[AES_BLOCK_SIZE])
1124 {
1125         u8 essiv_key[SHA256_DIGEST_SIZE];
1126         struct aes_key essiv;
1127
1128         /* AES encrypt the original IV using a hash of the original key */
1129         STATIC_ASSERT(SHA256_DIGEST_SIZE == AES_256_KEY_SIZE);
1130         sha256(orig_key, AES_128_KEY_SIZE, essiv_key);
1131         aes_setkey(&essiv, essiv_key, AES_256_KEY_SIZE);
1132         aes_encrypt(&essiv, orig_iv, real_iv);
1133 }
1134
1135 static void aes_128_cbc_essiv_encrypt(const u8 key[AES_128_KEY_SIZE],
1136                                       const u8 iv[AES_BLOCK_SIZE],
1137                                       const u8 *src, u8 *dst, size_t nbytes)
1138 {
1139         struct aes_key k;
1140         u8 real_iv[AES_BLOCK_SIZE];
1141
1142         aes_setkey(&k, key, AES_128_KEY_SIZE);
1143         essiv_generate_iv(key, iv, real_iv);
1144         aes_cbc_encrypt(&k, real_iv, src, dst, nbytes);
1145 }
1146
1147 static void aes_128_cbc_essiv_decrypt(const u8 key[AES_128_KEY_SIZE],
1148                                       const u8 iv[AES_BLOCK_SIZE],
1149                                       const u8 *src, u8 *dst, size_t nbytes)
1150 {
1151         struct aes_key k;
1152         u8 real_iv[AES_BLOCK_SIZE];
1153
1154         aes_setkey(&k, key, AES_128_KEY_SIZE);
1155         essiv_generate_iv(key, iv, real_iv);
1156         aes_cbc_decrypt(&k, real_iv, src, dst, nbytes);
1157 }
1158
1159 static void aes_128_cts_cbc_encrypt(const u8 key[AES_128_KEY_SIZE],
1160                                     const u8 iv[AES_BLOCK_SIZE],
1161                                     const u8 *src, u8 *dst, size_t nbytes)
1162 {
1163         aes_cts_cbc_encrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1164 }
1165
1166 static void aes_128_cts_cbc_decrypt(const u8 key[AES_128_KEY_SIZE],
1167                                     const u8 iv[AES_BLOCK_SIZE],
1168                                     const u8 *src, u8 *dst, size_t nbytes)
1169 {
1170         aes_cts_cbc_decrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1171 }
1172
1173 /*----------------------------------------------------------------------------*
1174  *                           XChaCha12 stream cipher                          *
1175  *----------------------------------------------------------------------------*/
1176
1177 /*
1178  * References:
1179  *   - "XChaCha: eXtended-nonce ChaCha and AEAD_XChaCha20_Poly1305"
1180  *      https://tools.ietf.org/html/draft-arciszewski-xchacha-03
1181  *
1182  *   - "ChaCha, a variant of Salsa20"
1183  *      https://cr.yp.to/chacha/chacha-20080128.pdf
1184  *
1185  *   - "Extending the Salsa20 nonce"
1186  *      https://cr.yp.to/snuffle/xsalsa-20081128.pdf
1187  */
1188
1189 #define CHACHA_KEY_SIZE         32
1190 #define XCHACHA_KEY_SIZE        CHACHA_KEY_SIZE
1191 #define XCHACHA_NONCE_SIZE      24
1192
1193 static void chacha_init_state(u32 state[16], const u8 key[CHACHA_KEY_SIZE],
1194                               const u8 iv[16])
1195 {
1196         static const u8 consts[16] = "expand 32-byte k";
1197         int i;
1198
1199         for (i = 0; i < 4; i++)
1200                 state[i] = get_unaligned_le32(&consts[i * sizeof(__le32)]);
1201         for (i = 0; i < 8; i++)
1202                 state[4 + i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
1203         for (i = 0; i < 4; i++)
1204                 state[12 + i] = get_unaligned_le32(&iv[i * sizeof(__le32)]);
1205 }
1206
1207 #define CHACHA_QUARTERROUND(a, b, c, d)         \
1208         do {                                    \
1209                 a += b; d = rol32(d ^ a, 16);   \
1210                 c += d; b = rol32(b ^ c, 12);   \
1211                 a += b; d = rol32(d ^ a, 8);    \
1212                 c += d; b = rol32(b ^ c, 7);    \
1213         } while (0)
1214
1215 static void chacha_permute(u32 x[16], int nrounds)
1216 {
1217         do {
1218                 /* column round */
1219                 CHACHA_QUARTERROUND(x[0], x[4], x[8], x[12]);
1220                 CHACHA_QUARTERROUND(x[1], x[5], x[9], x[13]);
1221                 CHACHA_QUARTERROUND(x[2], x[6], x[10], x[14]);
1222                 CHACHA_QUARTERROUND(x[3], x[7], x[11], x[15]);
1223
1224                 /* diagonal round */
1225                 CHACHA_QUARTERROUND(x[0], x[5], x[10], x[15]);
1226                 CHACHA_QUARTERROUND(x[1], x[6], x[11], x[12]);
1227                 CHACHA_QUARTERROUND(x[2], x[7], x[8], x[13]);
1228                 CHACHA_QUARTERROUND(x[3], x[4], x[9], x[14]);
1229         } while ((nrounds -= 2) != 0);
1230 }
1231
1232 static void xchacha(const u8 key[XCHACHA_KEY_SIZE],
1233                     const u8 nonce[XCHACHA_NONCE_SIZE],
1234                     const u8 *src, u8 *dst, size_t nbytes, int nrounds)
1235 {
1236         u32 state[16];
1237         u8 real_key[CHACHA_KEY_SIZE];
1238         u8 real_iv[16] = { 0 };
1239         size_t i, j;
1240
1241         /* Compute real key using original key and first 128 nonce bits */
1242         chacha_init_state(state, key, nonce);
1243         chacha_permute(state, nrounds);
1244         for (i = 0; i < 8; i++) /* state words 0..3, 12..15 */
1245                 put_unaligned_le32(state[(i < 4 ? 0 : 8) + i],
1246                                    &real_key[i * sizeof(__le32)]);
1247
1248         /* Now do regular ChaCha, using real key and remaining nonce bits */
1249         memcpy(&real_iv[8], nonce + 16, 8);
1250         chacha_init_state(state, real_key, real_iv);
1251         for (i = 0; i < nbytes; i += 64) {
1252                 u32 x[16];
1253                 __le32 keystream[16];
1254
1255                 memcpy(x, state, 64);
1256                 chacha_permute(x, nrounds);
1257                 for (j = 0; j < 16; j++)
1258                         keystream[j] = cpu_to_le32(x[j] + state[j]);
1259                 xor(&dst[i], &src[i], (u8 *)keystream, MIN(nbytes - i, 64));
1260                 if (++state[12] == 0)
1261                         state[13]++;
1262         }
1263 }
1264
1265 static void xchacha12(const u8 key[XCHACHA_KEY_SIZE],
1266                       const u8 nonce[XCHACHA_NONCE_SIZE],
1267                       const u8 *src, u8 *dst, size_t nbytes)
1268 {
1269         xchacha(key, nonce, src, dst, nbytes, 12);
1270 }
1271
1272 /*----------------------------------------------------------------------------*
1273  *                                 Poly1305                                   *
1274  *----------------------------------------------------------------------------*/
1275
1276 /*
1277  * Note: this is only the Poly1305 Îµ-almost-∆-universal hash function, not the
1278  * full Poly1305 MAC.  I.e., it doesn't add anything at the end.
1279  */
1280
1281 #define POLY1305_KEY_SIZE       16
1282 #define POLY1305_BLOCK_SIZE     16
1283
1284 static void poly1305(const u8 key[POLY1305_KEY_SIZE],
1285                      const u8 *msg, size_t msglen, le128 *out)
1286 {
1287         const u32 limb_mask = 0x3ffffff;        /* limbs are base 2^26 */
1288         const u64 r0 = (get_unaligned_le32(key +  0) >> 0) & 0x3ffffff;
1289         const u64 r1 = (get_unaligned_le32(key +  3) >> 2) & 0x3ffff03;
1290         const u64 r2 = (get_unaligned_le32(key +  6) >> 4) & 0x3ffc0ff;
1291         const u64 r3 = (get_unaligned_le32(key +  9) >> 6) & 0x3f03fff;
1292         const u64 r4 = (get_unaligned_le32(key + 12) >> 8) & 0x00fffff;
1293         u32 h0 = 0, h1 = 0, h2 = 0, h3 = 0, h4 = 0;
1294         u32 g0, g1, g2, g3, g4, ge_p_mask;
1295
1296         /* Partial block support is not necessary for Adiantum */
1297         ASSERT(msglen % POLY1305_BLOCK_SIZE == 0);
1298
1299         while (msglen) {
1300                 u64 d0, d1, d2, d3, d4;
1301
1302                 /* h += *msg */
1303                 h0 += (get_unaligned_le32(msg +  0) >> 0) & limb_mask;
1304                 h1 += (get_unaligned_le32(msg +  3) >> 2) & limb_mask;
1305                 h2 += (get_unaligned_le32(msg +  6) >> 4) & limb_mask;
1306                 h3 += (get_unaligned_le32(msg +  9) >> 6) & limb_mask;
1307                 h4 += (get_unaligned_le32(msg + 12) >> 8) | (1 << 24);
1308
1309                 /* h *= r */
1310                 d0 = h0*r0 + h1*5*r4 + h2*5*r3 + h3*5*r2 + h4*5*r1;
1311                 d1 = h0*r1 + h1*r0   + h2*5*r4 + h3*5*r3 + h4*5*r2;
1312                 d2 = h0*r2 + h1*r1   + h2*r0   + h3*5*r4 + h4*5*r3;
1313                 d3 = h0*r3 + h1*r2   + h2*r1   + h3*r0   + h4*5*r4;
1314                 d4 = h0*r4 + h1*r3   + h2*r2   + h3*r1   + h4*r0;
1315
1316                 /* (partial) h %= 2^130 - 5 */
1317                 d1 += d0 >> 26;         h0 = d0 & limb_mask;
1318                 d2 += d1 >> 26;         h1 = d1 & limb_mask;
1319                 d3 += d2 >> 26;         h2 = d2 & limb_mask;
1320                 d4 += d3 >> 26;         h3 = d3 & limb_mask;
1321                 h0 += (d4 >> 26) * 5;   h4 = d4 & limb_mask;
1322                 h1 += h0 >> 26;         h0 &= limb_mask;
1323
1324                 msg += POLY1305_BLOCK_SIZE;
1325                 msglen -= POLY1305_BLOCK_SIZE;
1326         }
1327
1328         /* fully carry h */
1329         h2 += (h1 >> 26);       h1 &= limb_mask;
1330         h3 += (h2 >> 26);       h2 &= limb_mask;
1331         h4 += (h3 >> 26);       h3 &= limb_mask;
1332         h0 += (h4 >> 26) * 5;   h4 &= limb_mask;
1333         h1 += (h0 >> 26);       h0 &= limb_mask;
1334
1335         /* if (h >= 2^130 - 5) h -= 2^130 - 5; */
1336         g0 = h0 + 5;
1337         g1 = h1 + (g0 >> 26);   g0 &= limb_mask;
1338         g2 = h2 + (g1 >> 26);   g1 &= limb_mask;
1339         g3 = h3 + (g2 >> 26);   g2 &= limb_mask;
1340         g4 = h4 + (g3 >> 26);   g3 &= limb_mask;
1341         ge_p_mask = ~((g4 >> 26) - 1); /* all 1's if h >= 2^130 - 5, else 0 */
1342         h0 = (h0 & ~ge_p_mask) | (g0 & ge_p_mask);
1343         h1 = (h1 & ~ge_p_mask) | (g1 & ge_p_mask);
1344         h2 = (h2 & ~ge_p_mask) | (g2 & ge_p_mask);
1345         h3 = (h3 & ~ge_p_mask) | (g3 & ge_p_mask);
1346         h4 = (h4 & ~ge_p_mask) | (g4 & ge_p_mask & limb_mask);
1347
1348         /* h %= 2^128 */
1349         out->lo = cpu_to_le64(((u64)h2 << 52) | ((u64)h1 << 26) | h0);
1350         out->hi = cpu_to_le64(((u64)h4 << 40) | ((u64)h3 << 14) | (h2 >> 12));
1351 }
1352
1353 /*----------------------------------------------------------------------------*
1354  *                          Adiantum encryption mode                          *
1355  *----------------------------------------------------------------------------*/
1356
1357 /*
1358  * Reference: "Adiantum: length-preserving encryption for entry-level processors"
1359  *      https://tosc.iacr.org/index.php/ToSC/article/view/7360
1360  */
1361
1362 #define ADIANTUM_KEY_SIZE       32
1363 #define ADIANTUM_IV_SIZE        32
1364 #define ADIANTUM_HASH_KEY_SIZE  ((2 * POLY1305_KEY_SIZE) + NH_KEY_SIZE)
1365
1366 #define NH_KEY_SIZE             1072
1367 #define NH_KEY_WORDS            (NH_KEY_SIZE / sizeof(u32))
1368 #define NH_BLOCK_SIZE           1024
1369 #define NH_HASH_SIZE            32
1370 #define NH_MESSAGE_UNIT         16
1371
1372 static u64 nh_pass(const u32 *key, const u8 *msg, size_t msglen)
1373 {
1374         u64 sum = 0;
1375
1376         ASSERT(msglen % NH_MESSAGE_UNIT == 0);
1377         while (msglen) {
1378                 sum += (u64)(u32)(get_unaligned_le32(msg +  0) + key[0]) *
1379                             (u32)(get_unaligned_le32(msg +  8) + key[2]);
1380                 sum += (u64)(u32)(get_unaligned_le32(msg +  4) + key[1]) *
1381                             (u32)(get_unaligned_le32(msg + 12) + key[3]);
1382                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1383                 msg += NH_MESSAGE_UNIT;
1384                 msglen -= NH_MESSAGE_UNIT;
1385         }
1386         return sum;
1387 }
1388
1389 /* NH Îµ-almost-universal hash function */
1390 static void nh(const u32 *key, const u8 *msg, size_t msglen,
1391                u8 result[NH_HASH_SIZE])
1392 {
1393         size_t i;
1394
1395         for (i = 0; i < NH_HASH_SIZE; i += sizeof(__le64)) {
1396                 put_unaligned_le64(nh_pass(key, msg, msglen), &result[i]);
1397                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1398         }
1399 }
1400
1401 /* Adiantum's Îµ-almost-∆-universal hash function */
1402 static void adiantum_hash(const u8 key[ADIANTUM_HASH_KEY_SIZE],
1403                           const u8 iv[ADIANTUM_IV_SIZE],
1404                           const u8 *msg, size_t msglen, le128 *result)
1405 {
1406         const u8 *header_poly_key = key;
1407         const u8 *msg_poly_key = header_poly_key + POLY1305_KEY_SIZE;
1408         const u8 *nh_key = msg_poly_key + POLY1305_KEY_SIZE;
1409         u32 nh_key_words[NH_KEY_WORDS];
1410         u8 header[POLY1305_BLOCK_SIZE + ADIANTUM_IV_SIZE];
1411         const size_t num_nh_blocks = DIV_ROUND_UP(msglen, NH_BLOCK_SIZE);
1412         u8 *nh_hashes = xmalloc(num_nh_blocks * NH_HASH_SIZE);
1413         const size_t padded_msglen = ROUND_UP(msglen, NH_MESSAGE_UNIT);
1414         u8 *padded_msg = xmalloc(padded_msglen);
1415         le128 hash1, hash2;
1416         size_t i;
1417
1418         for (i = 0; i < NH_KEY_WORDS; i++)
1419                 nh_key_words[i] = get_unaligned_le32(&nh_key[i * sizeof(u32)]);
1420
1421         /* Hash tweak and message length with first Poly1305 key */
1422         put_unaligned_le64((u64)msglen * 8, header);
1423         put_unaligned_le64(0, &header[sizeof(__le64)]);
1424         memcpy(&header[POLY1305_BLOCK_SIZE], iv, ADIANTUM_IV_SIZE);
1425         poly1305(header_poly_key, header, sizeof(header), &hash1);
1426
1427         /* Hash NH hashes of message blocks using second Poly1305 key */
1428         /* (using a super naive way of handling the padding) */
1429         memcpy(padded_msg, msg, msglen);
1430         memset(&padded_msg[msglen], 0, padded_msglen - msglen);
1431         for (i = 0; i < num_nh_blocks; i++) {
1432                 nh(nh_key_words, &padded_msg[i * NH_BLOCK_SIZE],
1433                    MIN(NH_BLOCK_SIZE, padded_msglen - (i * NH_BLOCK_SIZE)),
1434                    &nh_hashes[i * NH_HASH_SIZE]);
1435         }
1436         poly1305(msg_poly_key, nh_hashes, num_nh_blocks * NH_HASH_SIZE, &hash2);
1437
1438         /* Add the two hashes together to get the final hash */
1439         le128_add(result, &hash1, &hash2);
1440
1441         free(nh_hashes);
1442         free(padded_msg);
1443 }
1444
1445 static void adiantum_crypt(const u8 key[ADIANTUM_KEY_SIZE],
1446                            const u8 iv[ADIANTUM_IV_SIZE], const u8 *src,
1447                            u8 *dst, size_t nbytes, bool decrypting)
1448 {
1449         u8 subkeys[AES_256_KEY_SIZE + ADIANTUM_HASH_KEY_SIZE] = { 0 };
1450         struct aes_key aes_key;
1451         union {
1452                 u8 nonce[XCHACHA_NONCE_SIZE];
1453                 le128 block;
1454         } u = { .nonce = { 1 } };
1455         const size_t bulk_len = nbytes - sizeof(u.block);
1456         le128 hash;
1457
1458         ASSERT(nbytes >= sizeof(u.block));
1459
1460         /* Derive subkeys */
1461         xchacha12(key, u.nonce, subkeys, subkeys, sizeof(subkeys));
1462         aes_setkey(&aes_key, subkeys, AES_256_KEY_SIZE);
1463
1464         /* Hash left part and add to right part */
1465         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, src, bulk_len, &hash);
1466         memcpy(&u.block, &src[bulk_len], sizeof(u.block));
1467         le128_add(&u.block, &u.block, &hash);
1468
1469         if (!decrypting) /* Encrypt right part with block cipher */
1470                 aes_encrypt(&aes_key, u.nonce, u.nonce);
1471
1472         /* Encrypt left part with stream cipher, using the computed nonce */
1473         u.nonce[sizeof(u.block)] = 1;
1474         xchacha12(key, u.nonce, src, dst, bulk_len);
1475
1476         if (decrypting) /* Decrypt right part with block cipher */
1477                 aes_decrypt(&aes_key, u.nonce, u.nonce);
1478
1479         /* Finalize right part by subtracting hash of left part */
1480         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, dst, bulk_len, &hash);
1481         le128_sub(&u.block, &u.block, &hash);
1482         memcpy(&dst[bulk_len], &u.block, sizeof(u.block));
1483 }
1484
1485 static void adiantum_encrypt(const u8 key[ADIANTUM_KEY_SIZE],
1486                              const u8 iv[ADIANTUM_IV_SIZE],
1487                              const u8 *src, u8 *dst, size_t nbytes)
1488 {
1489         adiantum_crypt(key, iv, src, dst, nbytes, false);
1490 }
1491
1492 static void adiantum_decrypt(const u8 key[ADIANTUM_KEY_SIZE],
1493                              const u8 iv[ADIANTUM_IV_SIZE],
1494                              const u8 *src, u8 *dst, size_t nbytes)
1495 {
1496         adiantum_crypt(key, iv, src, dst, nbytes, true);
1497 }
1498
1499 #ifdef ENABLE_ALG_TESTS
1500 #include <linux/if_alg.h>
1501 #include <sys/socket.h>
1502 #define SOL_ALG 279
1503 static void af_alg_crypt(int algfd, int op, const u8 *key, size_t keylen,
1504                          const u8 *iv, size_t ivlen,
1505                          const u8 *src, u8 *dst, size_t datalen)
1506 {
1507         size_t controllen = CMSG_SPACE(sizeof(int)) +
1508                             CMSG_SPACE(sizeof(struct af_alg_iv) + ivlen);
1509         u8 *control = xmalloc(controllen);
1510         struct iovec iov = { .iov_base = (u8 *)src, .iov_len = datalen };
1511         struct msghdr msg = {
1512                 .msg_iov = &iov,
1513                 .msg_iovlen = 1,
1514                 .msg_control = control,
1515                 .msg_controllen = controllen,
1516         };
1517         struct cmsghdr *cmsg;
1518         struct af_alg_iv *algiv;
1519         int reqfd;
1520
1521         memset(control, 0, controllen);
1522
1523         cmsg = CMSG_FIRSTHDR(&msg);
1524         cmsg->cmsg_len = CMSG_LEN(sizeof(int));
1525         cmsg->cmsg_level = SOL_ALG;
1526         cmsg->cmsg_type = ALG_SET_OP;
1527         *(int *)CMSG_DATA(cmsg) = op;
1528
1529         cmsg = CMSG_NXTHDR(&msg, cmsg);
1530         cmsg->cmsg_len = CMSG_LEN(sizeof(struct af_alg_iv) + ivlen);
1531         cmsg->cmsg_level = SOL_ALG;
1532         cmsg->cmsg_type = ALG_SET_IV;
1533         algiv = (struct af_alg_iv *)CMSG_DATA(cmsg);
1534         algiv->ivlen = ivlen;
1535         memcpy(algiv->iv, iv, ivlen);
1536
1537         if (setsockopt(algfd, SOL_ALG, ALG_SET_KEY, key, keylen) != 0)
1538                 die_errno("can't set key on AF_ALG socket");
1539
1540         reqfd = accept(algfd, NULL, NULL);
1541         if (reqfd < 0)
1542                 die_errno("can't accept() AF_ALG socket");
1543         if (sendmsg(reqfd, &msg, 0) != datalen)
1544                 die_errno("can't sendmsg() AF_ALG request socket");
1545         if (xread(reqfd, dst, datalen) != datalen)
1546                 die("short read from AF_ALG request socket");
1547         close(reqfd);
1548
1549         free(control);
1550 }
1551
1552 static void test_adiantum(void)
1553 {
1554         int algfd = socket(AF_ALG, SOCK_SEQPACKET, 0);
1555         struct sockaddr_alg addr = {
1556                 .salg_type = "skcipher",
1557                 .salg_name = "adiantum(xchacha12,aes)",
1558         };
1559         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1560
1561         if (algfd < 0)
1562                 die_errno("can't create AF_ALG socket");
1563         if (bind(algfd, (struct sockaddr *)&addr, sizeof(addr)) != 0)
1564                 die_errno("can't bind AF_ALG socket to Adiantum algorithm");
1565
1566         while (num_tests--) {
1567                 u8 key[ADIANTUM_KEY_SIZE];
1568                 u8 iv[ADIANTUM_IV_SIZE];
1569                 u8 ptext[4096];
1570                 u8 ctext[sizeof(ptext)];
1571                 u8 ref_ctext[sizeof(ptext)];
1572                 u8 decrypted[sizeof(ptext)];
1573                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1574
1575                 rand_bytes(key, sizeof(key));
1576                 rand_bytes(iv, sizeof(iv));
1577                 rand_bytes(ptext, datalen);
1578
1579                 adiantum_encrypt(key, iv, ptext, ctext, datalen);
1580                 af_alg_crypt(algfd, ALG_OP_ENCRYPT, key, sizeof(key),
1581                              iv, sizeof(iv), ptext, ref_ctext, datalen);
1582                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1583
1584                 adiantum_decrypt(key, iv, ctext, decrypted, datalen);
1585                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1586         }
1587         close(algfd);
1588 }
1589 #endif /* ENABLE_ALG_TESTS */
1590
1591 /*----------------------------------------------------------------------------*
1592  *                               SipHash-2-4                                  *
1593  *----------------------------------------------------------------------------*/
1594
1595 /*
1596  * Reference: "SipHash: a fast short-input PRF"
1597  *      https://cr.yp.to/siphash/siphash-20120918.pdf
1598  */
1599
1600 #define SIPROUND                                                \
1601         do {                                                    \
1602                 v0 += v1;           v2 += v3;                   \
1603                 v1 = rol64(v1, 13); v3 = rol64(v3, 16);         \
1604                 v1 ^= v0;           v3 ^= v2;                   \
1605                 v0 = rol64(v0, 32);                             \
1606                 v2 += v1;           v0 += v3;                   \
1607                 v1 = rol64(v1, 17); v3 = rol64(v3, 21);         \
1608                 v1 ^= v2;           v3 ^= v0;                   \
1609                 v2 = rol64(v2, 32);                             \
1610         } while (0)
1611
1612 /* Compute the SipHash-2-4 of a 64-bit number when formatted as little endian */
1613 static u64 siphash_1u64(const u64 key[2], u64 data)
1614 {
1615         u64 v0 = key[0] ^ 0x736f6d6570736575ULL;
1616         u64 v1 = key[1] ^ 0x646f72616e646f6dULL;
1617         u64 v2 = key[0] ^ 0x6c7967656e657261ULL;
1618         u64 v3 = key[1] ^ 0x7465646279746573ULL;
1619         u64 m[2] = {data, (u64)sizeof(data) << 56};
1620         size_t i;
1621
1622         for (i = 0; i < ARRAY_SIZE(m); i++) {
1623                 v3 ^= m[i];
1624                 SIPROUND;
1625                 SIPROUND;
1626                 v0 ^= m[i];
1627         }
1628
1629         v2 ^= 0xff;
1630         for (i = 0; i < 4; i++)
1631                 SIPROUND;
1632         return v0 ^ v1 ^ v2 ^ v3;
1633 }
1634
1635 /*----------------------------------------------------------------------------*
1636  *                               Main program                                 *
1637  *----------------------------------------------------------------------------*/
1638
1639 #define FILE_NONCE_SIZE         16
1640 #define UUID_SIZE               16
1641 #define MAX_KEY_SIZE            64
1642
1643 static const struct fscrypt_cipher {
1644         const char *name;
1645         void (*encrypt)(const u8 *key, const u8 *iv, const u8 *src,
1646                         u8 *dst, size_t nbytes);
1647         void (*decrypt)(const u8 *key, const u8 *iv, const u8 *src,
1648                         u8 *dst, size_t nbytes);
1649         int keysize;
1650         int min_input_size;
1651 } fscrypt_ciphers[] = {
1652         {
1653                 .name = "AES-256-XTS",
1654                 .encrypt = aes_256_xts_encrypt,
1655                 .decrypt = aes_256_xts_decrypt,
1656                 .keysize = 2 * AES_256_KEY_SIZE,
1657         }, {
1658                 .name = "AES-256-CTS-CBC",
1659                 .encrypt = aes_256_cts_cbc_encrypt,
1660                 .decrypt = aes_256_cts_cbc_decrypt,
1661                 .keysize = AES_256_KEY_SIZE,
1662                 .min_input_size = AES_BLOCK_SIZE,
1663         }, {
1664                 .name = "AES-128-CBC-ESSIV",
1665                 .encrypt = aes_128_cbc_essiv_encrypt,
1666                 .decrypt = aes_128_cbc_essiv_decrypt,
1667                 .keysize = AES_128_KEY_SIZE,
1668         }, {
1669                 .name = "AES-128-CTS-CBC",
1670                 .encrypt = aes_128_cts_cbc_encrypt,
1671                 .decrypt = aes_128_cts_cbc_decrypt,
1672                 .keysize = AES_128_KEY_SIZE,
1673                 .min_input_size = AES_BLOCK_SIZE,
1674         }, {
1675                 .name = "Adiantum",
1676                 .encrypt = adiantum_encrypt,
1677                 .decrypt = adiantum_decrypt,
1678                 .keysize = ADIANTUM_KEY_SIZE,
1679                 .min_input_size = AES_BLOCK_SIZE,
1680         }
1681 };
1682
1683 static const struct fscrypt_cipher *find_fscrypt_cipher(const char *name)
1684 {
1685         size_t i;
1686
1687         for (i = 0; i < ARRAY_SIZE(fscrypt_ciphers); i++) {
1688                 if (strcmp(fscrypt_ciphers[i].name, name) == 0)
1689                         return &fscrypt_ciphers[i];
1690         }
1691         return NULL;
1692 }
1693
1694 struct fscrypt_iv {
1695         union {
1696                 __le64 block_num;
1697                 u8 bytes[32];
1698         };
1699 };
1700
1701 static void crypt_loop(const struct fscrypt_cipher *cipher, const u8 *key,
1702                        struct fscrypt_iv *iv, bool decrypting,
1703                        size_t block_size, size_t padding)
1704 {
1705         u8 *buf = xmalloc(block_size);
1706         size_t res;
1707
1708         while ((res = xread(STDIN_FILENO, buf, block_size)) > 0) {
1709                 size_t crypt_len = block_size;
1710
1711                 if (padding > 0) {
1712                         crypt_len = MAX(res, cipher->min_input_size);
1713                         crypt_len = ROUND_UP(crypt_len, padding);
1714                         crypt_len = MIN(crypt_len, block_size);
1715                 }
1716                 ASSERT(crypt_len >= res);
1717                 memset(&buf[res], 0, crypt_len - res);
1718
1719                 if (decrypting)
1720                         cipher->decrypt(key, iv->bytes, buf, buf, crypt_len);
1721                 else
1722                         cipher->encrypt(key, iv->bytes, buf, buf, crypt_len);
1723
1724                 full_write(STDOUT_FILENO, buf, crypt_len);
1725
1726                 iv->block_num = cpu_to_le64(le64_to_cpu(iv->block_num) + 1);
1727         }
1728         free(buf);
1729 }
1730
1731 /* The supported key derivation functions */
1732 enum kdf_algorithm {
1733         KDF_NONE,
1734         KDF_AES_128_ECB,
1735         KDF_HKDF_SHA512,
1736 };
1737
1738 static enum kdf_algorithm parse_kdf_algorithm(const char *arg)
1739 {
1740         if (strcmp(arg, "none") == 0)
1741                 return KDF_NONE;
1742         if (strcmp(arg, "AES-128-ECB") == 0)
1743                 return KDF_AES_128_ECB;
1744         if (strcmp(arg, "HKDF-SHA512") == 0)
1745                 return KDF_HKDF_SHA512;
1746         die("Unknown KDF: %s", arg);
1747 }
1748
1749 static u8 parse_mode_number(const char *arg)
1750 {
1751         char *tmp;
1752         long num = strtol(arg, &tmp, 10);
1753
1754         if (num <= 0 || *tmp || (u8)num != num)
1755                 die("Invalid mode number: %s", arg);
1756         return num;
1757 }
1758
1759 static u32 parse_inode_number(const char *arg)
1760 {
1761         char *tmp;
1762         unsigned long long num = strtoull(arg, &tmp, 10);
1763
1764         if (num <= 0 || *tmp)
1765                 die("Invalid inode number: %s", arg);
1766         if ((u32)num != num)
1767                 die("Inode number %s is too large; must be 32-bit", arg);
1768         return num;
1769 }
1770
1771 struct key_and_iv_params {
1772         u8 master_key[MAX_KEY_SIZE];
1773         int master_key_size;
1774         enum kdf_algorithm kdf;
1775         u8 mode_num;
1776         u8 file_nonce[FILE_NONCE_SIZE];
1777         bool file_nonce_specified;
1778         bool iv_ino_lblk_64;
1779         bool iv_ino_lblk_32;
1780         u32 inode_number;
1781         u8 fs_uuid[UUID_SIZE];
1782         bool fs_uuid_specified;
1783 };
1784
1785 #define HKDF_CONTEXT_KEY_IDENTIFIER     1
1786 #define HKDF_CONTEXT_PER_FILE_ENC_KEY   2
1787 #define HKDF_CONTEXT_DIRECT_KEY         3
1788 #define HKDF_CONTEXT_IV_INO_LBLK_64_KEY 4
1789 #define HKDF_CONTEXT_DIRHASH_KEY        5
1790 #define HKDF_CONTEXT_IV_INO_LBLK_32_KEY 6
1791 #define HKDF_CONTEXT_INODE_HASH_KEY     7
1792
1793 /* Hash the file's inode number using SipHash keyed by a derived key */
1794 static u32 hash_inode_number(const struct key_and_iv_params *params)
1795 {
1796         u8 info[9] = "fscrypt";
1797         union {
1798                 u64 words[2];
1799                 u8 bytes[16];
1800         } hash_key;
1801
1802         info[8] = HKDF_CONTEXT_INODE_HASH_KEY;
1803         hkdf_sha512(params->master_key, params->master_key_size,
1804                     NULL, 0, info, sizeof(info),
1805                     hash_key.bytes, sizeof(hash_key));
1806
1807         hash_key.words[0] = get_unaligned_le64(&hash_key.bytes[0]);
1808         hash_key.words[1] = get_unaligned_le64(&hash_key.bytes[8]);
1809
1810         return (u32)siphash_1u64(hash_key.words, params->inode_number);
1811 }
1812
1813 /*
1814  * Get the key and starting IV with which the encryption will actually be done.
1815  * If a KDF was specified, a subkey is derived from the master key and the mode
1816  * number or file nonce.  Otherwise, the master key is used directly.
1817  */
1818 static void get_key_and_iv(const struct key_and_iv_params *params,
1819                            u8 *real_key, size_t real_key_size,
1820                            struct fscrypt_iv *iv)
1821 {
1822         bool file_nonce_in_iv = false;
1823         struct aes_key aes_key;
1824         u8 info[8 + 1 + 1 + UUID_SIZE] = "fscrypt";
1825         size_t infolen = 8;
1826         size_t i;
1827
1828         ASSERT(real_key_size <= params->master_key_size);
1829
1830         memset(iv, 0, sizeof(*iv));
1831
1832         if (params->iv_ino_lblk_64 || params->iv_ino_lblk_32) {
1833                 const char *opt = params->iv_ino_lblk_64 ? "--iv-ino-lblk-64" :
1834                                                            "--iv-ino-lblk-32";
1835                 if (params->iv_ino_lblk_64 && params->iv_ino_lblk_32)
1836                         die("--iv-ino-lblk-64 and --iv-ino-lblk-32 are mutually exclusive");
1837                 if (params->kdf != KDF_HKDF_SHA512)
1838                         die("%s requires --kdf=HKDF-SHA512", opt);
1839                 if (!params->fs_uuid_specified)
1840                         die("%s requires --fs-uuid", opt);
1841                 if (params->inode_number == 0)
1842                         die("%s requires --inode-number", opt);
1843                 if (params->mode_num == 0)
1844                         die("%s requires --mode-num", opt);
1845         }
1846
1847         switch (params->kdf) {
1848         case KDF_NONE:
1849                 if (params->mode_num != 0)
1850                         die("--mode-num isn't supported with --kdf=none");
1851                 memcpy(real_key, params->master_key, real_key_size);
1852                 file_nonce_in_iv = true;
1853                 break;
1854         case KDF_AES_128_ECB:
1855                 if (!params->file_nonce_specified)
1856                         die("--file-nonce is required with --kdf=AES-128-ECB");
1857                 if (params->mode_num != 0)
1858                         die("--mode-num isn't supported with --kdf=AES-128-ECB");
1859                 STATIC_ASSERT(FILE_NONCE_SIZE == AES_128_KEY_SIZE);
1860                 ASSERT(real_key_size % AES_BLOCK_SIZE == 0);
1861                 aes_setkey(&aes_key, params->file_nonce, AES_128_KEY_SIZE);
1862                 for (i = 0; i < real_key_size; i += AES_BLOCK_SIZE)
1863                         aes_encrypt(&aes_key, &params->master_key[i],
1864                                     &real_key[i]);
1865                 break;
1866         case KDF_HKDF_SHA512:
1867                 if (params->iv_ino_lblk_64) {
1868                         info[infolen++] = HKDF_CONTEXT_IV_INO_LBLK_64_KEY;
1869                         info[infolen++] = params->mode_num;
1870                         memcpy(&info[infolen], params->fs_uuid, UUID_SIZE);
1871                         infolen += UUID_SIZE;
1872                         put_unaligned_le32(params->inode_number, &iv->bytes[4]);
1873                 } else if (params->iv_ino_lblk_32) {
1874                         info[infolen++] = HKDF_CONTEXT_IV_INO_LBLK_32_KEY;
1875                         info[infolen++] = params->mode_num;
1876                         memcpy(&info[infolen], params->fs_uuid, UUID_SIZE);
1877                         infolen += UUID_SIZE;
1878                         put_unaligned_le32(hash_inode_number(params),
1879                                            iv->bytes);
1880                 } else if (params->mode_num != 0) {
1881                         info[infolen++] = HKDF_CONTEXT_DIRECT_KEY;
1882                         info[infolen++] = params->mode_num;
1883                         file_nonce_in_iv = true;
1884                 } else if (params->file_nonce_specified) {
1885                         info[infolen++] = HKDF_CONTEXT_PER_FILE_ENC_KEY;
1886                         memcpy(&info[infolen], params->file_nonce,
1887                                FILE_NONCE_SIZE);
1888                         infolen += FILE_NONCE_SIZE;
1889                 } else {
1890                         die("With --kdf=HKDF-SHA512, at least one of --file-nonce and --mode-num must be specified");
1891                 }
1892                 hkdf_sha512(params->master_key, params->master_key_size,
1893                             NULL, 0, info, infolen, real_key, real_key_size);
1894                 break;
1895         default:
1896                 ASSERT(0);
1897         }
1898
1899         if (file_nonce_in_iv && params->file_nonce_specified)
1900                 memcpy(&iv->bytes[8], params->file_nonce, FILE_NONCE_SIZE);
1901 }
1902
1903 enum {
1904         OPT_BLOCK_SIZE,
1905         OPT_DECRYPT,
1906         OPT_FILE_NONCE,
1907         OPT_FS_UUID,
1908         OPT_HELP,
1909         OPT_INODE_NUMBER,
1910         OPT_IV_INO_LBLK_32,
1911         OPT_IV_INO_LBLK_64,
1912         OPT_KDF,
1913         OPT_MODE_NUM,
1914         OPT_PADDING,
1915 };
1916
1917 static const struct option longopts[] = {
1918         { "block-size",      required_argument, NULL, OPT_BLOCK_SIZE },
1919         { "decrypt",         no_argument,       NULL, OPT_DECRYPT },
1920         { "file-nonce",      required_argument, NULL, OPT_FILE_NONCE },
1921         { "fs-uuid",         required_argument, NULL, OPT_FS_UUID },
1922         { "help",            no_argument,       NULL, OPT_HELP },
1923         { "inode-number",    required_argument, NULL, OPT_INODE_NUMBER },
1924         { "iv-ino-lblk-32",  no_argument,       NULL, OPT_IV_INO_LBLK_32 },
1925         { "iv-ino-lblk-64",  no_argument,       NULL, OPT_IV_INO_LBLK_64 },
1926         { "kdf",             required_argument, NULL, OPT_KDF },
1927         { "mode-num",        required_argument, NULL, OPT_MODE_NUM },
1928         { "padding",         required_argument, NULL, OPT_PADDING },
1929         { NULL, 0, NULL, 0 },
1930 };
1931
1932 int main(int argc, char *argv[])
1933 {
1934         size_t block_size = 4096;
1935         bool decrypting = false;
1936         struct key_and_iv_params params;
1937         size_t padding = 0;
1938         const struct fscrypt_cipher *cipher;
1939         u8 real_key[MAX_KEY_SIZE];
1940         struct fscrypt_iv iv;
1941         char *tmp;
1942         int c;
1943
1944         memset(&params, 0, sizeof(params));
1945
1946         aes_init();
1947
1948 #ifdef ENABLE_ALG_TESTS
1949         test_aes();
1950         test_sha2();
1951         test_hkdf_sha512();
1952         test_aes_256_xts();
1953         test_aes_256_cts_cbc();
1954         test_adiantum();
1955 #endif
1956
1957         while ((c = getopt_long(argc, argv, "", longopts, NULL)) != -1) {
1958                 switch (c) {
1959                 case OPT_BLOCK_SIZE:
1960                         block_size = strtoul(optarg, &tmp, 10);
1961                         if (block_size <= 0 || *tmp)
1962                                 die("Invalid block size: %s", optarg);
1963                         break;
1964                 case OPT_DECRYPT:
1965                         decrypting = true;
1966                         break;
1967                 case OPT_FILE_NONCE:
1968                         if (hex2bin(optarg, params.file_nonce, FILE_NONCE_SIZE)
1969                             != FILE_NONCE_SIZE)
1970                                 die("Invalid file nonce: %s", optarg);
1971                         params.file_nonce_specified = true;
1972                         break;
1973                 case OPT_FS_UUID:
1974                         if (hex2bin(optarg, params.fs_uuid, UUID_SIZE)
1975                             != UUID_SIZE)
1976                                 die("Invalid filesystem UUID: %s", optarg);
1977                         params.fs_uuid_specified = true;
1978                         break;
1979                 case OPT_HELP:
1980                         usage(stdout);
1981                         return 0;
1982                 case OPT_INODE_NUMBER:
1983                         params.inode_number = parse_inode_number(optarg);
1984                         break;
1985                 case OPT_IV_INO_LBLK_32:
1986                         params.iv_ino_lblk_32 = true;
1987                         break;
1988                 case OPT_IV_INO_LBLK_64:
1989                         params.iv_ino_lblk_64 = true;
1990                         break;
1991                 case OPT_KDF:
1992                         params.kdf = parse_kdf_algorithm(optarg);
1993                         break;
1994                 case OPT_MODE_NUM:
1995                         params.mode_num = parse_mode_number(optarg);
1996                         break;
1997                 case OPT_PADDING:
1998                         padding = strtoul(optarg, &tmp, 10);
1999                         if (padding <= 0 || *tmp || !is_power_of_2(padding) ||
2000                             padding > INT_MAX)
2001                                 die("Invalid padding amount: %s", optarg);
2002                         break;
2003                 default:
2004                         usage(stderr);
2005                         return 2;
2006                 }
2007         }
2008         argc -= optind;
2009         argv += optind;
2010
2011         if (argc != 2) {
2012                 usage(stderr);
2013                 return 2;
2014         }
2015
2016         cipher = find_fscrypt_cipher(argv[0]);
2017         if (cipher == NULL)
2018                 die("Unknown cipher: %s", argv[0]);
2019
2020         if (block_size < cipher->min_input_size)
2021                 die("Block size of %zu bytes is too small for cipher %s",
2022                     block_size, cipher->name);
2023
2024         params.master_key_size = hex2bin(argv[1], params.master_key,
2025                                          MAX_KEY_SIZE);
2026         if (params.master_key_size < 0)
2027                 die("Invalid master_key: %s", argv[1]);
2028         if (params.master_key_size < cipher->keysize)
2029                 die("Master key is too short for cipher %s", cipher->name);
2030
2031         get_key_and_iv(&params, real_key, cipher->keysize, &iv);
2032
2033         crypt_loop(cipher, real_key, &iv, decrypting, block_size, padding);
2034         return 0;
2035 }