03cc3c4a12997734195a455c1bfd2ba249525c5c
[xfstests-dev.git] / src / fscrypt-crypt-util.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * fscrypt-crypt-util.c - utility for verifying fscrypt-encrypted data
4  *
5  * Copyright 2019 Google LLC
6  */
7
8 /*
9  * This program implements all crypto algorithms supported by fscrypt (a.k.a.
10  * ext4, f2fs, and ubifs encryption), for the purpose of verifying the
11  * correctness of the ciphertext stored on-disk.  See usage() below.
12  *
13  * All algorithms are implemented in portable C code to avoid depending on
14  * libcrypto (OpenSSL), and because some fscrypt-supported algorithms aren't
15  * available in libcrypto anyway (e.g. Adiantum), or are only supported in
16  * recent versions (e.g. HKDF-SHA512).  For simplicity, all crypto code here
17  * tries to follow the mathematical definitions directly, without optimizing for
18  * performance or worrying about following security best practices such as
19  * mitigating side-channel attacks.  So, only use this program for testing!
20  */
21
22 #include <asm/byteorder.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <limits.h>
26 #include <linux/types.h>
27 #include <stdarg.h>
28 #include <stdbool.h>
29 #include <stdint.h>
30 #include <stdio.h>
31 #include <stdlib.h>
32 #include <string.h>
33 #include <unistd.h>
34
35 #define PROGRAM_NAME "fscrypt-crypt-util"
36
37 /*
38  * Define to enable the tests of the crypto code in this file.  If enabled, you
39  * must link this program with OpenSSL (-lcrypto) v1.1.0 or later, and your
40  * kernel needs CONFIG_CRYPTO_USER_API_SKCIPHER=y and CONFIG_CRYPTO_ADIANTUM=y.
41  */
42 #undef ENABLE_ALG_TESTS
43
44 #define NUM_ALG_TEST_ITERATIONS 10000
45
46 static void usage(FILE *fp)
47 {
48         fputs(
49 "Usage: " PROGRAM_NAME " [OPTION]... CIPHER MASTER_KEY\n"
50 "\n"
51 "Utility for verifying fscrypt-encrypted data.  This program encrypts\n"
52 "(or decrypts) the data on stdin using the given CIPHER with the given\n"
53 "MASTER_KEY (or a key derived from it, if a KDF is specified), and writes the\n"
54 "resulting ciphertext (or plaintext) to stdout.\n"
55 "\n"
56 "CIPHER can be AES-256-XTS, AES-256-CTS-CBC, AES-128-CBC-ESSIV, AES-128-CTS-CBC,\n"
57 "or Adiantum.  MASTER_KEY must be a hex string long enough for the cipher.\n"
58 "\n"
59 "WARNING: this program is only meant for testing, not for \"real\" use!\n"
60 "\n"
61 "Options:\n"
62 "  --block-number=BNUM         Starting block number for IV generation.\n"
63 "                                Default: 0\n"
64 "  --block-size=BLOCK_SIZE     Encrypt each BLOCK_SIZE bytes independently.\n"
65 "                                Default: 4096 bytes\n"
66 "  --decrypt                   Decrypt instead of encrypt\n"
67 "  --file-nonce=NONCE          File's nonce as a 32-character hex string\n"
68 "  --fs-uuid=UUID              The filesystem UUID as a 32-character hex string.\n"
69 "                                Required for --iv-ino-lblk-32 and\n"
70 "                                --iv-ino-lblk-64; otherwise is unused.\n"
71 "  --help                      Show this help\n"
72 "  --inode-number=INUM         The file's inode number.  Required for\n"
73 "                                --iv-ino-lblk-32 and --iv-ino-lblk-64;\n"
74 "                                otherwise is unused.\n"
75 "  --iv-ino-lblk-32            Similar to --iv-ino-lblk-64, but selects the\n"
76 "                                32-bit variant.\n"
77 "  --iv-ino-lblk-64            Use the format where the IVs include the inode\n"
78 "                                number and the same key is shared across files.\n"
79 "                                Requires --kdf=HKDF-SHA512, --fs-uuid,\n"
80 "                                --inode-number, and --mode-num.\n"
81 "  --kdf=KDF                   Key derivation function to use: AES-128-ECB,\n"
82 "                                HKDF-SHA512, or none.  Default: none\n"
83 "  --mode-num=NUM              Derive per-mode key using mode number NUM\n"
84 "  --padding=PADDING           If last block is partial, zero-pad it to next\n"
85 "                                PADDING-byte boundary.  Default: BLOCK_SIZE\n"
86         , fp);
87 }
88
89 /*----------------------------------------------------------------------------*
90  *                                 Utilities                                  *
91  *----------------------------------------------------------------------------*/
92
93 #define ARRAY_SIZE(A)           (sizeof(A) / sizeof((A)[0]))
94 #define MIN(x, y)               ((x) < (y) ? (x) : (y))
95 #define MAX(x, y)               ((x) > (y) ? (x) : (y))
96 #define ROUND_DOWN(x, y)        ((x) & ~((y) - 1))
97 #define ROUND_UP(x, y)          (((x) + (y) - 1) & ~((y) - 1))
98 #define DIV_ROUND_UP(n, d)      (((n) + (d) - 1) / (d))
99 #define STATIC_ASSERT(e)        ((void)sizeof(char[1 - 2*!(e)]))
100
101 typedef __u8                    u8;
102 typedef __u16                   u16;
103 typedef __u32                   u32;
104 typedef __u64                   u64;
105
106 #define cpu_to_le32             __cpu_to_le32
107 #define cpu_to_be32             __cpu_to_be32
108 #define cpu_to_le64             __cpu_to_le64
109 #define cpu_to_be64             __cpu_to_be64
110 #define le32_to_cpu             __le32_to_cpu
111 #define be32_to_cpu             __be32_to_cpu
112 #define le64_to_cpu             __le64_to_cpu
113 #define be64_to_cpu             __be64_to_cpu
114
115 #define DEFINE_UNALIGNED_ACCESS_HELPERS(type, native_type)      \
116 static inline native_type __attribute__((unused))               \
117 get_unaligned_##type(const void *p)                             \
118 {                                                               \
119         __##type x;                                             \
120                                                                 \
121         memcpy(&x, p, sizeof(x));                               \
122         return type##_to_cpu(x);                                \
123 }                                                               \
124                                                                 \
125 static inline void __attribute__((unused))                      \
126 put_unaligned_##type(native_type v, void *p)                    \
127 {                                                               \
128         __##type x = cpu_to_##type(v);                          \
129                                                                 \
130         memcpy(p, &x, sizeof(x));                               \
131 }
132
133 DEFINE_UNALIGNED_ACCESS_HELPERS(le32, u32)
134 DEFINE_UNALIGNED_ACCESS_HELPERS(be32, u32)
135 DEFINE_UNALIGNED_ACCESS_HELPERS(le64, u64)
136 DEFINE_UNALIGNED_ACCESS_HELPERS(be64, u64)
137
138 static inline bool is_power_of_2(unsigned long v)
139 {
140         return v != 0 && (v & (v - 1)) == 0;
141 }
142
143 static inline u32 rol32(u32 v, int n)
144 {
145         return (v << n) | (v >> (32 - n));
146 }
147
148 static inline u32 ror32(u32 v, int n)
149 {
150         return (v >> n) | (v << (32 - n));
151 }
152
153 static inline u64 rol64(u64 v, int n)
154 {
155         return (v << n) | (v >> (64 - n));
156 }
157
158 static inline u64 ror64(u64 v, int n)
159 {
160         return (v >> n) | (v << (64 - n));
161 }
162
163 static inline void xor(u8 *res, const u8 *a, const u8 *b, size_t count)
164 {
165         while (count--)
166                 *res++ = *a++ ^ *b++;
167 }
168
169 static void __attribute__((noreturn, format(printf, 2, 3)))
170 do_die(int err, const char *format, ...)
171 {
172         va_list va;
173
174         va_start(va, format);
175         fputs("[" PROGRAM_NAME "] ERROR: ", stderr);
176         vfprintf(stderr, format, va);
177         if (err)
178                 fprintf(stderr, ": %s", strerror(errno));
179         putc('\n', stderr);
180         va_end(va);
181         exit(1);
182 }
183
184 #define die(format, ...)        do_die(0,     (format), ##__VA_ARGS__)
185 #define die_errno(format, ...)  do_die(errno, (format), ##__VA_ARGS__)
186
187 static __attribute__((noreturn)) void
188 assertion_failed(const char *expr, const char *file, int line)
189 {
190         die("Assertion failed: %s at %s:%d", expr, file, line);
191 }
192
193 #define ASSERT(e) ({ if (!(e)) assertion_failed(#e, __FILE__, __LINE__); })
194
195 static void *xmalloc(size_t size)
196 {
197         void *p = malloc(size);
198
199         ASSERT(p != NULL);
200         return p;
201 }
202
203 static int hexchar2bin(char c)
204 {
205         if (c >= 'a' && c <= 'f')
206                 return 10 + c - 'a';
207         if (c >= 'A' && c <= 'F')
208                 return 10 + c - 'A';
209         if (c >= '0' && c <= '9')
210                 return c - '0';
211         return -1;
212 }
213
214 static int hex2bin(const char *hex, u8 *bin, int max_bin_size)
215 {
216         size_t len = strlen(hex);
217         size_t i;
218
219         if (len & 1)
220                 return -1;
221         len /= 2;
222         if (len > max_bin_size)
223                 return -1;
224
225         for (i = 0; i < len; i++) {
226                 int high = hexchar2bin(hex[2 * i]);
227                 int low = hexchar2bin(hex[2 * i + 1]);
228
229                 if (high < 0 || low < 0)
230                         return -1;
231                 bin[i] = (high << 4) | low;
232         }
233         return len;
234 }
235
236 static size_t xread(int fd, void *buf, size_t count)
237 {
238         const size_t orig_count = count;
239
240         while (count) {
241                 ssize_t res = read(fd, buf, count);
242
243                 if (res < 0)
244                         die_errno("read error");
245                 if (res == 0)
246                         break;
247                 buf += res;
248                 count -= res;
249         }
250         return orig_count - count;
251 }
252
253 static void full_write(int fd, const void *buf, size_t count)
254 {
255         while (count) {
256                 ssize_t res = write(fd, buf, count);
257
258                 if (res < 0)
259                         die_errno("write error");
260                 buf += res;
261                 count -= res;
262         }
263 }
264
265 #ifdef ENABLE_ALG_TESTS
266 static void rand_bytes(u8 *buf, size_t count)
267 {
268         while (count--)
269                 *buf++ = rand();
270 }
271 #endif
272
273 /*----------------------------------------------------------------------------*
274  *                          Finite field arithmetic                           *
275  *----------------------------------------------------------------------------*/
276
277 /* Multiply a GF(2^8) element by the polynomial 'x' */
278 static inline u8 gf2_8_mul_x(u8 b)
279 {
280         return (b << 1) ^ ((b & 0x80) ? 0x1B : 0);
281 }
282
283 /* Multiply four packed GF(2^8) elements by the polynomial 'x' */
284 static inline u32 gf2_8_mul_x_4way(u32 w)
285 {
286         return ((w & 0x7F7F7F7F) << 1) ^ (((w & 0x80808080) >> 7) * 0x1B);
287 }
288
289 /* Element of GF(2^128) */
290 typedef struct {
291         __le64 lo;
292         __le64 hi;
293 } ble128;
294
295 /* Multiply a GF(2^128) element by the polynomial 'x' */
296 static inline void gf2_128_mul_x(ble128 *t)
297 {
298         u64 lo = le64_to_cpu(t->lo);
299         u64 hi = le64_to_cpu(t->hi);
300
301         t->hi = cpu_to_le64((hi << 1) | (lo >> 63));
302         t->lo = cpu_to_le64((lo << 1) ^ ((hi & (1ULL << 63)) ? 0x87 : 0));
303 }
304
305 /*----------------------------------------------------------------------------*
306  *                             Group arithmetic                               *
307  *----------------------------------------------------------------------------*/
308
309 /* Element of Z/(2^{128}Z)  (a.k.a. the integers modulo 2^128) */
310 typedef struct {
311         __le64 lo;
312         __le64 hi;
313 } le128;
314
315 static inline void le128_add(le128 *res, const le128 *a, const le128 *b)
316 {
317         u64 a_lo = le64_to_cpu(a->lo);
318         u64 b_lo = le64_to_cpu(b->lo);
319
320         res->lo = cpu_to_le64(a_lo + b_lo);
321         res->hi = cpu_to_le64(le64_to_cpu(a->hi) + le64_to_cpu(b->hi) +
322                               (a_lo + b_lo < a_lo));
323 }
324
325 static inline void le128_sub(le128 *res, const le128 *a, const le128 *b)
326 {
327         u64 a_lo = le64_to_cpu(a->lo);
328         u64 b_lo = le64_to_cpu(b->lo);
329
330         res->lo = cpu_to_le64(a_lo - b_lo);
331         res->hi = cpu_to_le64(le64_to_cpu(a->hi) - le64_to_cpu(b->hi) -
332                               (a_lo - b_lo > a_lo));
333 }
334
335 /*----------------------------------------------------------------------------*
336  *                              AES block cipher                              *
337  *----------------------------------------------------------------------------*/
338
339 /*
340  * Reference: "FIPS 197, Advanced Encryption Standard"
341  *      https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.197.pdf
342  */
343
344 #define AES_BLOCK_SIZE          16
345 #define AES_128_KEY_SIZE        16
346 #define AES_192_KEY_SIZE        24
347 #define AES_256_KEY_SIZE        32
348
349 static inline void AddRoundKey(u32 state[4], const u32 *rk)
350 {
351         int i;
352
353         for (i = 0; i < 4; i++)
354                 state[i] ^= rk[i];
355 }
356
357 static const u8 aes_sbox[256] = {
358         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b,
359         0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0,
360         0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26,
361         0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
362         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2,
363         0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0,
364         0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed,
365         0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
366         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f,
367         0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
368         0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec,
369         0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
370         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14,
371         0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c,
372         0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d,
373         0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
374         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f,
375         0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e,
376         0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11,
377         0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
378         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f,
379         0xb0, 0x54, 0xbb, 0x16,
380 };
381
382 static u8 aes_inverse_sbox[256];
383
384 static void aes_init(void)
385 {
386         int i;
387
388         for (i = 0; i < 256; i++)
389                 aes_inverse_sbox[aes_sbox[i]] = i;
390 }
391
392 static inline u32 DoSubWord(u32 w, const u8 sbox[256])
393 {
394         return ((u32)sbox[(u8)(w >> 24)] << 24) |
395                ((u32)sbox[(u8)(w >> 16)] << 16) |
396                ((u32)sbox[(u8)(w >>  8)] <<  8) |
397                ((u32)sbox[(u8)(w >>  0)] <<  0);
398 }
399
400 static inline u32 SubWord(u32 w)
401 {
402         return DoSubWord(w, aes_sbox);
403 }
404
405 static inline u32 InvSubWord(u32 w)
406 {
407         return DoSubWord(w, aes_inverse_sbox);
408 }
409
410 static inline void SubBytes(u32 state[4])
411 {
412         int i;
413
414         for (i = 0; i < 4; i++)
415                 state[i] = SubWord(state[i]);
416 }
417
418 static inline void InvSubBytes(u32 state[4])
419 {
420         int i;
421
422         for (i = 0; i < 4; i++)
423                 state[i] = InvSubWord(state[i]);
424 }
425
426 static inline void DoShiftRows(u32 state[4], int direction)
427 {
428         u32 newstate[4];
429         int i;
430
431         for (i = 0; i < 4; i++)
432                 newstate[i] = (state[(i + direction*0) & 3] & 0xff) |
433                               (state[(i + direction*1) & 3] & 0xff00) |
434                               (state[(i + direction*2) & 3] & 0xff0000) |
435                               (state[(i + direction*3) & 3] & 0xff000000);
436         memcpy(state, newstate, 16);
437 }
438
439 static inline void ShiftRows(u32 state[4])
440 {
441         DoShiftRows(state, 1);
442 }
443
444 static inline void InvShiftRows(u32 state[4])
445 {
446         DoShiftRows(state, -1);
447 }
448
449 /*
450  * Mix one column by doing the following matrix multiplication in GF(2^8):
451  *
452  *     | 2 3 1 1 |   | w[0] |
453  *     | 1 2 3 1 |   | w[1] |
454  *     | 1 1 2 3 | x | w[2] |
455  *     | 3 1 1 2 |   | w[3] |
456  *
457  * a.k.a. w[i] = 2*w[i] + 3*w[(i+1)%4] + w[(i+2)%4] + w[(i+3)%4]
458  */
459 static inline u32 MixColumn(u32 w)
460 {
461         u32 _2w0_w2 = gf2_8_mul_x_4way(w) ^ ror32(w, 16);
462         u32 _3w1_w3 = ror32(_2w0_w2 ^ w, 8);
463
464         return _2w0_w2 ^ _3w1_w3;
465 }
466
467 /*
468  *           ( | 5 0 4 0 |   | w[0] | )
469  *          (  | 0 5 0 4 |   | w[1] |  )
470  * MixColumn(  | 4 0 5 0 | x | w[2] |  )
471  *           ( | 0 4 0 5 |   | w[3] | )
472  */
473 static inline u32 InvMixColumn(u32 w)
474 {
475         u32 _4w = gf2_8_mul_x_4way(gf2_8_mul_x_4way(w));
476
477         return MixColumn(_4w ^ w ^ ror32(_4w, 16));
478 }
479
480 static inline void MixColumns(u32 state[4])
481 {
482         int i;
483
484         for (i = 0; i < 4; i++)
485                 state[i] = MixColumn(state[i]);
486 }
487
488 static inline void InvMixColumns(u32 state[4])
489 {
490         int i;
491
492         for (i = 0; i < 4; i++)
493                 state[i] = InvMixColumn(state[i]);
494 }
495
496 struct aes_key {
497         u32 round_keys[15 * 4];
498         int nrounds;
499 };
500
501 /* Expand an AES key */
502 static void aes_setkey(struct aes_key *k, const u8 *key, int keysize)
503 {
504         const int N = keysize / 4;
505         u32 * const rk = k->round_keys;
506         u8 rcon = 1;
507         int i;
508
509         ASSERT(keysize == 16 || keysize == 24 || keysize == 32);
510         k->nrounds = 6 + N;
511         for (i = 0; i < 4 * (k->nrounds + 1); i++) {
512                 if (i < N) {
513                         rk[i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
514                 } else if (i % N == 0) {
515                         rk[i] = rk[i - N] ^ SubWord(ror32(rk[i - 1], 8)) ^ rcon;
516                         rcon = gf2_8_mul_x(rcon);
517                 } else if (N > 6 && i % N == 4) {
518                         rk[i] = rk[i - N] ^ SubWord(rk[i - 1]);
519                 } else {
520                         rk[i] = rk[i - N] ^ rk[i - 1];
521                 }
522         }
523 }
524
525 /* Encrypt one 16-byte block with AES */
526 static void aes_encrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
527                         u8 dst[AES_BLOCK_SIZE])
528 {
529         u32 state[4];
530         int i;
531
532         for (i = 0; i < 4; i++)
533                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
534
535         AddRoundKey(state, k->round_keys);
536         for (i = 1; i < k->nrounds; i++) {
537                 SubBytes(state);
538                 ShiftRows(state);
539                 MixColumns(state);
540                 AddRoundKey(state, &k->round_keys[4 * i]);
541         }
542         SubBytes(state);
543         ShiftRows(state);
544         AddRoundKey(state, &k->round_keys[4 * i]);
545
546         for (i = 0; i < 4; i++)
547                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
548 }
549
550 /* Decrypt one 16-byte block with AES */
551 static void aes_decrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
552                         u8 dst[AES_BLOCK_SIZE])
553 {
554         u32 state[4];
555         int i;
556
557         for (i = 0; i < 4; i++)
558                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
559
560         AddRoundKey(state, &k->round_keys[4 * k->nrounds]);
561         InvShiftRows(state);
562         InvSubBytes(state);
563         for (i = k->nrounds - 1; i >= 1; i--) {
564                 AddRoundKey(state, &k->round_keys[4 * i]);
565                 InvMixColumns(state);
566                 InvShiftRows(state);
567                 InvSubBytes(state);
568         }
569         AddRoundKey(state, k->round_keys);
570
571         for (i = 0; i < 4; i++)
572                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
573 }
574
575 #ifdef ENABLE_ALG_TESTS
576 #include <openssl/aes.h>
577 static void test_aes_keysize(int keysize)
578 {
579         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
580
581         while (num_tests--) {
582                 struct aes_key k;
583                 AES_KEY ref_k;
584                 u8 key[AES_256_KEY_SIZE];
585                 u8 ptext[AES_BLOCK_SIZE];
586                 u8 ctext[AES_BLOCK_SIZE];
587                 u8 ref_ctext[AES_BLOCK_SIZE];
588                 u8 decrypted[AES_BLOCK_SIZE];
589
590                 rand_bytes(key, keysize);
591                 rand_bytes(ptext, AES_BLOCK_SIZE);
592
593                 aes_setkey(&k, key, keysize);
594                 aes_encrypt(&k, ptext, ctext);
595
596                 ASSERT(AES_set_encrypt_key(key, keysize*8, &ref_k) == 0);
597                 AES_encrypt(ptext, ref_ctext, &ref_k);
598
599                 ASSERT(memcmp(ctext, ref_ctext, AES_BLOCK_SIZE) == 0);
600
601                 aes_decrypt(&k, ctext, decrypted);
602                 ASSERT(memcmp(ptext, decrypted, AES_BLOCK_SIZE) == 0);
603         }
604 }
605
606 static void test_aes(void)
607 {
608         test_aes_keysize(AES_128_KEY_SIZE);
609         test_aes_keysize(AES_192_KEY_SIZE);
610         test_aes_keysize(AES_256_KEY_SIZE);
611 }
612 #endif /* ENABLE_ALG_TESTS */
613
614 /*----------------------------------------------------------------------------*
615  *                            SHA-512 and SHA-256                             *
616  *----------------------------------------------------------------------------*/
617
618 /*
619  * Reference: "FIPS 180-2, Secure Hash Standard"
620  *      https://csrc.nist.gov/csrc/media/publications/fips/180/2/archive/2002-08-01/documents/fips180-2withchangenotice.pdf
621  */
622
623 #define SHA512_DIGEST_SIZE      64
624 #define SHA512_BLOCK_SIZE       128
625
626 #define SHA256_DIGEST_SIZE      32
627 #define SHA256_BLOCK_SIZE       64
628
629 #define Ch(x, y, z)     (((x) & (y)) ^ (~(x) & (z)))
630 #define Maj(x, y, z)    (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z)))
631
632 #define Sigma512_0(x)   (ror64((x), 28) ^ ror64((x), 34) ^ ror64((x), 39))
633 #define Sigma512_1(x)   (ror64((x), 14) ^ ror64((x), 18) ^ ror64((x), 41))
634 #define sigma512_0(x)   (ror64((x),  1) ^ ror64((x),  8) ^ ((x) >> 7))
635 #define sigma512_1(x)   (ror64((x), 19) ^ ror64((x), 61) ^ ((x) >> 6))
636
637 #define Sigma256_0(x)   (ror32((x),  2) ^ ror32((x), 13) ^ ror32((x), 22))
638 #define Sigma256_1(x)   (ror32((x),  6) ^ ror32((x), 11) ^ ror32((x), 25))
639 #define sigma256_0(x)   (ror32((x),  7) ^ ror32((x), 18) ^ ((x) >>  3))
640 #define sigma256_1(x)   (ror32((x), 17) ^ ror32((x), 19) ^ ((x) >> 10))
641
642 static const u64 sha512_iv[8] = {
643         0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b,
644         0xa54ff53a5f1d36f1, 0x510e527fade682d1, 0x9b05688c2b3e6c1f,
645         0x1f83d9abfb41bd6b, 0x5be0cd19137e2179,
646 };
647
648 static const u64 sha512_round_constants[80] = {
649         0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f,
650         0xe9b5dba58189dbbc, 0x3956c25bf348b538, 0x59f111f1b605d019,
651         0x923f82a4af194f9b, 0xab1c5ed5da6d8118, 0xd807aa98a3030242,
652         0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2,
653         0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235,
654         0xc19bf174cf692694, 0xe49b69c19ef14ad2, 0xefbe4786384f25e3,
655         0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, 0x2de92c6f592b0275,
656         0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5,
657         0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f,
658         0xbf597fc7beef0ee4, 0xc6e00bf33da88fc2, 0xd5a79147930aa725,
659         0x06ca6351e003826f, 0x142929670a0e6e70, 0x27b70a8546d22ffc,
660         0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df,
661         0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6,
662         0x92722c851482353b, 0xa2bfe8a14cf10364, 0xa81a664bbc423001,
663         0xc24b8b70d0f89791, 0xc76c51a30654be30, 0xd192e819d6ef5218,
664         0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8,
665         0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99,
666         0x34b0bcb5e19b48a8, 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb,
667         0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, 0x748f82ee5defb2fc,
668         0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec,
669         0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915,
670         0xc67178f2e372532b, 0xca273eceea26619c, 0xd186b8c721c0c207,
671         0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, 0x06f067aa72176fba,
672         0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b,
673         0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc,
674         0x431d67c49c100d4c, 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a,
675         0x5fcb6fab3ad6faec, 0x6c44198c4a475817,
676 };
677
678 /* Compute the SHA-512 digest of the given buffer */
679 static void sha512(const u8 *in, size_t inlen, u8 out[SHA512_DIGEST_SIZE])
680 {
681         const size_t msglen = ROUND_UP(inlen + 17, SHA512_BLOCK_SIZE);
682         u8 * const msg = xmalloc(msglen);
683         u64 H[8];
684         int i;
685
686         /* super naive way of handling the padding */
687         memcpy(msg, in, inlen);
688         memset(&msg[inlen], 0, msglen - inlen);
689         msg[inlen] = 0x80;
690         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
691         in = msg;
692
693         memcpy(H, sha512_iv, sizeof(H));
694         do {
695                 u64 a = H[0], b = H[1], c = H[2], d = H[3],
696                     e = H[4], f = H[5], g = H[6], h = H[7];
697                 u64 W[80];
698
699                 for (i = 0; i < 16; i++)
700                         W[i] = get_unaligned_be64(&in[i * sizeof(__be64)]);
701                 for (; i < ARRAY_SIZE(W); i++)
702                         W[i] = sigma512_1(W[i - 2]) + W[i - 7] +
703                                sigma512_0(W[i - 15]) + W[i - 16];
704                 for (i = 0; i < ARRAY_SIZE(W); i++) {
705                         u64 T1 = h + Sigma512_1(e) + Ch(e, f, g) +
706                                  sha512_round_constants[i] + W[i];
707                         u64 T2 = Sigma512_0(a) + Maj(a, b, c);
708
709                         h = g; g = f; f = e; e = d + T1;
710                         d = c; c = b; b = a; a = T1 + T2;
711                 }
712                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
713                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
714         } while ((in += SHA512_BLOCK_SIZE) != &msg[msglen]);
715
716         for (i = 0; i < ARRAY_SIZE(H); i++)
717                 put_unaligned_be64(H[i], &out[i * sizeof(__be64)]);
718         free(msg);
719 }
720
721 /* Compute the SHA-256 digest of the given buffer */
722 static void sha256(const u8 *in, size_t inlen, u8 out[SHA256_DIGEST_SIZE])
723 {
724         const size_t msglen = ROUND_UP(inlen + 9, SHA256_BLOCK_SIZE);
725         u8 * const msg = xmalloc(msglen);
726         u32 H[8];
727         int i;
728
729         /* super naive way of handling the padding */
730         memcpy(msg, in, inlen);
731         memset(&msg[inlen], 0, msglen - inlen);
732         msg[inlen] = 0x80;
733         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
734         in = msg;
735
736         for (i = 0; i < ARRAY_SIZE(H); i++)
737                 H[i] = (u32)(sha512_iv[i] >> 32);
738         do {
739                 u32 a = H[0], b = H[1], c = H[2], d = H[3],
740                     e = H[4], f = H[5], g = H[6], h = H[7];
741                 u32 W[64];
742
743                 for (i = 0; i < 16; i++)
744                         W[i] = get_unaligned_be32(&in[i * sizeof(__be32)]);
745                 for (; i < ARRAY_SIZE(W); i++)
746                         W[i] = sigma256_1(W[i - 2]) + W[i - 7] +
747                                sigma256_0(W[i - 15]) + W[i - 16];
748                 for (i = 0; i < ARRAY_SIZE(W); i++) {
749                         u32 T1 = h + Sigma256_1(e) + Ch(e, f, g) +
750                                  (u32)(sha512_round_constants[i] >> 32) + W[i];
751                         u32 T2 = Sigma256_0(a) + Maj(a, b, c);
752
753                         h = g; g = f; f = e; e = d + T1;
754                         d = c; c = b; b = a; a = T1 + T2;
755                 }
756                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
757                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
758         } while ((in += SHA256_BLOCK_SIZE) != &msg[msglen]);
759
760         for (i = 0; i < ARRAY_SIZE(H); i++)
761                 put_unaligned_be32(H[i], &out[i * sizeof(__be32)]);
762         free(msg);
763 }
764
765 #ifdef ENABLE_ALG_TESTS
766 #include <openssl/sha.h>
767 static void test_sha2(void)
768 {
769         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
770
771         while (num_tests--) {
772                 u8 in[4096];
773                 u8 digest[SHA512_DIGEST_SIZE];
774                 u8 ref_digest[SHA512_DIGEST_SIZE];
775                 const size_t inlen = rand() % (1 + sizeof(in));
776
777                 rand_bytes(in, inlen);
778
779                 sha256(in, inlen, digest);
780                 SHA256(in, inlen, ref_digest);
781                 ASSERT(memcmp(digest, ref_digest, SHA256_DIGEST_SIZE) == 0);
782
783                 sha512(in, inlen, digest);
784                 SHA512(in, inlen, ref_digest);
785                 ASSERT(memcmp(digest, ref_digest, SHA512_DIGEST_SIZE) == 0);
786         }
787 }
788 #endif /* ENABLE_ALG_TESTS */
789
790 /*----------------------------------------------------------------------------*
791  *                            HKDF implementation                             *
792  *----------------------------------------------------------------------------*/
793
794 static void hmac_sha512(const u8 *key, size_t keylen, const u8 *msg,
795                         size_t msglen, u8 mac[SHA512_DIGEST_SIZE])
796 {
797         u8 *ibuf = xmalloc(SHA512_BLOCK_SIZE + msglen);
798         u8 obuf[SHA512_BLOCK_SIZE + SHA512_DIGEST_SIZE];
799
800         ASSERT(keylen <= SHA512_BLOCK_SIZE); /* keylen > bs not implemented */
801
802         memset(ibuf, 0x36, SHA512_BLOCK_SIZE);
803         xor(ibuf, ibuf, key, keylen);
804         memcpy(&ibuf[SHA512_BLOCK_SIZE], msg, msglen);
805
806         memset(obuf, 0x5c, SHA512_BLOCK_SIZE);
807         xor(obuf, obuf, key, keylen);
808         sha512(ibuf, SHA512_BLOCK_SIZE + msglen, &obuf[SHA512_BLOCK_SIZE]);
809         sha512(obuf, sizeof(obuf), mac);
810
811         free(ibuf);
812 }
813
814 static void hkdf_sha512(const u8 *ikm, size_t ikmlen,
815                         const u8 *salt, size_t saltlen,
816                         const u8 *info, size_t infolen,
817                         u8 *output, size_t outlen)
818 {
819         static const u8 default_salt[SHA512_DIGEST_SIZE];
820         u8 prk[SHA512_DIGEST_SIZE]; /* pseudorandom key */
821         u8 *buf = xmalloc(1 + infolen + SHA512_DIGEST_SIZE);
822         u8 counter = 1;
823         size_t i;
824
825         if (saltlen == 0) {
826                 salt = default_salt;
827                 saltlen = sizeof(default_salt);
828         }
829
830         /* HKDF-Extract */
831         ASSERT(ikmlen > 0);
832         hmac_sha512(salt, saltlen, ikm, ikmlen, prk);
833
834         /* HKDF-Expand */
835         for (i = 0; i < outlen; i += SHA512_DIGEST_SIZE) {
836                 u8 *p = buf;
837                 u8 tmp[SHA512_DIGEST_SIZE];
838
839                 ASSERT(counter != 0);
840                 if (i > 0) {
841                         memcpy(p, &output[i - SHA512_DIGEST_SIZE],
842                                SHA512_DIGEST_SIZE);
843                         p += SHA512_DIGEST_SIZE;
844                 }
845                 memcpy(p, info, infolen);
846                 p += infolen;
847                 *p++ = counter++;
848                 hmac_sha512(prk, sizeof(prk), buf, p - buf, tmp);
849                 memcpy(&output[i], tmp, MIN(sizeof(tmp), outlen - i));
850         }
851         free(buf);
852 }
853
854 #ifdef ENABLE_ALG_TESTS
855 #include <openssl/evp.h>
856 #include <openssl/kdf.h>
857 static void openssl_hkdf_sha512(const u8 *ikm, size_t ikmlen,
858                                 const u8 *salt, size_t saltlen,
859                                 const u8 *info, size_t infolen,
860                                 u8 *output, size_t outlen)
861 {
862         EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, NULL);
863         size_t actual_outlen = outlen;
864
865         ASSERT(pctx != NULL);
866         ASSERT(EVP_PKEY_derive_init(pctx) > 0);
867         ASSERT(EVP_PKEY_CTX_set_hkdf_md(pctx, EVP_sha512()) > 0);
868         ASSERT(EVP_PKEY_CTX_set1_hkdf_key(pctx, ikm, ikmlen) > 0);
869         ASSERT(EVP_PKEY_CTX_set1_hkdf_salt(pctx, salt, saltlen) > 0);
870         ASSERT(EVP_PKEY_CTX_add1_hkdf_info(pctx, info, infolen) > 0);
871         ASSERT(EVP_PKEY_derive(pctx, output, &actual_outlen) > 0);
872         ASSERT(actual_outlen == outlen);
873         EVP_PKEY_CTX_free(pctx);
874 }
875
876 static void test_hkdf_sha512(void)
877 {
878         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
879
880         while (num_tests--) {
881                 u8 ikm[SHA512_DIGEST_SIZE];
882                 u8 salt[SHA512_DIGEST_SIZE];
883                 u8 info[128];
884                 u8 actual_output[512];
885                 u8 expected_output[sizeof(actual_output)];
886                 size_t ikmlen = 1 + (rand() % sizeof(ikm));
887                 size_t saltlen = rand() % (1 + sizeof(salt));
888                 size_t infolen = rand() % (1 + sizeof(info));
889                 size_t outlen = rand() % (1 + sizeof(actual_output));
890
891                 rand_bytes(ikm, ikmlen);
892                 rand_bytes(salt, saltlen);
893                 rand_bytes(info, infolen);
894
895                 hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
896                             actual_output, outlen);
897                 openssl_hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
898                                     expected_output, outlen);
899                 ASSERT(memcmp(actual_output, expected_output, outlen) == 0);
900         }
901 }
902 #endif /* ENABLE_ALG_TESTS */
903
904 /*----------------------------------------------------------------------------*
905  *                            AES encryption modes                            *
906  *----------------------------------------------------------------------------*/
907
908 static void aes_256_xts_crypt(const u8 key[2 * AES_256_KEY_SIZE],
909                               const u8 iv[AES_BLOCK_SIZE], const u8 *src,
910                               u8 *dst, size_t nbytes, bool decrypting)
911 {
912         struct aes_key tweak_key, cipher_key;
913         ble128 t;
914         size_t i;
915
916         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
917         aes_setkey(&cipher_key, key, AES_256_KEY_SIZE);
918         aes_setkey(&tweak_key, &key[AES_256_KEY_SIZE], AES_256_KEY_SIZE);
919         aes_encrypt(&tweak_key, iv, (u8 *)&t);
920         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
921                 xor(&dst[i], &src[i], (const u8 *)&t, AES_BLOCK_SIZE);
922                 if (decrypting)
923                         aes_decrypt(&cipher_key, &dst[i], &dst[i]);
924                 else
925                         aes_encrypt(&cipher_key, &dst[i], &dst[i]);
926                 xor(&dst[i], &dst[i], (const u8 *)&t, AES_BLOCK_SIZE);
927                 gf2_128_mul_x(&t);
928         }
929 }
930
931 static void aes_256_xts_encrypt(const u8 key[2 * AES_256_KEY_SIZE],
932                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
933                                 u8 *dst, size_t nbytes)
934 {
935         aes_256_xts_crypt(key, iv, src, dst, nbytes, false);
936 }
937
938 static void aes_256_xts_decrypt(const u8 key[2 * AES_256_KEY_SIZE],
939                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
940                                 u8 *dst, size_t nbytes)
941 {
942         aes_256_xts_crypt(key, iv, src, dst, nbytes, true);
943 }
944
945 #ifdef ENABLE_ALG_TESTS
946 #include <openssl/evp.h>
947 static void test_aes_256_xts(void)
948 {
949         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
950         EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
951
952         ASSERT(ctx != NULL);
953         while (num_tests--) {
954                 u8 key[2 * AES_256_KEY_SIZE];
955                 u8 iv[AES_BLOCK_SIZE];
956                 u8 ptext[512];
957                 u8 ctext[sizeof(ptext)];
958                 u8 ref_ctext[sizeof(ptext)];
959                 u8 decrypted[sizeof(ptext)];
960                 const size_t datalen = ROUND_DOWN(rand() % (1 + sizeof(ptext)),
961                                                   AES_BLOCK_SIZE);
962                 int outl, res;
963
964                 rand_bytes(key, sizeof(key));
965                 rand_bytes(iv, sizeof(iv));
966                 rand_bytes(ptext, datalen);
967
968                 aes_256_xts_encrypt(key, iv, ptext, ctext, datalen);
969                 res = EVP_EncryptInit_ex(ctx, EVP_aes_256_xts(), NULL, key, iv);
970                 ASSERT(res > 0);
971                 res = EVP_EncryptUpdate(ctx, ref_ctext, &outl, ptext, datalen);
972                 ASSERT(res > 0);
973                 ASSERT(outl == datalen);
974                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
975
976                 aes_256_xts_decrypt(key, iv, ctext, decrypted, datalen);
977                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
978         }
979         EVP_CIPHER_CTX_free(ctx);
980 }
981 #endif /* ENABLE_ALG_TESTS */
982
983 static void aes_cbc_encrypt(const struct aes_key *k,
984                             const u8 iv[AES_BLOCK_SIZE],
985                             const u8 *src, u8 *dst, size_t nbytes)
986 {
987         size_t i;
988
989         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
990         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
991                 xor(&dst[i], &src[i], (i == 0 ? iv : &dst[i - AES_BLOCK_SIZE]),
992                     AES_BLOCK_SIZE);
993                 aes_encrypt(k, &dst[i], &dst[i]);
994         }
995 }
996
997 static void aes_cbc_decrypt(const struct aes_key *k,
998                             const u8 iv[AES_BLOCK_SIZE],
999                             const u8 *src, u8 *dst, size_t nbytes)
1000 {
1001         size_t i = nbytes;
1002
1003         ASSERT(i % AES_BLOCK_SIZE == 0);
1004         while (i) {
1005                 i -= AES_BLOCK_SIZE;
1006                 aes_decrypt(k, &src[i], &dst[i]);
1007                 xor(&dst[i], &dst[i], (i == 0 ? iv : &src[i - AES_BLOCK_SIZE]),
1008                     AES_BLOCK_SIZE);
1009         }
1010 }
1011
1012 static void aes_cts_cbc_encrypt(const u8 *key, int keysize,
1013                                 const u8 iv[AES_BLOCK_SIZE],
1014                                 const u8 *src, u8 *dst, size_t nbytes)
1015 {
1016         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1017         const size_t final_bsize = nbytes - offset;
1018         struct aes_key k;
1019         u8 *pad;
1020         u8 buf[AES_BLOCK_SIZE];
1021
1022         ASSERT(nbytes >= AES_BLOCK_SIZE);
1023
1024         aes_setkey(&k, key, keysize);
1025
1026         if (nbytes == AES_BLOCK_SIZE)
1027                 return aes_cbc_encrypt(&k, iv, src, dst, nbytes);
1028
1029         aes_cbc_encrypt(&k, iv, src, dst, offset);
1030         pad = &dst[offset - AES_BLOCK_SIZE];
1031
1032         memcpy(buf, pad, AES_BLOCK_SIZE);
1033         xor(buf, buf, &src[offset], final_bsize);
1034         memcpy(&dst[offset], pad, final_bsize);
1035         aes_encrypt(&k, buf, pad);
1036 }
1037
1038 static void aes_cts_cbc_decrypt(const u8 *key, int keysize,
1039                                 const u8 iv[AES_BLOCK_SIZE],
1040                                 const u8 *src, u8 *dst, size_t nbytes)
1041 {
1042         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1043         const size_t final_bsize = nbytes - offset;
1044         struct aes_key k;
1045         u8 *pad;
1046
1047         ASSERT(nbytes >= AES_BLOCK_SIZE);
1048
1049         aes_setkey(&k, key, keysize);
1050
1051         if (nbytes == AES_BLOCK_SIZE)
1052                 return aes_cbc_decrypt(&k, iv, src, dst, nbytes);
1053
1054         pad = &dst[offset - AES_BLOCK_SIZE];
1055         aes_decrypt(&k, &src[offset - AES_BLOCK_SIZE], pad);
1056         xor(&dst[offset], &src[offset], pad, final_bsize);
1057         xor(pad, pad, &dst[offset], final_bsize);
1058
1059         aes_cbc_decrypt(&k, (offset == AES_BLOCK_SIZE ?
1060                              iv : &src[offset - 2 * AES_BLOCK_SIZE]),
1061                         pad, pad, AES_BLOCK_SIZE);
1062         aes_cbc_decrypt(&k, iv, src, dst, offset - AES_BLOCK_SIZE);
1063 }
1064
1065 static void aes_256_cts_cbc_encrypt(const u8 key[AES_256_KEY_SIZE],
1066                                     const u8 iv[AES_BLOCK_SIZE],
1067                                     const u8 *src, u8 *dst, size_t nbytes)
1068 {
1069         aes_cts_cbc_encrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1070 }
1071
1072 static void aes_256_cts_cbc_decrypt(const u8 key[AES_256_KEY_SIZE],
1073                                     const u8 iv[AES_BLOCK_SIZE],
1074                                     const u8 *src, u8 *dst, size_t nbytes)
1075 {
1076         aes_cts_cbc_decrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1077 }
1078
1079 #ifdef ENABLE_ALG_TESTS
1080 #include <openssl/modes.h>
1081 static void aes_block128_f(const unsigned char in[16],
1082                            unsigned char out[16], const void *key)
1083 {
1084         aes_encrypt(key, in, out);
1085 }
1086
1087 static void test_aes_256_cts_cbc(void)
1088 {
1089         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1090
1091         while (num_tests--) {
1092                 u8 key[AES_256_KEY_SIZE];
1093                 u8 iv[AES_BLOCK_SIZE];
1094                 u8 iv_copy[AES_BLOCK_SIZE];
1095                 u8 ptext[512];
1096                 u8 ctext[sizeof(ptext)];
1097                 u8 ref_ctext[sizeof(ptext)];
1098                 u8 decrypted[sizeof(ptext)];
1099                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1100                 struct aes_key k;
1101
1102                 rand_bytes(key, sizeof(key));
1103                 rand_bytes(iv, sizeof(iv));
1104                 rand_bytes(ptext, datalen);
1105
1106                 aes_256_cts_cbc_encrypt(key, iv, ptext, ctext, datalen);
1107
1108                 /* OpenSSL doesn't allow datalen=AES_BLOCK_SIZE; Linux does */
1109                 if (datalen != AES_BLOCK_SIZE) {
1110                         aes_setkey(&k, key, sizeof(key));
1111                         memcpy(iv_copy, iv, sizeof(iv));
1112                         ASSERT(CRYPTO_cts128_encrypt_block(ptext, ref_ctext,
1113                                                            datalen, &k, iv_copy,
1114                                                            aes_block128_f)
1115                                == datalen);
1116                         ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1117                 }
1118                 aes_256_cts_cbc_decrypt(key, iv, ctext, decrypted, datalen);
1119                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1120         }
1121 }
1122 #endif /* ENABLE_ALG_TESTS */
1123
1124 static void essiv_generate_iv(const u8 orig_key[AES_128_KEY_SIZE],
1125                               const u8 orig_iv[AES_BLOCK_SIZE],
1126                               u8 real_iv[AES_BLOCK_SIZE])
1127 {
1128         u8 essiv_key[SHA256_DIGEST_SIZE];
1129         struct aes_key essiv;
1130
1131         /* AES encrypt the original IV using a hash of the original key */
1132         STATIC_ASSERT(SHA256_DIGEST_SIZE == AES_256_KEY_SIZE);
1133         sha256(orig_key, AES_128_KEY_SIZE, essiv_key);
1134         aes_setkey(&essiv, essiv_key, AES_256_KEY_SIZE);
1135         aes_encrypt(&essiv, orig_iv, real_iv);
1136 }
1137
1138 static void aes_128_cbc_essiv_encrypt(const u8 key[AES_128_KEY_SIZE],
1139                                       const u8 iv[AES_BLOCK_SIZE],
1140                                       const u8 *src, u8 *dst, size_t nbytes)
1141 {
1142         struct aes_key k;
1143         u8 real_iv[AES_BLOCK_SIZE];
1144
1145         aes_setkey(&k, key, AES_128_KEY_SIZE);
1146         essiv_generate_iv(key, iv, real_iv);
1147         aes_cbc_encrypt(&k, real_iv, src, dst, nbytes);
1148 }
1149
1150 static void aes_128_cbc_essiv_decrypt(const u8 key[AES_128_KEY_SIZE],
1151                                       const u8 iv[AES_BLOCK_SIZE],
1152                                       const u8 *src, u8 *dst, size_t nbytes)
1153 {
1154         struct aes_key k;
1155         u8 real_iv[AES_BLOCK_SIZE];
1156
1157         aes_setkey(&k, key, AES_128_KEY_SIZE);
1158         essiv_generate_iv(key, iv, real_iv);
1159         aes_cbc_decrypt(&k, real_iv, src, dst, nbytes);
1160 }
1161
1162 static void aes_128_cts_cbc_encrypt(const u8 key[AES_128_KEY_SIZE],
1163                                     const u8 iv[AES_BLOCK_SIZE],
1164                                     const u8 *src, u8 *dst, size_t nbytes)
1165 {
1166         aes_cts_cbc_encrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1167 }
1168
1169 static void aes_128_cts_cbc_decrypt(const u8 key[AES_128_KEY_SIZE],
1170                                     const u8 iv[AES_BLOCK_SIZE],
1171                                     const u8 *src, u8 *dst, size_t nbytes)
1172 {
1173         aes_cts_cbc_decrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1174 }
1175
1176 /*----------------------------------------------------------------------------*
1177  *                           XChaCha12 stream cipher                          *
1178  *----------------------------------------------------------------------------*/
1179
1180 /*
1181  * References:
1182  *   - "XChaCha: eXtended-nonce ChaCha and AEAD_XChaCha20_Poly1305"
1183  *      https://tools.ietf.org/html/draft-arciszewski-xchacha-03
1184  *
1185  *   - "ChaCha, a variant of Salsa20"
1186  *      https://cr.yp.to/chacha/chacha-20080128.pdf
1187  *
1188  *   - "Extending the Salsa20 nonce"
1189  *      https://cr.yp.to/snuffle/xsalsa-20081128.pdf
1190  */
1191
1192 #define CHACHA_KEY_SIZE         32
1193 #define XCHACHA_KEY_SIZE        CHACHA_KEY_SIZE
1194 #define XCHACHA_NONCE_SIZE      24
1195
1196 static void chacha_init_state(u32 state[16], const u8 key[CHACHA_KEY_SIZE],
1197                               const u8 iv[16])
1198 {
1199         static const u8 consts[16] = "expand 32-byte k";
1200         int i;
1201
1202         for (i = 0; i < 4; i++)
1203                 state[i] = get_unaligned_le32(&consts[i * sizeof(__le32)]);
1204         for (i = 0; i < 8; i++)
1205                 state[4 + i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
1206         for (i = 0; i < 4; i++)
1207                 state[12 + i] = get_unaligned_le32(&iv[i * sizeof(__le32)]);
1208 }
1209
1210 #define CHACHA_QUARTERROUND(a, b, c, d)         \
1211         do {                                    \
1212                 a += b; d = rol32(d ^ a, 16);   \
1213                 c += d; b = rol32(b ^ c, 12);   \
1214                 a += b; d = rol32(d ^ a, 8);    \
1215                 c += d; b = rol32(b ^ c, 7);    \
1216         } while (0)
1217
1218 static void chacha_permute(u32 x[16], int nrounds)
1219 {
1220         do {
1221                 /* column round */
1222                 CHACHA_QUARTERROUND(x[0], x[4], x[8], x[12]);
1223                 CHACHA_QUARTERROUND(x[1], x[5], x[9], x[13]);
1224                 CHACHA_QUARTERROUND(x[2], x[6], x[10], x[14]);
1225                 CHACHA_QUARTERROUND(x[3], x[7], x[11], x[15]);
1226
1227                 /* diagonal round */
1228                 CHACHA_QUARTERROUND(x[0], x[5], x[10], x[15]);
1229                 CHACHA_QUARTERROUND(x[1], x[6], x[11], x[12]);
1230                 CHACHA_QUARTERROUND(x[2], x[7], x[8], x[13]);
1231                 CHACHA_QUARTERROUND(x[3], x[4], x[9], x[14]);
1232         } while ((nrounds -= 2) != 0);
1233 }
1234
1235 static void xchacha(const u8 key[XCHACHA_KEY_SIZE],
1236                     const u8 nonce[XCHACHA_NONCE_SIZE],
1237                     const u8 *src, u8 *dst, size_t nbytes, int nrounds)
1238 {
1239         u32 state[16];
1240         u8 real_key[CHACHA_KEY_SIZE];
1241         u8 real_iv[16] = { 0 };
1242         size_t i, j;
1243
1244         /* Compute real key using original key and first 128 nonce bits */
1245         chacha_init_state(state, key, nonce);
1246         chacha_permute(state, nrounds);
1247         for (i = 0; i < 8; i++) /* state words 0..3, 12..15 */
1248                 put_unaligned_le32(state[(i < 4 ? 0 : 8) + i],
1249                                    &real_key[i * sizeof(__le32)]);
1250
1251         /* Now do regular ChaCha, using real key and remaining nonce bits */
1252         memcpy(&real_iv[8], nonce + 16, 8);
1253         chacha_init_state(state, real_key, real_iv);
1254         for (i = 0; i < nbytes; i += 64) {
1255                 u32 x[16];
1256                 __le32 keystream[16];
1257
1258                 memcpy(x, state, 64);
1259                 chacha_permute(x, nrounds);
1260                 for (j = 0; j < 16; j++)
1261                         keystream[j] = cpu_to_le32(x[j] + state[j]);
1262                 xor(&dst[i], &src[i], (u8 *)keystream, MIN(nbytes - i, 64));
1263                 if (++state[12] == 0)
1264                         state[13]++;
1265         }
1266 }
1267
1268 static void xchacha12(const u8 key[XCHACHA_KEY_SIZE],
1269                       const u8 nonce[XCHACHA_NONCE_SIZE],
1270                       const u8 *src, u8 *dst, size_t nbytes)
1271 {
1272         xchacha(key, nonce, src, dst, nbytes, 12);
1273 }
1274
1275 /*----------------------------------------------------------------------------*
1276  *                                 Poly1305                                   *
1277  *----------------------------------------------------------------------------*/
1278
1279 /*
1280  * Note: this is only the Poly1305 Îµ-almost-∆-universal hash function, not the
1281  * full Poly1305 MAC.  I.e., it doesn't add anything at the end.
1282  */
1283
1284 #define POLY1305_KEY_SIZE       16
1285 #define POLY1305_BLOCK_SIZE     16
1286
1287 static void poly1305(const u8 key[POLY1305_KEY_SIZE],
1288                      const u8 *msg, size_t msglen, le128 *out)
1289 {
1290         const u32 limb_mask = 0x3ffffff;        /* limbs are base 2^26 */
1291         const u64 r0 = (get_unaligned_le32(key +  0) >> 0) & 0x3ffffff;
1292         const u64 r1 = (get_unaligned_le32(key +  3) >> 2) & 0x3ffff03;
1293         const u64 r2 = (get_unaligned_le32(key +  6) >> 4) & 0x3ffc0ff;
1294         const u64 r3 = (get_unaligned_le32(key +  9) >> 6) & 0x3f03fff;
1295         const u64 r4 = (get_unaligned_le32(key + 12) >> 8) & 0x00fffff;
1296         u32 h0 = 0, h1 = 0, h2 = 0, h3 = 0, h4 = 0;
1297         u32 g0, g1, g2, g3, g4, ge_p_mask;
1298
1299         /* Partial block support is not necessary for Adiantum */
1300         ASSERT(msglen % POLY1305_BLOCK_SIZE == 0);
1301
1302         while (msglen) {
1303                 u64 d0, d1, d2, d3, d4;
1304
1305                 /* h += *msg */
1306                 h0 += (get_unaligned_le32(msg +  0) >> 0) & limb_mask;
1307                 h1 += (get_unaligned_le32(msg +  3) >> 2) & limb_mask;
1308                 h2 += (get_unaligned_le32(msg +  6) >> 4) & limb_mask;
1309                 h3 += (get_unaligned_le32(msg +  9) >> 6) & limb_mask;
1310                 h4 += (get_unaligned_le32(msg + 12) >> 8) | (1 << 24);
1311
1312                 /* h *= r */
1313                 d0 = h0*r0 + h1*5*r4 + h2*5*r3 + h3*5*r2 + h4*5*r1;
1314                 d1 = h0*r1 + h1*r0   + h2*5*r4 + h3*5*r3 + h4*5*r2;
1315                 d2 = h0*r2 + h1*r1   + h2*r0   + h3*5*r4 + h4*5*r3;
1316                 d3 = h0*r3 + h1*r2   + h2*r1   + h3*r0   + h4*5*r4;
1317                 d4 = h0*r4 + h1*r3   + h2*r2   + h3*r1   + h4*r0;
1318
1319                 /* (partial) h %= 2^130 - 5 */
1320                 d1 += d0 >> 26;         h0 = d0 & limb_mask;
1321                 d2 += d1 >> 26;         h1 = d1 & limb_mask;
1322                 d3 += d2 >> 26;         h2 = d2 & limb_mask;
1323                 d4 += d3 >> 26;         h3 = d3 & limb_mask;
1324                 h0 += (d4 >> 26) * 5;   h4 = d4 & limb_mask;
1325                 h1 += h0 >> 26;         h0 &= limb_mask;
1326
1327                 msg += POLY1305_BLOCK_SIZE;
1328                 msglen -= POLY1305_BLOCK_SIZE;
1329         }
1330
1331         /* fully carry h */
1332         h2 += (h1 >> 26);       h1 &= limb_mask;
1333         h3 += (h2 >> 26);       h2 &= limb_mask;
1334         h4 += (h3 >> 26);       h3 &= limb_mask;
1335         h0 += (h4 >> 26) * 5;   h4 &= limb_mask;
1336         h1 += (h0 >> 26);       h0 &= limb_mask;
1337
1338         /* if (h >= 2^130 - 5) h -= 2^130 - 5; */
1339         g0 = h0 + 5;
1340         g1 = h1 + (g0 >> 26);   g0 &= limb_mask;
1341         g2 = h2 + (g1 >> 26);   g1 &= limb_mask;
1342         g3 = h3 + (g2 >> 26);   g2 &= limb_mask;
1343         g4 = h4 + (g3 >> 26);   g3 &= limb_mask;
1344         ge_p_mask = ~((g4 >> 26) - 1); /* all 1's if h >= 2^130 - 5, else 0 */
1345         h0 = (h0 & ~ge_p_mask) | (g0 & ge_p_mask);
1346         h1 = (h1 & ~ge_p_mask) | (g1 & ge_p_mask);
1347         h2 = (h2 & ~ge_p_mask) | (g2 & ge_p_mask);
1348         h3 = (h3 & ~ge_p_mask) | (g3 & ge_p_mask);
1349         h4 = (h4 & ~ge_p_mask) | (g4 & ge_p_mask & limb_mask);
1350
1351         /* h %= 2^128 */
1352         out->lo = cpu_to_le64(((u64)h2 << 52) | ((u64)h1 << 26) | h0);
1353         out->hi = cpu_to_le64(((u64)h4 << 40) | ((u64)h3 << 14) | (h2 >> 12));
1354 }
1355
1356 /*----------------------------------------------------------------------------*
1357  *                          Adiantum encryption mode                          *
1358  *----------------------------------------------------------------------------*/
1359
1360 /*
1361  * Reference: "Adiantum: length-preserving encryption for entry-level processors"
1362  *      https://tosc.iacr.org/index.php/ToSC/article/view/7360
1363  */
1364
1365 #define ADIANTUM_KEY_SIZE       32
1366 #define ADIANTUM_IV_SIZE        32
1367 #define ADIANTUM_HASH_KEY_SIZE  ((2 * POLY1305_KEY_SIZE) + NH_KEY_SIZE)
1368
1369 #define NH_KEY_SIZE             1072
1370 #define NH_KEY_WORDS            (NH_KEY_SIZE / sizeof(u32))
1371 #define NH_BLOCK_SIZE           1024
1372 #define NH_HASH_SIZE            32
1373 #define NH_MESSAGE_UNIT         16
1374
1375 static u64 nh_pass(const u32 *key, const u8 *msg, size_t msglen)
1376 {
1377         u64 sum = 0;
1378
1379         ASSERT(msglen % NH_MESSAGE_UNIT == 0);
1380         while (msglen) {
1381                 sum += (u64)(u32)(get_unaligned_le32(msg +  0) + key[0]) *
1382                             (u32)(get_unaligned_le32(msg +  8) + key[2]);
1383                 sum += (u64)(u32)(get_unaligned_le32(msg +  4) + key[1]) *
1384                             (u32)(get_unaligned_le32(msg + 12) + key[3]);
1385                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1386                 msg += NH_MESSAGE_UNIT;
1387                 msglen -= NH_MESSAGE_UNIT;
1388         }
1389         return sum;
1390 }
1391
1392 /* NH Îµ-almost-universal hash function */
1393 static void nh(const u32 *key, const u8 *msg, size_t msglen,
1394                u8 result[NH_HASH_SIZE])
1395 {
1396         size_t i;
1397
1398         for (i = 0; i < NH_HASH_SIZE; i += sizeof(__le64)) {
1399                 put_unaligned_le64(nh_pass(key, msg, msglen), &result[i]);
1400                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1401         }
1402 }
1403
1404 /* Adiantum's Îµ-almost-∆-universal hash function */
1405 static void adiantum_hash(const u8 key[ADIANTUM_HASH_KEY_SIZE],
1406                           const u8 iv[ADIANTUM_IV_SIZE],
1407                           const u8 *msg, size_t msglen, le128 *result)
1408 {
1409         const u8 *header_poly_key = key;
1410         const u8 *msg_poly_key = header_poly_key + POLY1305_KEY_SIZE;
1411         const u8 *nh_key = msg_poly_key + POLY1305_KEY_SIZE;
1412         u32 nh_key_words[NH_KEY_WORDS];
1413         u8 header[POLY1305_BLOCK_SIZE + ADIANTUM_IV_SIZE];
1414         const size_t num_nh_blocks = DIV_ROUND_UP(msglen, NH_BLOCK_SIZE);
1415         u8 *nh_hashes = xmalloc(num_nh_blocks * NH_HASH_SIZE);
1416         const size_t padded_msglen = ROUND_UP(msglen, NH_MESSAGE_UNIT);
1417         u8 *padded_msg = xmalloc(padded_msglen);
1418         le128 hash1, hash2;
1419         size_t i;
1420
1421         for (i = 0; i < NH_KEY_WORDS; i++)
1422                 nh_key_words[i] = get_unaligned_le32(&nh_key[i * sizeof(u32)]);
1423
1424         /* Hash tweak and message length with first Poly1305 key */
1425         put_unaligned_le64((u64)msglen * 8, header);
1426         put_unaligned_le64(0, &header[sizeof(__le64)]);
1427         memcpy(&header[POLY1305_BLOCK_SIZE], iv, ADIANTUM_IV_SIZE);
1428         poly1305(header_poly_key, header, sizeof(header), &hash1);
1429
1430         /* Hash NH hashes of message blocks using second Poly1305 key */
1431         /* (using a super naive way of handling the padding) */
1432         memcpy(padded_msg, msg, msglen);
1433         memset(&padded_msg[msglen], 0, padded_msglen - msglen);
1434         for (i = 0; i < num_nh_blocks; i++) {
1435                 nh(nh_key_words, &padded_msg[i * NH_BLOCK_SIZE],
1436                    MIN(NH_BLOCK_SIZE, padded_msglen - (i * NH_BLOCK_SIZE)),
1437                    &nh_hashes[i * NH_HASH_SIZE]);
1438         }
1439         poly1305(msg_poly_key, nh_hashes, num_nh_blocks * NH_HASH_SIZE, &hash2);
1440
1441         /* Add the two hashes together to get the final hash */
1442         le128_add(result, &hash1, &hash2);
1443
1444         free(nh_hashes);
1445         free(padded_msg);
1446 }
1447
1448 static void adiantum_crypt(const u8 key[ADIANTUM_KEY_SIZE],
1449                            const u8 iv[ADIANTUM_IV_SIZE], const u8 *src,
1450                            u8 *dst, size_t nbytes, bool decrypting)
1451 {
1452         u8 subkeys[AES_256_KEY_SIZE + ADIANTUM_HASH_KEY_SIZE] = { 0 };
1453         struct aes_key aes_key;
1454         union {
1455                 u8 nonce[XCHACHA_NONCE_SIZE];
1456                 le128 block;
1457         } u = { .nonce = { 1 } };
1458         const size_t bulk_len = nbytes - sizeof(u.block);
1459         le128 hash;
1460
1461         ASSERT(nbytes >= sizeof(u.block));
1462
1463         /* Derive subkeys */
1464         xchacha12(key, u.nonce, subkeys, subkeys, sizeof(subkeys));
1465         aes_setkey(&aes_key, subkeys, AES_256_KEY_SIZE);
1466
1467         /* Hash left part and add to right part */
1468         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, src, bulk_len, &hash);
1469         memcpy(&u.block, &src[bulk_len], sizeof(u.block));
1470         le128_add(&u.block, &u.block, &hash);
1471
1472         if (!decrypting) /* Encrypt right part with block cipher */
1473                 aes_encrypt(&aes_key, u.nonce, u.nonce);
1474
1475         /* Encrypt left part with stream cipher, using the computed nonce */
1476         u.nonce[sizeof(u.block)] = 1;
1477         xchacha12(key, u.nonce, src, dst, bulk_len);
1478
1479         if (decrypting) /* Decrypt right part with block cipher */
1480                 aes_decrypt(&aes_key, u.nonce, u.nonce);
1481
1482         /* Finalize right part by subtracting hash of left part */
1483         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, dst, bulk_len, &hash);
1484         le128_sub(&u.block, &u.block, &hash);
1485         memcpy(&dst[bulk_len], &u.block, sizeof(u.block));
1486 }
1487
1488 static void adiantum_encrypt(const u8 key[ADIANTUM_KEY_SIZE],
1489                              const u8 iv[ADIANTUM_IV_SIZE],
1490                              const u8 *src, u8 *dst, size_t nbytes)
1491 {
1492         adiantum_crypt(key, iv, src, dst, nbytes, false);
1493 }
1494
1495 static void adiantum_decrypt(const u8 key[ADIANTUM_KEY_SIZE],
1496                              const u8 iv[ADIANTUM_IV_SIZE],
1497                              const u8 *src, u8 *dst, size_t nbytes)
1498 {
1499         adiantum_crypt(key, iv, src, dst, nbytes, true);
1500 }
1501
1502 #ifdef ENABLE_ALG_TESTS
1503 #include <linux/if_alg.h>
1504 #include <sys/socket.h>
1505 #define SOL_ALG 279
1506 static void af_alg_crypt(int algfd, int op, const u8 *key, size_t keylen,
1507                          const u8 *iv, size_t ivlen,
1508                          const u8 *src, u8 *dst, size_t datalen)
1509 {
1510         size_t controllen = CMSG_SPACE(sizeof(int)) +
1511                             CMSG_SPACE(sizeof(struct af_alg_iv) + ivlen);
1512         u8 *control = xmalloc(controllen);
1513         struct iovec iov = { .iov_base = (u8 *)src, .iov_len = datalen };
1514         struct msghdr msg = {
1515                 .msg_iov = &iov,
1516                 .msg_iovlen = 1,
1517                 .msg_control = control,
1518                 .msg_controllen = controllen,
1519         };
1520         struct cmsghdr *cmsg;
1521         struct af_alg_iv *algiv;
1522         int reqfd;
1523
1524         memset(control, 0, controllen);
1525
1526         cmsg = CMSG_FIRSTHDR(&msg);
1527         cmsg->cmsg_len = CMSG_LEN(sizeof(int));
1528         cmsg->cmsg_level = SOL_ALG;
1529         cmsg->cmsg_type = ALG_SET_OP;
1530         *(int *)CMSG_DATA(cmsg) = op;
1531
1532         cmsg = CMSG_NXTHDR(&msg, cmsg);
1533         cmsg->cmsg_len = CMSG_LEN(sizeof(struct af_alg_iv) + ivlen);
1534         cmsg->cmsg_level = SOL_ALG;
1535         cmsg->cmsg_type = ALG_SET_IV;
1536         algiv = (struct af_alg_iv *)CMSG_DATA(cmsg);
1537         algiv->ivlen = ivlen;
1538         memcpy(algiv->iv, iv, ivlen);
1539
1540         if (setsockopt(algfd, SOL_ALG, ALG_SET_KEY, key, keylen) != 0)
1541                 die_errno("can't set key on AF_ALG socket");
1542
1543         reqfd = accept(algfd, NULL, NULL);
1544         if (reqfd < 0)
1545                 die_errno("can't accept() AF_ALG socket");
1546         if (sendmsg(reqfd, &msg, 0) != datalen)
1547                 die_errno("can't sendmsg() AF_ALG request socket");
1548         if (xread(reqfd, dst, datalen) != datalen)
1549                 die("short read from AF_ALG request socket");
1550         close(reqfd);
1551
1552         free(control);
1553 }
1554
1555 static void test_adiantum(void)
1556 {
1557         int algfd = socket(AF_ALG, SOCK_SEQPACKET, 0);
1558         struct sockaddr_alg addr = {
1559                 .salg_type = "skcipher",
1560                 .salg_name = "adiantum(xchacha12,aes)",
1561         };
1562         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1563
1564         if (algfd < 0)
1565                 die_errno("can't create AF_ALG socket");
1566         if (bind(algfd, (struct sockaddr *)&addr, sizeof(addr)) != 0)
1567                 die_errno("can't bind AF_ALG socket to Adiantum algorithm");
1568
1569         while (num_tests--) {
1570                 u8 key[ADIANTUM_KEY_SIZE];
1571                 u8 iv[ADIANTUM_IV_SIZE];
1572                 u8 ptext[4096];
1573                 u8 ctext[sizeof(ptext)];
1574                 u8 ref_ctext[sizeof(ptext)];
1575                 u8 decrypted[sizeof(ptext)];
1576                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1577
1578                 rand_bytes(key, sizeof(key));
1579                 rand_bytes(iv, sizeof(iv));
1580                 rand_bytes(ptext, datalen);
1581
1582                 adiantum_encrypt(key, iv, ptext, ctext, datalen);
1583                 af_alg_crypt(algfd, ALG_OP_ENCRYPT, key, sizeof(key),
1584                              iv, sizeof(iv), ptext, ref_ctext, datalen);
1585                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1586
1587                 adiantum_decrypt(key, iv, ctext, decrypted, datalen);
1588                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1589         }
1590         close(algfd);
1591 }
1592 #endif /* ENABLE_ALG_TESTS */
1593
1594 /*----------------------------------------------------------------------------*
1595  *                               SipHash-2-4                                  *
1596  *----------------------------------------------------------------------------*/
1597
1598 /*
1599  * Reference: "SipHash: a fast short-input PRF"
1600  *      https://cr.yp.to/siphash/siphash-20120918.pdf
1601  */
1602
1603 #define SIPROUND                                                \
1604         do {                                                    \
1605                 v0 += v1;           v2 += v3;                   \
1606                 v1 = rol64(v1, 13); v3 = rol64(v3, 16);         \
1607                 v1 ^= v0;           v3 ^= v2;                   \
1608                 v0 = rol64(v0, 32);                             \
1609                 v2 += v1;           v0 += v3;                   \
1610                 v1 = rol64(v1, 17); v3 = rol64(v3, 21);         \
1611                 v1 ^= v2;           v3 ^= v0;                   \
1612                 v2 = rol64(v2, 32);                             \
1613         } while (0)
1614
1615 /* Compute the SipHash-2-4 of a 64-bit number when formatted as little endian */
1616 static u64 siphash_1u64(const u64 key[2], u64 data)
1617 {
1618         u64 v0 = key[0] ^ 0x736f6d6570736575ULL;
1619         u64 v1 = key[1] ^ 0x646f72616e646f6dULL;
1620         u64 v2 = key[0] ^ 0x6c7967656e657261ULL;
1621         u64 v3 = key[1] ^ 0x7465646279746573ULL;
1622         u64 m[2] = {data, (u64)sizeof(data) << 56};
1623         size_t i;
1624
1625         for (i = 0; i < ARRAY_SIZE(m); i++) {
1626                 v3 ^= m[i];
1627                 SIPROUND;
1628                 SIPROUND;
1629                 v0 ^= m[i];
1630         }
1631
1632         v2 ^= 0xff;
1633         for (i = 0; i < 4; i++)
1634                 SIPROUND;
1635         return v0 ^ v1 ^ v2 ^ v3;
1636 }
1637
1638 /*----------------------------------------------------------------------------*
1639  *                               Main program                                 *
1640  *----------------------------------------------------------------------------*/
1641
1642 #define FILE_NONCE_SIZE         16
1643 #define UUID_SIZE               16
1644 #define MAX_KEY_SIZE            64
1645 #define MAX_IV_SIZE             ADIANTUM_IV_SIZE
1646
1647 static const struct fscrypt_cipher {
1648         const char *name;
1649         void (*encrypt)(const u8 *key, const u8 *iv, const u8 *src,
1650                         u8 *dst, size_t nbytes);
1651         void (*decrypt)(const u8 *key, const u8 *iv, const u8 *src,
1652                         u8 *dst, size_t nbytes);
1653         int keysize;
1654         int min_input_size;
1655 } fscrypt_ciphers[] = {
1656         {
1657                 .name = "AES-256-XTS",
1658                 .encrypt = aes_256_xts_encrypt,
1659                 .decrypt = aes_256_xts_decrypt,
1660                 .keysize = 2 * AES_256_KEY_SIZE,
1661         }, {
1662                 .name = "AES-256-CTS-CBC",
1663                 .encrypt = aes_256_cts_cbc_encrypt,
1664                 .decrypt = aes_256_cts_cbc_decrypt,
1665                 .keysize = AES_256_KEY_SIZE,
1666                 .min_input_size = AES_BLOCK_SIZE,
1667         }, {
1668                 .name = "AES-128-CBC-ESSIV",
1669                 .encrypt = aes_128_cbc_essiv_encrypt,
1670                 .decrypt = aes_128_cbc_essiv_decrypt,
1671                 .keysize = AES_128_KEY_SIZE,
1672         }, {
1673                 .name = "AES-128-CTS-CBC",
1674                 .encrypt = aes_128_cts_cbc_encrypt,
1675                 .decrypt = aes_128_cts_cbc_decrypt,
1676                 .keysize = AES_128_KEY_SIZE,
1677                 .min_input_size = AES_BLOCK_SIZE,
1678         }, {
1679                 .name = "Adiantum",
1680                 .encrypt = adiantum_encrypt,
1681                 .decrypt = adiantum_decrypt,
1682                 .keysize = ADIANTUM_KEY_SIZE,
1683                 .min_input_size = AES_BLOCK_SIZE,
1684         }
1685 };
1686
1687 static const struct fscrypt_cipher *find_fscrypt_cipher(const char *name)
1688 {
1689         size_t i;
1690
1691         for (i = 0; i < ARRAY_SIZE(fscrypt_ciphers); i++) {
1692                 if (strcmp(fscrypt_ciphers[i].name, name) == 0)
1693                         return &fscrypt_ciphers[i];
1694         }
1695         return NULL;
1696 }
1697
1698 union fscrypt_iv {
1699         /* usual IV format */
1700         struct {
1701                 /* logical block number within the file */
1702                 __le64 block_number;
1703
1704                 /* per-file nonce; only set in DIRECT_KEY mode */
1705                 u8 nonce[FILE_NONCE_SIZE];
1706         };
1707         /* IV format for IV_INO_LBLK_* modes */
1708         struct {
1709                 /*
1710                  * IV_INO_LBLK_64: logical block number within the file
1711                  * IV_INO_LBLK_32: hashed inode number + logical block number
1712                  *                 within the file, mod 2^32
1713                  */
1714                 __le32 block_number32;
1715
1716                 /* IV_INO_LBLK_64: inode number */
1717                 __le32 inode_number;
1718         };
1719         /* Any extra bytes up to the algorithm's IV size must be zeroed */
1720         u8 bytes[MAX_IV_SIZE];
1721 };
1722
1723 static void crypt_loop(const struct fscrypt_cipher *cipher, const u8 *key,
1724                        union fscrypt_iv *iv, bool decrypting,
1725                        size_t block_size, size_t padding, bool is_bnum_32bit)
1726 {
1727         u8 *buf = xmalloc(block_size);
1728         size_t res;
1729
1730         while ((res = xread(STDIN_FILENO, buf, block_size)) > 0) {
1731                 size_t crypt_len = block_size;
1732
1733                 if (padding > 0) {
1734                         crypt_len = MAX(res, cipher->min_input_size);
1735                         crypt_len = ROUND_UP(crypt_len, padding);
1736                         crypt_len = MIN(crypt_len, block_size);
1737                 }
1738                 ASSERT(crypt_len >= res);
1739                 memset(&buf[res], 0, crypt_len - res);
1740
1741                 if (decrypting)
1742                         cipher->decrypt(key, iv->bytes, buf, buf, crypt_len);
1743                 else
1744                         cipher->encrypt(key, iv->bytes, buf, buf, crypt_len);
1745
1746                 full_write(STDOUT_FILENO, buf, crypt_len);
1747
1748                 if (is_bnum_32bit)
1749                         iv->block_number32 = cpu_to_le32(
1750                                         le32_to_cpu(iv->block_number32) + 1);
1751                 else
1752                         iv->block_number = cpu_to_le64(
1753                                         le64_to_cpu(iv->block_number) + 1);
1754         }
1755         free(buf);
1756 }
1757
1758 /* The supported key derivation functions */
1759 enum kdf_algorithm {
1760         KDF_NONE,
1761         KDF_AES_128_ECB,
1762         KDF_HKDF_SHA512,
1763 };
1764
1765 static enum kdf_algorithm parse_kdf_algorithm(const char *arg)
1766 {
1767         if (strcmp(arg, "none") == 0)
1768                 return KDF_NONE;
1769         if (strcmp(arg, "AES-128-ECB") == 0)
1770                 return KDF_AES_128_ECB;
1771         if (strcmp(arg, "HKDF-SHA512") == 0)
1772                 return KDF_HKDF_SHA512;
1773         die("Unknown KDF: %s", arg);
1774 }
1775
1776 static u8 parse_mode_number(const char *arg)
1777 {
1778         char *tmp;
1779         long num = strtol(arg, &tmp, 10);
1780
1781         if (num <= 0 || *tmp || (u8)num != num)
1782                 die("Invalid mode number: %s", arg);
1783         return num;
1784 }
1785
1786 struct key_and_iv_params {
1787         u8 master_key[MAX_KEY_SIZE];
1788         int master_key_size;
1789         enum kdf_algorithm kdf;
1790         u8 mode_num;
1791         u8 file_nonce[FILE_NONCE_SIZE];
1792         bool file_nonce_specified;
1793         bool iv_ino_lblk_64;
1794         bool iv_ino_lblk_32;
1795         u64 block_number;
1796         u64 inode_number;
1797         u8 fs_uuid[UUID_SIZE];
1798         bool fs_uuid_specified;
1799 };
1800
1801 #define HKDF_CONTEXT_KEY_IDENTIFIER     1
1802 #define HKDF_CONTEXT_PER_FILE_ENC_KEY   2
1803 #define HKDF_CONTEXT_DIRECT_KEY         3
1804 #define HKDF_CONTEXT_IV_INO_LBLK_64_KEY 4
1805 #define HKDF_CONTEXT_DIRHASH_KEY        5
1806 #define HKDF_CONTEXT_IV_INO_LBLK_32_KEY 6
1807 #define HKDF_CONTEXT_INODE_HASH_KEY     7
1808
1809 /* Hash the file's inode number using SipHash keyed by a derived key */
1810 static u32 hash_inode_number(const struct key_and_iv_params *params)
1811 {
1812         u8 info[9] = "fscrypt";
1813         union {
1814                 u64 words[2];
1815                 u8 bytes[16];
1816         } hash_key;
1817
1818         info[8] = HKDF_CONTEXT_INODE_HASH_KEY;
1819         hkdf_sha512(params->master_key, params->master_key_size,
1820                     NULL, 0, info, sizeof(info),
1821                     hash_key.bytes, sizeof(hash_key));
1822
1823         hash_key.words[0] = get_unaligned_le64(&hash_key.bytes[0]);
1824         hash_key.words[1] = get_unaligned_le64(&hash_key.bytes[8]);
1825
1826         return (u32)siphash_1u64(hash_key.words, params->inode_number);
1827 }
1828
1829 /*
1830  * Get the key and starting IV with which the encryption will actually be done.
1831  * If a KDF was specified, a subkey is derived from the master key and the mode
1832  * number or file nonce.  Otherwise, the master key is used directly.
1833  */
1834 static void get_key_and_iv(const struct key_and_iv_params *params,
1835                            u8 *real_key, size_t real_key_size,
1836                            union fscrypt_iv *iv)
1837 {
1838         bool file_nonce_in_iv = false;
1839         struct aes_key aes_key;
1840         u8 info[8 + 1 + 1 + UUID_SIZE] = "fscrypt";
1841         size_t infolen = 8;
1842         size_t i;
1843
1844         ASSERT(real_key_size <= params->master_key_size);
1845
1846         memset(iv, 0, sizeof(*iv));
1847
1848         /* Overridden later for iv_ino_lblk_{64,32} */
1849         iv->block_number = cpu_to_le64(params->block_number);
1850
1851         if (params->iv_ino_lblk_64 || params->iv_ino_lblk_32) {
1852                 const char *opt = params->iv_ino_lblk_64 ? "--iv-ino-lblk-64" :
1853                                                            "--iv-ino-lblk-32";
1854                 if (params->iv_ino_lblk_64 && params->iv_ino_lblk_32)
1855                         die("--iv-ino-lblk-64 and --iv-ino-lblk-32 are mutually exclusive");
1856                 if (params->kdf != KDF_HKDF_SHA512)
1857                         die("%s requires --kdf=HKDF-SHA512", opt);
1858                 if (!params->fs_uuid_specified)
1859                         die("%s requires --fs-uuid", opt);
1860                 if (params->inode_number == 0)
1861                         die("%s requires --inode-number", opt);
1862                 if (params->mode_num == 0)
1863                         die("%s requires --mode-num", opt);
1864                 if (params->block_number > UINT32_MAX)
1865                         die("%s can't use --block-number > UINT32_MAX", opt);
1866                 if (params->inode_number > UINT32_MAX)
1867                         die("%s can't use --inode-number > UINT32_MAX", opt);
1868         }
1869
1870         switch (params->kdf) {
1871         case KDF_NONE:
1872                 if (params->mode_num != 0)
1873                         die("--mode-num isn't supported with --kdf=none");
1874                 memcpy(real_key, params->master_key, real_key_size);
1875                 file_nonce_in_iv = true;
1876                 break;
1877         case KDF_AES_128_ECB:
1878                 if (!params->file_nonce_specified)
1879                         die("--file-nonce is required with --kdf=AES-128-ECB");
1880                 if (params->mode_num != 0)
1881                         die("--mode-num isn't supported with --kdf=AES-128-ECB");
1882                 STATIC_ASSERT(FILE_NONCE_SIZE == AES_128_KEY_SIZE);
1883                 ASSERT(real_key_size % AES_BLOCK_SIZE == 0);
1884                 aes_setkey(&aes_key, params->file_nonce, AES_128_KEY_SIZE);
1885                 for (i = 0; i < real_key_size; i += AES_BLOCK_SIZE)
1886                         aes_encrypt(&aes_key, &params->master_key[i],
1887                                     &real_key[i]);
1888                 break;
1889         case KDF_HKDF_SHA512:
1890                 if (params->iv_ino_lblk_64) {
1891                         info[infolen++] = HKDF_CONTEXT_IV_INO_LBLK_64_KEY;
1892                         info[infolen++] = params->mode_num;
1893                         memcpy(&info[infolen], params->fs_uuid, UUID_SIZE);
1894                         infolen += UUID_SIZE;
1895                         iv->block_number32 = cpu_to_le32(params->block_number);
1896                         iv->inode_number = cpu_to_le32(params->inode_number);
1897                 } else if (params->iv_ino_lblk_32) {
1898                         info[infolen++] = HKDF_CONTEXT_IV_INO_LBLK_32_KEY;
1899                         info[infolen++] = params->mode_num;
1900                         memcpy(&info[infolen], params->fs_uuid, UUID_SIZE);
1901                         infolen += UUID_SIZE;
1902                         iv->block_number32 =
1903                                 cpu_to_le32(hash_inode_number(params) +
1904                                             params->block_number);
1905                         iv->inode_number = 0;
1906                 } else if (params->mode_num != 0) {
1907                         info[infolen++] = HKDF_CONTEXT_DIRECT_KEY;
1908                         info[infolen++] = params->mode_num;
1909                         file_nonce_in_iv = true;
1910                 } else if (params->file_nonce_specified) {
1911                         info[infolen++] = HKDF_CONTEXT_PER_FILE_ENC_KEY;
1912                         memcpy(&info[infolen], params->file_nonce,
1913                                FILE_NONCE_SIZE);
1914                         infolen += FILE_NONCE_SIZE;
1915                 } else {
1916                         die("With --kdf=HKDF-SHA512, at least one of --file-nonce and --mode-num must be specified");
1917                 }
1918                 hkdf_sha512(params->master_key, params->master_key_size,
1919                             NULL, 0, info, infolen, real_key, real_key_size);
1920                 break;
1921         default:
1922                 ASSERT(0);
1923         }
1924
1925         if (file_nonce_in_iv && params->file_nonce_specified)
1926                 memcpy(iv->nonce, params->file_nonce, FILE_NONCE_SIZE);
1927 }
1928
1929 enum {
1930         OPT_BLOCK_NUMBER,
1931         OPT_BLOCK_SIZE,
1932         OPT_DECRYPT,
1933         OPT_FILE_NONCE,
1934         OPT_FS_UUID,
1935         OPT_HELP,
1936         OPT_INODE_NUMBER,
1937         OPT_IV_INO_LBLK_32,
1938         OPT_IV_INO_LBLK_64,
1939         OPT_KDF,
1940         OPT_MODE_NUM,
1941         OPT_PADDING,
1942 };
1943
1944 static const struct option longopts[] = {
1945         { "block-number",    required_argument, NULL, OPT_BLOCK_NUMBER },
1946         { "block-size",      required_argument, NULL, OPT_BLOCK_SIZE },
1947         { "decrypt",         no_argument,       NULL, OPT_DECRYPT },
1948         { "file-nonce",      required_argument, NULL, OPT_FILE_NONCE },
1949         { "fs-uuid",         required_argument, NULL, OPT_FS_UUID },
1950         { "help",            no_argument,       NULL, OPT_HELP },
1951         { "inode-number",    required_argument, NULL, OPT_INODE_NUMBER },
1952         { "iv-ino-lblk-32",  no_argument,       NULL, OPT_IV_INO_LBLK_32 },
1953         { "iv-ino-lblk-64",  no_argument,       NULL, OPT_IV_INO_LBLK_64 },
1954         { "kdf",             required_argument, NULL, OPT_KDF },
1955         { "mode-num",        required_argument, NULL, OPT_MODE_NUM },
1956         { "padding",         required_argument, NULL, OPT_PADDING },
1957         { NULL, 0, NULL, 0 },
1958 };
1959
1960 int main(int argc, char *argv[])
1961 {
1962         size_t block_size = 4096;
1963         bool decrypting = false;
1964         struct key_and_iv_params params;
1965         size_t padding = 0;
1966         const struct fscrypt_cipher *cipher;
1967         u8 real_key[MAX_KEY_SIZE];
1968         union fscrypt_iv iv;
1969         char *tmp;
1970         int c;
1971
1972         memset(&params, 0, sizeof(params));
1973
1974         aes_init();
1975
1976 #ifdef ENABLE_ALG_TESTS
1977         test_aes();
1978         test_sha2();
1979         test_hkdf_sha512();
1980         test_aes_256_xts();
1981         test_aes_256_cts_cbc();
1982         test_adiantum();
1983 #endif
1984
1985         while ((c = getopt_long(argc, argv, "", longopts, NULL)) != -1) {
1986                 switch (c) {
1987                 case OPT_BLOCK_NUMBER:
1988                         errno = 0;
1989                         params.block_number = strtoull(optarg, &tmp, 10);
1990                         if (*tmp || errno)
1991                                 die("Invalid block number: %s", optarg);
1992                         break;
1993                 case OPT_BLOCK_SIZE:
1994                         errno = 0;
1995                         block_size = strtoul(optarg, &tmp, 10);
1996                         if (block_size <= 0 || *tmp || errno)
1997                                 die("Invalid block size: %s", optarg);
1998                         break;
1999                 case OPT_DECRYPT:
2000                         decrypting = true;
2001                         break;
2002                 case OPT_FILE_NONCE:
2003                         if (hex2bin(optarg, params.file_nonce, FILE_NONCE_SIZE)
2004                             != FILE_NONCE_SIZE)
2005                                 die("Invalid file nonce: %s", optarg);
2006                         params.file_nonce_specified = true;
2007                         break;
2008                 case OPT_FS_UUID:
2009                         if (hex2bin(optarg, params.fs_uuid, UUID_SIZE)
2010                             != UUID_SIZE)
2011                                 die("Invalid filesystem UUID: %s", optarg);
2012                         params.fs_uuid_specified = true;
2013                         break;
2014                 case OPT_HELP:
2015                         usage(stdout);
2016                         return 0;
2017                 case OPT_INODE_NUMBER:
2018                         errno = 0;
2019                         params.inode_number = strtoull(optarg, &tmp, 10);
2020                         if (params.inode_number <= 0 || *tmp || errno)
2021                                 die("Invalid inode number: %s", optarg);
2022                         break;
2023                 case OPT_IV_INO_LBLK_32:
2024                         params.iv_ino_lblk_32 = true;
2025                         break;
2026                 case OPT_IV_INO_LBLK_64:
2027                         params.iv_ino_lblk_64 = true;
2028                         break;
2029                 case OPT_KDF:
2030                         params.kdf = parse_kdf_algorithm(optarg);
2031                         break;
2032                 case OPT_MODE_NUM:
2033                         params.mode_num = parse_mode_number(optarg);
2034                         break;
2035                 case OPT_PADDING:
2036                         padding = strtoul(optarg, &tmp, 10);
2037                         if (padding <= 0 || *tmp || !is_power_of_2(padding) ||
2038                             padding > INT_MAX)
2039                                 die("Invalid padding amount: %s", optarg);
2040                         break;
2041                 default:
2042                         usage(stderr);
2043                         return 2;
2044                 }
2045         }
2046         argc -= optind;
2047         argv += optind;
2048
2049         if (argc != 2) {
2050                 usage(stderr);
2051                 return 2;
2052         }
2053
2054         cipher = find_fscrypt_cipher(argv[0]);
2055         if (cipher == NULL)
2056                 die("Unknown cipher: %s", argv[0]);
2057
2058         if (block_size < cipher->min_input_size)
2059                 die("Block size of %zu bytes is too small for cipher %s",
2060                     block_size, cipher->name);
2061
2062         params.master_key_size = hex2bin(argv[1], params.master_key,
2063                                          MAX_KEY_SIZE);
2064         if (params.master_key_size < 0)
2065                 die("Invalid master_key: %s", argv[1]);
2066         if (params.master_key_size < cipher->keysize)
2067                 die("Master key is too short for cipher %s", cipher->name);
2068
2069         get_key_and_iv(&params, real_key, cipher->keysize, &iv);
2070
2071         crypt_loop(cipher, real_key, &iv, decrypting, block_size, padding,
2072                    params.iv_ino_lblk_64 || params.iv_ino_lblk_32);
2073         return 0;
2074 }