generic: test MADV_POPULATE_READ with IO errors
[xfstests-dev.git] / src / fscrypt-crypt-util.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * fscrypt-crypt-util.c - utility for verifying fscrypt-encrypted data
4  *
5  * Copyright 2019 Google LLC
6  */
7
8 /*
9  * This program implements all crypto algorithms supported by fscrypt (a.k.a.
10  * ext4, f2fs, and ubifs encryption), for the purpose of verifying the
11  * correctness of the ciphertext stored on-disk.  See usage() below.
12  *
13  * All algorithms are implemented in portable C code to avoid depending on
14  * libcrypto (OpenSSL), and because some fscrypt-supported algorithms aren't
15  * available in libcrypto anyway (e.g. Adiantum), or are only supported in
16  * recent versions (e.g. HKDF-SHA512).  For simplicity, all crypto code here
17  * tries to follow the mathematical definitions directly, without optimizing for
18  * performance or worrying about following security best practices such as
19  * mitigating side-channel attacks.  So, only use this program for testing!
20  */
21
22 #include <asm/byteorder.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <limits.h>
26 #include <linux/types.h>
27 #include <stdarg.h>
28 #include <stdbool.h>
29 #include <stdint.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <unistd.h>
34
35 #define PROGRAM_NAME "fscrypt-crypt-util"
36
37 /*
38  * Define to enable the tests of the crypto code in this file.  If enabled, you
39  * must link this program with OpenSSL (-lcrypto) v1.1.0 or later, and your
40  * kernel needs CONFIG_CRYPTO_USER_API_SKCIPHER=y, CONFIG_CRYPTO_ADIANTUM=y, and
41  * CONFIG_CRYPTO_HCTR2=y.
42  */
43 #undef ENABLE_ALG_TESTS
44
45 #define NUM_ALG_TEST_ITERATIONS 10000
46
47 static void usage(FILE *fp)
48 {
49         fputs(
50 "Usage: " PROGRAM_NAME " [OPTION]... [CIPHER | --dump-key-identifier] MASTER_KEY\n"
51 "\n"
52 "Utility for verifying fscrypt-encrypted data.  This program encrypts\n"
53 "(or decrypts) the data on stdin using the given CIPHER with the given\n"
54 "MASTER_KEY (or a key derived from it, if a KDF is specified), and writes the\n"
55 "resulting ciphertext (or plaintext) to stdout.\n"
56 "\n"
57 "CIPHER can be AES-256-XTS, AES-256-CTS-CBC, AES-128-CBC-ESSIV, AES-128-CTS-CBC,\n"
58 "Adiantum, or AES-256-HCTR2.  MASTER_KEY must be a hex string long enough for\n"
59 "the cipher.\n"
60 "\n"
61 "WARNING: this program is only meant for testing, not for \"real\" use!\n"
62 "\n"
63 "Options:\n"
64 "  --data-unit-index=DUIDX     Starting data unit index for IV generation.\n"
65 "                                Default: 0\n"
66 "  --data-unit-size=DUSIZE     Encrypt each DUSIZE bytes independently.\n"
67 "                                Default: 4096 bytes\n"
68 "  --decrypt                   Decrypt instead of encrypt\n"
69 "  --direct-key                Use the format where the IVs include the file\n"
70 "                                nonce and the same key is shared across files.\n"
71 "  --dump-key-identifier       Instead of encrypting/decrypting data, just\n"
72 "                                compute and dump the key identifier.\n"
73 "  --file-nonce=NONCE          File's nonce as a 32-character hex string\n"
74 "  --fs-uuid=UUID              The filesystem UUID as a 32-character hex string.\n"
75 "                                Required for --iv-ino-lblk-32 and\n"
76 "                                --iv-ino-lblk-64; otherwise is unused.\n"
77 "  --help                      Show this help\n"
78 "  --inode-number=INUM         The file's inode number.  Required for\n"
79 "                                --iv-ino-lblk-32 and --iv-ino-lblk-64;\n"
80 "                                otherwise is unused.\n"
81 "  --iv-ino-lblk-32            Similar to --iv-ino-lblk-64, but selects the\n"
82 "                                32-bit variant.\n"
83 "  --iv-ino-lblk-64            Use the format where the IVs include the inode\n"
84 "                                number and the same key is shared across files.\n"
85 "  --kdf=KDF                   Key derivation function to use: AES-128-ECB,\n"
86 "                                HKDF-SHA512, or none.  Default: none\n"
87 "  --mode-num=NUM              The encryption mode number.  This may be required\n"
88 "                                for key derivation, depending on other options.\n"
89 "  --padding=PADDING           If last data unit is partial, zero-pad it to next\n"
90 "                                PADDING-byte boundary.  Default: DUSIZE\n"
91         , fp);
92 }
93
94 /*----------------------------------------------------------------------------*
95  *                                 Utilities                                  *
96  *----------------------------------------------------------------------------*/
97
98 #define ARRAY_SIZE(A)           (sizeof(A) / sizeof((A)[0]))
99 #define MIN(x, y)               ((x) < (y) ? (x) : (y))
100 #define MAX(x, y)               ((x) > (y) ? (x) : (y))
101 #define ROUND_DOWN(x, y)        ((x) & ~((y) - 1))
102 #define ROUND_UP(x, y)          (((x) + (y) - 1) & ~((y) - 1))
103 #define DIV_ROUND_UP(n, d)      (((n) + (d) - 1) / (d))
104 #define STATIC_ASSERT(e)        ((void)sizeof(char[1 - 2*!(e)]))
105
106 typedef __u8                    u8;
107 typedef __u16                   u16;
108 typedef __u32                   u32;
109 typedef __u64                   u64;
110
111 #define cpu_to_le32             __cpu_to_le32
112 #define cpu_to_be32             __cpu_to_be32
113 #define cpu_to_le64             __cpu_to_le64
114 #define cpu_to_be64             __cpu_to_be64
115 #define le32_to_cpu             __le32_to_cpu
116 #define be32_to_cpu             __be32_to_cpu
117 #define le64_to_cpu             __le64_to_cpu
118 #define be64_to_cpu             __be64_to_cpu
119
120 #define DEFINE_UNALIGNED_ACCESS_HELPERS(type, native_type)      \
121 static inline native_type __attribute__((unused))               \
122 get_unaligned_##type(const void *p)                             \
123 {                                                               \
124         __##type x;                                             \
125                                                                 \
126         memcpy(&x, p, sizeof(x));                               \
127         return type##_to_cpu(x);                                \
128 }                                                               \
129                                                                 \
130 static inline void __attribute__((unused))                      \
131 put_unaligned_##type(native_type v, void *p)                    \
132 {                                                               \
133         __##type x = cpu_to_##type(v);                          \
134                                                                 \
135         memcpy(p, &x, sizeof(x));                               \
136 }
137
138 DEFINE_UNALIGNED_ACCESS_HELPERS(le32, u32)
139 DEFINE_UNALIGNED_ACCESS_HELPERS(be32, u32)
140 DEFINE_UNALIGNED_ACCESS_HELPERS(le64, u64)
141 DEFINE_UNALIGNED_ACCESS_HELPERS(be64, u64)
142
143 static inline bool is_power_of_2(unsigned long v)
144 {
145         return v != 0 && (v & (v - 1)) == 0;
146 }
147
148 static inline u32 rol32(u32 v, int n)
149 {
150         return (v << n) | (v >> (32 - n));
151 }
152
153 static inline u32 ror32(u32 v, int n)
154 {
155         return (v >> n) | (v << (32 - n));
156 }
157
158 static inline u64 rol64(u64 v, int n)
159 {
160         return (v << n) | (v >> (64 - n));
161 }
162
163 static inline u64 ror64(u64 v, int n)
164 {
165         return (v >> n) | (v << (64 - n));
166 }
167
168 static inline void xor(u8 *res, const u8 *a, const u8 *b, size_t count)
169 {
170         while (count--)
171                 *res++ = *a++ ^ *b++;
172 }
173
174 static void __attribute__((noreturn, format(printf, 2, 3)))
175 do_die(int err, const char *format, ...)
176 {
177         va_list va;
178
179         va_start(va, format);
180         fputs("[" PROGRAM_NAME "] ERROR: ", stderr);
181         vfprintf(stderr, format, va);
182         if (err)
183                 fprintf(stderr, ": %s", strerror(errno));
184         putc('\n', stderr);
185         va_end(va);
186         exit(1);
187 }
188
189 #define die(format, ...)        do_die(0,     (format), ##__VA_ARGS__)
190 #define die_errno(format, ...)  do_die(errno, (format), ##__VA_ARGS__)
191
192 static __attribute__((noreturn)) void
193 assertion_failed(const char *expr, const char *file, int line)
194 {
195         die("Assertion failed: %s at %s:%d", expr, file, line);
196 }
197
198 #define ASSERT(e) ({ if (!(e)) assertion_failed(#e, __FILE__, __LINE__); })
199
200 static void *xmalloc(size_t size)
201 {
202         void *p = malloc(size);
203
204         ASSERT(p != NULL);
205         return p;
206 }
207
208 static int hexchar2bin(char c)
209 {
210         if (c >= 'a' && c <= 'f')
211                 return 10 + c - 'a';
212         if (c >= 'A' && c <= 'F')
213                 return 10 + c - 'A';
214         if (c >= '0' && c <= '9')
215                 return c - '0';
216         return -1;
217 }
218
219 static int hex2bin(const char *hex, u8 *bin, int max_bin_size)
220 {
221         size_t len = strlen(hex);
222         size_t i;
223
224         if (len & 1)
225                 return -1;
226         len /= 2;
227         if (len > max_bin_size)
228                 return -1;
229
230         for (i = 0; i < len; i++) {
231                 int high = hexchar2bin(hex[2 * i]);
232                 int low = hexchar2bin(hex[2 * i + 1]);
233
234                 if (high < 0 || low < 0)
235                         return -1;
236                 bin[i] = (high << 4) | low;
237         }
238         return len;
239 }
240
241 static size_t xread(int fd, void *buf, size_t count)
242 {
243         const size_t orig_count = count;
244
245         while (count) {
246                 ssize_t res = read(fd, buf, count);
247
248                 if (res < 0)
249                         die_errno("read error");
250                 if (res == 0)
251                         break;
252                 buf += res;
253                 count -= res;
254         }
255         return orig_count - count;
256 }
257
258 static void full_write(int fd, const void *buf, size_t count)
259 {
260         while (count) {
261                 ssize_t res = write(fd, buf, count);
262
263                 if (res < 0)
264                         die_errno("write error");
265                 buf += res;
266                 count -= res;
267         }
268 }
269
270 #ifdef ENABLE_ALG_TESTS
271 static void rand_bytes(u8 *buf, size_t count)
272 {
273         while (count--)
274                 *buf++ = rand();
275 }
276
277 #include <linux/if_alg.h>
278 #include <sys/socket.h>
279 #define SOL_ALG 279
280 static void af_alg_crypt(int algfd, int op, const u8 *key, size_t keylen,
281                          const u8 *iv, size_t ivlen,
282                          const u8 *src, u8 *dst, size_t datalen)
283 {
284         size_t controllen = CMSG_SPACE(sizeof(int)) +
285                             CMSG_SPACE(sizeof(struct af_alg_iv) + ivlen);
286         u8 *control = xmalloc(controllen);
287         struct iovec iov = { .iov_base = (u8 *)src, .iov_len = datalen };
288         struct msghdr msg = {
289                 .msg_iov = &iov,
290                 .msg_iovlen = 1,
291                 .msg_control = control,
292                 .msg_controllen = controllen,
293         };
294         struct cmsghdr *cmsg;
295         struct af_alg_iv *algiv;
296         int reqfd;
297
298         memset(control, 0, controllen);
299
300         cmsg = CMSG_FIRSTHDR(&msg);
301         cmsg->cmsg_len = CMSG_LEN(sizeof(int));
302         cmsg->cmsg_level = SOL_ALG;
303         cmsg->cmsg_type = ALG_SET_OP;
304         *(int *)CMSG_DATA(cmsg) = op;
305
306         cmsg = CMSG_NXTHDR(&msg, cmsg);
307         cmsg->cmsg_len = CMSG_LEN(sizeof(struct af_alg_iv) + ivlen);
308         cmsg->cmsg_level = SOL_ALG;
309         cmsg->cmsg_type = ALG_SET_IV;
310         algiv = (struct af_alg_iv *)CMSG_DATA(cmsg);
311         algiv->ivlen = ivlen;
312         memcpy(algiv->iv, iv, ivlen);
313
314         if (setsockopt(algfd, SOL_ALG, ALG_SET_KEY, key, keylen) != 0)
315                 die_errno("can't set key on AF_ALG socket");
316
317         reqfd = accept(algfd, NULL, NULL);
318         if (reqfd < 0)
319                 die_errno("can't accept() AF_ALG socket");
320         if (sendmsg(reqfd, &msg, 0) != datalen)
321                 die_errno("can't sendmsg() AF_ALG request socket");
322         if (xread(reqfd, dst, datalen) != datalen)
323                 die("short read from AF_ALG request socket");
324         close(reqfd);
325
326         free(control);
327 }
328 #endif /* ENABLE_ALG_TESTS */
329
330 /*----------------------------------------------------------------------------*
331  *                          Finite field arithmetic                           *
332  *----------------------------------------------------------------------------*/
333
334 /* Multiply a GF(2^8) element by the polynomial 'x' */
335 static inline u8 gf2_8_mul_x(u8 b)
336 {
337         return (b << 1) ^ ((b & 0x80) ? 0x1B : 0);
338 }
339
340 /* Multiply four packed GF(2^8) elements by the polynomial 'x' */
341 static inline u32 gf2_8_mul_x_4way(u32 w)
342 {
343         return ((w & 0x7F7F7F7F) << 1) ^ (((w & 0x80808080) >> 7) * 0x1B);
344 }
345
346 /* Element of GF(2^128) */
347 typedef struct {
348         __le64 lo;
349         __le64 hi;
350 } ble128;
351
352 /* Multiply a GF(2^128) element by the polynomial 'x' */
353 static inline void gf2_128_mul_x_xts(ble128 *t)
354 {
355         u64 lo = le64_to_cpu(t->lo);
356         u64 hi = le64_to_cpu(t->hi);
357
358         t->hi = cpu_to_le64((hi << 1) | (lo >> 63));
359         t->lo = cpu_to_le64((lo << 1) ^ ((hi & (1ULL << 63)) ? 0x87 : 0));
360 }
361
362 static inline void gf2_128_mul_x_polyval(ble128 *t)
363 {
364         u64 lo = le64_to_cpu(t->lo);
365         u64 hi = le64_to_cpu(t->hi);
366         u64 lo_reducer = (hi & (1ULL << 63)) ? 1 : 0;
367         u64 hi_reducer = (hi & (1ULL << 63)) ? 0xc2ULL << 56 : 0;
368
369         t->hi = cpu_to_le64(((hi << 1) | (lo >> 63)) ^ hi_reducer);
370         t->lo = cpu_to_le64((lo << 1) ^ lo_reducer);
371 }
372
373 static void gf2_128_mul_polyval(ble128 *r, const ble128 *b)
374 {
375         int i;
376         ble128 p;
377         u64 lo = le64_to_cpu(b->lo);
378         u64 hi = le64_to_cpu(b->hi);
379
380         memset(&p, 0, sizeof(p));
381         for (i = 0; i < 64; i++) {
382                 if (lo & (1ULL << i))
383                         xor((u8 *)&p, (u8 *)&p, (u8 *)r, sizeof(p));
384                 gf2_128_mul_x_polyval(r);
385         }
386         for (i = 0; i < 64; i++) {
387                 if (hi & (1ULL << i))
388                         xor((u8 *)&p, (u8 *)&p, (u8 *)r, sizeof(p));
389                 gf2_128_mul_x_polyval(r);
390         }
391         *r = p;
392 }
393
394 /*----------------------------------------------------------------------------*
395  *                             Group arithmetic                               *
396  *----------------------------------------------------------------------------*/
397
398 /* Element of Z/(2^{128}Z)  (a.k.a. the integers modulo 2^128) */
399 typedef struct {
400         __le64 lo;
401         __le64 hi;
402 } le128;
403
404 static inline void le128_add(le128 *res, const le128 *a, const le128 *b)
405 {
406         u64 a_lo = le64_to_cpu(a->lo);
407         u64 b_lo = le64_to_cpu(b->lo);
408
409         res->lo = cpu_to_le64(a_lo + b_lo);
410         res->hi = cpu_to_le64(le64_to_cpu(a->hi) + le64_to_cpu(b->hi) +
411                               (a_lo + b_lo < a_lo));
412 }
413
414 static inline void le128_sub(le128 *res, const le128 *a, const le128 *b)
415 {
416         u64 a_lo = le64_to_cpu(a->lo);
417         u64 b_lo = le64_to_cpu(b->lo);
418
419         res->lo = cpu_to_le64(a_lo - b_lo);
420         res->hi = cpu_to_le64(le64_to_cpu(a->hi) - le64_to_cpu(b->hi) -
421                               (a_lo - b_lo > a_lo));
422 }
423
424 /*----------------------------------------------------------------------------*
425  *                              AES block cipher                              *
426  *----------------------------------------------------------------------------*/
427
428 /*
429  * Reference: "FIPS 197, Advanced Encryption Standard"
430  *      https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.197.pdf
431  */
432
433 #define AES_BLOCK_SIZE          16
434 #define AES_128_KEY_SIZE        16
435 #define AES_192_KEY_SIZE        24
436 #define AES_256_KEY_SIZE        32
437
438 static inline void AddRoundKey(u32 state[4], const u32 *rk)
439 {
440         int i;
441
442         for (i = 0; i < 4; i++)
443                 state[i] ^= rk[i];
444 }
445
446 static const u8 aes_sbox[256] = {
447         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b,
448         0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0,
449         0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26,
450         0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
451         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2,
452         0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0,
453         0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed,
454         0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
455         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f,
456         0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
457         0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec,
458         0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
459         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14,
460         0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c,
461         0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d,
462         0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
463         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f,
464         0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e,
465         0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11,
466         0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
467         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f,
468         0xb0, 0x54, 0xbb, 0x16,
469 };
470
471 static u8 aes_inverse_sbox[256];
472
473 static void aes_init(void)
474 {
475         int i;
476
477         for (i = 0; i < 256; i++)
478                 aes_inverse_sbox[aes_sbox[i]] = i;
479 }
480
481 static inline u32 DoSubWord(u32 w, const u8 sbox[256])
482 {
483         return ((u32)sbox[(u8)(w >> 24)] << 24) |
484                ((u32)sbox[(u8)(w >> 16)] << 16) |
485                ((u32)sbox[(u8)(w >>  8)] <<  8) |
486                ((u32)sbox[(u8)(w >>  0)] <<  0);
487 }
488
489 static inline u32 SubWord(u32 w)
490 {
491         return DoSubWord(w, aes_sbox);
492 }
493
494 static inline u32 InvSubWord(u32 w)
495 {
496         return DoSubWord(w, aes_inverse_sbox);
497 }
498
499 static inline void SubBytes(u32 state[4])
500 {
501         int i;
502
503         for (i = 0; i < 4; i++)
504                 state[i] = SubWord(state[i]);
505 }
506
507 static inline void InvSubBytes(u32 state[4])
508 {
509         int i;
510
511         for (i = 0; i < 4; i++)
512                 state[i] = InvSubWord(state[i]);
513 }
514
515 static inline void DoShiftRows(u32 state[4], int direction)
516 {
517         u32 newstate[4];
518         int i;
519
520         for (i = 0; i < 4; i++)
521                 newstate[i] = (state[(i + direction*0) & 3] & 0xff) |
522                               (state[(i + direction*1) & 3] & 0xff00) |
523                               (state[(i + direction*2) & 3] & 0xff0000) |
524                               (state[(i + direction*3) & 3] & 0xff000000);
525         memcpy(state, newstate, 16);
526 }
527
528 static inline void ShiftRows(u32 state[4])
529 {
530         DoShiftRows(state, 1);
531 }
532
533 static inline void InvShiftRows(u32 state[4])
534 {
535         DoShiftRows(state, -1);
536 }
537
538 /*
539  * Mix one column by doing the following matrix multiplication in GF(2^8):
540  *
541  *     | 2 3 1 1 |   | w[0] |
542  *     | 1 2 3 1 |   | w[1] |
543  *     | 1 1 2 3 | x | w[2] |
544  *     | 3 1 1 2 |   | w[3] |
545  *
546  * a.k.a. w[i] = 2*w[i] + 3*w[(i+1)%4] + w[(i+2)%4] + w[(i+3)%4]
547  */
548 static inline u32 MixColumn(u32 w)
549 {
550         u32 _2w0_w2 = gf2_8_mul_x_4way(w) ^ ror32(w, 16);
551         u32 _3w1_w3 = ror32(_2w0_w2 ^ w, 8);
552
553         return _2w0_w2 ^ _3w1_w3;
554 }
555
556 /*
557  *           ( | 5 0 4 0 |   | w[0] | )
558  *          (  | 0 5 0 4 |   | w[1] |  )
559  * MixColumn(  | 4 0 5 0 | x | w[2] |  )
560  *           ( | 0 4 0 5 |   | w[3] | )
561  */
562 static inline u32 InvMixColumn(u32 w)
563 {
564         u32 _4w = gf2_8_mul_x_4way(gf2_8_mul_x_4way(w));
565
566         return MixColumn(_4w ^ w ^ ror32(_4w, 16));
567 }
568
569 static inline void MixColumns(u32 state[4])
570 {
571         int i;
572
573         for (i = 0; i < 4; i++)
574                 state[i] = MixColumn(state[i]);
575 }
576
577 static inline void InvMixColumns(u32 state[4])
578 {
579         int i;
580
581         for (i = 0; i < 4; i++)
582                 state[i] = InvMixColumn(state[i]);
583 }
584
585 struct aes_key {
586         u32 round_keys[15 * 4];
587         int nrounds;
588 };
589
590 /* Expand an AES key */
591 static void aes_setkey(struct aes_key *k, const u8 *key, int keysize)
592 {
593         const int N = keysize / 4;
594         u32 * const rk = k->round_keys;
595         u8 rcon = 1;
596         int i;
597
598         ASSERT(keysize == 16 || keysize == 24 || keysize == 32);
599         k->nrounds = 6 + N;
600         for (i = 0; i < 4 * (k->nrounds + 1); i++) {
601                 if (i < N) {
602                         rk[i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
603                 } else if (i % N == 0) {
604                         rk[i] = rk[i - N] ^ SubWord(ror32(rk[i - 1], 8)) ^ rcon;
605                         rcon = gf2_8_mul_x(rcon);
606                 } else if (N > 6 && i % N == 4) {
607                         rk[i] = rk[i - N] ^ SubWord(rk[i - 1]);
608                 } else {
609                         rk[i] = rk[i - N] ^ rk[i - 1];
610                 }
611         }
612 }
613
614 /* Encrypt one 16-byte block with AES */
615 static void aes_encrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
616                         u8 dst[AES_BLOCK_SIZE])
617 {
618         u32 state[4];
619         int i;
620
621         for (i = 0; i < 4; i++)
622                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
623
624         AddRoundKey(state, k->round_keys);
625         for (i = 1; i < k->nrounds; i++) {
626                 SubBytes(state);
627                 ShiftRows(state);
628                 MixColumns(state);
629                 AddRoundKey(state, &k->round_keys[4 * i]);
630         }
631         SubBytes(state);
632         ShiftRows(state);
633         AddRoundKey(state, &k->round_keys[4 * i]);
634
635         for (i = 0; i < 4; i++)
636                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
637 }
638
639 /* Decrypt one 16-byte block with AES */
640 static void aes_decrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
641                         u8 dst[AES_BLOCK_SIZE])
642 {
643         u32 state[4];
644         int i;
645
646         for (i = 0; i < 4; i++)
647                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
648
649         AddRoundKey(state, &k->round_keys[4 * k->nrounds]);
650         InvShiftRows(state);
651         InvSubBytes(state);
652         for (i = k->nrounds - 1; i >= 1; i--) {
653                 AddRoundKey(state, &k->round_keys[4 * i]);
654                 InvMixColumns(state);
655                 InvShiftRows(state);
656                 InvSubBytes(state);
657         }
658         AddRoundKey(state, k->round_keys);
659
660         for (i = 0; i < 4; i++)
661                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
662 }
663
664 #ifdef ENABLE_ALG_TESTS
665 #include <openssl/evp.h>
666 static void test_aes_keysize(int keysize)
667 {
668         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
669         const EVP_CIPHER *evp_cipher;
670         EVP_CIPHER_CTX *ctx;
671
672         switch (keysize) {
673         case 16:
674                 evp_cipher = EVP_aes_128_ecb();
675                 break;
676         case 24:
677                 evp_cipher = EVP_aes_192_ecb();
678                 break;
679         case 32:
680                 evp_cipher = EVP_aes_256_ecb();
681                 break;
682         default:
683                 ASSERT(0);
684         }
685
686         ctx = EVP_CIPHER_CTX_new();
687         ASSERT(ctx != NULL);
688         while (num_tests--) {
689                 struct aes_key k;
690                 u8 key[AES_256_KEY_SIZE];
691                 u8 ptext[AES_BLOCK_SIZE];
692                 u8 ctext[AES_BLOCK_SIZE];
693                 u8 ref_ctext[AES_BLOCK_SIZE];
694                 u8 decrypted[AES_BLOCK_SIZE];
695                 int outl, res;
696
697                 rand_bytes(key, keysize);
698                 rand_bytes(ptext, AES_BLOCK_SIZE);
699
700                 aes_setkey(&k, key, keysize);
701                 aes_encrypt(&k, ptext, ctext);
702
703                 res = EVP_EncryptInit_ex(ctx, evp_cipher, NULL, key, NULL);
704                 ASSERT(res > 0);
705                 res = EVP_EncryptUpdate(ctx, ref_ctext, &outl, ptext,
706                                         AES_BLOCK_SIZE);
707                 ASSERT(res > 0);
708                 ASSERT(outl == AES_BLOCK_SIZE);
709                 ASSERT(memcmp(ctext, ref_ctext, AES_BLOCK_SIZE) == 0);
710
711                 aes_decrypt(&k, ctext, decrypted);
712                 ASSERT(memcmp(ptext, decrypted, AES_BLOCK_SIZE) == 0);
713         }
714         EVP_CIPHER_CTX_free(ctx);
715 }
716
717 static void test_aes(void)
718 {
719         test_aes_keysize(AES_128_KEY_SIZE);
720         test_aes_keysize(AES_192_KEY_SIZE);
721         test_aes_keysize(AES_256_KEY_SIZE);
722 }
723 #endif /* ENABLE_ALG_TESTS */
724
725 /*----------------------------------------------------------------------------*
726  *                            SHA-512 and SHA-256                             *
727  *----------------------------------------------------------------------------*/
728
729 /*
730  * Reference: "FIPS 180-2, Secure Hash Standard"
731  *      https://csrc.nist.gov/csrc/media/publications/fips/180/2/archive/2002-08-01/documents/fips180-2withchangenotice.pdf
732  */
733
734 #define SHA512_DIGEST_SIZE      64
735 #define SHA512_BLOCK_SIZE       128
736
737 #define SHA256_DIGEST_SIZE      32
738 #define SHA256_BLOCK_SIZE       64
739
740 #define Ch(x, y, z)     (((x) & (y)) ^ (~(x) & (z)))
741 #define Maj(x, y, z)    (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z)))
742
743 #define Sigma512_0(x)   (ror64((x), 28) ^ ror64((x), 34) ^ ror64((x), 39))
744 #define Sigma512_1(x)   (ror64((x), 14) ^ ror64((x), 18) ^ ror64((x), 41))
745 #define sigma512_0(x)   (ror64((x),  1) ^ ror64((x),  8) ^ ((x) >> 7))
746 #define sigma512_1(x)   (ror64((x), 19) ^ ror64((x), 61) ^ ((x) >> 6))
747
748 #define Sigma256_0(x)   (ror32((x),  2) ^ ror32((x), 13) ^ ror32((x), 22))
749 #define Sigma256_1(x)   (ror32((x),  6) ^ ror32((x), 11) ^ ror32((x), 25))
750 #define sigma256_0(x)   (ror32((x),  7) ^ ror32((x), 18) ^ ((x) >>  3))
751 #define sigma256_1(x)   (ror32((x), 17) ^ ror32((x), 19) ^ ((x) >> 10))
752
753 static const u64 sha512_iv[8] = {
754         0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b,
755         0xa54ff53a5f1d36f1, 0x510e527fade682d1, 0x9b05688c2b3e6c1f,
756         0x1f83d9abfb41bd6b, 0x5be0cd19137e2179,
757 };
758
759 static const u64 sha512_round_constants[80] = {
760         0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f,
761         0xe9b5dba58189dbbc, 0x3956c25bf348b538, 0x59f111f1b605d019,
762         0x923f82a4af194f9b, 0xab1c5ed5da6d8118, 0xd807aa98a3030242,
763         0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2,
764         0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235,
765         0xc19bf174cf692694, 0xe49b69c19ef14ad2, 0xefbe4786384f25e3,
766         0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, 0x2de92c6f592b0275,
767         0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5,
768         0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f,
769         0xbf597fc7beef0ee4, 0xc6e00bf33da88fc2, 0xd5a79147930aa725,
770         0x06ca6351e003826f, 0x142929670a0e6e70, 0x27b70a8546d22ffc,
771         0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df,
772         0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6,
773         0x92722c851482353b, 0xa2bfe8a14cf10364, 0xa81a664bbc423001,
774         0xc24b8b70d0f89791, 0xc76c51a30654be30, 0xd192e819d6ef5218,
775         0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8,
776         0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99,
777         0x34b0bcb5e19b48a8, 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb,
778         0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, 0x748f82ee5defb2fc,
779         0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec,
780         0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915,
781         0xc67178f2e372532b, 0xca273eceea26619c, 0xd186b8c721c0c207,
782         0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, 0x06f067aa72176fba,
783         0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b,
784         0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc,
785         0x431d67c49c100d4c, 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a,
786         0x5fcb6fab3ad6faec, 0x6c44198c4a475817,
787 };
788
789 /* Compute the SHA-512 digest of the given buffer */
790 static void sha512(const u8 *in, size_t inlen, u8 out[SHA512_DIGEST_SIZE])
791 {
792         const size_t msglen = ROUND_UP(inlen + 17, SHA512_BLOCK_SIZE);
793         u8 * const msg = xmalloc(msglen);
794         u64 H[8];
795         int i;
796
797         /* super naive way of handling the padding */
798         memcpy(msg, in, inlen);
799         memset(&msg[inlen], 0, msglen - inlen);
800         msg[inlen] = 0x80;
801         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
802         in = msg;
803
804         memcpy(H, sha512_iv, sizeof(H));
805         do {
806                 u64 a = H[0], b = H[1], c = H[2], d = H[3],
807                     e = H[4], f = H[5], g = H[6], h = H[7];
808                 u64 W[80];
809
810                 for (i = 0; i < 16; i++)
811                         W[i] = get_unaligned_be64(&in[i * sizeof(__be64)]);
812                 for (; i < ARRAY_SIZE(W); i++)
813                         W[i] = sigma512_1(W[i - 2]) + W[i - 7] +
814                                sigma512_0(W[i - 15]) + W[i - 16];
815                 for (i = 0; i < ARRAY_SIZE(W); i++) {
816                         u64 T1 = h + Sigma512_1(e) + Ch(e, f, g) +
817                                  sha512_round_constants[i] + W[i];
818                         u64 T2 = Sigma512_0(a) + Maj(a, b, c);
819
820                         h = g; g = f; f = e; e = d + T1;
821                         d = c; c = b; b = a; a = T1 + T2;
822                 }
823                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
824                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
825         } while ((in += SHA512_BLOCK_SIZE) != &msg[msglen]);
826
827         for (i = 0; i < ARRAY_SIZE(H); i++)
828                 put_unaligned_be64(H[i], &out[i * sizeof(__be64)]);
829         free(msg);
830 }
831
832 /* Compute the SHA-256 digest of the given buffer */
833 static void sha256(const u8 *in, size_t inlen, u8 out[SHA256_DIGEST_SIZE])
834 {
835         const size_t msglen = ROUND_UP(inlen + 9, SHA256_BLOCK_SIZE);
836         u8 * const msg = xmalloc(msglen);
837         u32 H[8];
838         int i;
839
840         /* super naive way of handling the padding */
841         memcpy(msg, in, inlen);
842         memset(&msg[inlen], 0, msglen - inlen);
843         msg[inlen] = 0x80;
844         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
845         in = msg;
846
847         for (i = 0; i < ARRAY_SIZE(H); i++)
848                 H[i] = (u32)(sha512_iv[i] >> 32);
849         do {
850                 u32 a = H[0], b = H[1], c = H[2], d = H[3],
851                     e = H[4], f = H[5], g = H[6], h = H[7];
852                 u32 W[64];
853
854                 for (i = 0; i < 16; i++)
855                         W[i] = get_unaligned_be32(&in[i * sizeof(__be32)]);
856                 for (; i < ARRAY_SIZE(W); i++)
857                         W[i] = sigma256_1(W[i - 2]) + W[i - 7] +
858                                sigma256_0(W[i - 15]) + W[i - 16];
859                 for (i = 0; i < ARRAY_SIZE(W); i++) {
860                         u32 T1 = h + Sigma256_1(e) + Ch(e, f, g) +
861                                  (u32)(sha512_round_constants[i] >> 32) + W[i];
862                         u32 T2 = Sigma256_0(a) + Maj(a, b, c);
863
864                         h = g; g = f; f = e; e = d + T1;
865                         d = c; c = b; b = a; a = T1 + T2;
866                 }
867                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
868                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
869         } while ((in += SHA256_BLOCK_SIZE) != &msg[msglen]);
870
871         for (i = 0; i < ARRAY_SIZE(H); i++)
872                 put_unaligned_be32(H[i], &out[i * sizeof(__be32)]);
873         free(msg);
874 }
875
876 #ifdef ENABLE_ALG_TESTS
877 #include <openssl/sha.h>
878 static void test_sha2(void)
879 {
880         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
881
882         while (num_tests--) {
883                 u8 in[4096];
884                 u8 digest[SHA512_DIGEST_SIZE];
885                 u8 ref_digest[SHA512_DIGEST_SIZE];
886                 const size_t inlen = rand() % (1 + sizeof(in));
887
888                 rand_bytes(in, inlen);
889
890                 sha256(in, inlen, digest);
891                 SHA256(in, inlen, ref_digest);
892                 ASSERT(memcmp(digest, ref_digest, SHA256_DIGEST_SIZE) == 0);
893
894                 sha512(in, inlen, digest);
895                 SHA512(in, inlen, ref_digest);
896                 ASSERT(memcmp(digest, ref_digest, SHA512_DIGEST_SIZE) == 0);
897         }
898 }
899 #endif /* ENABLE_ALG_TESTS */
900
901 /*----------------------------------------------------------------------------*
902  *                            HKDF implementation                             *
903  *----------------------------------------------------------------------------*/
904
905 static void hmac_sha512(const u8 *key, size_t keylen, const u8 *msg,
906                         size_t msglen, u8 mac[SHA512_DIGEST_SIZE])
907 {
908         u8 *ibuf = xmalloc(SHA512_BLOCK_SIZE + msglen);
909         u8 obuf[SHA512_BLOCK_SIZE + SHA512_DIGEST_SIZE];
910
911         ASSERT(keylen <= SHA512_BLOCK_SIZE); /* keylen > bs not implemented */
912
913         memset(ibuf, 0x36, SHA512_BLOCK_SIZE);
914         xor(ibuf, ibuf, key, keylen);
915         memcpy(&ibuf[SHA512_BLOCK_SIZE], msg, msglen);
916
917         memset(obuf, 0x5c, SHA512_BLOCK_SIZE);
918         xor(obuf, obuf, key, keylen);
919         sha512(ibuf, SHA512_BLOCK_SIZE + msglen, &obuf[SHA512_BLOCK_SIZE]);
920         sha512(obuf, sizeof(obuf), mac);
921
922         free(ibuf);
923 }
924
925 static void hkdf_sha512(const u8 *ikm, size_t ikmlen,
926                         const u8 *salt, size_t saltlen,
927                         const u8 *info, size_t infolen,
928                         u8 *output, size_t outlen)
929 {
930         static const u8 default_salt[SHA512_DIGEST_SIZE];
931         u8 prk[SHA512_DIGEST_SIZE]; /* pseudorandom key */
932         u8 *buf = xmalloc(1 + infolen + SHA512_DIGEST_SIZE);
933         u8 counter = 1;
934         size_t i;
935
936         if (saltlen == 0) {
937                 salt = default_salt;
938                 saltlen = sizeof(default_salt);
939         }
940
941         /* HKDF-Extract */
942         ASSERT(ikmlen > 0);
943         hmac_sha512(salt, saltlen, ikm, ikmlen, prk);
944
945         /* HKDF-Expand */
946         for (i = 0; i < outlen; i += SHA512_DIGEST_SIZE) {
947                 u8 *p = buf;
948                 u8 tmp[SHA512_DIGEST_SIZE];
949
950                 ASSERT(counter != 0);
951                 if (i > 0) {
952                         memcpy(p, &output[i - SHA512_DIGEST_SIZE],
953                                SHA512_DIGEST_SIZE);
954                         p += SHA512_DIGEST_SIZE;
955                 }
956                 memcpy(p, info, infolen);
957                 p += infolen;
958                 *p++ = counter++;
959                 hmac_sha512(prk, sizeof(prk), buf, p - buf, tmp);
960                 memcpy(&output[i], tmp, MIN(sizeof(tmp), outlen - i));
961         }
962         free(buf);
963 }
964
965 #ifdef ENABLE_ALG_TESTS
966 #include <openssl/evp.h>
967 #include <openssl/kdf.h>
968 static void openssl_hkdf_sha512(const u8 *ikm, size_t ikmlen,
969                                 const u8 *salt, size_t saltlen,
970                                 const u8 *info, size_t infolen,
971                                 u8 *output, size_t outlen)
972 {
973         EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, NULL);
974         size_t actual_outlen = outlen;
975
976         ASSERT(pctx != NULL);
977         ASSERT(EVP_PKEY_derive_init(pctx) > 0);
978         ASSERT(EVP_PKEY_CTX_set_hkdf_md(pctx, EVP_sha512()) > 0);
979         ASSERT(EVP_PKEY_CTX_set1_hkdf_key(pctx, ikm, ikmlen) > 0);
980         ASSERT(EVP_PKEY_CTX_set1_hkdf_salt(pctx, salt, saltlen) > 0);
981         ASSERT(EVP_PKEY_CTX_add1_hkdf_info(pctx, info, infolen) > 0);
982         ASSERT(EVP_PKEY_derive(pctx, output, &actual_outlen) > 0);
983         ASSERT(actual_outlen == outlen);
984         EVP_PKEY_CTX_free(pctx);
985 }
986
987 static void test_hkdf_sha512(void)
988 {
989         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
990
991         while (num_tests--) {
992                 u8 ikm[SHA512_DIGEST_SIZE];
993                 u8 salt[SHA512_DIGEST_SIZE];
994                 u8 info[128];
995                 u8 actual_output[512];
996                 u8 expected_output[sizeof(actual_output)];
997                 size_t ikmlen = 1 + (rand() % sizeof(ikm));
998                 size_t saltlen = rand() % (1 + sizeof(salt));
999                 size_t infolen = rand() % (1 + sizeof(info));
1000                 /*
1001                  * Don't test zero-length outputs, since OpenSSL 3.0 and later
1002                  * returns an error for those.
1003                  */
1004                 size_t outlen = 1 + (rand() % sizeof(actual_output));
1005
1006                 rand_bytes(ikm, ikmlen);
1007                 rand_bytes(salt, saltlen);
1008                 rand_bytes(info, infolen);
1009
1010                 hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
1011                             actual_output, outlen);
1012                 openssl_hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
1013                                     expected_output, outlen);
1014                 ASSERT(memcmp(actual_output, expected_output, outlen) == 0);
1015         }
1016 }
1017 #endif /* ENABLE_ALG_TESTS */
1018
1019 /*----------------------------------------------------------------------------*
1020  *                                 POLYVAL                                     *
1021  *----------------------------------------------------------------------------*/
1022
1023 /*
1024  * Reference: "AES-GCM-SIV: Nonce Misuse-Resistant Authenticated Encryption"
1025  *      https://datatracker.ietf.org/doc/html/rfc8452
1026  */
1027
1028 #define POLYVAL_KEY_SIZE        16
1029 #define POLYVAL_BLOCK_SIZE      16
1030
1031 static void polyval_update(const u8 key[POLYVAL_KEY_SIZE],
1032                            const u8 *msg, size_t msglen,
1033                            u8 accumulator[POLYVAL_BLOCK_SIZE])
1034 {
1035         ble128 h;
1036         ble128 aligned_accumulator;
1037         // x^{-128} = x^127 + x^124 + x^121 + x^114 + 1
1038         static const ble128 inv128 = {
1039                 cpu_to_le64(1),
1040                 cpu_to_le64(0x9204ULL << 48)
1041         };
1042
1043         /* Partial block support is not necessary for HCTR2 */
1044         ASSERT(msglen % POLYVAL_BLOCK_SIZE == 0);
1045
1046         memcpy(&h, key, POLYVAL_BLOCK_SIZE);
1047         memcpy(&aligned_accumulator, accumulator, POLYVAL_BLOCK_SIZE);
1048         gf2_128_mul_polyval(&h, &inv128);
1049
1050         while (msglen > 0) {
1051                 xor((u8 *)&aligned_accumulator, (u8 *)&aligned_accumulator, msg,
1052                     POLYVAL_BLOCK_SIZE);
1053                 gf2_128_mul_polyval(&aligned_accumulator, &h);
1054                 msg += POLYVAL_BLOCK_SIZE;
1055                 msglen -= POLYVAL_BLOCK_SIZE;
1056         }
1057         memcpy(accumulator, &aligned_accumulator, POLYVAL_BLOCK_SIZE);
1058 }
1059
1060 /*----------------------------------------------------------------------------*
1061  *                            AES encryption modes                            *
1062  *----------------------------------------------------------------------------*/
1063
1064 static void aes_256_xts_crypt(const u8 key[2 * AES_256_KEY_SIZE],
1065                               const u8 iv[AES_BLOCK_SIZE], const u8 *src,
1066                               u8 *dst, size_t nbytes, bool decrypting)
1067 {
1068         struct aes_key tweak_key, cipher_key;
1069         ble128 t;
1070         size_t i;
1071
1072         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
1073         aes_setkey(&cipher_key, key, AES_256_KEY_SIZE);
1074         aes_setkey(&tweak_key, &key[AES_256_KEY_SIZE], AES_256_KEY_SIZE);
1075         aes_encrypt(&tweak_key, iv, (u8 *)&t);
1076         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
1077                 xor(&dst[i], &src[i], (const u8 *)&t, AES_BLOCK_SIZE);
1078                 if (decrypting)
1079                         aes_decrypt(&cipher_key, &dst[i], &dst[i]);
1080                 else
1081                         aes_encrypt(&cipher_key, &dst[i], &dst[i]);
1082                 xor(&dst[i], &dst[i], (const u8 *)&t, AES_BLOCK_SIZE);
1083                 gf2_128_mul_x_xts(&t);
1084         }
1085 }
1086
1087 static void aes_256_xts_encrypt(const u8 key[2 * AES_256_KEY_SIZE],
1088                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
1089                                 u8 *dst, size_t nbytes)
1090 {
1091         aes_256_xts_crypt(key, iv, src, dst, nbytes, false);
1092 }
1093
1094 static void aes_256_xts_decrypt(const u8 key[2 * AES_256_KEY_SIZE],
1095                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
1096                                 u8 *dst, size_t nbytes)
1097 {
1098         aes_256_xts_crypt(key, iv, src, dst, nbytes, true);
1099 }
1100
1101 #ifdef ENABLE_ALG_TESTS
1102 #include <openssl/evp.h>
1103 static void test_aes_256_xts(void)
1104 {
1105         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1106         EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
1107
1108         ASSERT(ctx != NULL);
1109         while (num_tests--) {
1110                 u8 key[2 * AES_256_KEY_SIZE];
1111                 u8 iv[AES_BLOCK_SIZE];
1112                 u8 ptext[32 * AES_BLOCK_SIZE];
1113                 u8 ctext[sizeof(ptext)];
1114                 u8 ref_ctext[sizeof(ptext)];
1115                 u8 decrypted[sizeof(ptext)];
1116                 /*
1117                  * Don't test message lengths that aren't a multiple of the AES
1118                  * block size, since support for that is not implemented here.
1119                  * Also don't test zero-length messages, since OpenSSL 3.0 and
1120                  * later returns an error for those.
1121                  */
1122                 const size_t datalen = AES_BLOCK_SIZE *
1123                         (1 + rand() % (sizeof(ptext) / AES_BLOCK_SIZE));
1124                 int outl, res;
1125
1126                 rand_bytes(key, sizeof(key));
1127                 rand_bytes(iv, sizeof(iv));
1128                 rand_bytes(ptext, datalen);
1129
1130                 aes_256_xts_encrypt(key, iv, ptext, ctext, datalen);
1131                 res = EVP_EncryptInit_ex(ctx, EVP_aes_256_xts(), NULL, key, iv);
1132                 ASSERT(res > 0);
1133                 res = EVP_EncryptUpdate(ctx, ref_ctext, &outl, ptext, datalen);
1134                 ASSERT(res > 0);
1135                 ASSERT(outl == datalen);
1136                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1137
1138                 aes_256_xts_decrypt(key, iv, ctext, decrypted, datalen);
1139                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1140         }
1141         EVP_CIPHER_CTX_free(ctx);
1142 }
1143 #endif /* ENABLE_ALG_TESTS */
1144
1145 static void aes_cbc_encrypt(const struct aes_key *k,
1146                             const u8 iv[AES_BLOCK_SIZE],
1147                             const u8 *src, u8 *dst, size_t nbytes)
1148 {
1149         size_t i;
1150
1151         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
1152         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
1153                 xor(&dst[i], &src[i], (i == 0 ? iv : &dst[i - AES_BLOCK_SIZE]),
1154                     AES_BLOCK_SIZE);
1155                 aes_encrypt(k, &dst[i], &dst[i]);
1156         }
1157 }
1158
1159 static void aes_cbc_decrypt(const struct aes_key *k,
1160                             const u8 iv[AES_BLOCK_SIZE],
1161                             const u8 *src, u8 *dst, size_t nbytes)
1162 {
1163         size_t i = nbytes;
1164
1165         ASSERT(i % AES_BLOCK_SIZE == 0);
1166         while (i) {
1167                 i -= AES_BLOCK_SIZE;
1168                 aes_decrypt(k, &src[i], &dst[i]);
1169                 xor(&dst[i], &dst[i], (i == 0 ? iv : &src[i - AES_BLOCK_SIZE]),
1170                     AES_BLOCK_SIZE);
1171         }
1172 }
1173
1174 static void aes_cts_cbc_encrypt(const u8 *key, int keysize,
1175                                 const u8 iv[AES_BLOCK_SIZE],
1176                                 const u8 *src, u8 *dst, size_t nbytes)
1177 {
1178         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1179         const size_t final_bsize = nbytes - offset;
1180         struct aes_key k;
1181         u8 *pad;
1182         u8 buf[AES_BLOCK_SIZE];
1183
1184         ASSERT(nbytes >= AES_BLOCK_SIZE);
1185
1186         aes_setkey(&k, key, keysize);
1187
1188         if (nbytes == AES_BLOCK_SIZE)
1189                 return aes_cbc_encrypt(&k, iv, src, dst, nbytes);
1190
1191         aes_cbc_encrypt(&k, iv, src, dst, offset);
1192         pad = &dst[offset - AES_BLOCK_SIZE];
1193
1194         memcpy(buf, pad, AES_BLOCK_SIZE);
1195         xor(buf, buf, &src[offset], final_bsize);
1196         memcpy(&dst[offset], pad, final_bsize);
1197         aes_encrypt(&k, buf, pad);
1198 }
1199
1200 static void aes_cts_cbc_decrypt(const u8 *key, int keysize,
1201                                 const u8 iv[AES_BLOCK_SIZE],
1202                                 const u8 *src, u8 *dst, size_t nbytes)
1203 {
1204         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1205         const size_t final_bsize = nbytes - offset;
1206         struct aes_key k;
1207         u8 *pad;
1208
1209         ASSERT(nbytes >= AES_BLOCK_SIZE);
1210
1211         aes_setkey(&k, key, keysize);
1212
1213         if (nbytes == AES_BLOCK_SIZE)
1214                 return aes_cbc_decrypt(&k, iv, src, dst, nbytes);
1215
1216         pad = &dst[offset - AES_BLOCK_SIZE];
1217         aes_decrypt(&k, &src[offset - AES_BLOCK_SIZE], pad);
1218         xor(&dst[offset], &src[offset], pad, final_bsize);
1219         xor(pad, pad, &dst[offset], final_bsize);
1220
1221         aes_cbc_decrypt(&k, (offset == AES_BLOCK_SIZE ?
1222                              iv : &src[offset - 2 * AES_BLOCK_SIZE]),
1223                         pad, pad, AES_BLOCK_SIZE);
1224         aes_cbc_decrypt(&k, iv, src, dst, offset - AES_BLOCK_SIZE);
1225 }
1226
1227 static void aes_256_cts_cbc_encrypt(const u8 key[AES_256_KEY_SIZE],
1228                                     const u8 iv[AES_BLOCK_SIZE],
1229                                     const u8 *src, u8 *dst, size_t nbytes)
1230 {
1231         aes_cts_cbc_encrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1232 }
1233
1234 static void aes_256_cts_cbc_decrypt(const u8 key[AES_256_KEY_SIZE],
1235                                     const u8 iv[AES_BLOCK_SIZE],
1236                                     const u8 *src, u8 *dst, size_t nbytes)
1237 {
1238         aes_cts_cbc_decrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1239 }
1240
1241 #ifdef ENABLE_ALG_TESTS
1242 #include <openssl/modes.h>
1243 static void aes_block128_f(const unsigned char in[16],
1244                            unsigned char out[16], const void *key)
1245 {
1246         aes_encrypt(key, in, out);
1247 }
1248
1249 static void test_aes_256_cts_cbc(void)
1250 {
1251         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1252
1253         while (num_tests--) {
1254                 u8 key[AES_256_KEY_SIZE];
1255                 u8 iv[AES_BLOCK_SIZE];
1256                 u8 iv_copy[AES_BLOCK_SIZE];
1257                 u8 ptext[512];
1258                 u8 ctext[sizeof(ptext)];
1259                 u8 ref_ctext[sizeof(ptext)];
1260                 u8 decrypted[sizeof(ptext)];
1261                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1262                 struct aes_key k;
1263
1264                 rand_bytes(key, sizeof(key));
1265                 rand_bytes(iv, sizeof(iv));
1266                 rand_bytes(ptext, datalen);
1267
1268                 aes_256_cts_cbc_encrypt(key, iv, ptext, ctext, datalen);
1269
1270                 /* OpenSSL doesn't allow datalen=AES_BLOCK_SIZE; Linux does */
1271                 if (datalen != AES_BLOCK_SIZE) {
1272                         aes_setkey(&k, key, sizeof(key));
1273                         memcpy(iv_copy, iv, sizeof(iv));
1274                         ASSERT(CRYPTO_cts128_encrypt_block(ptext, ref_ctext,
1275                                                            datalen, &k, iv_copy,
1276                                                            aes_block128_f)
1277                                == datalen);
1278                         ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1279                 }
1280                 aes_256_cts_cbc_decrypt(key, iv, ctext, decrypted, datalen);
1281                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1282         }
1283 }
1284 #endif /* ENABLE_ALG_TESTS */
1285
1286 static void essiv_generate_iv(const u8 orig_key[AES_128_KEY_SIZE],
1287                               const u8 orig_iv[AES_BLOCK_SIZE],
1288                               u8 real_iv[AES_BLOCK_SIZE])
1289 {
1290         u8 essiv_key[SHA256_DIGEST_SIZE];
1291         struct aes_key essiv;
1292
1293         /* AES encrypt the original IV using a hash of the original key */
1294         STATIC_ASSERT(SHA256_DIGEST_SIZE == AES_256_KEY_SIZE);
1295         sha256(orig_key, AES_128_KEY_SIZE, essiv_key);
1296         aes_setkey(&essiv, essiv_key, AES_256_KEY_SIZE);
1297         aes_encrypt(&essiv, orig_iv, real_iv);
1298 }
1299
1300 static void aes_128_cbc_essiv_encrypt(const u8 key[AES_128_KEY_SIZE],
1301                                       const u8 iv[AES_BLOCK_SIZE],
1302                                       const u8 *src, u8 *dst, size_t nbytes)
1303 {
1304         struct aes_key k;
1305         u8 real_iv[AES_BLOCK_SIZE];
1306
1307         aes_setkey(&k, key, AES_128_KEY_SIZE);
1308         essiv_generate_iv(key, iv, real_iv);
1309         aes_cbc_encrypt(&k, real_iv, src, dst, nbytes);
1310 }
1311
1312 static void aes_128_cbc_essiv_decrypt(const u8 key[AES_128_KEY_SIZE],
1313                                       const u8 iv[AES_BLOCK_SIZE],
1314                                       const u8 *src, u8 *dst, size_t nbytes)
1315 {
1316         struct aes_key k;
1317         u8 real_iv[AES_BLOCK_SIZE];
1318
1319         aes_setkey(&k, key, AES_128_KEY_SIZE);
1320         essiv_generate_iv(key, iv, real_iv);
1321         aes_cbc_decrypt(&k, real_iv, src, dst, nbytes);
1322 }
1323
1324 static void aes_128_cts_cbc_encrypt(const u8 key[AES_128_KEY_SIZE],
1325                                     const u8 iv[AES_BLOCK_SIZE],
1326                                     const u8 *src, u8 *dst, size_t nbytes)
1327 {
1328         aes_cts_cbc_encrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1329 }
1330
1331 static void aes_128_cts_cbc_decrypt(const u8 key[AES_128_KEY_SIZE],
1332                                     const u8 iv[AES_BLOCK_SIZE],
1333                                     const u8 *src, u8 *dst, size_t nbytes)
1334 {
1335         aes_cts_cbc_decrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1336 }
1337
1338 /*
1339  * Reference: "Length-preserving encryption with HCTR2"
1340  *      https://ia.cr/2021/1441
1341  */
1342
1343 static void aes_256_xctr_crypt(const u8 key[AES_256_KEY_SIZE],
1344                               const u8 iv[AES_BLOCK_SIZE], const u8 *src,
1345                               u8 *dst, size_t nbytes)
1346 {
1347         struct aes_key k;
1348         union {
1349                 u8 bytes[AES_BLOCK_SIZE];
1350                 __le64 ctr;
1351         } blk;
1352         size_t i;
1353
1354         aes_setkey(&k, key, AES_256_KEY_SIZE);
1355
1356         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
1357                 memcpy(blk.bytes, iv, AES_BLOCK_SIZE);
1358                 blk.ctr ^= cpu_to_le64((i / AES_BLOCK_SIZE) + 1);
1359                 aes_encrypt(&k, blk.bytes, blk.bytes);
1360                 xor(&dst[i], blk.bytes, &src[i], MIN(AES_BLOCK_SIZE, nbytes - i));
1361         }
1362 }
1363
1364 /*
1365  * Reference: "Length-preserving encryption with HCTR2"
1366  *      https://ia.cr/2021/1441
1367  */
1368
1369 #define HCTR2_IV_SIZE 32
1370 static void hctr2_hash_iv(const u8 hbar[POLYVAL_KEY_SIZE],
1371                           const u8 iv[HCTR2_IV_SIZE], size_t msglen,
1372                           u8 digest[POLYVAL_BLOCK_SIZE])
1373 {
1374         le128 tweaklen_blk = {
1375                 .lo = cpu_to_le64(HCTR2_IV_SIZE * 8 * 2 + 2 +
1376                                   (msglen % AES_BLOCK_SIZE != 0))
1377         };
1378
1379         memset(digest, 0, POLYVAL_BLOCK_SIZE);
1380         polyval_update(hbar, (u8 *)&tweaklen_blk, POLYVAL_BLOCK_SIZE, digest);
1381         polyval_update(hbar, iv, HCTR2_IV_SIZE, digest);
1382 }
1383
1384 static void hctr2_hash_message(const u8 hbar[POLYVAL_KEY_SIZE],
1385                                const u8 *msg, size_t msglen,
1386                                u8 digest[POLYVAL_BLOCK_SIZE])
1387 {
1388         size_t remainder = msglen % AES_BLOCK_SIZE;
1389         u8 padded_block[POLYVAL_BLOCK_SIZE] = {0};
1390
1391         polyval_update(hbar, msg, msglen - remainder, digest);
1392         if (remainder) {
1393                 memcpy(padded_block, &msg[msglen - remainder], remainder);
1394                 padded_block[remainder] = 1;
1395                 polyval_update(hbar, padded_block, POLYVAL_BLOCK_SIZE, digest);
1396         }
1397 }
1398
1399 static void aes_256_hctr2_crypt(const u8 key[AES_256_KEY_SIZE],
1400                                 const u8 iv[HCTR2_IV_SIZE], const u8 *src,
1401                                 u8 *dst, size_t nbytes, bool decrypting)
1402 {
1403         struct aes_key k;
1404         u8 hbar[AES_BLOCK_SIZE] = {0};
1405         u8 L[AES_BLOCK_SIZE] = {1};
1406         size_t bulk_bytes = nbytes - AES_BLOCK_SIZE;
1407         u8 digest[POLYVAL_BLOCK_SIZE];
1408         const u8 *M = src;
1409         const u8 *N = src + AES_BLOCK_SIZE;
1410         u8 MM[AES_BLOCK_SIZE];
1411         u8 UU[AES_BLOCK_SIZE];
1412         u8 S[AES_BLOCK_SIZE];
1413         u8 *U = dst;
1414         u8 *V = dst + AES_BLOCK_SIZE;
1415
1416         ASSERT(nbytes >= AES_BLOCK_SIZE);
1417         aes_setkey(&k, key, AES_256_KEY_SIZE);
1418
1419         aes_encrypt(&k, hbar, hbar);
1420         aes_encrypt(&k, L, L);
1421
1422         hctr2_hash_iv(hbar, iv, bulk_bytes, digest);
1423         hctr2_hash_message(hbar, N, bulk_bytes, digest);
1424
1425         xor(MM, M, digest, AES_BLOCK_SIZE);
1426
1427         if (decrypting)
1428                 aes_decrypt(&k, MM, UU);
1429         else
1430                 aes_encrypt(&k, MM, UU);
1431
1432         xor(S, MM, UU, AES_BLOCK_SIZE);
1433         xor(S, L, S, AES_BLOCK_SIZE);
1434
1435         aes_256_xctr_crypt(key, S, N, V, bulk_bytes);
1436
1437         hctr2_hash_iv(hbar, iv, bulk_bytes, digest);
1438         hctr2_hash_message(hbar, V, bulk_bytes, digest);
1439
1440         xor(U, UU, digest, AES_BLOCK_SIZE);
1441 }
1442
1443 static void aes_256_hctr2_encrypt(const u8 key[AES_256_KEY_SIZE],
1444                                   const u8 iv[HCTR2_IV_SIZE], const u8 *src,
1445                                   u8 *dst, size_t nbytes)
1446 {
1447         aes_256_hctr2_crypt(key, iv, src, dst, nbytes, false);
1448 }
1449
1450 static void aes_256_hctr2_decrypt(const u8 key[AES_256_KEY_SIZE],
1451                                   const u8 iv[HCTR2_IV_SIZE], const u8 *src,
1452                                   u8 *dst, size_t nbytes)
1453 {
1454         aes_256_hctr2_crypt(key, iv, src, dst, nbytes, true);
1455 }
1456
1457 #ifdef ENABLE_ALG_TESTS
1458 #include <linux/if_alg.h>
1459 #include <sys/socket.h>
1460 static void test_aes_256_hctr2(void)
1461 {
1462         int algfd = socket(AF_ALG, SOCK_SEQPACKET, 0);
1463         struct sockaddr_alg addr = {
1464                 .salg_type = "skcipher",
1465                 .salg_name = "hctr2(aes)",
1466         };
1467         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1468
1469         if (algfd < 0)
1470                 die_errno("can't create AF_ALG socket");
1471         if (bind(algfd, (struct sockaddr *)&addr, sizeof(addr)) != 0)
1472                 die_errno("can't bind AF_ALG socket to HCTR2 algorithm");
1473
1474         while (num_tests--) {
1475                 u8 key[AES_256_KEY_SIZE];
1476                 u8 iv[HCTR2_IV_SIZE];
1477                 u8 ptext[4096];
1478                 u8 ctext[sizeof(ptext)];
1479                 u8 ref_ctext[sizeof(ptext)];
1480                 u8 decrypted[sizeof(ptext)];
1481                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1482
1483                 rand_bytes(key, sizeof(key));
1484                 rand_bytes(iv, sizeof(iv));
1485                 rand_bytes(ptext, datalen);
1486
1487                 aes_256_hctr2_encrypt(key, iv, ptext, ctext, datalen);
1488                 af_alg_crypt(algfd, ALG_OP_ENCRYPT, key, sizeof(key),
1489                              iv, sizeof(iv), ptext, ref_ctext, datalen);
1490                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1491
1492                 aes_256_hctr2_decrypt(key, iv, ctext, decrypted, datalen);
1493                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1494         }
1495         close(algfd);
1496 }
1497 #endif /* ENABLE_ALG_TESTS */
1498
1499 /*----------------------------------------------------------------------------*
1500  *                           XChaCha12 stream cipher                          *
1501  *----------------------------------------------------------------------------*/
1502
1503 /*
1504  * References:
1505  *   - "XChaCha: eXtended-nonce ChaCha and AEAD_XChaCha20_Poly1305"
1506  *      https://tools.ietf.org/html/draft-arciszewski-xchacha-03
1507  *
1508  *   - "ChaCha, a variant of Salsa20"
1509  *      https://cr.yp.to/chacha/chacha-20080128.pdf
1510  *
1511  *   - "Extending the Salsa20 nonce"
1512  *      https://cr.yp.to/snuffle/xsalsa-20081128.pdf
1513  */
1514
1515 #define CHACHA_KEY_SIZE         32
1516 #define XCHACHA_KEY_SIZE        CHACHA_KEY_SIZE
1517 #define XCHACHA_NONCE_SIZE      24
1518
1519 static void chacha_init_state(u32 state[16], const u8 key[CHACHA_KEY_SIZE],
1520                               const u8 iv[16])
1521 {
1522         static const u8 consts[16] = "expand 32-byte k";
1523         int i;
1524
1525         for (i = 0; i < 4; i++)
1526                 state[i] = get_unaligned_le32(&consts[i * sizeof(__le32)]);
1527         for (i = 0; i < 8; i++)
1528                 state[4 + i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
1529         for (i = 0; i < 4; i++)
1530                 state[12 + i] = get_unaligned_le32(&iv[i * sizeof(__le32)]);
1531 }
1532
1533 #define CHACHA_QUARTERROUND(a, b, c, d)         \
1534         do {                                    \
1535                 a += b; d = rol32(d ^ a, 16);   \
1536                 c += d; b = rol32(b ^ c, 12);   \
1537                 a += b; d = rol32(d ^ a, 8);    \
1538                 c += d; b = rol32(b ^ c, 7);    \
1539         } while (0)
1540
1541 static void chacha_permute(u32 x[16], int nrounds)
1542 {
1543         do {
1544                 /* column round */
1545                 CHACHA_QUARTERROUND(x[0], x[4], x[8], x[12]);
1546                 CHACHA_QUARTERROUND(x[1], x[5], x[9], x[13]);
1547                 CHACHA_QUARTERROUND(x[2], x[6], x[10], x[14]);
1548                 CHACHA_QUARTERROUND(x[3], x[7], x[11], x[15]);
1549
1550                 /* diagonal round */
1551                 CHACHA_QUARTERROUND(x[0], x[5], x[10], x[15]);
1552                 CHACHA_QUARTERROUND(x[1], x[6], x[11], x[12]);
1553                 CHACHA_QUARTERROUND(x[2], x[7], x[8], x[13]);
1554                 CHACHA_QUARTERROUND(x[3], x[4], x[9], x[14]);
1555         } while ((nrounds -= 2) != 0);
1556 }
1557
1558 static void xchacha(const u8 key[XCHACHA_KEY_SIZE],
1559                     const u8 nonce[XCHACHA_NONCE_SIZE],
1560                     const u8 *src, u8 *dst, size_t nbytes, int nrounds)
1561 {
1562         u32 state[16];
1563         u8 real_key[CHACHA_KEY_SIZE];
1564         u8 real_iv[16] = { 0 };
1565         size_t i, j;
1566
1567         /* Compute real key using original key and first 128 nonce bits */
1568         chacha_init_state(state, key, nonce);
1569         chacha_permute(state, nrounds);
1570         for (i = 0; i < 8; i++) /* state words 0..3, 12..15 */
1571                 put_unaligned_le32(state[(i < 4 ? 0 : 8) + i],
1572                                    &real_key[i * sizeof(__le32)]);
1573
1574         /* Now do regular ChaCha, using real key and remaining nonce bits */
1575         memcpy(&real_iv[8], nonce + 16, 8);
1576         chacha_init_state(state, real_key, real_iv);
1577         for (i = 0; i < nbytes; i += 64) {
1578                 u32 x[16];
1579                 __le32 keystream[16];
1580
1581                 memcpy(x, state, 64);
1582                 chacha_permute(x, nrounds);
1583                 for (j = 0; j < 16; j++)
1584                         keystream[j] = cpu_to_le32(x[j] + state[j]);
1585                 xor(&dst[i], &src[i], (u8 *)keystream, MIN(nbytes - i, 64));
1586                 if (++state[12] == 0)
1587                         state[13]++;
1588         }
1589 }
1590
1591 static void xchacha12(const u8 key[XCHACHA_KEY_SIZE],
1592                       const u8 nonce[XCHACHA_NONCE_SIZE],
1593                       const u8 *src, u8 *dst, size_t nbytes)
1594 {
1595         xchacha(key, nonce, src, dst, nbytes, 12);
1596 }
1597
1598 /*----------------------------------------------------------------------------*
1599  *                                 Poly1305                                   *
1600  *----------------------------------------------------------------------------*/
1601
1602 /*
1603  * Note: this is only the Poly1305 Îµ-almost-∆-universal hash function, not the
1604  * full Poly1305 MAC.  I.e., it doesn't add anything at the end.
1605  */
1606
1607 #define POLY1305_KEY_SIZE       16
1608 #define POLY1305_BLOCK_SIZE     16
1609
1610 static void poly1305(const u8 key[POLY1305_KEY_SIZE],
1611                      const u8 *msg, size_t msglen, le128 *out)
1612 {
1613         const u32 limb_mask = 0x3ffffff;        /* limbs are base 2^26 */
1614         const u64 r0 = (get_unaligned_le32(key +  0) >> 0) & 0x3ffffff;
1615         const u64 r1 = (get_unaligned_le32(key +  3) >> 2) & 0x3ffff03;
1616         const u64 r2 = (get_unaligned_le32(key +  6) >> 4) & 0x3ffc0ff;
1617         const u64 r3 = (get_unaligned_le32(key +  9) >> 6) & 0x3f03fff;
1618         const u64 r4 = (get_unaligned_le32(key + 12) >> 8) & 0x00fffff;
1619         u32 h0 = 0, h1 = 0, h2 = 0, h3 = 0, h4 = 0;
1620         u32 g0, g1, g2, g3, g4, ge_p_mask;
1621
1622         /* Partial block support is not necessary for Adiantum */
1623         ASSERT(msglen % POLY1305_BLOCK_SIZE == 0);
1624
1625         while (msglen) {
1626                 u64 d0, d1, d2, d3, d4;
1627
1628                 /* h += *msg */
1629                 h0 += (get_unaligned_le32(msg +  0) >> 0) & limb_mask;
1630                 h1 += (get_unaligned_le32(msg +  3) >> 2) & limb_mask;
1631                 h2 += (get_unaligned_le32(msg +  6) >> 4) & limb_mask;
1632                 h3 += (get_unaligned_le32(msg +  9) >> 6) & limb_mask;
1633                 h4 += (get_unaligned_le32(msg + 12) >> 8) | (1 << 24);
1634
1635                 /* h *= r */
1636                 d0 = h0*r0 + h1*5*r4 + h2*5*r3 + h3*5*r2 + h4*5*r1;
1637                 d1 = h0*r1 + h1*r0   + h2*5*r4 + h3*5*r3 + h4*5*r2;
1638                 d2 = h0*r2 + h1*r1   + h2*r0   + h3*5*r4 + h4*5*r3;
1639                 d3 = h0*r3 + h1*r2   + h2*r1   + h3*r0   + h4*5*r4;
1640                 d4 = h0*r4 + h1*r3   + h2*r2   + h3*r1   + h4*r0;
1641
1642                 /* (partial) h %= 2^130 - 5 */
1643                 d1 += d0 >> 26;         h0 = d0 & limb_mask;
1644                 d2 += d1 >> 26;         h1 = d1 & limb_mask;
1645                 d3 += d2 >> 26;         h2 = d2 & limb_mask;
1646                 d4 += d3 >> 26;         h3 = d3 & limb_mask;
1647                 h0 += (d4 >> 26) * 5;   h4 = d4 & limb_mask;
1648                 h1 += h0 >> 26;         h0 &= limb_mask;
1649
1650                 msg += POLY1305_BLOCK_SIZE;
1651                 msglen -= POLY1305_BLOCK_SIZE;
1652         }
1653
1654         /* fully carry h */
1655         h2 += (h1 >> 26);       h1 &= limb_mask;
1656         h3 += (h2 >> 26);       h2 &= limb_mask;
1657         h4 += (h3 >> 26);       h3 &= limb_mask;
1658         h0 += (h4 >> 26) * 5;   h4 &= limb_mask;
1659         h1 += (h0 >> 26);       h0 &= limb_mask;
1660
1661         /* if (h >= 2^130 - 5) h -= 2^130 - 5; */
1662         g0 = h0 + 5;
1663         g1 = h1 + (g0 >> 26);   g0 &= limb_mask;
1664         g2 = h2 + (g1 >> 26);   g1 &= limb_mask;
1665         g3 = h3 + (g2 >> 26);   g2 &= limb_mask;
1666         g4 = h4 + (g3 >> 26);   g3 &= limb_mask;
1667         ge_p_mask = ~((g4 >> 26) - 1); /* all 1's if h >= 2^130 - 5, else 0 */
1668         h0 = (h0 & ~ge_p_mask) | (g0 & ge_p_mask);
1669         h1 = (h1 & ~ge_p_mask) | (g1 & ge_p_mask);
1670         h2 = (h2 & ~ge_p_mask) | (g2 & ge_p_mask);
1671         h3 = (h3 & ~ge_p_mask) | (g3 & ge_p_mask);
1672         h4 = (h4 & ~ge_p_mask) | (g4 & ge_p_mask & limb_mask);
1673
1674         /* h %= 2^128 */
1675         out->lo = cpu_to_le64(((u64)h2 << 52) | ((u64)h1 << 26) | h0);
1676         out->hi = cpu_to_le64(((u64)h4 << 40) | ((u64)h3 << 14) | (h2 >> 12));
1677 }
1678
1679 /*----------------------------------------------------------------------------*
1680  *                          Adiantum encryption mode                          *
1681  *----------------------------------------------------------------------------*/
1682
1683 /*
1684  * Reference: "Adiantum: length-preserving encryption for entry-level processors"
1685  *      https://tosc.iacr.org/index.php/ToSC/article/view/7360
1686  */
1687
1688 #define ADIANTUM_KEY_SIZE       32
1689 #define ADIANTUM_IV_SIZE        32
1690 #define ADIANTUM_HASH_KEY_SIZE  ((2 * POLY1305_KEY_SIZE) + NH_KEY_SIZE)
1691
1692 #define NH_KEY_SIZE             1072
1693 #define NH_KEY_WORDS            (NH_KEY_SIZE / sizeof(u32))
1694 #define NH_BLOCK_SIZE           1024
1695 #define NH_HASH_SIZE            32
1696 #define NH_MESSAGE_UNIT         16
1697
1698 static u64 nh_pass(const u32 *key, const u8 *msg, size_t msglen)
1699 {
1700         u64 sum = 0;
1701
1702         ASSERT(msglen % NH_MESSAGE_UNIT == 0);
1703         while (msglen) {
1704                 sum += (u64)(u32)(get_unaligned_le32(msg +  0) + key[0]) *
1705                             (u32)(get_unaligned_le32(msg +  8) + key[2]);
1706                 sum += (u64)(u32)(get_unaligned_le32(msg +  4) + key[1]) *
1707                             (u32)(get_unaligned_le32(msg + 12) + key[3]);
1708                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1709                 msg += NH_MESSAGE_UNIT;
1710                 msglen -= NH_MESSAGE_UNIT;
1711         }
1712         return sum;
1713 }
1714
1715 /* NH Îµ-almost-universal hash function */
1716 static void nh(const u32 *key, const u8 *msg, size_t msglen,
1717                u8 result[NH_HASH_SIZE])
1718 {
1719         size_t i;
1720
1721         for (i = 0; i < NH_HASH_SIZE; i += sizeof(__le64)) {
1722                 put_unaligned_le64(nh_pass(key, msg, msglen), &result[i]);
1723                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1724         }
1725 }
1726
1727 /* Adiantum's Îµ-almost-∆-universal hash function */
1728 static void adiantum_hash(const u8 key[ADIANTUM_HASH_KEY_SIZE],
1729                           const u8 iv[ADIANTUM_IV_SIZE],
1730                           const u8 *msg, size_t msglen, le128 *result)
1731 {
1732         const u8 *header_poly_key = key;
1733         const u8 *msg_poly_key = header_poly_key + POLY1305_KEY_SIZE;
1734         const u8 *nh_key = msg_poly_key + POLY1305_KEY_SIZE;
1735         u32 nh_key_words[NH_KEY_WORDS];
1736         u8 header[POLY1305_BLOCK_SIZE + ADIANTUM_IV_SIZE];
1737         const size_t num_nh_blocks = DIV_ROUND_UP(msglen, NH_BLOCK_SIZE);
1738         u8 *nh_hashes = xmalloc(num_nh_blocks * NH_HASH_SIZE);
1739         const size_t padded_msglen = ROUND_UP(msglen, NH_MESSAGE_UNIT);
1740         u8 *padded_msg = xmalloc(padded_msglen);
1741         le128 hash1, hash2;
1742         size_t i;
1743
1744         for (i = 0; i < NH_KEY_WORDS; i++)
1745                 nh_key_words[i] = get_unaligned_le32(&nh_key[i * sizeof(u32)]);
1746
1747         /* Hash tweak and message length with first Poly1305 key */
1748         put_unaligned_le64((u64)msglen * 8, header);
1749         put_unaligned_le64(0, &header[sizeof(__le64)]);
1750         memcpy(&header[POLY1305_BLOCK_SIZE], iv, ADIANTUM_IV_SIZE);
1751         poly1305(header_poly_key, header, sizeof(header), &hash1);
1752
1753         /* Hash NH hashes of message blocks using second Poly1305 key */
1754         /* (using a super naive way of handling the padding) */
1755         memcpy(padded_msg, msg, msglen);
1756         memset(&padded_msg[msglen], 0, padded_msglen - msglen);
1757         for (i = 0; i < num_nh_blocks; i++) {
1758                 nh(nh_key_words, &padded_msg[i * NH_BLOCK_SIZE],
1759                    MIN(NH_BLOCK_SIZE, padded_msglen - (i * NH_BLOCK_SIZE)),
1760                    &nh_hashes[i * NH_HASH_SIZE]);
1761         }
1762         poly1305(msg_poly_key, nh_hashes, num_nh_blocks * NH_HASH_SIZE, &hash2);
1763
1764         /* Add the two hashes together to get the final hash */
1765         le128_add(result, &hash1, &hash2);
1766
1767         free(nh_hashes);
1768         free(padded_msg);
1769 }
1770
1771 static void adiantum_crypt(const u8 key[ADIANTUM_KEY_SIZE],
1772                            const u8 iv[ADIANTUM_IV_SIZE], const u8 *src,
1773                            u8 *dst, size_t nbytes, bool decrypting)
1774 {
1775         u8 subkeys[AES_256_KEY_SIZE + ADIANTUM_HASH_KEY_SIZE] = { 0 };
1776         struct aes_key aes_key;
1777         union {
1778                 u8 nonce[XCHACHA_NONCE_SIZE];
1779                 le128 block;
1780         } u = { .nonce = { 1 } };
1781         const size_t bulk_len = nbytes - sizeof(u.block);
1782         le128 hash;
1783
1784         ASSERT(nbytes >= sizeof(u.block));
1785
1786         /* Derive subkeys */
1787         xchacha12(key, u.nonce, subkeys, subkeys, sizeof(subkeys));
1788         aes_setkey(&aes_key, subkeys, AES_256_KEY_SIZE);
1789
1790         /* Hash left part and add to right part */
1791         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, src, bulk_len, &hash);
1792         memcpy(&u.block, &src[bulk_len], sizeof(u.block));
1793         le128_add(&u.block, &u.block, &hash);
1794
1795         if (!decrypting) /* Encrypt right part with block cipher */
1796                 aes_encrypt(&aes_key, u.nonce, u.nonce);
1797
1798         /* Encrypt left part with stream cipher, using the computed nonce */
1799         u.nonce[sizeof(u.block)] = 1;
1800         xchacha12(key, u.nonce, src, dst, bulk_len);
1801
1802         if (decrypting) /* Decrypt right part with block cipher */
1803                 aes_decrypt(&aes_key, u.nonce, u.nonce);
1804
1805         /* Finalize right part by subtracting hash of left part */
1806         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, dst, bulk_len, &hash);
1807         le128_sub(&u.block, &u.block, &hash);
1808         memcpy(&dst[bulk_len], &u.block, sizeof(u.block));
1809 }
1810
1811 static void adiantum_encrypt(const u8 key[ADIANTUM_KEY_SIZE],
1812                              const u8 iv[ADIANTUM_IV_SIZE],
1813                              const u8 *src, u8 *dst, size_t nbytes)
1814 {
1815         adiantum_crypt(key, iv, src, dst, nbytes, false);
1816 }
1817
1818 static void adiantum_decrypt(const u8 key[ADIANTUM_KEY_SIZE],
1819                              const u8 iv[ADIANTUM_IV_SIZE],
1820                              const u8 *src, u8 *dst, size_t nbytes)
1821 {
1822         adiantum_crypt(key, iv, src, dst, nbytes, true);
1823 }
1824
1825 #ifdef ENABLE_ALG_TESTS
1826 static void test_adiantum(void)
1827 {
1828         int algfd = socket(AF_ALG, SOCK_SEQPACKET, 0);
1829         struct sockaddr_alg addr = {
1830                 .salg_type = "skcipher",
1831                 .salg_name = "adiantum(xchacha12,aes)",
1832         };
1833         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1834
1835         if (algfd < 0)
1836                 die_errno("can't create AF_ALG socket");
1837         if (bind(algfd, (struct sockaddr *)&addr, sizeof(addr)) != 0)
1838                 die_errno("can't bind AF_ALG socket to Adiantum algorithm");
1839
1840         while (num_tests--) {
1841                 u8 key[ADIANTUM_KEY_SIZE];
1842                 u8 iv[ADIANTUM_IV_SIZE];
1843                 u8 ptext[4096];
1844                 u8 ctext[sizeof(ptext)];
1845                 u8 ref_ctext[sizeof(ptext)];
1846                 u8 decrypted[sizeof(ptext)];
1847                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1848
1849                 rand_bytes(key, sizeof(key));
1850                 rand_bytes(iv, sizeof(iv));
1851                 rand_bytes(ptext, datalen);
1852
1853                 adiantum_encrypt(key, iv, ptext, ctext, datalen);
1854                 af_alg_crypt(algfd, ALG_OP_ENCRYPT, key, sizeof(key),
1855                              iv, sizeof(iv), ptext, ref_ctext, datalen);
1856                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1857
1858                 adiantum_decrypt(key, iv, ctext, decrypted, datalen);
1859                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1860         }
1861         close(algfd);
1862 }
1863 #endif /* ENABLE_ALG_TESTS */
1864
1865 /*----------------------------------------------------------------------------*
1866  *                               SipHash-2-4                                  *
1867  *----------------------------------------------------------------------------*/
1868
1869 /*
1870  * Reference: "SipHash: a fast short-input PRF"
1871  *      https://cr.yp.to/siphash/siphash-20120918.pdf
1872  */
1873
1874 #define SIPROUND                                                \
1875         do {                                                    \
1876                 v0 += v1;           v2 += v3;                   \
1877                 v1 = rol64(v1, 13); v3 = rol64(v3, 16);         \
1878                 v1 ^= v0;           v3 ^= v2;                   \
1879                 v0 = rol64(v0, 32);                             \
1880                 v2 += v1;           v0 += v3;                   \
1881                 v1 = rol64(v1, 17); v3 = rol64(v3, 21);         \
1882                 v1 ^= v2;           v3 ^= v0;                   \
1883                 v2 = rol64(v2, 32);                             \
1884         } while (0)
1885
1886 /* Compute the SipHash-2-4 of a 64-bit number when formatted as little endian */
1887 static u64 siphash_1u64(const u64 key[2], u64 data)
1888 {
1889         u64 v0 = key[0] ^ 0x736f6d6570736575ULL;
1890         u64 v1 = key[1] ^ 0x646f72616e646f6dULL;
1891         u64 v2 = key[0] ^ 0x6c7967656e657261ULL;
1892         u64 v3 = key[1] ^ 0x7465646279746573ULL;
1893         u64 m[2] = {data, (u64)sizeof(data) << 56};
1894         size_t i;
1895
1896         for (i = 0; i < ARRAY_SIZE(m); i++) {
1897                 v3 ^= m[i];
1898                 SIPROUND;
1899                 SIPROUND;
1900                 v0 ^= m[i];
1901         }
1902
1903         v2 ^= 0xff;
1904         for (i = 0; i < 4; i++)
1905                 SIPROUND;
1906         return v0 ^ v1 ^ v2 ^ v3;
1907 }
1908
1909 /*----------------------------------------------------------------------------*
1910  *                               Main program                                 *
1911  *----------------------------------------------------------------------------*/
1912
1913 #define FILE_NONCE_SIZE         16
1914 #define UUID_SIZE               16
1915 #define MAX_KEY_SIZE            64
1916 #define MAX_IV_SIZE             ADIANTUM_IV_SIZE
1917
1918 static const struct fscrypt_cipher {
1919         const char *name;
1920         void (*encrypt)(const u8 *key, const u8 *iv, const u8 *src,
1921                         u8 *dst, size_t nbytes);
1922         void (*decrypt)(const u8 *key, const u8 *iv, const u8 *src,
1923                         u8 *dst, size_t nbytes);
1924         int keysize;
1925         int min_input_size;
1926 } fscrypt_ciphers[] = {
1927         {
1928                 .name = "AES-256-XTS",
1929                 .encrypt = aes_256_xts_encrypt,
1930                 .decrypt = aes_256_xts_decrypt,
1931                 .keysize = 2 * AES_256_KEY_SIZE,
1932         }, {
1933                 .name = "AES-256-CTS-CBC",
1934                 .encrypt = aes_256_cts_cbc_encrypt,
1935                 .decrypt = aes_256_cts_cbc_decrypt,
1936                 .keysize = AES_256_KEY_SIZE,
1937                 .min_input_size = AES_BLOCK_SIZE,
1938         }, {
1939                 .name = "AES-128-CBC-ESSIV",
1940                 .encrypt = aes_128_cbc_essiv_encrypt,
1941                 .decrypt = aes_128_cbc_essiv_decrypt,
1942                 .keysize = AES_128_KEY_SIZE,
1943         }, {
1944                 .name = "AES-128-CTS-CBC",
1945                 .encrypt = aes_128_cts_cbc_encrypt,
1946                 .decrypt = aes_128_cts_cbc_decrypt,
1947                 .keysize = AES_128_KEY_SIZE,
1948                 .min_input_size = AES_BLOCK_SIZE,
1949         }, {
1950                 .name = "AES-256-HCTR2",
1951                 .encrypt = aes_256_hctr2_encrypt,
1952                 .decrypt = aes_256_hctr2_decrypt,
1953                 .keysize = AES_256_KEY_SIZE,
1954                 .min_input_size = AES_BLOCK_SIZE,
1955         }, {
1956                 .name = "Adiantum",
1957                 .encrypt = adiantum_encrypt,
1958                 .decrypt = adiantum_decrypt,
1959                 .keysize = ADIANTUM_KEY_SIZE,
1960                 .min_input_size = AES_BLOCK_SIZE,
1961         }
1962 };
1963
1964 static const struct fscrypt_cipher *find_fscrypt_cipher(const char *name)
1965 {
1966         size_t i;
1967
1968         for (i = 0; i < ARRAY_SIZE(fscrypt_ciphers); i++) {
1969                 if (strcmp(fscrypt_ciphers[i].name, name) == 0)
1970                         return &fscrypt_ciphers[i];
1971         }
1972         return NULL;
1973 }
1974
1975 union fscrypt_iv {
1976         /* usual IV format */
1977         struct {
1978                 /* data unit index within the file */
1979                 __le64 data_unit_index;
1980
1981                 /* per-file nonce; only set in DIRECT_KEY mode */
1982                 u8 nonce[FILE_NONCE_SIZE];
1983         };
1984         /* IV format for IV_INO_LBLK_* modes */
1985         struct {
1986                 /*
1987                  * IV_INO_LBLK_64: data unit index within the file
1988                  * IV_INO_LBLK_32: hashed inode number + data unit index within
1989                  *                 the file, mod 2^32
1990                  */
1991                 __le32 data_unit_index32;
1992
1993                 /* IV_INO_LBLK_64: inode number */
1994                 __le32 inode_number;
1995         };
1996         /* Any extra bytes up to the algorithm's IV size must be zeroed */
1997         u8 bytes[MAX_IV_SIZE];
1998 };
1999
2000 static void crypt_loop(const struct fscrypt_cipher *cipher, const u8 *key,
2001                        union fscrypt_iv *iv, bool decrypting,
2002                        size_t data_unit_size, size_t padding,
2003                        bool is_data_unit_index_32bit)
2004 {
2005         u8 *buf = xmalloc(data_unit_size);
2006         size_t res;
2007
2008         while ((res = xread(STDIN_FILENO, buf, data_unit_size)) > 0) {
2009                 size_t crypt_len = data_unit_size;
2010
2011                 if (padding > 0) {
2012                         crypt_len = MAX(res, cipher->min_input_size);
2013                         crypt_len = ROUND_UP(crypt_len, padding);
2014                         crypt_len = MIN(crypt_len, data_unit_size);
2015                 }
2016                 ASSERT(crypt_len >= res);
2017                 memset(&buf[res], 0, crypt_len - res);
2018
2019                 if (decrypting)
2020                         cipher->decrypt(key, iv->bytes, buf, buf, crypt_len);
2021                 else
2022                         cipher->encrypt(key, iv->bytes, buf, buf, crypt_len);
2023
2024                 full_write(STDOUT_FILENO, buf, crypt_len);
2025
2026                 if (is_data_unit_index_32bit)
2027                         iv->data_unit_index32 = cpu_to_le32(
2028                                 le32_to_cpu(iv->data_unit_index32) + 1);
2029                 else
2030                         iv->data_unit_index = cpu_to_le64(
2031                                 le64_to_cpu(iv->data_unit_index) + 1);
2032         }
2033         free(buf);
2034 }
2035
2036 /* The supported key derivation functions */
2037 enum kdf_algorithm {
2038         KDF_NONE,
2039         KDF_AES_128_ECB,
2040         KDF_HKDF_SHA512,
2041 };
2042
2043 static enum kdf_algorithm parse_kdf_algorithm(const char *arg)
2044 {
2045         if (strcmp(arg, "none") == 0)
2046                 return KDF_NONE;
2047         if (strcmp(arg, "AES-128-ECB") == 0)
2048                 return KDF_AES_128_ECB;
2049         if (strcmp(arg, "HKDF-SHA512") == 0)
2050                 return KDF_HKDF_SHA512;
2051         die("Unknown KDF: %s", arg);
2052 }
2053
2054 static u8 parse_mode_number(const char *arg)
2055 {
2056         char *tmp;
2057         long num = strtol(arg, &tmp, 10);
2058
2059         if (num <= 0 || *tmp || (u8)num != num)
2060                 die("Invalid mode number: %s", arg);
2061         return num;
2062 }
2063
2064 struct key_and_iv_params {
2065         u8 master_key[MAX_KEY_SIZE];
2066         int master_key_size;
2067         enum kdf_algorithm kdf;
2068         u8 mode_num;
2069         u8 file_nonce[FILE_NONCE_SIZE];
2070         bool file_nonce_specified;
2071         bool direct_key;
2072         bool iv_ino_lblk_64;
2073         bool iv_ino_lblk_32;
2074         u64 data_unit_index;
2075         u64 inode_number;
2076         u8 fs_uuid[UUID_SIZE];
2077         bool fs_uuid_specified;
2078 };
2079
2080 #define HKDF_CONTEXT_KEY_IDENTIFIER     1
2081 #define HKDF_CONTEXT_PER_FILE_ENC_KEY   2
2082 #define HKDF_CONTEXT_DIRECT_KEY         3
2083 #define HKDF_CONTEXT_IV_INO_LBLK_64_KEY 4
2084 #define HKDF_CONTEXT_DIRHASH_KEY        5
2085 #define HKDF_CONTEXT_IV_INO_LBLK_32_KEY 6
2086 #define HKDF_CONTEXT_INODE_HASH_KEY     7
2087
2088 /* Hash the file's inode number using SipHash keyed by a derived key */
2089 static u32 hash_inode_number(const struct key_and_iv_params *params)
2090 {
2091         u8 info[9] = "fscrypt";
2092         union {
2093                 u64 words[2];
2094                 u8 bytes[16];
2095         } hash_key;
2096
2097         info[8] = HKDF_CONTEXT_INODE_HASH_KEY;
2098
2099         if (params->kdf != KDF_HKDF_SHA512)
2100                 die("--iv-ino-lblk-32 requires --kdf=HKDF-SHA512");
2101         hkdf_sha512(params->master_key, params->master_key_size,
2102                     NULL, 0, info, sizeof(info),
2103                     hash_key.bytes, sizeof(hash_key));
2104
2105         hash_key.words[0] = get_unaligned_le64(&hash_key.bytes[0]);
2106         hash_key.words[1] = get_unaligned_le64(&hash_key.bytes[8]);
2107
2108         return (u32)siphash_1u64(hash_key.words, params->inode_number);
2109 }
2110
2111 static void derive_real_key(const struct key_and_iv_params *params,
2112                             u8 *real_key, size_t real_key_size)
2113 {
2114         struct aes_key aes_key;
2115         u8 info[8 + 1 + 1 + UUID_SIZE] = "fscrypt";
2116         size_t infolen = 8;
2117         size_t i;
2118
2119         ASSERT(real_key_size <= params->master_key_size);
2120
2121         switch (params->kdf) {
2122         case KDF_NONE:
2123                 memcpy(real_key, params->master_key, real_key_size);
2124                 break;
2125         case KDF_AES_128_ECB:
2126                 if (!params->file_nonce_specified)
2127                         die("--kdf=AES-128-ECB requires --file-nonce");
2128                 STATIC_ASSERT(FILE_NONCE_SIZE == AES_128_KEY_SIZE);
2129                 ASSERT(real_key_size % AES_BLOCK_SIZE == 0);
2130                 aes_setkey(&aes_key, params->file_nonce, AES_128_KEY_SIZE);
2131                 for (i = 0; i < real_key_size; i += AES_BLOCK_SIZE)
2132                         aes_encrypt(&aes_key, &params->master_key[i],
2133                                     &real_key[i]);
2134                 break;
2135         case KDF_HKDF_SHA512:
2136                 if (params->direct_key) {
2137                         if (params->mode_num == 0)
2138                                 die("--direct-key with KDF requires --mode-num");
2139                         info[infolen++] = HKDF_CONTEXT_DIRECT_KEY;
2140                         info[infolen++] = params->mode_num;
2141                 } else if (params->iv_ino_lblk_64) {
2142                         if (params->mode_num == 0)
2143                                 die("--iv-ino-lblk-64 with KDF requires --mode-num");
2144                         if (!params->fs_uuid_specified)
2145                                 die("--iv-ino-lblk-64 with KDF requires --fs-uuid");
2146                         info[infolen++] = HKDF_CONTEXT_IV_INO_LBLK_64_KEY;
2147                         info[infolen++] = params->mode_num;
2148                         memcpy(&info[infolen], params->fs_uuid, UUID_SIZE);
2149                         infolen += UUID_SIZE;
2150                 } else if (params->iv_ino_lblk_32) {
2151                         if (params->mode_num == 0)
2152                                 die("--iv-ino-lblk-32 with KDF requires --mode-num");
2153                         if (!params->fs_uuid_specified)
2154                                 die("--iv-ino-lblk-32 with KDF requires --fs-uuid");
2155                         info[infolen++] = HKDF_CONTEXT_IV_INO_LBLK_32_KEY;
2156                         info[infolen++] = params->mode_num;
2157                         memcpy(&info[infolen], params->fs_uuid, UUID_SIZE);
2158                         infolen += UUID_SIZE;
2159                 } else {
2160                         if (!params->file_nonce_specified)
2161                                 die("--kdf=HKDF-SHA512 requires --file-nonce or --iv-ino-lblk-{64,32}");
2162                         info[infolen++] = HKDF_CONTEXT_PER_FILE_ENC_KEY;
2163                         memcpy(&info[infolen], params->file_nonce,
2164                                FILE_NONCE_SIZE);
2165                         infolen += FILE_NONCE_SIZE;
2166                 }
2167                 hkdf_sha512(params->master_key, params->master_key_size,
2168                             NULL, 0, info, infolen, real_key, real_key_size);
2169                 break;
2170         default:
2171                 ASSERT(0);
2172         }
2173 }
2174
2175 static void generate_iv(const struct key_and_iv_params *params,
2176                         union fscrypt_iv *iv)
2177 {
2178         memset(iv, 0, sizeof(*iv));
2179         if (params->direct_key) {
2180                 if (!params->file_nonce_specified)
2181                         die("--direct-key requires --file-nonce");
2182                 iv->data_unit_index = cpu_to_le64(params->data_unit_index);
2183                 memcpy(iv->nonce, params->file_nonce, FILE_NONCE_SIZE);
2184         } else if (params->iv_ino_lblk_64) {
2185                 if (params->data_unit_index > UINT32_MAX)
2186                         die("iv-ino-lblk-64 can't use --data-unit-index > UINT32_MAX");
2187                 if (params->inode_number == 0)
2188                         die("iv-ino-lblk-64 requires --inode-number");
2189                 if (params->inode_number > UINT32_MAX)
2190                         die("iv-ino-lblk-64 can't use --inode-number > UINT32_MAX");
2191                 iv->data_unit_index32 = cpu_to_le32(params->data_unit_index);
2192                 iv->inode_number = cpu_to_le32(params->inode_number);
2193         } else if (params->iv_ino_lblk_32) {
2194                 if (params->data_unit_index > UINT32_MAX)
2195                         die("iv-ino-lblk-32 can't use --data-unit-index > UINT32_MAX");
2196                 if (params->inode_number == 0)
2197                         die("iv-ino-lblk-32 requires --inode-number");
2198                 iv->data_unit_index32 = cpu_to_le32(hash_inode_number(params) +
2199                                                     params->data_unit_index);
2200         } else {
2201                 iv->data_unit_index = cpu_to_le64(params->data_unit_index);
2202         }
2203 }
2204
2205 /*
2206  * Get the key and starting IV with which the encryption will actually be done.
2207  * If a KDF was specified, then a subkey is derived from the master key.
2208  * Otherwise, the master key is used directly.
2209  */
2210 static void get_key_and_iv(const struct key_and_iv_params *params,
2211                            u8 *real_key, size_t real_key_size,
2212                            union fscrypt_iv *iv)
2213 {
2214         int iv_methods = 0;
2215
2216         iv_methods += params->direct_key;
2217         iv_methods += params->iv_ino_lblk_64;
2218         iv_methods += params->iv_ino_lblk_32;
2219         if (iv_methods > 1)
2220                 die("Conflicting IV methods specified");
2221         if (iv_methods > 0 && params->kdf == KDF_AES_128_ECB)
2222                 die("--kdf=AES-128-ECB is incompatible with IV method options");
2223
2224         derive_real_key(params, real_key, real_key_size);
2225
2226         generate_iv(params, iv);
2227 }
2228
2229 static void do_dump_key_identifier(const struct key_and_iv_params *params)
2230 {
2231         u8 info[9] = "fscrypt";
2232         u8 key_identifier[16];
2233         int i;
2234
2235         info[8] = HKDF_CONTEXT_KEY_IDENTIFIER;
2236
2237         if (params->kdf != KDF_HKDF_SHA512)
2238                 die("--dump-key-identifier requires --kdf=HKDF-SHA512");
2239         hkdf_sha512(params->master_key, params->master_key_size,
2240                     NULL, 0, info, sizeof(info),
2241                     key_identifier, sizeof(key_identifier));
2242
2243         for (i = 0; i < sizeof(key_identifier); i++)
2244                 printf("%02x", key_identifier[i]);
2245 }
2246
2247 static void parse_master_key(const char *arg, struct key_and_iv_params *params)
2248 {
2249         params->master_key_size = hex2bin(arg, params->master_key,
2250                                           MAX_KEY_SIZE);
2251         if (params->master_key_size < 0)
2252                 die("Invalid master_key: %s", arg);
2253 }
2254
2255 enum {
2256         OPT_DATA_UNIT_INDEX,
2257         OPT_DATA_UNIT_SIZE,
2258         OPT_DECRYPT,
2259         OPT_DIRECT_KEY,
2260         OPT_DUMP_KEY_IDENTIFIER,
2261         OPT_FILE_NONCE,
2262         OPT_FS_UUID,
2263         OPT_HELP,
2264         OPT_INODE_NUMBER,
2265         OPT_IV_INO_LBLK_32,
2266         OPT_IV_INO_LBLK_64,
2267         OPT_KDF,
2268         OPT_MODE_NUM,
2269         OPT_PADDING,
2270 };
2271
2272 static const struct option longopts[] = {
2273         { "data-unit-index", required_argument, NULL, OPT_DATA_UNIT_INDEX },
2274         { "data-unit-size",  required_argument, NULL, OPT_DATA_UNIT_SIZE },
2275         { "decrypt",         no_argument,       NULL, OPT_DECRYPT },
2276         { "direct-key",      no_argument,       NULL, OPT_DIRECT_KEY },
2277         { "dump-key-identifier", no_argument,   NULL, OPT_DUMP_KEY_IDENTIFIER },
2278         { "file-nonce",      required_argument, NULL, OPT_FILE_NONCE },
2279         { "fs-uuid",         required_argument, NULL, OPT_FS_UUID },
2280         { "help",            no_argument,       NULL, OPT_HELP },
2281         { "inode-number",    required_argument, NULL, OPT_INODE_NUMBER },
2282         { "iv-ino-lblk-32",  no_argument,       NULL, OPT_IV_INO_LBLK_32 },
2283         { "iv-ino-lblk-64",  no_argument,       NULL, OPT_IV_INO_LBLK_64 },
2284         { "kdf",             required_argument, NULL, OPT_KDF },
2285         { "mode-num",        required_argument, NULL, OPT_MODE_NUM },
2286         { "padding",         required_argument, NULL, OPT_PADDING },
2287         { NULL, 0, NULL, 0 },
2288 };
2289
2290 int main(int argc, char *argv[])
2291 {
2292         size_t data_unit_size = 4096;
2293         bool decrypting = false;
2294         bool dump_key_identifier = false;
2295         struct key_and_iv_params params;
2296         size_t padding = 0;
2297         const struct fscrypt_cipher *cipher;
2298         u8 real_key[MAX_KEY_SIZE];
2299         union fscrypt_iv iv;
2300         char *tmp;
2301         int c;
2302
2303         memset(&params, 0, sizeof(params));
2304
2305         aes_init();
2306
2307 #ifdef ENABLE_ALG_TESTS
2308         test_aes();
2309         test_sha2();
2310         test_hkdf_sha512();
2311         test_aes_256_xts();
2312         test_aes_256_cts_cbc();
2313         test_adiantum();
2314         test_aes_256_hctr2();
2315 #endif
2316
2317         while ((c = getopt_long(argc, argv, "", longopts, NULL)) != -1) {
2318                 switch (c) {
2319                 case OPT_DATA_UNIT_INDEX:
2320                         errno = 0;
2321                         params.data_unit_index = strtoull(optarg, &tmp, 10);
2322                         if (*tmp || errno)
2323                                 die("Invalid data unit index: %s", optarg);
2324                         break;
2325                 case OPT_DATA_UNIT_SIZE:
2326                         errno = 0;
2327                         data_unit_size = strtoul(optarg, &tmp, 10);
2328                         if (data_unit_size <= 0 || *tmp || errno)
2329                                 die("Invalid data unit size: %s", optarg);
2330                         break;
2331                 case OPT_DECRYPT:
2332                         decrypting = true;
2333                         break;
2334                 case OPT_DIRECT_KEY:
2335                         params.direct_key = true;
2336                         break;
2337                 case OPT_DUMP_KEY_IDENTIFIER:
2338                         dump_key_identifier = true;
2339                         break;
2340                 case OPT_FILE_NONCE:
2341                         if (hex2bin(optarg, params.file_nonce, FILE_NONCE_SIZE)
2342                             != FILE_NONCE_SIZE)
2343                                 die("Invalid file nonce: %s", optarg);
2344                         params.file_nonce_specified = true;
2345                         break;
2346                 case OPT_FS_UUID:
2347                         if (hex2bin(optarg, params.fs_uuid, UUID_SIZE)
2348                             != UUID_SIZE)
2349                                 die("Invalid filesystem UUID: %s", optarg);
2350                         params.fs_uuid_specified = true;
2351                         break;
2352                 case OPT_HELP:
2353                         usage(stdout);
2354                         return 0;
2355                 case OPT_INODE_NUMBER:
2356                         errno = 0;
2357                         params.inode_number = strtoull(optarg, &tmp, 10);
2358                         if (params.inode_number <= 0 || *tmp || errno)
2359                                 die("Invalid inode number: %s", optarg);
2360                         break;
2361                 case OPT_IV_INO_LBLK_32:
2362                         params.iv_ino_lblk_32 = true;
2363                         break;
2364                 case OPT_IV_INO_LBLK_64:
2365                         params.iv_ino_lblk_64 = true;
2366                         break;
2367                 case OPT_KDF:
2368                         params.kdf = parse_kdf_algorithm(optarg);
2369                         break;
2370                 case OPT_MODE_NUM:
2371                         params.mode_num = parse_mode_number(optarg);
2372                         break;
2373                 case OPT_PADDING:
2374                         padding = strtoul(optarg, &tmp, 10);
2375                         if (padding <= 0 || *tmp || !is_power_of_2(padding) ||
2376                             padding > INT_MAX)
2377                                 die("Invalid padding amount: %s", optarg);
2378                         break;
2379                 default:
2380                         usage(stderr);
2381                         return 2;
2382                 }
2383         }
2384         argc -= optind;
2385         argv += optind;
2386
2387         if (dump_key_identifier) {
2388                 if (argc != 1) {
2389                         usage(stderr);
2390                         return 2;
2391                 }
2392                 parse_master_key(argv[0], &params);
2393                 do_dump_key_identifier(&params);
2394                 return 0;
2395         }
2396         if (argc != 2) {
2397                 usage(stderr);
2398                 return 2;
2399         }
2400
2401         cipher = find_fscrypt_cipher(argv[0]);
2402         if (cipher == NULL)
2403                 die("Unknown cipher: %s", argv[0]);
2404
2405         if (data_unit_size < cipher->min_input_size)
2406                 die("Data unit size of %zu bytes is too small for cipher %s",
2407                     data_unit_size, cipher->name);
2408
2409         parse_master_key(argv[1], &params);
2410
2411         if (params.master_key_size < cipher->keysize)
2412                 die("Master key is too short for cipher %s", cipher->name);
2413
2414         get_key_and_iv(&params, real_key, cipher->keysize, &iv);
2415
2416         crypt_loop(cipher, real_key, &iv, decrypting, data_unit_size, padding,
2417                    params.iv_ino_lblk_64 || params.iv_ino_lblk_32);
2418         return 0;
2419 }