common/encrypt: support verifying ciphertext of IV_INO_LBLK_64 policies
[xfstests-dev.git] / src / fscrypt-crypt-util.c
1 // SPDX-License-Identifier: GPL-2.0+
2 /*
3  * fscrypt-crypt-util.c - utility for verifying fscrypt-encrypted data
4  *
5  * Copyright 2019 Google LLC
6  */
7
8 /*
9  * This program implements all crypto algorithms supported by fscrypt (a.k.a.
10  * ext4, f2fs, and ubifs encryption), for the purpose of verifying the
11  * correctness of the ciphertext stored on-disk.  See usage() below.
12  *
13  * All algorithms are implemented in portable C code to avoid depending on
14  * libcrypto (OpenSSL), and because some fscrypt-supported algorithms aren't
15  * available in libcrypto anyway (e.g. Adiantum), or are only supported in
16  * recent versions (e.g. HKDF-SHA512).  For simplicity, all crypto code here
17  * tries to follow the mathematical definitions directly, without optimizing for
18  * performance or worrying about following security best practices such as
19  * mitigating side-channel attacks.  So, only use this program for testing!
20  */
21
22 #include <asm/byteorder.h>
23 #include <errno.h>
24 #include <getopt.h>
25 #include <limits.h>
26 #include <linux/types.h>
27 #include <stdarg.h>
28 #include <stdbool.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 #include <unistd.h>
33
34 #define PROGRAM_NAME "fscrypt-crypt-util"
35
36 /*
37  * Define to enable the tests of the crypto code in this file.  If enabled, you
38  * must link this program with OpenSSL (-lcrypto) v1.1.0 or later, and your
39  * kernel needs CONFIG_CRYPTO_USER_API_SKCIPHER=y and CONFIG_CRYPTO_ADIANTUM=y.
40  */
41 #undef ENABLE_ALG_TESTS
42
43 #define NUM_ALG_TEST_ITERATIONS 10000
44
45 static void usage(FILE *fp)
46 {
47         fputs(
48 "Usage: " PROGRAM_NAME " [OPTION]... CIPHER MASTER_KEY\n"
49 "\n"
50 "Utility for verifying fscrypt-encrypted data.  This program encrypts\n"
51 "(or decrypts) the data on stdin using the given CIPHER with the given\n"
52 "MASTER_KEY (or a key derived from it, if a KDF is specified), and writes the\n"
53 "resulting ciphertext (or plaintext) to stdout.\n"
54 "\n"
55 "CIPHER can be AES-256-XTS, AES-256-CTS-CBC, AES-128-CBC-ESSIV, AES-128-CTS-CBC,\n"
56 "or Adiantum.  MASTER_KEY must be a hex string long enough for the cipher.\n"
57 "\n"
58 "WARNING: this program is only meant for testing, not for \"real\" use!\n"
59 "\n"
60 "Options:\n"
61 "  --block-size=BLOCK_SIZE     Encrypt each BLOCK_SIZE bytes independently.\n"
62 "                                Default: 4096 bytes\n"
63 "  --decrypt                   Decrypt instead of encrypt\n"
64 "  --file-nonce=NONCE          File's nonce as a 32-character hex string\n"
65 "  --fs-uuid=UUID              The filesystem UUID as a 32-character hex string.\n"
66 "                                Only required for --iv-ino-lblk-64.\n"
67 "  --help                      Show this help\n"
68 "  --inode-number=INUM         The file's inode number.  Only required for\n"
69 "                                --iv-ino-lblk-64.\n"
70 "  --iv-ino-lblk-64            Use the format where the IVs include the inode\n"
71 "                                number and the same key is shared across files.\n"
72 "                                Requires --kdf=HKDF-SHA512, --fs-uuid,\n"
73 "                                --inode-number, and --mode-num.\n"
74 "  --kdf=KDF                   Key derivation function to use: AES-128-ECB,\n"
75 "                                HKDF-SHA512, or none.  Default: none\n"
76 "  --mode-num=NUM              Derive per-mode key using mode number NUM\n"
77 "  --padding=PADDING           If last block is partial, zero-pad it to next\n"
78 "                                PADDING-byte boundary.  Default: BLOCK_SIZE\n"
79         , fp);
80 }
81
82 /*----------------------------------------------------------------------------*
83  *                                 Utilities                                  *
84  *----------------------------------------------------------------------------*/
85
86 #define ARRAY_SIZE(A)           (sizeof(A) / sizeof((A)[0]))
87 #define MIN(x, y)               ((x) < (y) ? (x) : (y))
88 #define MAX(x, y)               ((x) > (y) ? (x) : (y))
89 #define ROUND_DOWN(x, y)        ((x) & ~((y) - 1))
90 #define ROUND_UP(x, y)          (((x) + (y) - 1) & ~((y) - 1))
91 #define DIV_ROUND_UP(n, d)      (((n) + (d) - 1) / (d))
92 #define STATIC_ASSERT(e)        ((void)sizeof(char[1 - 2*!(e)]))
93
94 typedef __u8                    u8;
95 typedef __u16                   u16;
96 typedef __u32                   u32;
97 typedef __u64                   u64;
98
99 #define cpu_to_le32             __cpu_to_le32
100 #define cpu_to_be32             __cpu_to_be32
101 #define cpu_to_le64             __cpu_to_le64
102 #define cpu_to_be64             __cpu_to_be64
103 #define le32_to_cpu             __le32_to_cpu
104 #define be32_to_cpu             __be32_to_cpu
105 #define le64_to_cpu             __le64_to_cpu
106 #define be64_to_cpu             __be64_to_cpu
107
108 #define DEFINE_UNALIGNED_ACCESS_HELPERS(type, native_type)      \
109 static inline native_type __attribute__((unused))               \
110 get_unaligned_##type(const void *p)                             \
111 {                                                               \
112         __##type x;                                             \
113                                                                 \
114         memcpy(&x, p, sizeof(x));                               \
115         return type##_to_cpu(x);                                \
116 }                                                               \
117                                                                 \
118 static inline void __attribute__((unused))                      \
119 put_unaligned_##type(native_type v, void *p)                    \
120 {                                                               \
121         __##type x = cpu_to_##type(v);                          \
122                                                                 \
123         memcpy(p, &x, sizeof(x));                               \
124 }
125
126 DEFINE_UNALIGNED_ACCESS_HELPERS(le32, u32)
127 DEFINE_UNALIGNED_ACCESS_HELPERS(be32, u32)
128 DEFINE_UNALIGNED_ACCESS_HELPERS(le64, u64)
129 DEFINE_UNALIGNED_ACCESS_HELPERS(be64, u64)
130
131 static inline bool is_power_of_2(unsigned long v)
132 {
133         return v != 0 && (v & (v - 1)) == 0;
134 }
135
136 static inline u32 rol32(u32 v, int n)
137 {
138         return (v << n) | (v >> (32 - n));
139 }
140
141 static inline u32 ror32(u32 v, int n)
142 {
143         return (v >> n) | (v << (32 - n));
144 }
145
146 static inline u64 ror64(u64 v, int n)
147 {
148         return (v >> n) | (v << (64 - n));
149 }
150
151 static inline void xor(u8 *res, const u8 *a, const u8 *b, size_t count)
152 {
153         while (count--)
154                 *res++ = *a++ ^ *b++;
155 }
156
157 static void __attribute__((noreturn, format(printf, 2, 3)))
158 do_die(int err, const char *format, ...)
159 {
160         va_list va;
161
162         va_start(va, format);
163         fputs("[" PROGRAM_NAME "] ERROR: ", stderr);
164         vfprintf(stderr, format, va);
165         if (err)
166                 fprintf(stderr, ": %s", strerror(errno));
167         putc('\n', stderr);
168         va_end(va);
169         exit(1);
170 }
171
172 #define die(format, ...)        do_die(0,     (format), ##__VA_ARGS__)
173 #define die_errno(format, ...)  do_die(errno, (format), ##__VA_ARGS__)
174
175 static __attribute__((noreturn)) void
176 assertion_failed(const char *expr, const char *file, int line)
177 {
178         die("Assertion failed: %s at %s:%d", expr, file, line);
179 }
180
181 #define ASSERT(e) ({ if (!(e)) assertion_failed(#e, __FILE__, __LINE__); })
182
183 static void *xmalloc(size_t size)
184 {
185         void *p = malloc(size);
186
187         ASSERT(p != NULL);
188         return p;
189 }
190
191 static int hexchar2bin(char c)
192 {
193         if (c >= 'a' && c <= 'f')
194                 return 10 + c - 'a';
195         if (c >= 'A' && c <= 'F')
196                 return 10 + c - 'A';
197         if (c >= '0' && c <= '9')
198                 return c - '0';
199         return -1;
200 }
201
202 static int hex2bin(const char *hex, u8 *bin, int max_bin_size)
203 {
204         size_t len = strlen(hex);
205         size_t i;
206
207         if (len & 1)
208                 return -1;
209         len /= 2;
210         if (len > max_bin_size)
211                 return -1;
212
213         for (i = 0; i < len; i++) {
214                 int high = hexchar2bin(hex[2 * i]);
215                 int low = hexchar2bin(hex[2 * i + 1]);
216
217                 if (high < 0 || low < 0)
218                         return -1;
219                 bin[i] = (high << 4) | low;
220         }
221         return len;
222 }
223
224 static size_t xread(int fd, void *buf, size_t count)
225 {
226         const size_t orig_count = count;
227
228         while (count) {
229                 ssize_t res = read(fd, buf, count);
230
231                 if (res < 0)
232                         die_errno("read error");
233                 if (res == 0)
234                         break;
235                 buf += res;
236                 count -= res;
237         }
238         return orig_count - count;
239 }
240
241 static void full_write(int fd, const void *buf, size_t count)
242 {
243         while (count) {
244                 ssize_t res = write(fd, buf, count);
245
246                 if (res < 0)
247                         die_errno("write error");
248                 buf += res;
249                 count -= res;
250         }
251 }
252
253 #ifdef ENABLE_ALG_TESTS
254 static void rand_bytes(u8 *buf, size_t count)
255 {
256         while (count--)
257                 *buf++ = rand();
258 }
259 #endif
260
261 /*----------------------------------------------------------------------------*
262  *                          Finite field arithmetic                           *
263  *----------------------------------------------------------------------------*/
264
265 /* Multiply a GF(2^8) element by the polynomial 'x' */
266 static inline u8 gf2_8_mul_x(u8 b)
267 {
268         return (b << 1) ^ ((b & 0x80) ? 0x1B : 0);
269 }
270
271 /* Multiply four packed GF(2^8) elements by the polynomial 'x' */
272 static inline u32 gf2_8_mul_x_4way(u32 w)
273 {
274         return ((w & 0x7F7F7F7F) << 1) ^ (((w & 0x80808080) >> 7) * 0x1B);
275 }
276
277 /* Element of GF(2^128) */
278 typedef struct {
279         __le64 lo;
280         __le64 hi;
281 } ble128;
282
283 /* Multiply a GF(2^128) element by the polynomial 'x' */
284 static inline void gf2_128_mul_x(ble128 *t)
285 {
286         u64 lo = le64_to_cpu(t->lo);
287         u64 hi = le64_to_cpu(t->hi);
288
289         t->hi = cpu_to_le64((hi << 1) | (lo >> 63));
290         t->lo = cpu_to_le64((lo << 1) ^ ((hi & (1ULL << 63)) ? 0x87 : 0));
291 }
292
293 /*----------------------------------------------------------------------------*
294  *                             Group arithmetic                               *
295  *----------------------------------------------------------------------------*/
296
297 /* Element of Z/(2^{128}Z)  (a.k.a. the integers modulo 2^128) */
298 typedef struct {
299         __le64 lo;
300         __le64 hi;
301 } le128;
302
303 static inline void le128_add(le128 *res, const le128 *a, const le128 *b)
304 {
305         u64 a_lo = le64_to_cpu(a->lo);
306         u64 b_lo = le64_to_cpu(b->lo);
307
308         res->lo = cpu_to_le64(a_lo + b_lo);
309         res->hi = cpu_to_le64(le64_to_cpu(a->hi) + le64_to_cpu(b->hi) +
310                               (a_lo + b_lo < a_lo));
311 }
312
313 static inline void le128_sub(le128 *res, const le128 *a, const le128 *b)
314 {
315         u64 a_lo = le64_to_cpu(a->lo);
316         u64 b_lo = le64_to_cpu(b->lo);
317
318         res->lo = cpu_to_le64(a_lo - b_lo);
319         res->hi = cpu_to_le64(le64_to_cpu(a->hi) - le64_to_cpu(b->hi) -
320                               (a_lo - b_lo > a_lo));
321 }
322
323 /*----------------------------------------------------------------------------*
324  *                              AES block cipher                              *
325  *----------------------------------------------------------------------------*/
326
327 /*
328  * Reference: "FIPS 197, Advanced Encryption Standard"
329  *      https://nvlpubs.nist.gov/nistpubs/fips/nist.fips.197.pdf
330  */
331
332 #define AES_BLOCK_SIZE          16
333 #define AES_128_KEY_SIZE        16
334 #define AES_192_KEY_SIZE        24
335 #define AES_256_KEY_SIZE        32
336
337 static inline void AddRoundKey(u32 state[4], const u32 *rk)
338 {
339         int i;
340
341         for (i = 0; i < 4; i++)
342                 state[i] ^= rk[i];
343 }
344
345 static const u8 aes_sbox[256] = {
346         0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b,
347         0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0,
348         0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26,
349         0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
350         0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2,
351         0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0,
352         0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed,
353         0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
354         0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f,
355         0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5,
356         0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec,
357         0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
358         0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14,
359         0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c,
360         0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d,
361         0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
362         0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f,
363         0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e,
364         0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11,
365         0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
366         0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f,
367         0xb0, 0x54, 0xbb, 0x16,
368 };
369
370 static u8 aes_inverse_sbox[256];
371
372 static void aes_init(void)
373 {
374         int i;
375
376         for (i = 0; i < 256; i++)
377                 aes_inverse_sbox[aes_sbox[i]] = i;
378 }
379
380 static inline u32 DoSubWord(u32 w, const u8 sbox[256])
381 {
382         return ((u32)sbox[(u8)(w >> 24)] << 24) |
383                ((u32)sbox[(u8)(w >> 16)] << 16) |
384                ((u32)sbox[(u8)(w >>  8)] <<  8) |
385                ((u32)sbox[(u8)(w >>  0)] <<  0);
386 }
387
388 static inline u32 SubWord(u32 w)
389 {
390         return DoSubWord(w, aes_sbox);
391 }
392
393 static inline u32 InvSubWord(u32 w)
394 {
395         return DoSubWord(w, aes_inverse_sbox);
396 }
397
398 static inline void SubBytes(u32 state[4])
399 {
400         int i;
401
402         for (i = 0; i < 4; i++)
403                 state[i] = SubWord(state[i]);
404 }
405
406 static inline void InvSubBytes(u32 state[4])
407 {
408         int i;
409
410         for (i = 0; i < 4; i++)
411                 state[i] = InvSubWord(state[i]);
412 }
413
414 static inline void DoShiftRows(u32 state[4], int direction)
415 {
416         u32 newstate[4];
417         int i;
418
419         for (i = 0; i < 4; i++)
420                 newstate[i] = (state[(i + direction*0) & 3] & 0xff) |
421                               (state[(i + direction*1) & 3] & 0xff00) |
422                               (state[(i + direction*2) & 3] & 0xff0000) |
423                               (state[(i + direction*3) & 3] & 0xff000000);
424         memcpy(state, newstate, 16);
425 }
426
427 static inline void ShiftRows(u32 state[4])
428 {
429         DoShiftRows(state, 1);
430 }
431
432 static inline void InvShiftRows(u32 state[4])
433 {
434         DoShiftRows(state, -1);
435 }
436
437 /*
438  * Mix one column by doing the following matrix multiplication in GF(2^8):
439  *
440  *     | 2 3 1 1 |   | w[0] |
441  *     | 1 2 3 1 |   | w[1] |
442  *     | 1 1 2 3 | x | w[2] |
443  *     | 3 1 1 2 |   | w[3] |
444  *
445  * a.k.a. w[i] = 2*w[i] + 3*w[(i+1)%4] + w[(i+2)%4] + w[(i+3)%4]
446  */
447 static inline u32 MixColumn(u32 w)
448 {
449         u32 _2w0_w2 = gf2_8_mul_x_4way(w) ^ ror32(w, 16);
450         u32 _3w1_w3 = ror32(_2w0_w2 ^ w, 8);
451
452         return _2w0_w2 ^ _3w1_w3;
453 }
454
455 /*
456  *           ( | 5 0 4 0 |   | w[0] | )
457  *          (  | 0 5 0 4 |   | w[1] |  )
458  * MixColumn(  | 4 0 5 0 | x | w[2] |  )
459  *           ( | 0 4 0 5 |   | w[3] | )
460  */
461 static inline u32 InvMixColumn(u32 w)
462 {
463         u32 _4w = gf2_8_mul_x_4way(gf2_8_mul_x_4way(w));
464
465         return MixColumn(_4w ^ w ^ ror32(_4w, 16));
466 }
467
468 static inline void MixColumns(u32 state[4])
469 {
470         int i;
471
472         for (i = 0; i < 4; i++)
473                 state[i] = MixColumn(state[i]);
474 }
475
476 static inline void InvMixColumns(u32 state[4])
477 {
478         int i;
479
480         for (i = 0; i < 4; i++)
481                 state[i] = InvMixColumn(state[i]);
482 }
483
484 struct aes_key {
485         u32 round_keys[15 * 4];
486         int nrounds;
487 };
488
489 /* Expand an AES key */
490 static void aes_setkey(struct aes_key *k, const u8 *key, int keysize)
491 {
492         const int N = keysize / 4;
493         u32 * const rk = k->round_keys;
494         u8 rcon = 1;
495         int i;
496
497         ASSERT(keysize == 16 || keysize == 24 || keysize == 32);
498         k->nrounds = 6 + N;
499         for (i = 0; i < 4 * (k->nrounds + 1); i++) {
500                 if (i < N) {
501                         rk[i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
502                 } else if (i % N == 0) {
503                         rk[i] = rk[i - N] ^ SubWord(ror32(rk[i - 1], 8)) ^ rcon;
504                         rcon = gf2_8_mul_x(rcon);
505                 } else if (N > 6 && i % N == 4) {
506                         rk[i] = rk[i - N] ^ SubWord(rk[i - 1]);
507                 } else {
508                         rk[i] = rk[i - N] ^ rk[i - 1];
509                 }
510         }
511 }
512
513 /* Encrypt one 16-byte block with AES */
514 static void aes_encrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
515                         u8 dst[AES_BLOCK_SIZE])
516 {
517         u32 state[4];
518         int i;
519
520         for (i = 0; i < 4; i++)
521                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
522
523         AddRoundKey(state, k->round_keys);
524         for (i = 1; i < k->nrounds; i++) {
525                 SubBytes(state);
526                 ShiftRows(state);
527                 MixColumns(state);
528                 AddRoundKey(state, &k->round_keys[4 * i]);
529         }
530         SubBytes(state);
531         ShiftRows(state);
532         AddRoundKey(state, &k->round_keys[4 * i]);
533
534         for (i = 0; i < 4; i++)
535                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
536 }
537
538 /* Decrypt one 16-byte block with AES */
539 static void aes_decrypt(const struct aes_key *k, const u8 src[AES_BLOCK_SIZE],
540                         u8 dst[AES_BLOCK_SIZE])
541 {
542         u32 state[4];
543         int i;
544
545         for (i = 0; i < 4; i++)
546                 state[i] = get_unaligned_le32(&src[i * sizeof(__le32)]);
547
548         AddRoundKey(state, &k->round_keys[4 * k->nrounds]);
549         InvShiftRows(state);
550         InvSubBytes(state);
551         for (i = k->nrounds - 1; i >= 1; i--) {
552                 AddRoundKey(state, &k->round_keys[4 * i]);
553                 InvMixColumns(state);
554                 InvShiftRows(state);
555                 InvSubBytes(state);
556         }
557         AddRoundKey(state, k->round_keys);
558
559         for (i = 0; i < 4; i++)
560                 put_unaligned_le32(state[i], &dst[i * sizeof(__le32)]);
561 }
562
563 #ifdef ENABLE_ALG_TESTS
564 #include <openssl/aes.h>
565 static void test_aes_keysize(int keysize)
566 {
567         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
568
569         while (num_tests--) {
570                 struct aes_key k;
571                 AES_KEY ref_k;
572                 u8 key[AES_256_KEY_SIZE];
573                 u8 ptext[AES_BLOCK_SIZE];
574                 u8 ctext[AES_BLOCK_SIZE];
575                 u8 ref_ctext[AES_BLOCK_SIZE];
576                 u8 decrypted[AES_BLOCK_SIZE];
577
578                 rand_bytes(key, keysize);
579                 rand_bytes(ptext, AES_BLOCK_SIZE);
580
581                 aes_setkey(&k, key, keysize);
582                 aes_encrypt(&k, ptext, ctext);
583
584                 ASSERT(AES_set_encrypt_key(key, keysize*8, &ref_k) == 0);
585                 AES_encrypt(ptext, ref_ctext, &ref_k);
586
587                 ASSERT(memcmp(ctext, ref_ctext, AES_BLOCK_SIZE) == 0);
588
589                 aes_decrypt(&k, ctext, decrypted);
590                 ASSERT(memcmp(ptext, decrypted, AES_BLOCK_SIZE) == 0);
591         }
592 }
593
594 static void test_aes(void)
595 {
596         test_aes_keysize(AES_128_KEY_SIZE);
597         test_aes_keysize(AES_192_KEY_SIZE);
598         test_aes_keysize(AES_256_KEY_SIZE);
599 }
600 #endif /* ENABLE_ALG_TESTS */
601
602 /*----------------------------------------------------------------------------*
603  *                            SHA-512 and SHA-256                             *
604  *----------------------------------------------------------------------------*/
605
606 /*
607  * Reference: "FIPS 180-2, Secure Hash Standard"
608  *      https://csrc.nist.gov/csrc/media/publications/fips/180/2/archive/2002-08-01/documents/fips180-2withchangenotice.pdf
609  */
610
611 #define SHA512_DIGEST_SIZE      64
612 #define SHA512_BLOCK_SIZE       128
613
614 #define SHA256_DIGEST_SIZE      32
615 #define SHA256_BLOCK_SIZE       64
616
617 #define Ch(x, y, z)     (((x) & (y)) ^ (~(x) & (z)))
618 #define Maj(x, y, z)    (((x) & (y)) ^ ((x) & (z)) ^ ((y) & (z)))
619
620 #define Sigma512_0(x)   (ror64((x), 28) ^ ror64((x), 34) ^ ror64((x), 39))
621 #define Sigma512_1(x)   (ror64((x), 14) ^ ror64((x), 18) ^ ror64((x), 41))
622 #define sigma512_0(x)   (ror64((x),  1) ^ ror64((x),  8) ^ ((x) >> 7))
623 #define sigma512_1(x)   (ror64((x), 19) ^ ror64((x), 61) ^ ((x) >> 6))
624
625 #define Sigma256_0(x)   (ror32((x),  2) ^ ror32((x), 13) ^ ror32((x), 22))
626 #define Sigma256_1(x)   (ror32((x),  6) ^ ror32((x), 11) ^ ror32((x), 25))
627 #define sigma256_0(x)   (ror32((x),  7) ^ ror32((x), 18) ^ ((x) >>  3))
628 #define sigma256_1(x)   (ror32((x), 17) ^ ror32((x), 19) ^ ((x) >> 10))
629
630 static const u64 sha512_iv[8] = {
631         0x6a09e667f3bcc908, 0xbb67ae8584caa73b, 0x3c6ef372fe94f82b,
632         0xa54ff53a5f1d36f1, 0x510e527fade682d1, 0x9b05688c2b3e6c1f,
633         0x1f83d9abfb41bd6b, 0x5be0cd19137e2179,
634 };
635
636 static const u64 sha512_round_constants[80] = {
637         0x428a2f98d728ae22, 0x7137449123ef65cd, 0xb5c0fbcfec4d3b2f,
638         0xe9b5dba58189dbbc, 0x3956c25bf348b538, 0x59f111f1b605d019,
639         0x923f82a4af194f9b, 0xab1c5ed5da6d8118, 0xd807aa98a3030242,
640         0x12835b0145706fbe, 0x243185be4ee4b28c, 0x550c7dc3d5ffb4e2,
641         0x72be5d74f27b896f, 0x80deb1fe3b1696b1, 0x9bdc06a725c71235,
642         0xc19bf174cf692694, 0xe49b69c19ef14ad2, 0xefbe4786384f25e3,
643         0x0fc19dc68b8cd5b5, 0x240ca1cc77ac9c65, 0x2de92c6f592b0275,
644         0x4a7484aa6ea6e483, 0x5cb0a9dcbd41fbd4, 0x76f988da831153b5,
645         0x983e5152ee66dfab, 0xa831c66d2db43210, 0xb00327c898fb213f,
646         0xbf597fc7beef0ee4, 0xc6e00bf33da88fc2, 0xd5a79147930aa725,
647         0x06ca6351e003826f, 0x142929670a0e6e70, 0x27b70a8546d22ffc,
648         0x2e1b21385c26c926, 0x4d2c6dfc5ac42aed, 0x53380d139d95b3df,
649         0x650a73548baf63de, 0x766a0abb3c77b2a8, 0x81c2c92e47edaee6,
650         0x92722c851482353b, 0xa2bfe8a14cf10364, 0xa81a664bbc423001,
651         0xc24b8b70d0f89791, 0xc76c51a30654be30, 0xd192e819d6ef5218,
652         0xd69906245565a910, 0xf40e35855771202a, 0x106aa07032bbd1b8,
653         0x19a4c116b8d2d0c8, 0x1e376c085141ab53, 0x2748774cdf8eeb99,
654         0x34b0bcb5e19b48a8, 0x391c0cb3c5c95a63, 0x4ed8aa4ae3418acb,
655         0x5b9cca4f7763e373, 0x682e6ff3d6b2b8a3, 0x748f82ee5defb2fc,
656         0x78a5636f43172f60, 0x84c87814a1f0ab72, 0x8cc702081a6439ec,
657         0x90befffa23631e28, 0xa4506cebde82bde9, 0xbef9a3f7b2c67915,
658         0xc67178f2e372532b, 0xca273eceea26619c, 0xd186b8c721c0c207,
659         0xeada7dd6cde0eb1e, 0xf57d4f7fee6ed178, 0x06f067aa72176fba,
660         0x0a637dc5a2c898a6, 0x113f9804bef90dae, 0x1b710b35131c471b,
661         0x28db77f523047d84, 0x32caab7b40c72493, 0x3c9ebe0a15c9bebc,
662         0x431d67c49c100d4c, 0x4cc5d4becb3e42b6, 0x597f299cfc657e2a,
663         0x5fcb6fab3ad6faec, 0x6c44198c4a475817,
664 };
665
666 /* Compute the SHA-512 digest of the given buffer */
667 static void sha512(const u8 *in, size_t inlen, u8 out[SHA512_DIGEST_SIZE])
668 {
669         const size_t msglen = ROUND_UP(inlen + 17, SHA512_BLOCK_SIZE);
670         u8 * const msg = xmalloc(msglen);
671         u64 H[8];
672         int i;
673
674         /* super naive way of handling the padding */
675         memcpy(msg, in, inlen);
676         memset(&msg[inlen], 0, msglen - inlen);
677         msg[inlen] = 0x80;
678         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
679         in = msg;
680
681         memcpy(H, sha512_iv, sizeof(H));
682         do {
683                 u64 a = H[0], b = H[1], c = H[2], d = H[3],
684                     e = H[4], f = H[5], g = H[6], h = H[7];
685                 u64 W[80];
686
687                 for (i = 0; i < 16; i++)
688                         W[i] = get_unaligned_be64(&in[i * sizeof(__be64)]);
689                 for (; i < ARRAY_SIZE(W); i++)
690                         W[i] = sigma512_1(W[i - 2]) + W[i - 7] +
691                                sigma512_0(W[i - 15]) + W[i - 16];
692                 for (i = 0; i < ARRAY_SIZE(W); i++) {
693                         u64 T1 = h + Sigma512_1(e) + Ch(e, f, g) +
694                                  sha512_round_constants[i] + W[i];
695                         u64 T2 = Sigma512_0(a) + Maj(a, b, c);
696
697                         h = g; g = f; f = e; e = d + T1;
698                         d = c; c = b; b = a; a = T1 + T2;
699                 }
700                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
701                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
702         } while ((in += SHA512_BLOCK_SIZE) != &msg[msglen]);
703
704         for (i = 0; i < ARRAY_SIZE(H); i++)
705                 put_unaligned_be64(H[i], &out[i * sizeof(__be64)]);
706         free(msg);
707 }
708
709 /* Compute the SHA-256 digest of the given buffer */
710 static void sha256(const u8 *in, size_t inlen, u8 out[SHA256_DIGEST_SIZE])
711 {
712         const size_t msglen = ROUND_UP(inlen + 9, SHA256_BLOCK_SIZE);
713         u8 * const msg = xmalloc(msglen);
714         u32 H[8];
715         int i;
716
717         /* super naive way of handling the padding */
718         memcpy(msg, in, inlen);
719         memset(&msg[inlen], 0, msglen - inlen);
720         msg[inlen] = 0x80;
721         put_unaligned_be64((u64)inlen * 8, &msg[msglen - sizeof(__be64)]);
722         in = msg;
723
724         for (i = 0; i < ARRAY_SIZE(H); i++)
725                 H[i] = (u32)(sha512_iv[i] >> 32);
726         do {
727                 u32 a = H[0], b = H[1], c = H[2], d = H[3],
728                     e = H[4], f = H[5], g = H[6], h = H[7];
729                 u32 W[64];
730
731                 for (i = 0; i < 16; i++)
732                         W[i] = get_unaligned_be32(&in[i * sizeof(__be32)]);
733                 for (; i < ARRAY_SIZE(W); i++)
734                         W[i] = sigma256_1(W[i - 2]) + W[i - 7] +
735                                sigma256_0(W[i - 15]) + W[i - 16];
736                 for (i = 0; i < ARRAY_SIZE(W); i++) {
737                         u32 T1 = h + Sigma256_1(e) + Ch(e, f, g) +
738                                  (u32)(sha512_round_constants[i] >> 32) + W[i];
739                         u32 T2 = Sigma256_0(a) + Maj(a, b, c);
740
741                         h = g; g = f; f = e; e = d + T1;
742                         d = c; c = b; b = a; a = T1 + T2;
743                 }
744                 H[0] += a; H[1] += b; H[2] += c; H[3] += d;
745                 H[4] += e; H[5] += f; H[6] += g; H[7] += h;
746         } while ((in += SHA256_BLOCK_SIZE) != &msg[msglen]);
747
748         for (i = 0; i < ARRAY_SIZE(H); i++)
749                 put_unaligned_be32(H[i], &out[i * sizeof(__be32)]);
750         free(msg);
751 }
752
753 #ifdef ENABLE_ALG_TESTS
754 #include <openssl/sha.h>
755 static void test_sha2(void)
756 {
757         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
758
759         while (num_tests--) {
760                 u8 in[4096];
761                 u8 digest[SHA512_DIGEST_SIZE];
762                 u8 ref_digest[SHA512_DIGEST_SIZE];
763                 const size_t inlen = rand() % (1 + sizeof(in));
764
765                 rand_bytes(in, inlen);
766
767                 sha256(in, inlen, digest);
768                 SHA256(in, inlen, ref_digest);
769                 ASSERT(memcmp(digest, ref_digest, SHA256_DIGEST_SIZE) == 0);
770
771                 sha512(in, inlen, digest);
772                 SHA512(in, inlen, ref_digest);
773                 ASSERT(memcmp(digest, ref_digest, SHA512_DIGEST_SIZE) == 0);
774         }
775 }
776 #endif /* ENABLE_ALG_TESTS */
777
778 /*----------------------------------------------------------------------------*
779  *                            HKDF implementation                             *
780  *----------------------------------------------------------------------------*/
781
782 static void hmac_sha512(const u8 *key, size_t keylen, const u8 *msg,
783                         size_t msglen, u8 mac[SHA512_DIGEST_SIZE])
784 {
785         u8 *ibuf = xmalloc(SHA512_BLOCK_SIZE + msglen);
786         u8 obuf[SHA512_BLOCK_SIZE + SHA512_DIGEST_SIZE];
787
788         ASSERT(keylen <= SHA512_BLOCK_SIZE); /* keylen > bs not implemented */
789
790         memset(ibuf, 0x36, SHA512_BLOCK_SIZE);
791         xor(ibuf, ibuf, key, keylen);
792         memcpy(&ibuf[SHA512_BLOCK_SIZE], msg, msglen);
793
794         memset(obuf, 0x5c, SHA512_BLOCK_SIZE);
795         xor(obuf, obuf, key, keylen);
796         sha512(ibuf, SHA512_BLOCK_SIZE + msglen, &obuf[SHA512_BLOCK_SIZE]);
797         sha512(obuf, sizeof(obuf), mac);
798
799         free(ibuf);
800 }
801
802 static void hkdf_sha512(const u8 *ikm, size_t ikmlen,
803                         const u8 *salt, size_t saltlen,
804                         const u8 *info, size_t infolen,
805                         u8 *output, size_t outlen)
806 {
807         static const u8 default_salt[SHA512_DIGEST_SIZE];
808         u8 prk[SHA512_DIGEST_SIZE]; /* pseudorandom key */
809         u8 *buf = xmalloc(1 + infolen + SHA512_DIGEST_SIZE);
810         u8 counter = 1;
811         size_t i;
812
813         if (saltlen == 0) {
814                 salt = default_salt;
815                 saltlen = sizeof(default_salt);
816         }
817
818         /* HKDF-Extract */
819         ASSERT(ikmlen > 0);
820         hmac_sha512(salt, saltlen, ikm, ikmlen, prk);
821
822         /* HKDF-Expand */
823         for (i = 0; i < outlen; i += SHA512_DIGEST_SIZE) {
824                 u8 *p = buf;
825                 u8 tmp[SHA512_DIGEST_SIZE];
826
827                 ASSERT(counter != 0);
828                 if (i > 0) {
829                         memcpy(p, &output[i - SHA512_DIGEST_SIZE],
830                                SHA512_DIGEST_SIZE);
831                         p += SHA512_DIGEST_SIZE;
832                 }
833                 memcpy(p, info, infolen);
834                 p += infolen;
835                 *p++ = counter++;
836                 hmac_sha512(prk, sizeof(prk), buf, p - buf, tmp);
837                 memcpy(&output[i], tmp, MIN(sizeof(tmp), outlen - i));
838         }
839         free(buf);
840 }
841
842 #ifdef ENABLE_ALG_TESTS
843 #include <openssl/evp.h>
844 #include <openssl/kdf.h>
845 static void openssl_hkdf_sha512(const u8 *ikm, size_t ikmlen,
846                                 const u8 *salt, size_t saltlen,
847                                 const u8 *info, size_t infolen,
848                                 u8 *output, size_t outlen)
849 {
850         EVP_PKEY_CTX *pctx = EVP_PKEY_CTX_new_id(EVP_PKEY_HKDF, NULL);
851         size_t actual_outlen = outlen;
852
853         ASSERT(pctx != NULL);
854         ASSERT(EVP_PKEY_derive_init(pctx) > 0);
855         ASSERT(EVP_PKEY_CTX_set_hkdf_md(pctx, EVP_sha512()) > 0);
856         ASSERT(EVP_PKEY_CTX_set1_hkdf_key(pctx, ikm, ikmlen) > 0);
857         ASSERT(EVP_PKEY_CTX_set1_hkdf_salt(pctx, salt, saltlen) > 0);
858         ASSERT(EVP_PKEY_CTX_add1_hkdf_info(pctx, info, infolen) > 0);
859         ASSERT(EVP_PKEY_derive(pctx, output, &actual_outlen) > 0);
860         ASSERT(actual_outlen == outlen);
861         EVP_PKEY_CTX_free(pctx);
862 }
863
864 static void test_hkdf_sha512(void)
865 {
866         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
867
868         while (num_tests--) {
869                 u8 ikm[SHA512_DIGEST_SIZE];
870                 u8 salt[SHA512_DIGEST_SIZE];
871                 u8 info[128];
872                 u8 actual_output[512];
873                 u8 expected_output[sizeof(actual_output)];
874                 size_t ikmlen = 1 + (rand() % sizeof(ikm));
875                 size_t saltlen = rand() % (1 + sizeof(salt));
876                 size_t infolen = rand() % (1 + sizeof(info));
877                 size_t outlen = rand() % (1 + sizeof(actual_output));
878
879                 rand_bytes(ikm, ikmlen);
880                 rand_bytes(salt, saltlen);
881                 rand_bytes(info, infolen);
882
883                 hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
884                             actual_output, outlen);
885                 openssl_hkdf_sha512(ikm, ikmlen, salt, saltlen, info, infolen,
886                                     expected_output, outlen);
887                 ASSERT(memcmp(actual_output, expected_output, outlen) == 0);
888         }
889 }
890 #endif /* ENABLE_ALG_TESTS */
891
892 /*----------------------------------------------------------------------------*
893  *                            AES encryption modes                            *
894  *----------------------------------------------------------------------------*/
895
896 static void aes_256_xts_crypt(const u8 key[2 * AES_256_KEY_SIZE],
897                               const u8 iv[AES_BLOCK_SIZE], const u8 *src,
898                               u8 *dst, size_t nbytes, bool decrypting)
899 {
900         struct aes_key tweak_key, cipher_key;
901         ble128 t;
902         size_t i;
903
904         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
905         aes_setkey(&cipher_key, key, AES_256_KEY_SIZE);
906         aes_setkey(&tweak_key, &key[AES_256_KEY_SIZE], AES_256_KEY_SIZE);
907         aes_encrypt(&tweak_key, iv, (u8 *)&t);
908         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
909                 xor(&dst[i], &src[i], (const u8 *)&t, AES_BLOCK_SIZE);
910                 if (decrypting)
911                         aes_decrypt(&cipher_key, &dst[i], &dst[i]);
912                 else
913                         aes_encrypt(&cipher_key, &dst[i], &dst[i]);
914                 xor(&dst[i], &dst[i], (const u8 *)&t, AES_BLOCK_SIZE);
915                 gf2_128_mul_x(&t);
916         }
917 }
918
919 static void aes_256_xts_encrypt(const u8 key[2 * AES_256_KEY_SIZE],
920                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
921                                 u8 *dst, size_t nbytes)
922 {
923         aes_256_xts_crypt(key, iv, src, dst, nbytes, false);
924 }
925
926 static void aes_256_xts_decrypt(const u8 key[2 * AES_256_KEY_SIZE],
927                                 const u8 iv[AES_BLOCK_SIZE], const u8 *src,
928                                 u8 *dst, size_t nbytes)
929 {
930         aes_256_xts_crypt(key, iv, src, dst, nbytes, true);
931 }
932
933 #ifdef ENABLE_ALG_TESTS
934 #include <openssl/evp.h>
935 static void test_aes_256_xts(void)
936 {
937         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
938         EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
939
940         ASSERT(ctx != NULL);
941         while (num_tests--) {
942                 u8 key[2 * AES_256_KEY_SIZE];
943                 u8 iv[AES_BLOCK_SIZE];
944                 u8 ptext[512];
945                 u8 ctext[sizeof(ptext)];
946                 u8 ref_ctext[sizeof(ptext)];
947                 u8 decrypted[sizeof(ptext)];
948                 const size_t datalen = ROUND_DOWN(rand() % (1 + sizeof(ptext)),
949                                                   AES_BLOCK_SIZE);
950                 int outl, res;
951
952                 rand_bytes(key, sizeof(key));
953                 rand_bytes(iv, sizeof(iv));
954                 rand_bytes(ptext, datalen);
955
956                 aes_256_xts_encrypt(key, iv, ptext, ctext, datalen);
957                 res = EVP_EncryptInit_ex(ctx, EVP_aes_256_xts(), NULL, key, iv);
958                 ASSERT(res > 0);
959                 res = EVP_EncryptUpdate(ctx, ref_ctext, &outl, ptext, datalen);
960                 ASSERT(res > 0);
961                 ASSERT(outl == datalen);
962                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
963
964                 aes_256_xts_decrypt(key, iv, ctext, decrypted, datalen);
965                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
966         }
967         EVP_CIPHER_CTX_free(ctx);
968 }
969 #endif /* ENABLE_ALG_TESTS */
970
971 static void aes_cbc_encrypt(const struct aes_key *k,
972                             const u8 iv[AES_BLOCK_SIZE],
973                             const u8 *src, u8 *dst, size_t nbytes)
974 {
975         size_t i;
976
977         ASSERT(nbytes % AES_BLOCK_SIZE == 0);
978         for (i = 0; i < nbytes; i += AES_BLOCK_SIZE) {
979                 xor(&dst[i], &src[i], (i == 0 ? iv : &dst[i - AES_BLOCK_SIZE]),
980                     AES_BLOCK_SIZE);
981                 aes_encrypt(k, &dst[i], &dst[i]);
982         }
983 }
984
985 static void aes_cbc_decrypt(const struct aes_key *k,
986                             const u8 iv[AES_BLOCK_SIZE],
987                             const u8 *src, u8 *dst, size_t nbytes)
988 {
989         size_t i = nbytes;
990
991         ASSERT(i % AES_BLOCK_SIZE == 0);
992         while (i) {
993                 i -= AES_BLOCK_SIZE;
994                 aes_decrypt(k, &src[i], &dst[i]);
995                 xor(&dst[i], &dst[i], (i == 0 ? iv : &src[i - AES_BLOCK_SIZE]),
996                     AES_BLOCK_SIZE);
997         }
998 }
999
1000 static void aes_cts_cbc_encrypt(const u8 *key, int keysize,
1001                                 const u8 iv[AES_BLOCK_SIZE],
1002                                 const u8 *src, u8 *dst, size_t nbytes)
1003 {
1004         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1005         const size_t final_bsize = nbytes - offset;
1006         struct aes_key k;
1007         u8 *pad;
1008         u8 buf[AES_BLOCK_SIZE];
1009
1010         ASSERT(nbytes >= AES_BLOCK_SIZE);
1011
1012         aes_setkey(&k, key, keysize);
1013
1014         if (nbytes == AES_BLOCK_SIZE)
1015                 return aes_cbc_encrypt(&k, iv, src, dst, nbytes);
1016
1017         aes_cbc_encrypt(&k, iv, src, dst, offset);
1018         pad = &dst[offset - AES_BLOCK_SIZE];
1019
1020         memcpy(buf, pad, AES_BLOCK_SIZE);
1021         xor(buf, buf, &src[offset], final_bsize);
1022         memcpy(&dst[offset], pad, final_bsize);
1023         aes_encrypt(&k, buf, pad);
1024 }
1025
1026 static void aes_cts_cbc_decrypt(const u8 *key, int keysize,
1027                                 const u8 iv[AES_BLOCK_SIZE],
1028                                 const u8 *src, u8 *dst, size_t nbytes)
1029 {
1030         const size_t offset = ROUND_DOWN(nbytes - 1, AES_BLOCK_SIZE);
1031         const size_t final_bsize = nbytes - offset;
1032         struct aes_key k;
1033         u8 *pad;
1034
1035         ASSERT(nbytes >= AES_BLOCK_SIZE);
1036
1037         aes_setkey(&k, key, keysize);
1038
1039         if (nbytes == AES_BLOCK_SIZE)
1040                 return aes_cbc_decrypt(&k, iv, src, dst, nbytes);
1041
1042         pad = &dst[offset - AES_BLOCK_SIZE];
1043         aes_decrypt(&k, &src[offset - AES_BLOCK_SIZE], pad);
1044         xor(&dst[offset], &src[offset], pad, final_bsize);
1045         xor(pad, pad, &dst[offset], final_bsize);
1046
1047         aes_cbc_decrypt(&k, (offset == AES_BLOCK_SIZE ?
1048                              iv : &src[offset - 2 * AES_BLOCK_SIZE]),
1049                         pad, pad, AES_BLOCK_SIZE);
1050         aes_cbc_decrypt(&k, iv, src, dst, offset - AES_BLOCK_SIZE);
1051 }
1052
1053 static void aes_256_cts_cbc_encrypt(const u8 key[AES_256_KEY_SIZE],
1054                                     const u8 iv[AES_BLOCK_SIZE],
1055                                     const u8 *src, u8 *dst, size_t nbytes)
1056 {
1057         aes_cts_cbc_encrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1058 }
1059
1060 static void aes_256_cts_cbc_decrypt(const u8 key[AES_256_KEY_SIZE],
1061                                     const u8 iv[AES_BLOCK_SIZE],
1062                                     const u8 *src, u8 *dst, size_t nbytes)
1063 {
1064         aes_cts_cbc_decrypt(key, AES_256_KEY_SIZE, iv, src, dst, nbytes);
1065 }
1066
1067 #ifdef ENABLE_ALG_TESTS
1068 #include <openssl/modes.h>
1069 static void aes_block128_f(const unsigned char in[16],
1070                            unsigned char out[16], const void *key)
1071 {
1072         aes_encrypt(key, in, out);
1073 }
1074
1075 static void test_aes_256_cts_cbc(void)
1076 {
1077         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1078
1079         while (num_tests--) {
1080                 u8 key[AES_256_KEY_SIZE];
1081                 u8 iv[AES_BLOCK_SIZE];
1082                 u8 iv_copy[AES_BLOCK_SIZE];
1083                 u8 ptext[512];
1084                 u8 ctext[sizeof(ptext)];
1085                 u8 ref_ctext[sizeof(ptext)];
1086                 u8 decrypted[sizeof(ptext)];
1087                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1088                 struct aes_key k;
1089
1090                 rand_bytes(key, sizeof(key));
1091                 rand_bytes(iv, sizeof(iv));
1092                 rand_bytes(ptext, datalen);
1093
1094                 aes_256_cts_cbc_encrypt(key, iv, ptext, ctext, datalen);
1095
1096                 /* OpenSSL doesn't allow datalen=AES_BLOCK_SIZE; Linux does */
1097                 if (datalen != AES_BLOCK_SIZE) {
1098                         aes_setkey(&k, key, sizeof(key));
1099                         memcpy(iv_copy, iv, sizeof(iv));
1100                         ASSERT(CRYPTO_cts128_encrypt_block(ptext, ref_ctext,
1101                                                            datalen, &k, iv_copy,
1102                                                            aes_block128_f)
1103                                == datalen);
1104                         ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1105                 }
1106                 aes_256_cts_cbc_decrypt(key, iv, ctext, decrypted, datalen);
1107                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1108         }
1109 }
1110 #endif /* ENABLE_ALG_TESTS */
1111
1112 static void essiv_generate_iv(const u8 orig_key[AES_128_KEY_SIZE],
1113                               const u8 orig_iv[AES_BLOCK_SIZE],
1114                               u8 real_iv[AES_BLOCK_SIZE])
1115 {
1116         u8 essiv_key[SHA256_DIGEST_SIZE];
1117         struct aes_key essiv;
1118
1119         /* AES encrypt the original IV using a hash of the original key */
1120         STATIC_ASSERT(SHA256_DIGEST_SIZE == AES_256_KEY_SIZE);
1121         sha256(orig_key, AES_128_KEY_SIZE, essiv_key);
1122         aes_setkey(&essiv, essiv_key, AES_256_KEY_SIZE);
1123         aes_encrypt(&essiv, orig_iv, real_iv);
1124 }
1125
1126 static void aes_128_cbc_essiv_encrypt(const u8 key[AES_128_KEY_SIZE],
1127                                       const u8 iv[AES_BLOCK_SIZE],
1128                                       const u8 *src, u8 *dst, size_t nbytes)
1129 {
1130         struct aes_key k;
1131         u8 real_iv[AES_BLOCK_SIZE];
1132
1133         aes_setkey(&k, key, AES_128_KEY_SIZE);
1134         essiv_generate_iv(key, iv, real_iv);
1135         aes_cbc_encrypt(&k, real_iv, src, dst, nbytes);
1136 }
1137
1138 static void aes_128_cbc_essiv_decrypt(const u8 key[AES_128_KEY_SIZE],
1139                                       const u8 iv[AES_BLOCK_SIZE],
1140                                       const u8 *src, u8 *dst, size_t nbytes)
1141 {
1142         struct aes_key k;
1143         u8 real_iv[AES_BLOCK_SIZE];
1144
1145         aes_setkey(&k, key, AES_128_KEY_SIZE);
1146         essiv_generate_iv(key, iv, real_iv);
1147         aes_cbc_decrypt(&k, real_iv, src, dst, nbytes);
1148 }
1149
1150 static void aes_128_cts_cbc_encrypt(const u8 key[AES_128_KEY_SIZE],
1151                                     const u8 iv[AES_BLOCK_SIZE],
1152                                     const u8 *src, u8 *dst, size_t nbytes)
1153 {
1154         aes_cts_cbc_encrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1155 }
1156
1157 static void aes_128_cts_cbc_decrypt(const u8 key[AES_128_KEY_SIZE],
1158                                     const u8 iv[AES_BLOCK_SIZE],
1159                                     const u8 *src, u8 *dst, size_t nbytes)
1160 {
1161         aes_cts_cbc_decrypt(key, AES_128_KEY_SIZE, iv, src, dst, nbytes);
1162 }
1163
1164 /*----------------------------------------------------------------------------*
1165  *                           XChaCha12 stream cipher                          *
1166  *----------------------------------------------------------------------------*/
1167
1168 /*
1169  * References:
1170  *   - "XChaCha: eXtended-nonce ChaCha and AEAD_XChaCha20_Poly1305"
1171  *      https://tools.ietf.org/html/draft-arciszewski-xchacha-03
1172  *
1173  *   - "ChaCha, a variant of Salsa20"
1174  *      https://cr.yp.to/chacha/chacha-20080128.pdf
1175  *
1176  *   - "Extending the Salsa20 nonce"
1177  *      https://cr.yp.to/snuffle/xsalsa-20081128.pdf
1178  */
1179
1180 #define CHACHA_KEY_SIZE         32
1181 #define XCHACHA_KEY_SIZE        CHACHA_KEY_SIZE
1182 #define XCHACHA_NONCE_SIZE      24
1183
1184 static void chacha_init_state(u32 state[16], const u8 key[CHACHA_KEY_SIZE],
1185                               const u8 iv[16])
1186 {
1187         static const u8 consts[16] = "expand 32-byte k";
1188         int i;
1189
1190         for (i = 0; i < 4; i++)
1191                 state[i] = get_unaligned_le32(&consts[i * sizeof(__le32)]);
1192         for (i = 0; i < 8; i++)
1193                 state[4 + i] = get_unaligned_le32(&key[i * sizeof(__le32)]);
1194         for (i = 0; i < 4; i++)
1195                 state[12 + i] = get_unaligned_le32(&iv[i * sizeof(__le32)]);
1196 }
1197
1198 #define CHACHA_QUARTERROUND(a, b, c, d)         \
1199         do {                                    \
1200                 a += b; d = rol32(d ^ a, 16);   \
1201                 c += d; b = rol32(b ^ c, 12);   \
1202                 a += b; d = rol32(d ^ a, 8);    \
1203                 c += d; b = rol32(b ^ c, 7);    \
1204         } while (0)
1205
1206 static void chacha_permute(u32 x[16], int nrounds)
1207 {
1208         do {
1209                 /* column round */
1210                 CHACHA_QUARTERROUND(x[0], x[4], x[8], x[12]);
1211                 CHACHA_QUARTERROUND(x[1], x[5], x[9], x[13]);
1212                 CHACHA_QUARTERROUND(x[2], x[6], x[10], x[14]);
1213                 CHACHA_QUARTERROUND(x[3], x[7], x[11], x[15]);
1214
1215                 /* diagonal round */
1216                 CHACHA_QUARTERROUND(x[0], x[5], x[10], x[15]);
1217                 CHACHA_QUARTERROUND(x[1], x[6], x[11], x[12]);
1218                 CHACHA_QUARTERROUND(x[2], x[7], x[8], x[13]);
1219                 CHACHA_QUARTERROUND(x[3], x[4], x[9], x[14]);
1220         } while ((nrounds -= 2) != 0);
1221 }
1222
1223 static void xchacha(const u8 key[XCHACHA_KEY_SIZE],
1224                     const u8 nonce[XCHACHA_NONCE_SIZE],
1225                     const u8 *src, u8 *dst, size_t nbytes, int nrounds)
1226 {
1227         u32 state[16];
1228         u8 real_key[CHACHA_KEY_SIZE];
1229         u8 real_iv[16] = { 0 };
1230         size_t i, j;
1231
1232         /* Compute real key using original key and first 128 nonce bits */
1233         chacha_init_state(state, key, nonce);
1234         chacha_permute(state, nrounds);
1235         for (i = 0; i < 8; i++) /* state words 0..3, 12..15 */
1236                 put_unaligned_le32(state[(i < 4 ? 0 : 8) + i],
1237                                    &real_key[i * sizeof(__le32)]);
1238
1239         /* Now do regular ChaCha, using real key and remaining nonce bits */
1240         memcpy(&real_iv[8], nonce + 16, 8);
1241         chacha_init_state(state, real_key, real_iv);
1242         for (i = 0; i < nbytes; i += 64) {
1243                 u32 x[16];
1244                 __le32 keystream[16];
1245
1246                 memcpy(x, state, 64);
1247                 chacha_permute(x, nrounds);
1248                 for (j = 0; j < 16; j++)
1249                         keystream[j] = cpu_to_le32(x[j] + state[j]);
1250                 xor(&dst[i], &src[i], (u8 *)keystream, MIN(nbytes - i, 64));
1251                 if (++state[12] == 0)
1252                         state[13]++;
1253         }
1254 }
1255
1256 static void xchacha12(const u8 key[XCHACHA_KEY_SIZE],
1257                       const u8 nonce[XCHACHA_NONCE_SIZE],
1258                       const u8 *src, u8 *dst, size_t nbytes)
1259 {
1260         xchacha(key, nonce, src, dst, nbytes, 12);
1261 }
1262
1263 /*----------------------------------------------------------------------------*
1264  *                                 Poly1305                                   *
1265  *----------------------------------------------------------------------------*/
1266
1267 /*
1268  * Note: this is only the Poly1305 Îµ-almost-∆-universal hash function, not the
1269  * full Poly1305 MAC.  I.e., it doesn't add anything at the end.
1270  */
1271
1272 #define POLY1305_KEY_SIZE       16
1273 #define POLY1305_BLOCK_SIZE     16
1274
1275 static void poly1305(const u8 key[POLY1305_KEY_SIZE],
1276                      const u8 *msg, size_t msglen, le128 *out)
1277 {
1278         const u32 limb_mask = 0x3ffffff;        /* limbs are base 2^26 */
1279         const u64 r0 = (get_unaligned_le32(key +  0) >> 0) & 0x3ffffff;
1280         const u64 r1 = (get_unaligned_le32(key +  3) >> 2) & 0x3ffff03;
1281         const u64 r2 = (get_unaligned_le32(key +  6) >> 4) & 0x3ffc0ff;
1282         const u64 r3 = (get_unaligned_le32(key +  9) >> 6) & 0x3f03fff;
1283         const u64 r4 = (get_unaligned_le32(key + 12) >> 8) & 0x00fffff;
1284         u32 h0 = 0, h1 = 0, h2 = 0, h3 = 0, h4 = 0;
1285         u32 g0, g1, g2, g3, g4, ge_p_mask;
1286
1287         /* Partial block support is not necessary for Adiantum */
1288         ASSERT(msglen % POLY1305_BLOCK_SIZE == 0);
1289
1290         while (msglen) {
1291                 u64 d0, d1, d2, d3, d4;
1292
1293                 /* h += *msg */
1294                 h0 += (get_unaligned_le32(msg +  0) >> 0) & limb_mask;
1295                 h1 += (get_unaligned_le32(msg +  3) >> 2) & limb_mask;
1296                 h2 += (get_unaligned_le32(msg +  6) >> 4) & limb_mask;
1297                 h3 += (get_unaligned_le32(msg +  9) >> 6) & limb_mask;
1298                 h4 += (get_unaligned_le32(msg + 12) >> 8) | (1 << 24);
1299
1300                 /* h *= r */
1301                 d0 = h0*r0 + h1*5*r4 + h2*5*r3 + h3*5*r2 + h4*5*r1;
1302                 d1 = h0*r1 + h1*r0   + h2*5*r4 + h3*5*r3 + h4*5*r2;
1303                 d2 = h0*r2 + h1*r1   + h2*r0   + h3*5*r4 + h4*5*r3;
1304                 d3 = h0*r3 + h1*r2   + h2*r1   + h3*r0   + h4*5*r4;
1305                 d4 = h0*r4 + h1*r3   + h2*r2   + h3*r1   + h4*r0;
1306
1307                 /* (partial) h %= 2^130 - 5 */
1308                 d1 += d0 >> 26;         h0 = d0 & limb_mask;
1309                 d2 += d1 >> 26;         h1 = d1 & limb_mask;
1310                 d3 += d2 >> 26;         h2 = d2 & limb_mask;
1311                 d4 += d3 >> 26;         h3 = d3 & limb_mask;
1312                 h0 += (d4 >> 26) * 5;   h4 = d4 & limb_mask;
1313                 h1 += h0 >> 26;         h0 &= limb_mask;
1314
1315                 msg += POLY1305_BLOCK_SIZE;
1316                 msglen -= POLY1305_BLOCK_SIZE;
1317         }
1318
1319         /* fully carry h */
1320         h2 += (h1 >> 26);       h1 &= limb_mask;
1321         h3 += (h2 >> 26);       h2 &= limb_mask;
1322         h4 += (h3 >> 26);       h3 &= limb_mask;
1323         h0 += (h4 >> 26) * 5;   h4 &= limb_mask;
1324         h1 += (h0 >> 26);       h0 &= limb_mask;
1325
1326         /* if (h >= 2^130 - 5) h -= 2^130 - 5; */
1327         g0 = h0 + 5;
1328         g1 = h1 + (g0 >> 26);   g0 &= limb_mask;
1329         g2 = h2 + (g1 >> 26);   g1 &= limb_mask;
1330         g3 = h3 + (g2 >> 26);   g2 &= limb_mask;
1331         g4 = h4 + (g3 >> 26);   g3 &= limb_mask;
1332         ge_p_mask = ~((g4 >> 26) - 1); /* all 1's if h >= 2^130 - 5, else 0 */
1333         h0 = (h0 & ~ge_p_mask) | (g0 & ge_p_mask);
1334         h1 = (h1 & ~ge_p_mask) | (g1 & ge_p_mask);
1335         h2 = (h2 & ~ge_p_mask) | (g2 & ge_p_mask);
1336         h3 = (h3 & ~ge_p_mask) | (g3 & ge_p_mask);
1337         h4 = (h4 & ~ge_p_mask) | (g4 & ge_p_mask & limb_mask);
1338
1339         /* h %= 2^128 */
1340         out->lo = cpu_to_le64(((u64)h2 << 52) | ((u64)h1 << 26) | h0);
1341         out->hi = cpu_to_le64(((u64)h4 << 40) | ((u64)h3 << 14) | (h2 >> 12));
1342 }
1343
1344 /*----------------------------------------------------------------------------*
1345  *                          Adiantum encryption mode                          *
1346  *----------------------------------------------------------------------------*/
1347
1348 /*
1349  * Reference: "Adiantum: length-preserving encryption for entry-level processors"
1350  *      https://tosc.iacr.org/index.php/ToSC/article/view/7360
1351  */
1352
1353 #define ADIANTUM_KEY_SIZE       32
1354 #define ADIANTUM_IV_SIZE        32
1355 #define ADIANTUM_HASH_KEY_SIZE  ((2 * POLY1305_KEY_SIZE) + NH_KEY_SIZE)
1356
1357 #define NH_KEY_SIZE             1072
1358 #define NH_KEY_WORDS            (NH_KEY_SIZE / sizeof(u32))
1359 #define NH_BLOCK_SIZE           1024
1360 #define NH_HASH_SIZE            32
1361 #define NH_MESSAGE_UNIT         16
1362
1363 static u64 nh_pass(const u32 *key, const u8 *msg, size_t msglen)
1364 {
1365         u64 sum = 0;
1366
1367         ASSERT(msglen % NH_MESSAGE_UNIT == 0);
1368         while (msglen) {
1369                 sum += (u64)(u32)(get_unaligned_le32(msg +  0) + key[0]) *
1370                             (u32)(get_unaligned_le32(msg +  8) + key[2]);
1371                 sum += (u64)(u32)(get_unaligned_le32(msg +  4) + key[1]) *
1372                             (u32)(get_unaligned_le32(msg + 12) + key[3]);
1373                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1374                 msg += NH_MESSAGE_UNIT;
1375                 msglen -= NH_MESSAGE_UNIT;
1376         }
1377         return sum;
1378 }
1379
1380 /* NH Îµ-almost-universal hash function */
1381 static void nh(const u32 *key, const u8 *msg, size_t msglen,
1382                u8 result[NH_HASH_SIZE])
1383 {
1384         size_t i;
1385
1386         for (i = 0; i < NH_HASH_SIZE; i += sizeof(__le64)) {
1387                 put_unaligned_le64(nh_pass(key, msg, msglen), &result[i]);
1388                 key += NH_MESSAGE_UNIT / sizeof(key[0]);
1389         }
1390 }
1391
1392 /* Adiantum's Îµ-almost-∆-universal hash function */
1393 static void adiantum_hash(const u8 key[ADIANTUM_HASH_KEY_SIZE],
1394                           const u8 iv[ADIANTUM_IV_SIZE],
1395                           const u8 *msg, size_t msglen, le128 *result)
1396 {
1397         const u8 *header_poly_key = key;
1398         const u8 *msg_poly_key = header_poly_key + POLY1305_KEY_SIZE;
1399         const u8 *nh_key = msg_poly_key + POLY1305_KEY_SIZE;
1400         u32 nh_key_words[NH_KEY_WORDS];
1401         u8 header[POLY1305_BLOCK_SIZE + ADIANTUM_IV_SIZE];
1402         const size_t num_nh_blocks = DIV_ROUND_UP(msglen, NH_BLOCK_SIZE);
1403         u8 *nh_hashes = xmalloc(num_nh_blocks * NH_HASH_SIZE);
1404         const size_t padded_msglen = ROUND_UP(msglen, NH_MESSAGE_UNIT);
1405         u8 *padded_msg = xmalloc(padded_msglen);
1406         le128 hash1, hash2;
1407         size_t i;
1408
1409         for (i = 0; i < NH_KEY_WORDS; i++)
1410                 nh_key_words[i] = get_unaligned_le32(&nh_key[i * sizeof(u32)]);
1411
1412         /* Hash tweak and message length with first Poly1305 key */
1413         put_unaligned_le64((u64)msglen * 8, header);
1414         put_unaligned_le64(0, &header[sizeof(__le64)]);
1415         memcpy(&header[POLY1305_BLOCK_SIZE], iv, ADIANTUM_IV_SIZE);
1416         poly1305(header_poly_key, header, sizeof(header), &hash1);
1417
1418         /* Hash NH hashes of message blocks using second Poly1305 key */
1419         /* (using a super naive way of handling the padding) */
1420         memcpy(padded_msg, msg, msglen);
1421         memset(&padded_msg[msglen], 0, padded_msglen - msglen);
1422         for (i = 0; i < num_nh_blocks; i++) {
1423                 nh(nh_key_words, &padded_msg[i * NH_BLOCK_SIZE],
1424                    MIN(NH_BLOCK_SIZE, padded_msglen - (i * NH_BLOCK_SIZE)),
1425                    &nh_hashes[i * NH_HASH_SIZE]);
1426         }
1427         poly1305(msg_poly_key, nh_hashes, num_nh_blocks * NH_HASH_SIZE, &hash2);
1428
1429         /* Add the two hashes together to get the final hash */
1430         le128_add(result, &hash1, &hash2);
1431
1432         free(nh_hashes);
1433         free(padded_msg);
1434 }
1435
1436 static void adiantum_crypt(const u8 key[ADIANTUM_KEY_SIZE],
1437                            const u8 iv[ADIANTUM_IV_SIZE], const u8 *src,
1438                            u8 *dst, size_t nbytes, bool decrypting)
1439 {
1440         u8 subkeys[AES_256_KEY_SIZE + ADIANTUM_HASH_KEY_SIZE] = { 0 };
1441         struct aes_key aes_key;
1442         union {
1443                 u8 nonce[XCHACHA_NONCE_SIZE];
1444                 le128 block;
1445         } u = { .nonce = { 1 } };
1446         const size_t bulk_len = nbytes - sizeof(u.block);
1447         le128 hash;
1448
1449         ASSERT(nbytes >= sizeof(u.block));
1450
1451         /* Derive subkeys */
1452         xchacha12(key, u.nonce, subkeys, subkeys, sizeof(subkeys));
1453         aes_setkey(&aes_key, subkeys, AES_256_KEY_SIZE);
1454
1455         /* Hash left part and add to right part */
1456         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, src, bulk_len, &hash);
1457         memcpy(&u.block, &src[bulk_len], sizeof(u.block));
1458         le128_add(&u.block, &u.block, &hash);
1459
1460         if (!decrypting) /* Encrypt right part with block cipher */
1461                 aes_encrypt(&aes_key, u.nonce, u.nonce);
1462
1463         /* Encrypt left part with stream cipher, using the computed nonce */
1464         u.nonce[sizeof(u.block)] = 1;
1465         xchacha12(key, u.nonce, src, dst, bulk_len);
1466
1467         if (decrypting) /* Decrypt right part with block cipher */
1468                 aes_decrypt(&aes_key, u.nonce, u.nonce);
1469
1470         /* Finalize right part by subtracting hash of left part */
1471         adiantum_hash(&subkeys[AES_256_KEY_SIZE], iv, dst, bulk_len, &hash);
1472         le128_sub(&u.block, &u.block, &hash);
1473         memcpy(&dst[bulk_len], &u.block, sizeof(u.block));
1474 }
1475
1476 static void adiantum_encrypt(const u8 key[ADIANTUM_KEY_SIZE],
1477                              const u8 iv[ADIANTUM_IV_SIZE],
1478                              const u8 *src, u8 *dst, size_t nbytes)
1479 {
1480         adiantum_crypt(key, iv, src, dst, nbytes, false);
1481 }
1482
1483 static void adiantum_decrypt(const u8 key[ADIANTUM_KEY_SIZE],
1484                              const u8 iv[ADIANTUM_IV_SIZE],
1485                              const u8 *src, u8 *dst, size_t nbytes)
1486 {
1487         adiantum_crypt(key, iv, src, dst, nbytes, true);
1488 }
1489
1490 #ifdef ENABLE_ALG_TESTS
1491 #include <linux/if_alg.h>
1492 #include <sys/socket.h>
1493 #define SOL_ALG 279
1494 static void af_alg_crypt(int algfd, int op, const u8 *key, size_t keylen,
1495                          const u8 *iv, size_t ivlen,
1496                          const u8 *src, u8 *dst, size_t datalen)
1497 {
1498         size_t controllen = CMSG_SPACE(sizeof(int)) +
1499                             CMSG_SPACE(sizeof(struct af_alg_iv) + ivlen);
1500         u8 *control = xmalloc(controllen);
1501         struct iovec iov = { .iov_base = (u8 *)src, .iov_len = datalen };
1502         struct msghdr msg = {
1503                 .msg_iov = &iov,
1504                 .msg_iovlen = 1,
1505                 .msg_control = control,
1506                 .msg_controllen = controllen,
1507         };
1508         struct cmsghdr *cmsg;
1509         struct af_alg_iv *algiv;
1510         int reqfd;
1511
1512         memset(control, 0, controllen);
1513
1514         cmsg = CMSG_FIRSTHDR(&msg);
1515         cmsg->cmsg_len = CMSG_LEN(sizeof(int));
1516         cmsg->cmsg_level = SOL_ALG;
1517         cmsg->cmsg_type = ALG_SET_OP;
1518         *(int *)CMSG_DATA(cmsg) = op;
1519
1520         cmsg = CMSG_NXTHDR(&msg, cmsg);
1521         cmsg->cmsg_len = CMSG_LEN(sizeof(struct af_alg_iv) + ivlen);
1522         cmsg->cmsg_level = SOL_ALG;
1523         cmsg->cmsg_type = ALG_SET_IV;
1524         algiv = (struct af_alg_iv *)CMSG_DATA(cmsg);
1525         algiv->ivlen = ivlen;
1526         memcpy(algiv->iv, iv, ivlen);
1527
1528         if (setsockopt(algfd, SOL_ALG, ALG_SET_KEY, key, keylen) != 0)
1529                 die_errno("can't set key on AF_ALG socket");
1530
1531         reqfd = accept(algfd, NULL, NULL);
1532         if (reqfd < 0)
1533                 die_errno("can't accept() AF_ALG socket");
1534         if (sendmsg(reqfd, &msg, 0) != datalen)
1535                 die_errno("can't sendmsg() AF_ALG request socket");
1536         if (xread(reqfd, dst, datalen) != datalen)
1537                 die("short read from AF_ALG request socket");
1538         close(reqfd);
1539
1540         free(control);
1541 }
1542
1543 static void test_adiantum(void)
1544 {
1545         int algfd = socket(AF_ALG, SOCK_SEQPACKET, 0);
1546         struct sockaddr_alg addr = {
1547                 .salg_type = "skcipher",
1548                 .salg_name = "adiantum(xchacha12,aes)",
1549         };
1550         unsigned long num_tests = NUM_ALG_TEST_ITERATIONS;
1551
1552         if (algfd < 0)
1553                 die_errno("can't create AF_ALG socket");
1554         if (bind(algfd, (struct sockaddr *)&addr, sizeof(addr)) != 0)
1555                 die_errno("can't bind AF_ALG socket to Adiantum algorithm");
1556
1557         while (num_tests--) {
1558                 u8 key[ADIANTUM_KEY_SIZE];
1559                 u8 iv[ADIANTUM_IV_SIZE];
1560                 u8 ptext[4096];
1561                 u8 ctext[sizeof(ptext)];
1562                 u8 ref_ctext[sizeof(ptext)];
1563                 u8 decrypted[sizeof(ptext)];
1564                 const size_t datalen = 16 + (rand() % (sizeof(ptext) - 15));
1565
1566                 rand_bytes(key, sizeof(key));
1567                 rand_bytes(iv, sizeof(iv));
1568                 rand_bytes(ptext, datalen);
1569
1570                 adiantum_encrypt(key, iv, ptext, ctext, datalen);
1571                 af_alg_crypt(algfd, ALG_OP_ENCRYPT, key, sizeof(key),
1572                              iv, sizeof(iv), ptext, ref_ctext, datalen);
1573                 ASSERT(memcmp(ctext, ref_ctext, datalen) == 0);
1574
1575                 adiantum_decrypt(key, iv, ctext, decrypted, datalen);
1576                 ASSERT(memcmp(ptext, decrypted, datalen) == 0);
1577         }
1578         close(algfd);
1579 }
1580 #endif /* ENABLE_ALG_TESTS */
1581
1582 /*----------------------------------------------------------------------------*
1583  *                               Main program                                 *
1584  *----------------------------------------------------------------------------*/
1585
1586 #define FILE_NONCE_SIZE         16
1587 #define UUID_SIZE               16
1588 #define MAX_KEY_SIZE            64
1589
1590 static const struct fscrypt_cipher {
1591         const char *name;
1592         void (*encrypt)(const u8 *key, const u8 *iv, const u8 *src,
1593                         u8 *dst, size_t nbytes);
1594         void (*decrypt)(const u8 *key, const u8 *iv, const u8 *src,
1595                         u8 *dst, size_t nbytes);
1596         int keysize;
1597         int min_input_size;
1598 } fscrypt_ciphers[] = {
1599         {
1600                 .name = "AES-256-XTS",
1601                 .encrypt = aes_256_xts_encrypt,
1602                 .decrypt = aes_256_xts_decrypt,
1603                 .keysize = 2 * AES_256_KEY_SIZE,
1604         }, {
1605                 .name = "AES-256-CTS-CBC",
1606                 .encrypt = aes_256_cts_cbc_encrypt,
1607                 .decrypt = aes_256_cts_cbc_decrypt,
1608                 .keysize = AES_256_KEY_SIZE,
1609                 .min_input_size = AES_BLOCK_SIZE,
1610         }, {
1611                 .name = "AES-128-CBC-ESSIV",
1612                 .encrypt = aes_128_cbc_essiv_encrypt,
1613                 .decrypt = aes_128_cbc_essiv_decrypt,
1614                 .keysize = AES_128_KEY_SIZE,
1615         }, {
1616                 .name = "AES-128-CTS-CBC",
1617                 .encrypt = aes_128_cts_cbc_encrypt,
1618                 .decrypt = aes_128_cts_cbc_decrypt,
1619                 .keysize = AES_128_KEY_SIZE,
1620                 .min_input_size = AES_BLOCK_SIZE,
1621         }, {
1622                 .name = "Adiantum",
1623                 .encrypt = adiantum_encrypt,
1624                 .decrypt = adiantum_decrypt,
1625                 .keysize = ADIANTUM_KEY_SIZE,
1626                 .min_input_size = AES_BLOCK_SIZE,
1627         }
1628 };
1629
1630 static const struct fscrypt_cipher *find_fscrypt_cipher(const char *name)
1631 {
1632         size_t i;
1633
1634         for (i = 0; i < ARRAY_SIZE(fscrypt_ciphers); i++) {
1635                 if (strcmp(fscrypt_ciphers[i].name, name) == 0)
1636                         return &fscrypt_ciphers[i];
1637         }
1638         return NULL;
1639 }
1640
1641 struct fscrypt_iv {
1642         union {
1643                 __le64 block_num;
1644                 u8 bytes[32];
1645         };
1646 };
1647
1648 static void crypt_loop(const struct fscrypt_cipher *cipher, const u8 *key,
1649                        struct fscrypt_iv *iv, bool decrypting,
1650                        size_t block_size, size_t padding)
1651 {
1652         u8 *buf = xmalloc(block_size);
1653         size_t res;
1654
1655         while ((res = xread(STDIN_FILENO, buf, block_size)) > 0) {
1656                 size_t crypt_len = block_size;
1657
1658                 if (padding > 0) {
1659                         crypt_len = MAX(res, cipher->min_input_size);
1660                         crypt_len = ROUND_UP(crypt_len, padding);
1661                         crypt_len = MIN(crypt_len, block_size);
1662                 }
1663                 ASSERT(crypt_len >= res);
1664                 memset(&buf[res], 0, crypt_len - res);
1665
1666                 if (decrypting)
1667                         cipher->decrypt(key, iv->bytes, buf, buf, crypt_len);
1668                 else
1669                         cipher->encrypt(key, iv->bytes, buf, buf, crypt_len);
1670
1671                 full_write(STDOUT_FILENO, buf, crypt_len);
1672
1673                 iv->block_num = cpu_to_le64(le64_to_cpu(iv->block_num) + 1);
1674         }
1675         free(buf);
1676 }
1677
1678 /* The supported key derivation functions */
1679 enum kdf_algorithm {
1680         KDF_NONE,
1681         KDF_AES_128_ECB,
1682         KDF_HKDF_SHA512,
1683 };
1684
1685 static enum kdf_algorithm parse_kdf_algorithm(const char *arg)
1686 {
1687         if (strcmp(arg, "none") == 0)
1688                 return KDF_NONE;
1689         if (strcmp(arg, "AES-128-ECB") == 0)
1690                 return KDF_AES_128_ECB;
1691         if (strcmp(arg, "HKDF-SHA512") == 0)
1692                 return KDF_HKDF_SHA512;
1693         die("Unknown KDF: %s", arg);
1694 }
1695
1696 static u8 parse_mode_number(const char *arg)
1697 {
1698         char *tmp;
1699         long num = strtol(arg, &tmp, 10);
1700
1701         if (num <= 0 || *tmp || (u8)num != num)
1702                 die("Invalid mode number: %s", arg);
1703         return num;
1704 }
1705
1706 static u32 parse_inode_number(const char *arg)
1707 {
1708         char *tmp;
1709         unsigned long long num = strtoull(arg, &tmp, 10);
1710
1711         if (num <= 0 || *tmp)
1712                 die("Invalid inode number: %s", arg);
1713         if ((u32)num != num)
1714                 die("Inode number %s is too large; must be 32-bit", arg);
1715         return num;
1716 }
1717
1718 struct key_and_iv_params {
1719         u8 master_key[MAX_KEY_SIZE];
1720         int master_key_size;
1721         enum kdf_algorithm kdf;
1722         u8 mode_num;
1723         u8 file_nonce[FILE_NONCE_SIZE];
1724         bool file_nonce_specified;
1725         bool iv_ino_lblk_64;
1726         u32 inode_number;
1727         u8 fs_uuid[UUID_SIZE];
1728         bool fs_uuid_specified;
1729 };
1730
1731 #define HKDF_CONTEXT_KEY_IDENTIFIER     1
1732 #define HKDF_CONTEXT_PER_FILE_KEY       2
1733 #define HKDF_CONTEXT_DIRECT_KEY         3
1734 #define HKDF_CONTEXT_IV_INO_LBLK_64_KEY 4
1735
1736 /*
1737  * Get the key and starting IV with which the encryption will actually be done.
1738  * If a KDF was specified, a subkey is derived from the master key and the mode
1739  * number or file nonce.  Otherwise, the master key is used directly.
1740  */
1741 static void get_key_and_iv(const struct key_and_iv_params *params,
1742                            u8 *real_key, size_t real_key_size,
1743                            struct fscrypt_iv *iv)
1744 {
1745         bool file_nonce_in_iv = false;
1746         struct aes_key aes_key;
1747         u8 info[8 + 1 + 1 + UUID_SIZE] = "fscrypt";
1748         size_t infolen = 8;
1749         size_t i;
1750
1751         ASSERT(real_key_size <= params->master_key_size);
1752
1753         memset(iv, 0, sizeof(*iv));
1754
1755         if (params->iv_ino_lblk_64 && params->kdf != KDF_HKDF_SHA512)
1756                 die("--iv-ino-lblk-64 requires --kdf=HKDF-SHA512");
1757
1758         switch (params->kdf) {
1759         case KDF_NONE:
1760                 if (params->mode_num != 0)
1761                         die("--mode-num isn't supported with --kdf=none");
1762                 memcpy(real_key, params->master_key, real_key_size);
1763                 file_nonce_in_iv = true;
1764                 break;
1765         case KDF_AES_128_ECB:
1766                 if (!params->file_nonce_specified)
1767                         die("--file-nonce is required with --kdf=AES-128-ECB");
1768                 if (params->mode_num != 0)
1769                         die("--mode-num isn't supported with --kdf=AES-128-ECB");
1770                 STATIC_ASSERT(FILE_NONCE_SIZE == AES_128_KEY_SIZE);
1771                 ASSERT(real_key_size % AES_BLOCK_SIZE == 0);
1772                 aes_setkey(&aes_key, params->file_nonce, AES_128_KEY_SIZE);
1773                 for (i = 0; i < real_key_size; i += AES_BLOCK_SIZE)
1774                         aes_encrypt(&aes_key, &params->master_key[i],
1775                                     &real_key[i]);
1776                 break;
1777         case KDF_HKDF_SHA512:
1778                 if (params->iv_ino_lblk_64) {
1779                         if (!params->fs_uuid_specified)
1780                                 die("--iv-ino-lblk-64 requires --fs-uuid");
1781                         if (params->inode_number == 0)
1782                                 die("--iv-ino-lblk-64 requires --inode-number");
1783                         if (params->mode_num == 0)
1784                                 die("--iv-ino-lblk-64 requires --mode-num");
1785                         info[infolen++] = HKDF_CONTEXT_IV_INO_LBLK_64_KEY;
1786                         info[infolen++] = params->mode_num;
1787                         memcpy(&info[infolen], params->fs_uuid, UUID_SIZE);
1788                         infolen += UUID_SIZE;
1789                         put_unaligned_le32(params->inode_number, &iv->bytes[4]);
1790                 } else if (params->mode_num != 0) {
1791                         info[infolen++] = HKDF_CONTEXT_DIRECT_KEY;
1792                         info[infolen++] = params->mode_num;
1793                         file_nonce_in_iv = true;
1794                 } else if (params->file_nonce_specified) {
1795                         info[infolen++] = HKDF_CONTEXT_PER_FILE_KEY;
1796                         memcpy(&info[infolen], params->file_nonce,
1797                                FILE_NONCE_SIZE);
1798                         infolen += FILE_NONCE_SIZE;
1799                 } else {
1800                         die("With --kdf=HKDF-SHA512, at least one of --file-nonce and --mode-num must be specified");
1801                 }
1802                 hkdf_sha512(params->master_key, params->master_key_size,
1803                             NULL, 0, info, infolen, real_key, real_key_size);
1804                 break;
1805         default:
1806                 ASSERT(0);
1807         }
1808
1809         if (file_nonce_in_iv && params->file_nonce_specified)
1810                 memcpy(&iv->bytes[8], params->file_nonce, FILE_NONCE_SIZE);
1811 }
1812
1813 enum {
1814         OPT_BLOCK_SIZE,
1815         OPT_DECRYPT,
1816         OPT_FILE_NONCE,
1817         OPT_FS_UUID,
1818         OPT_HELP,
1819         OPT_INODE_NUMBER,
1820         OPT_IV_INO_LBLK_64,
1821         OPT_KDF,
1822         OPT_MODE_NUM,
1823         OPT_PADDING,
1824 };
1825
1826 static const struct option longopts[] = {
1827         { "block-size",      required_argument, NULL, OPT_BLOCK_SIZE },
1828         { "decrypt",         no_argument,       NULL, OPT_DECRYPT },
1829         { "file-nonce",      required_argument, NULL, OPT_FILE_NONCE },
1830         { "fs-uuid",         required_argument, NULL, OPT_FS_UUID },
1831         { "help",            no_argument,       NULL, OPT_HELP },
1832         { "inode-number",    required_argument, NULL, OPT_INODE_NUMBER },
1833         { "iv-ino-lblk-64",  no_argument,       NULL, OPT_IV_INO_LBLK_64 },
1834         { "kdf",             required_argument, NULL, OPT_KDF },
1835         { "mode-num",        required_argument, NULL, OPT_MODE_NUM },
1836         { "padding",         required_argument, NULL, OPT_PADDING },
1837         { NULL, 0, NULL, 0 },
1838 };
1839
1840 int main(int argc, char *argv[])
1841 {
1842         size_t block_size = 4096;
1843         bool decrypting = false;
1844         struct key_and_iv_params params;
1845         size_t padding = 0;
1846         const struct fscrypt_cipher *cipher;
1847         u8 real_key[MAX_KEY_SIZE];
1848         struct fscrypt_iv iv;
1849         char *tmp;
1850         int c;
1851
1852         memset(&params, 0, sizeof(params));
1853
1854         aes_init();
1855
1856 #ifdef ENABLE_ALG_TESTS
1857         test_aes();
1858         test_sha2();
1859         test_hkdf_sha512();
1860         test_aes_256_xts();
1861         test_aes_256_cts_cbc();
1862         test_adiantum();
1863 #endif
1864
1865         while ((c = getopt_long(argc, argv, "", longopts, NULL)) != -1) {
1866                 switch (c) {
1867                 case OPT_BLOCK_SIZE:
1868                         block_size = strtoul(optarg, &tmp, 10);
1869                         if (block_size <= 0 || *tmp)
1870                                 die("Invalid block size: %s", optarg);
1871                         break;
1872                 case OPT_DECRYPT:
1873                         decrypting = true;
1874                         break;
1875                 case OPT_FILE_NONCE:
1876                         if (hex2bin(optarg, params.file_nonce, FILE_NONCE_SIZE)
1877                             != FILE_NONCE_SIZE)
1878                                 die("Invalid file nonce: %s", optarg);
1879                         params.file_nonce_specified = true;
1880                         break;
1881                 case OPT_FS_UUID:
1882                         if (hex2bin(optarg, params.fs_uuid, UUID_SIZE)
1883                             != UUID_SIZE)
1884                                 die("Invalid filesystem UUID: %s", optarg);
1885                         params.fs_uuid_specified = true;
1886                         break;
1887                 case OPT_HELP:
1888                         usage(stdout);
1889                         return 0;
1890                 case OPT_INODE_NUMBER:
1891                         params.inode_number = parse_inode_number(optarg);
1892                         break;
1893                 case OPT_IV_INO_LBLK_64:
1894                         params.iv_ino_lblk_64 = true;
1895                         break;
1896                 case OPT_KDF:
1897                         params.kdf = parse_kdf_algorithm(optarg);
1898                         break;
1899                 case OPT_MODE_NUM:
1900                         params.mode_num = parse_mode_number(optarg);
1901                         break;
1902                 case OPT_PADDING:
1903                         padding = strtoul(optarg, &tmp, 10);
1904                         if (padding <= 0 || *tmp || !is_power_of_2(padding) ||
1905                             padding > INT_MAX)
1906                                 die("Invalid padding amount: %s", optarg);
1907                         break;
1908                 default:
1909                         usage(stderr);
1910                         return 2;
1911                 }
1912         }
1913         argc -= optind;
1914         argv += optind;
1915
1916         if (argc != 2) {
1917                 usage(stderr);
1918                 return 2;
1919         }
1920
1921         cipher = find_fscrypt_cipher(argv[0]);
1922         if (cipher == NULL)
1923                 die("Unknown cipher: %s", argv[0]);
1924
1925         if (block_size < cipher->min_input_size)
1926                 die("Block size of %zu bytes is too small for cipher %s",
1927                     block_size, cipher->name);
1928
1929         params.master_key_size = hex2bin(argv[1], params.master_key,
1930                                          MAX_KEY_SIZE);
1931         if (params.master_key_size < 0)
1932                 die("Invalid master_key: %s", argv[1]);
1933         if (params.master_key_size < cipher->keysize)
1934                 die("Master key is too short for cipher %s", cipher->name);
1935
1936         get_key_and_iv(&params, real_key, cipher->keysize, &iv);
1937
1938         crypt_loop(cipher, real_key, &iv, decrypting, block_size, padding);
1939         return 0;
1940 }