}
}
+func BenchmarkUnwrapNolock(b *testing.B) {
+ UseMlock = false
+ defer func() {
+ UseMlock = true
+ }()
+ data, _ := Wrap(fakeWrappingKey, fakeValidPolicyKey)
+
+ for n := 0; n < b.N; n++ {
+ _, _ = Unwrap(fakeWrappingKey, data)
+ }
+}
+
func BenchmarkRandomWrapUnwrap(b *testing.B) {
for n := 0; n < b.N; n++ {
wk, _ := NewRandomKey(InternalKeyLen)
package crypto
import (
+ "bytes"
+ "encoding/base32"
+ "fmt"
"io"
"os"
"runtime"
return nil
}
+
+var (
+ // The recovery code is base32 with a dash between each block of 8 characters.
+ encoding = base32.StdEncoding
+ blockSize = 8
+ separator = []byte("-")
+ encodedLength = encoding.EncodedLen(InternalKeyLen)
+ decodedLength = encoding.DecodedLen(encodedLength)
+ // RecoveryCodeLength is the number of bytes in every recovery code
+ RecoveryCodeLength = (encodedLength/blockSize)*(blockSize+len(separator)) - len(separator)
+)
+
+// WriteRecoveryCode outputs key's recovery code to the provided writer.
+// WARNING: This recovery key is enough to derive the original key, so it must
+// be given the same level of protection as a raw cryptographic key.
+func WriteRecoveryCode(key *Key, writer io.Writer) error {
+ if key.Len() != InternalKeyLen {
+ return util.InvalidLengthError("key", InternalKeyLen, key.Len())
+ }
+
+ // We store the base32 encoded data (without separators) in a temp key
+ encodedKey, err := newBlankKey(encodedLength)
+ if err != nil {
+ return err
+ }
+ defer encodedKey.Wipe()
+ encoding.Encode(encodedKey.data, key.data)
+
+ w := util.NewErrWriter(writer)
+
+ // Write the blocks with separators between them
+ w.Write(encodedKey.data[:blockSize])
+ for blockStart := blockSize; blockStart < encodedLength; blockStart += blockSize {
+ w.Write(separator)
+
+ blockEnd := util.MinInt(blockStart+blockSize, encodedLength)
+ w.Write(encodedKey.data[blockStart:blockEnd])
+ }
+
+ // If any writes have failed, return the error
+ return w.Err()
+}
+
+// ReadRecoveryCode gets the recovery code from the provided writer and returns
+// the corresponding cryptographic key.
+// WARNING: This recovery key is enough to derive the original key, so it must
+// be given the same level of protection as a raw cryptographic key.
+func ReadRecoveryCode(reader io.Reader) (*Key, error) {
+ // We store the base32 encoded data (without separators) in a temp key
+ encodedKey, err := newBlankKey(encodedLength)
+ if err != nil {
+ return nil, err
+ }
+ defer encodedKey.Wipe()
+
+ r := util.NewErrReader(reader)
+
+ // Read the other blocks, checking the separators between them
+ r.Read(encodedKey.data[:blockSize])
+ inputSeparator := make([]byte, len(separator))
+
+ for blockStart := blockSize; blockStart < encodedLength; blockStart += blockSize {
+ r.Read(inputSeparator)
+ if r.Err() == nil && !bytes.Equal(separator, inputSeparator) {
+ return nil, fmt.Errorf("invalid separator: %q", inputSeparator)
+ }
+
+ blockEnd := util.MinInt(blockStart+blockSize, encodedLength)
+ r.Read(encodedKey.data[blockStart:blockEnd])
+ }
+
+ // If any reads have failed, return the error
+ if r.Err() != nil {
+ return nil, r.Err()
+ }
+
+ // Now we decode the key, resizing if necessary
+ decodedKey, err := newBlankKey(decodedLength)
+ if err != nil {
+ return nil, err
+ }
+ if _, err = encoding.Decode(decodedKey.data, encodedKey.data); err != nil {
+ decodedKey.Wipe()
+ return nil, err
+ }
+ return decodedKey.resize(InternalKeyLen)
+}
--- /dev/null
+/*
+ * recovery_test.go - tests for recovery codes in the crypto package
+ * tests key wrapping/unwrapping and key generation
+ *
+ * Copyright 2017 Google Inc.
+ * Author: Joe Richey (joerichey@google.com)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package crypto
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+)
+
+const fakeSecretRecoveryCode = "EYTCMJRG-EYTCMJRG-EYTCMJRG-EYTCMJRG-EYTCMJRG-EYTCMJRG-EYTA===="
+
+var fakeSecretKey, _ = makeKey(38, InternalKeyLen)
+
+// Note that this function is INSECURE. FOR TESTING ONLY
+func getRecoveryCodeFromKey(key *Key) ([]byte, error) {
+ var buf bytes.Buffer
+ if err := WriteRecoveryCode(key, &buf); err != nil {
+ return nil, err
+ }
+ return buf.Bytes(), nil
+}
+
+func getRandomRecoveryCodeBuffer() ([]byte, error) {
+ key, err := NewRandomKey(InternalKeyLen)
+ if err != nil {
+ return nil, err
+ }
+ return getRecoveryCodeFromKey(key)
+}
+
+func getKeyFromRecoveryCode(buf []byte) (*Key, error) {
+ return ReadRecoveryCode(bytes.NewReader(buf))
+}
+
+// Given a key, make a recovery code from that key, use that code to rederive
+// another key and check if they are the same.
+func testKeyEncodeDecode(key *Key) error {
+ buf, err := getRecoveryCodeFromKey(key)
+ if err != nil {
+ return err
+ }
+
+ key2, err := getKeyFromRecoveryCode(buf)
+ if err != nil {
+ return err
+ }
+
+ if !bytes.Equal(key.data, key2.data) {
+ return fmt.Errorf("encoding then decoding %x didn't yield the same key", key.data)
+ }
+ return nil
+}
+
+// Given a recovery code, make a key from that recovery code, use that key to
+// rederive another recovery code and check if they are the same.
+func testRecoveryDecodeEncode(buf []byte) error {
+ key, err := getKeyFromRecoveryCode(buf)
+ if err != nil {
+ return err
+ }
+
+ buf2, err := getRecoveryCodeFromKey(key)
+ if err != nil {
+ return err
+ }
+
+ if !bytes.Equal(buf, buf2) {
+ return fmt.Errorf("decoding then encoding %x didn't yield the same key", buf)
+ }
+ return nil
+}
+
+func TestGetRandomRecoveryString(t *testing.T) {
+ b, err := getRandomRecoveryCodeBuffer()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ t.Log(string(b))
+ // t.Fail() // Uncomment to see an example random recovery code
+}
+
+func TestFakeSecretKey(t *testing.T) {
+ buf, err := getRecoveryCodeFromKey(fakeSecretKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ recoveryCode := string(buf)
+ if recoveryCode != fakeSecretRecoveryCode {
+ t.Errorf("got '%s' instead of '%s'", recoveryCode, fakeSecretRecoveryCode)
+ }
+}
+
+func TestEncodeDecode(t *testing.T) {
+ key, err := NewRandomKey(InternalKeyLen)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err = testKeyEncodeDecode(key); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestDecodeEncode(t *testing.T) {
+ buf, err := getRandomRecoveryCodeBuffer()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err = testRecoveryDecodeEncode(buf); err != nil {
+ t.Error(err)
+ }
+}
+
+func TestWrongLengthError(t *testing.T) {
+ key, err := NewRandomKey(InternalKeyLen - 1)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err = getRecoveryCodeFromKey(key); err == nil {
+ t.Error("key with wrong length should have failed to encode")
+ }
+}
+
+func TestBadCharacterError(t *testing.T) {
+ buf, err := getRandomRecoveryCodeBuffer()
+ // Lowercase letters not allowed
+ buf[3] = 'k'
+ if _, err = getKeyFromRecoveryCode(buf); err == nil {
+ t.Error("lowercase letters should make decoding fail")
+ }
+}
+
+func TestBadEndCharacterError(t *testing.T) {
+ buf, err := getRandomRecoveryCodeBuffer()
+ // Separator must be '-'
+ buf[blockSize] = '_'
+ if _, err = getKeyFromRecoveryCode(buf); err == nil {
+ t.Error("any separator that isn't '-' should make decoding fail")
+ }
+}
+
+func BenchmarkEncode(b *testing.B) {
+ key, err := NewRandomKey(InternalKeyLen)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ for n := 0; n < b.N; n++ {
+ if _, err = getRecoveryCodeFromKey(key); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkDecode(b *testing.B) {
+ buf, err := getRandomRecoveryCodeBuffer()
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ for n := 0; n < b.N; n++ {
+ if _, err = getKeyFromRecoveryCode(buf); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkEncodeDecode(b *testing.B) {
+ key, err := NewRandomKey(InternalKeyLen)
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ for n := 0; n < b.N; n++ {
+ if err = testKeyEncodeDecode(key); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
+
+func BenchmarkDecodeEncode(b *testing.B) {
+ buf, err := getRandomRecoveryCodeBuffer()
+ if err != nil {
+ b.Fatal(err)
+ }
+
+ for n := 0; n < b.N; n++ {
+ if err = testRecoveryDecodeEncode(buf); err != nil {
+ b.Fatal(err)
+ }
+ }
+}
package util
import (
+ "io"
"unsafe"
)
+// ErrReader wraps an io.Reader, passing along calls to Read() until a read
+// fails. Then, the error is stored, and all subsequent calls to Read() do
+// nothing. This allows you to write code which has many subsequent reads and
+// do all of the error checking at the end. For example:
+//
+// r := NewErrReader(reader)
+// r.Read(foo)
+// io.ReadFull(r, bar)
+// if r.Err() != nil {
+// // Handle error
+// }
+//
+// Taken from https://blog.golang.org/errors-are-values by Rob Pike.
+type ErrReader struct {
+ r io.Reader
+ err error
+}
+
+// NewErrReader creates an ErrReader which wraps the provided reader.
+func NewErrReader(reader io.Reader) *ErrReader {
+ return &ErrReader{r: reader, err: nil}
+}
+
+// Read runs ReadFull on the wrapped reader if no errors have occurred.
+// Otherwise, the previous error is just returned and no reads are attempted.
+func (e *ErrReader) Read(p []byte) (n int, err error) {
+ if e.err == nil {
+ n, e.err = io.ReadFull(e.r, p)
+ }
+ return n, e.err
+}
+
+// Err returns the first encountered err (or nil if no errors occurred).
+func (e *ErrReader) Err() error {
+ return e.err
+}
+
+// ErrWriter works exactly like ErrReader, except with io.Writer.
+type ErrWriter struct {
+ w io.Writer
+ err error
+}
+
+// NewErrWriter creates an ErrWriter which wraps the provided reader.
+func NewErrWriter(writer io.Writer) *ErrWriter {
+ return &ErrWriter{w: writer, err: nil}
+}
+
+// Write runs the wrapped writer's Write if no errors have occurred. Otherwise,
+// the previous error is just returned and no writes are attempted.
+func (e *ErrWriter) Write(p []byte) (n int, err error) {
+ if e.err == nil {
+ n, e.err = e.w.Write(p)
+ }
+ return n, e.err
+}
+
+// Err returns the first encountered err (or nil if no errors occurred).
+func (e *ErrWriter) Err() error {
+ return e.err
+}
+
// Ptr converts an Go byte array to a pointer to the start of the array.
func Ptr(slice []byte) unsafe.Pointer {
return unsafe.Pointer(&slice[0])
}
return outArray[index], true
}
+
+// MinInt returns the lesser of a and b.
+func MinInt(a, b int) int {
+ if a < b {
+ return a
+ }
+ return b
+}