]> git.apps.os.sepia.ceph.com Git - fscrypt.git/commitdiff
Increase checks for invalid HashingCosts
authorJoe Richey <joerichey@google.com>
Sat, 27 Aug 2022 07:32:56 +0000 (00:32 -0700)
committerEric Biggers <ebiggers3@gmail.com>
Sun, 4 Dec 2022 22:36:56 +0000 (14:36 -0800)
Signed-off-by: Joe Richey <joerichey@google.com>
[ebiggers: moved the new checks from PassphraseHash to CheckValidity]
Signed-off-by: Eric Biggers <ebiggers@google.com>
crypto/crypto_test.go
metadata/checks.go

index 10b3d177812528944aac2952e19a42ab719b41a0..f98c643e5ec8e539e4985f8c8dc54739be804237 100644 (file)
@@ -76,6 +76,12 @@ var hashTestCases = []hashTestCase{
                costs:   &metadata.HashingCosts{Time: 1, Memory: 1 << 10, Parallelism: 1},
                hexHash: "a66f5398e33761bf161fdf1273e99b148f07d88d12d85b7673fddd723f95ec34",
        },
+       // Make sure we maintain our backwards compatible behavior, where
+       // Parallelism is truncated to 8-bits unless TruncationFixed is true.
+       {
+               costs:   &metadata.HashingCosts{Time: 1, Memory: 1 << 10, Parallelism: 257},
+               hexHash: "a66f5398e33761bf161fdf1273e99b148f07d88d12d85b7673fddd723f95ec34",
+       },
        {
                costs:   &metadata.HashingCosts{Time: 10, Memory: 1 << 10, Parallelism: 1},
                hexHash: "5fa2cb89db1f7413ba1776258b7c8ee8c377d122078d28fe1fd645c353787f50",
@@ -88,6 +94,15 @@ var hashTestCases = []hashTestCase{
                costs:   &metadata.HashingCosts{Time: 1, Memory: 1 << 10, Parallelism: 10},
                hexHash: "b7c3d7a0be222680b5ea3af3fb1a0b7b02b92cbd7007821dc8b84800c86c7783",
        },
+       {
+               costs:   &metadata.HashingCosts{Time: 1, Memory: 1 << 11, Parallelism: 255},
+               hexHash: "d51af3775bbdd0cba31d96fd6d921d9de27f521ceffe667618cd7624f6643071",
+       },
+       // Adding TruncationFixed shouldn't matter if Parallelism < 256.
+       {
+               costs:   &metadata.HashingCosts{Time: 1, Memory: 1 << 11, Parallelism: 255, TruncationFixed: true},
+               hexHash: "d51af3775bbdd0cba31d96fd6d921d9de27f521ceffe667618cd7624f6643071",
+       },
 }
 
 // Checks that len(array) == expected
@@ -493,16 +508,21 @@ func TestComputeKeyDescriptorBadVersion(t *testing.T) {
 
 // Run our test cases for passphrase hashing
 func TestPassphraseHashing(t *testing.T) {
+       pk, err := fakePassphraseKey()
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer pk.Wipe()
+
        for i, testCase := range hashTestCases {
-               pk, err := fakePassphraseKey()
-               if err != nil {
-                       t.Fatal(err)
+               if err := testCase.costs.CheckValidity(); err != nil {
+                       t.Errorf("Hash test %d: for costs=%+v hashing failed: %v", i, testCase.costs, err)
+                       continue
                }
-               defer pk.Wipe()
-
                hash, err := PassphraseHash(pk, fakeSalt, testCase.costs)
                if err != nil {
-                       t.Fatal(err)
+                       t.Errorf("Hash test %d: for costs=%+v hashing failed: %v", i, testCase.costs, err)
+                       continue
                }
                defer hash.Wipe()
 
@@ -514,6 +534,29 @@ func TestPassphraseHashing(t *testing.T) {
        }
 }
 
+var badCosts = []*metadata.HashingCosts{
+       // Bad Time costs
+       {Time: 0, Memory: 1 << 11, Parallelism: 1},
+       {Time: 1 << 33, Memory: 1 << 11, Parallelism: 1},
+       // Bad Memory costs
+       {Time: 1, Memory: 5, Parallelism: 1},
+       {Time: 1, Memory: 1 << 33, Parallelism: 1},
+       // Bad Parallelism costs
+       {Time: 1, Memory: 1 << 11, Parallelism: 0, TruncationFixed: false},
+       {Time: 1, Memory: 1 << 11, Parallelism: 0, TruncationFixed: true},
+       {Time: 1, Memory: 1 << 11, Parallelism: 256, TruncationFixed: false},
+       {Time: 1, Memory: 1 << 11, Parallelism: 256, TruncationFixed: true},
+       {Time: 1, Memory: 1 << 11, Parallelism: 257, TruncationFixed: true},
+}
+
+func TestBadParameters(t *testing.T) {
+       for i, costs := range badCosts {
+               if costs.CheckValidity() == nil {
+                       t.Errorf("Hash test %d: expected error for costs=%+v", i, costs)
+               }
+       }
+}
+
 func BenchmarkWrap(b *testing.B) {
        for n := 0; n < b.N; n++ {
                Wrap(fakeWrappingKey, fakeValidPolicyKey)
index bddc8a78c87091a27c01a8866f2823530ef46920..d7dea416594bf555d56ad32b660fa42ea436d2ce 100644 (file)
@@ -20,6 +20,9 @@
 package metadata
 
 import (
+       "log"
+       "math"
+
        "github.com/pkg/errors"
        "google.golang.org/protobuf/proto"
 
@@ -57,20 +60,37 @@ func (s SourceType) CheckValidity() error {
        return nil
 }
 
+// MaxParallelism is the maximum allowed value for HashingCosts.Parallelism.
+const MaxParallelism = math.MaxUint8
+
 // CheckValidity ensures the hash costs will be accepted by Argon2.
 func (h *HashingCosts) CheckValidity() error {
        if h == nil {
                return errNotInitialized
        }
-       if h.Time <= 0 {
-               return errors.Errorf("time=%d is not positive", h.Time)
+
+       minP := int64(1)
+       p := uint8(h.Parallelism)
+       if h.Parallelism < minP || h.Parallelism > MaxParallelism {
+               if h.TruncationFixed || p == 0 {
+                       return errors.Errorf("parallelism cost %d is not in range [%d, %d]",
+                               h.Parallelism, minP, MaxParallelism)
+               }
+               // Previously we unconditionally casted costs.Parallelism to a uint8,
+               // so we replicate this behavior for backwards compatibility.
+               log.Printf("WARNING: Truncating parallelism cost of %d to %d", h.Parallelism, p)
        }
-       if h.Parallelism <= 0 {
-               return errors.Errorf("parallelism=%d is not positive", h.Parallelism)
+
+       minT := int64(1)
+       maxT := int64(math.MaxUint32)
+       if h.Time < minT || h.Time > maxT {
+               return errors.Errorf("time cost %d is not in range [%d, %d]", h.Time, minT, maxT)
        }
-       minMemory := 8 * h.Parallelism
-       if h.Memory < minMemory {
-               return errors.Errorf("memory=%d is less than minimum (%d)", h.Memory, minMemory)
+
+       minM := 8 * int64(p)
+       maxM := int64(math.MaxUint32)
+       if h.Memory < minM || h.Memory > maxM {
+               return errors.Errorf("memory cost %d KiB is not in range [%d, %d]", h.Memory, minM, maxM)
        }
        return nil
 }