]> git-server-git.apps.pok.os.sepia.ceph.com Git - ceph.git/commitdiff
librbd/crypto: auto detect plaintext parent
authorOr Ozeri <oro@il.ibm.com>
Sun, 21 Aug 2022 08:49:59 +0000 (11:49 +0300)
committerOr Ozeri <oro@il.ibm.com>
Thu, 25 Aug 2022 15:41:42 +0000 (18:41 +0300)
Encryption loading (i.e. rbd_encryption_load) gets a single passphrase
and tries to applies it to all ancestor images. If it fails, the entire load fails.
This commits extends encryption loading to assume ancestor is actually
in plaintext format if no known encryption header magic is detected.

Signed-off-by: Or Ozeri <oro@il.ibm.com>
14 files changed:
src/include/rbd/librbd.h
src/include/rbd/librbd.hpp
src/librbd/crypto/EncryptionFormat.h
src/librbd/crypto/LoadRequest.cc
src/librbd/crypto/LoadRequest.h
src/librbd/crypto/luks/Header.cc
src/librbd/crypto/luks/Header.h
src/librbd/crypto/luks/LUKSEncryptionFormat.cc
src/librbd/crypto/luks/LUKSEncryptionFormat.h
src/librbd/crypto/luks/LoadRequest.cc
src/librbd/crypto/luks/LoadRequest.h
src/test/librbd/crypto/luks/test_mock_LoadRequest.cc
src/test/librbd/crypto/test_mock_LoadRequest.cc
src/test/librbd/mock/crypto/MockEncryptionFormat.h

index a717cc58ab73a3ea66479df17e0f4fcfbb5b6d35..e0da63dd82218efe04e9a312bd36dfcee979d38a 100644 (file)
@@ -827,6 +827,8 @@ CEPH_RBD_API int rbd_encryption_format(rbd_image_t image,
                                        rbd_encryption_format_t format,
                                        rbd_encryption_options_t opts,
                                        size_t opts_size);
+/* encryption will be loaded on all ancestor images,
+ * until reaching an ancestor image which does not match any known format */
 CEPH_RBD_API int rbd_encryption_load(rbd_image_t image,
                                      rbd_encryption_format_t format,
                                      rbd_encryption_options_t opts,
index 1f6af1ebbe42914cc361c21a9c7f0179acc7c065..6e84a0a0f5c9659c5979a036214ba48c5088aaca 100644 (file)
@@ -597,6 +597,8 @@ public:
   /* encryption */
   int encryption_format(encryption_format_t format, encryption_options_t opts,
                         size_t opts_size);
+  /* encryption will be loaded on all ancestor images,
+   * until reaching an ancestor image which does not match any known format */
   int encryption_load(encryption_format_t format, encryption_options_t opts,
                       size_t opts_size);
 
index 515e8f931e4d0a356c4cccb72dcc8385115a06ff..98c62c2169cc7f4ee4220518feaeb8d6942e9766 100644 (file)
@@ -21,7 +21,8 @@ struct EncryptionFormat {
 
   virtual std::unique_ptr<EncryptionFormat<ImageCtxT>> clone() const = 0;
   virtual void format(ImageCtxT* ictx, Context* on_finish) = 0;
-  virtual void load(ImageCtxT* ictx, Context* on_finish) = 0;
+  virtual void load(ImageCtxT* ictx, std::string* detected_format_name,
+                    Context* on_finish) = 0;
   virtual void flatten(ImageCtxT* ictx, Context* on_finish) = 0;
 
   virtual CryptoInterface* get_crypto() = 0;
index 27ffb02a8f8c1c0877be77ed826582aa5890e428..5bc57d693c5d311db647de0cd3500b13ccf06864 100644 (file)
@@ -8,6 +8,7 @@
 #include "librbd/Utils.h"
 #include "librbd/ImageCtx.h"
 #include "librbd/crypto/EncryptionFormat.h"
+#include "librbd/crypto/Types.h"
 #include "librbd/crypto/Utils.h"
 #include "librbd/io/AioCompletion.h"
 #include "librbd/io/ImageDispatcherInterface.h"
@@ -95,9 +96,11 @@ template <typename I>
 void LoadRequest<I>::load() {
   ldout(m_image_ctx->cct, 20) << "format_idx=" << m_format_idx << dendl;
 
+  m_detected_format_name = "";
   auto ctx = create_context_callback<
           LoadRequest<I>, &LoadRequest<I>::handle_load>(this);
-  m_formats[m_format_idx]->load(m_current_image_ctx, ctx);
+  m_formats[m_format_idx]->load(m_current_image_ctx, &m_detected_format_name,
+                                ctx);
 }
 
 template <typename I>
@@ -105,12 +108,27 @@ void LoadRequest<I>::handle_load(int r) {
   ldout(m_image_ctx->cct, 20) << "r=" << r << dendl;
 
   if (r < 0) {
+    if (m_is_current_format_cloned &&
+        m_detected_format_name == UNKNOWN_FORMAT) {
+      // encryption format was not detected, assume plaintext
+      ldout(m_image_ctx->cct, 5) << "assuming plaintext for image "
+                                 << m_current_image_ctx->name << dendl;
+      m_formats.pop_back();
+      invalidate_cache();
+      return;
+    }
+
     lderr(m_image_ctx->cct) << "failed to load encryption. image name: "
                             << m_current_image_ctx->name << dendl;
     finish(r);
     return;
   }
 
+  ldout(m_image_ctx->cct, 5) << "loaded format " << m_detected_format_name
+                             << (m_is_current_format_cloned ? " (cloned)" : "")
+                             << " for image " << m_current_image_ctx->name
+                             << dendl;
+
   m_format_idx++;
   m_current_image_ctx = m_current_image_ctx->parent;
   if (m_current_image_ctx != nullptr) {
index e09510d7f0e947b06c571746aaf61dd9bc8862b4..84f595bb6c61631c5cfa58b14abd2bcfe55ca90c 100644 (file)
@@ -20,6 +20,8 @@ class LoadRequest {
 public:
     using EncryptionFormat = decltype(I::encryption_format);
 
+    static constexpr char UNKNOWN_FORMAT[] = "<unknown>";
+
     static LoadRequest* create(
             I* image_ctx, std::vector<EncryptionFormat>&& formats,
             Context* on_finish) {
@@ -45,6 +47,7 @@ private:
     bool m_is_current_format_cloned;
     std::vector<EncryptionFormat> m_formats;
     I* m_current_image_ctx;
+    std::string m_detected_format_name;
 };
 
 } // namespace crypto
index f6018e65a1b0b4518bb7ce1499a70714a266fcc0..0866f285f100f02ac8732fa89e131543b977a856 100644 (file)
@@ -251,6 +251,11 @@ const char* Header::get_cipher_mode() {
   return crypt_get_cipher_mode(m_cd);
 }
 
+const char* Header::get_format_name() {
+  ceph_assert(m_cd != nullptr);
+  return crypt_get_type(m_cd);
+}
+
 } // namespace luks
 } // namespace crypto
 } // namespace librbd
index cee80a8e4628fdf06b37348d20808088e46aaafb..067d96b4ae8606065b157b526ab9112bb803009f 100644 (file)
@@ -33,6 +33,7 @@ public:
     uint64_t get_data_offset();
     const char* get_cipher();
     const char* get_cipher_mode();
+    const char* get_format_name();
 
 private:
     void libcryptsetup_log(int level, const char* msg);
index 6c81c20ddbf2b59e6f280b4f4451df51f46f41ff..e739714932508424d62508c6627ed53910137c7c 100644 (file)
@@ -45,9 +45,10 @@ void LUKSEncryptionFormat<I>::format(I* image_ctx, Context* on_finish) {
 }
 
 template <typename I>
-void LUKSEncryptionFormat<I>::load(I* image_ctx, Context* on_finish) {
+void LUKSEncryptionFormat<I>::load(
+        I* image_ctx, std::string* detected_format_name, Context* on_finish) {
   auto req = luks::LoadRequest<I>::create(
-          image_ctx, m_passphrase, &m_crypto, on_finish);
+          image_ctx, m_passphrase, &m_crypto, detected_format_name, on_finish);
   req->send();
 }
 
index 7b21a7f14601b12e971fb680f701eae2122e9a49..3aa8950f6272c23c062bfbb20831fb754187363f 100644 (file)
@@ -36,7 +36,8 @@ public:
     }
 
     void format(ImageCtxT* ictx, Context* on_finish) override;
-    void load(ImageCtxT* ictx, Context* on_finish) override;
+    void load(ImageCtxT* ictx, std::string* detected_format_name,
+              Context* on_finish) override;
     void flatten(ImageCtxT* ictx, Context* on_finish) override;
 
     CryptoInterface* get_crypto() override {
index d1f636389ed515b7b32fae88301d301ef04ccd00..61eaebdbefae25ed40cac2a9e2d53dc878fd04a0 100644 (file)
@@ -7,6 +7,7 @@
 #include "common/errno.h"
 #include "librbd/Utils.h"
 #include "librbd/crypto/Utils.h"
+#include "librbd/crypto/LoadRequest.h"
 #include "librbd/crypto/luks/Magic.h"
 #include "librbd/io/AioCompletion.h"
 #include "librbd/io/ImageDispatchSpec.h"
@@ -27,10 +28,12 @@ template <typename I>
 LoadRequest<I>::LoadRequest(
         I* image_ctx, std::string_view passphrase,
         std::unique_ptr<CryptoInterface>* result_crypto,
+        std::string* detected_format_name,
         Context* on_finish) : m_image_ctx(image_ctx),
                               m_passphrase(passphrase),
                               m_on_finish(on_finish),
                               m_result_crypto(result_crypto),
+                              m_detected_format_name(detected_format_name),
                               m_initial_read_size(DEFAULT_INITIAL_READ_SIZE),
                               m_header(image_ctx->cct), m_offset(0) {
 }
@@ -72,29 +75,39 @@ bool LoadRequest<I>::handle_read(int r) {
     return false;
   }
 
-  m_offset += m_bl.length();
-
-  if (m_last_read_bl.length() > 0) {
-    m_last_read_bl.claim_append(m_bl);
-    m_bl = std::move(m_last_read_bl);
-  }
+  // first, check LUKS magic at the beginning of the image
+  // If no magic is detected, caller may assume image is actually plaintext
+  if (m_offset == 0) {
+    if (Magic::is_luks(m_bl) > 0 || Magic::is_rbd_clone(m_bl) > 0) {
+      *m_detected_format_name = "LUKS";
+    } else {
+      *m_detected_format_name = crypto::LoadRequest<I>::UNKNOWN_FORMAT;
+      finish(-EINVAL);
+      return false;
+    }
 
-  if (m_image_ctx->parent != nullptr && m_bl.length() == m_offset &&
-      Magic::is_rbd_clone(m_bl) > 0) {
-    r = Magic::replace_magic(m_image_ctx->cct, m_bl);
-    if (r < 0) {
-      if (r == -EINVAL && m_offset < MAXIMUM_HEADER_SIZE) {
-        m_last_read_bl = std::move(m_bl);
-        auto ctx = create_context_callback<
-              LoadRequest<I>, &LoadRequest<I>::handle_read_header>(this);
-        read(MAXIMUM_HEADER_SIZE, ctx);
+    if (m_image_ctx->parent != nullptr && Magic::is_rbd_clone(m_bl) > 0) {
+      r = Magic::replace_magic(m_image_ctx->cct, m_bl);
+      if (r < 0) {
+        m_image_ctx->image_lock.lock_shared();
+        auto image_size = m_image_ctx->get_image_size(m_image_ctx->snap_id);
+        m_image_ctx->image_lock.unlock_shared();
+
+        auto max_header_size = std::min(MAXIMUM_HEADER_SIZE, image_size);
+
+        if (r == -EINVAL && m_bl.length() < max_header_size) {
+          m_bl.clear();
+          auto ctx = create_context_callback<
+                LoadRequest<I>, &LoadRequest<I>::handle_read_header>(this);
+          read(max_header_size, ctx);
+          return false;
+        }
+
+        lderr(m_image_ctx->cct) << "error replacing rbd clone magic: "
+                                << cpp_strerror(r) << dendl;
+        finish(r);
         return false;
       }
-
-      lderr(m_image_ctx->cct) << "error replacing rbd clone magic: "
-                              << cpp_strerror(r) << dendl;
-      finish(r);
-      return false;
     }
   }
   
@@ -105,6 +118,8 @@ bool LoadRequest<I>::handle_read(int r) {
     return false;
   }
 
+  m_offset += m_bl.length();
+
   // write header to libcryptsetup interface
   r = m_header.write(m_bl);
   if (r < 0) {
@@ -128,11 +143,16 @@ void LoadRequest<I>::handle_read_header(int r) {
   // parse header via libcryptsetup
   r = m_header.load(CRYPT_LUKS);
   if (r != 0) {
-    if (m_offset < MAXIMUM_HEADER_SIZE) {
+    m_image_ctx->image_lock.lock_shared();
+    auto image_size = m_image_ctx->get_image_size(m_image_ctx->snap_id);
+    m_image_ctx->image_lock.unlock_shared();
+
+    auto max_header_size = std::min(MAXIMUM_HEADER_SIZE, image_size);
+    if (m_offset < max_header_size) {
       // perhaps we did not feed the entire header to libcryptsetup, retry
       auto ctx = create_context_callback<
               LoadRequest<I>, &LoadRequest<I>::handle_read_header>(this);
-      read(MAXIMUM_HEADER_SIZE, ctx);
+      read(max_header_size, ctx);
       return;
     }
 
@@ -140,6 +160,10 @@ void LoadRequest<I>::handle_read_header(int r) {
     return;
   }
 
+  // gets actual LUKS version (only used for logging)
+  ceph_assert(*m_detected_format_name == "LUKS");
+  *m_detected_format_name = m_header.get_format_name();
+
   auto cipher = m_header.get_cipher();
   if (strcmp(cipher, "aes") != 0) {
     lderr(m_image_ctx->cct) << "unsupported cipher: " << cipher << dendl;
index e8d7bb2d88d9fb2e849d674ac578ee7079e88446..5e64df6ffaee139322498df8f4260a4361c23382 100644 (file)
@@ -28,13 +28,15 @@ public:
     static LoadRequest* create(
             I* image_ctx, std::string_view passphrase,
             std::unique_ptr<CryptoInterface>* result_crypto,
+            std::string* detected_format_name,
             Context* on_finish) {
-      return new LoadRequest(image_ctx, passphrase, result_crypto, on_finish);
+      return new LoadRequest(image_ctx, passphrase, result_crypto,
+                             detected_format_name, on_finish);
     }
 
     LoadRequest(I* image_ctx, std::string_view passphrase,
                 std::unique_ptr<CryptoInterface>* result_crypto,
-                Context* on_finish);
+                std::string* detected_format_name, Context* on_finish);
     void send();
     void finish(int r);
     void set_initial_read_size(uint64_t read_size);
@@ -44,8 +46,8 @@ private:
     std::string_view m_passphrase;
     Context* m_on_finish;
     ceph::bufferlist m_bl;
-    ceph::bufferlist m_last_read_bl;
     std::unique_ptr<CryptoInterface>* m_result_crypto;
+    std::string* m_detected_format_name;
     uint64_t m_initial_read_size;
     Header m_header;
     uint64_t m_offset;
index 1b79c7189e8c3e813f90f6f70857c07df45baf76..c18b1486c9ebbb518710dbb069ccd326b8dfed2f 100644 (file)
@@ -40,6 +40,7 @@ struct TestMockCryptoLuksLoadRequest : public TestMockFixture {
   Context* image_read_request;
   ceph::bufferlist header_bl;
   uint64_t data_offset;
+  std::string detected_format_name;
 
   void SetUp() override {
     TestMockFixture::SetUp();
@@ -48,7 +49,9 @@ struct TestMockCryptoLuksLoadRequest : public TestMockFixture {
     ASSERT_EQ(0, open_image(m_image_name, &ictx));
     mock_image_ctx = new MockImageCtx(*ictx);
     mock_load_request = MockLoadRequest::create(
-            mock_image_ctx, std::move(passphrase), &crypto, on_finish);
+            mock_image_ctx, std::move(passphrase), &crypto,
+            &detected_format_name, on_finish);
+    detected_format_name = "";
   }
 
   void TearDown() override {
@@ -100,6 +103,11 @@ struct TestMockCryptoLuksLoadRequest : public TestMockFixture {
                 image_read_request = ctx;
             }));
   }
+
+  void expect_get_image_size() {
+    EXPECT_CALL(*mock_image_ctx, get_image_size(_)).WillOnce(
+            Return(100 * 1024 * 1024));
+  }
 };
 
 TEST_F(TestMockCryptoLuksLoadRequest, AES128) {
@@ -109,6 +117,7 @@ TEST_F(TestMockCryptoLuksLoadRequest, AES128) {
   image_read_request->complete(DEFAULT_INITIAL_READ_SIZE);
   ASSERT_EQ(0, finished_cond.wait());
   ASSERT_NE(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS2", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, AES256) {
@@ -118,18 +127,21 @@ TEST_F(TestMockCryptoLuksLoadRequest, AES256) {
   image_read_request->complete(DEFAULT_INITIAL_READ_SIZE);
   ASSERT_EQ(0, finished_cond.wait());
   ASSERT_NE(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS2", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, LUKS1) {
   delete mock_load_request;
   mock_load_request = MockLoadRequest::create(
-          mock_image_ctx, {passphrase_cstr}, &crypto, on_finish);
+          mock_image_ctx, {passphrase_cstr}, &crypto, &detected_format_name,
+          on_finish);
   generate_header(CRYPT_LUKS1, "aes", 32, "xts-plain64", 512, false);
   expect_image_read(0, DEFAULT_INITIAL_READ_SIZE);
   mock_load_request->send();
   image_read_request->complete(DEFAULT_INITIAL_READ_SIZE);
   ASSERT_EQ(0, finished_cond.wait());
   ASSERT_NE(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS1", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, WrongFormat) {
@@ -137,14 +149,11 @@ TEST_F(TestMockCryptoLuksLoadRequest, WrongFormat) {
   expect_image_read(0, DEFAULT_INITIAL_READ_SIZE);
   mock_load_request->send();
 
-  expect_image_read(DEFAULT_INITIAL_READ_SIZE,
-                    MAXIMUM_HEADER_SIZE - DEFAULT_INITIAL_READ_SIZE);
-  image_read_request->complete(DEFAULT_INITIAL_READ_SIZE); // complete 1st read
+  image_read_request->complete(DEFAULT_INITIAL_READ_SIZE);
 
-  image_read_request->complete(
-          MAXIMUM_HEADER_SIZE - DEFAULT_INITIAL_READ_SIZE);
   ASSERT_EQ(-EINVAL, finished_cond.wait());
   ASSERT_EQ(crypto.get(), nullptr);
+  ASSERT_EQ("<unknown>", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, UnsupportedAlgorithm) {
@@ -154,6 +163,7 @@ TEST_F(TestMockCryptoLuksLoadRequest, UnsupportedAlgorithm) {
   image_read_request->complete(DEFAULT_INITIAL_READ_SIZE);
   ASSERT_EQ(-ENOTSUP, finished_cond.wait());
   ASSERT_EQ(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS2", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, UnsupportedCipherMode) {
@@ -163,6 +173,7 @@ TEST_F(TestMockCryptoLuksLoadRequest, UnsupportedCipherMode) {
   image_read_request->complete(DEFAULT_INITIAL_READ_SIZE);
   ASSERT_EQ(-ENOTSUP, finished_cond.wait());
   ASSERT_EQ(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS2", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, HeaderBiggerThanInitialRead) {
@@ -171,25 +182,29 @@ TEST_F(TestMockCryptoLuksLoadRequest, HeaderBiggerThanInitialRead) {
   expect_image_read(0, 4096);
   mock_load_request->send();
 
+  expect_get_image_size();
   expect_image_read(4096, MAXIMUM_HEADER_SIZE - 4096);
   image_read_request->complete(4096); // complete initial read
 
   image_read_request->complete(MAXIMUM_HEADER_SIZE - 4096);
   ASSERT_EQ(0, finished_cond.wait());
   ASSERT_NE(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS2", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, LUKS1FormattedClone) {
   mock_image_ctx->parent = mock_image_ctx;
   delete mock_load_request;
   mock_load_request = MockLoadRequest::create(
-          mock_image_ctx, {passphrase_cstr}, &crypto, on_finish);
+          mock_image_ctx, {passphrase_cstr}, &crypto, &detected_format_name,
+          on_finish);
   generate_header(CRYPT_LUKS1, "aes", 64, "xts-plain64", 512, true);
   expect_image_read(0, DEFAULT_INITIAL_READ_SIZE);
   mock_load_request->send();
   image_read_request->complete(DEFAULT_INITIAL_READ_SIZE);
   ASSERT_EQ(0, finished_cond.wait());
   ASSERT_NE(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS1", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, LUKS2FormattedClone) {
@@ -200,6 +215,7 @@ TEST_F(TestMockCryptoLuksLoadRequest, LUKS2FormattedClone) {
   image_read_request->complete(DEFAULT_INITIAL_READ_SIZE);
   ASSERT_EQ(0, finished_cond.wait());
   ASSERT_NE(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS2", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, KeyslotsBiggerThanInitialRead) {
@@ -214,12 +230,13 @@ TEST_F(TestMockCryptoLuksLoadRequest, KeyslotsBiggerThanInitialRead) {
   image_read_request->complete(data_offset - 16384);
   ASSERT_EQ(0, finished_cond.wait());
   ASSERT_NE(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS2", detected_format_name);
 }
 
 TEST_F(TestMockCryptoLuksLoadRequest, WrongPassphrase) {
   delete mock_load_request;
   mock_load_request = MockLoadRequest::create(
-        mock_image_ctx, "wrong", &crypto, on_finish);
+        mock_image_ctx, "wrong", &crypto, &detected_format_name, on_finish);
 
   generate_header(CRYPT_LUKS2, "aes", 64, "xts-plain64", 4096, false);
   expect_image_read(0, DEFAULT_INITIAL_READ_SIZE);
@@ -233,6 +250,7 @@ TEST_F(TestMockCryptoLuksLoadRequest, WrongPassphrase) {
   image_read_request->complete(data_offset - DEFAULT_INITIAL_READ_SIZE);
   ASSERT_EQ(-EPERM, finished_cond.wait());
   ASSERT_EQ(crypto.get(), nullptr);
+  ASSERT_EQ("LUKS2", detected_format_name);
 }
 
 } // namespace luks
index cb3e546db6c97af6350f8904f10017ded2f40dc4..849710d827e83788358e55f033978506a6a31cdd 100644 (file)
@@ -94,10 +94,15 @@ struct TestMockCryptoLoadRequest : public TestMockFixture {
   }
 
   void expect_encryption_load(MockEncryptionFormat* encryption_format,
-                              MockTestImageCtx* ictx) {
+                              MockTestImageCtx* ictx,
+                              std::string detected_format = "SOMEFORMAT") {
     EXPECT_CALL(*encryption_format, load(
-            ictx, _)).WillOnce(
-                    WithArgs<1>(Invoke([this](Context* ctx) {
+            ictx, _, _)).WillOnce(
+                    WithArgs<1, 2>(Invoke([this, detected_format](
+                            std::string* detected_format_name, Context* ctx) {
+                      if (!detected_format.empty()) {
+                        *detected_format_name = detected_format;
+                      }
                       load_context = ctx;
     })));
   }
@@ -258,6 +263,45 @@ TEST_F(TestMockCryptoLoadRequest, LoadClonedParentFail) {
   ASSERT_EQ(nullptr, mock_parent_image_ctx->encryption_format.get());
 }
 
+
+TEST_F(TestMockCryptoLoadRequest, LoadClonedPlaintextParent) {
+  expect_test_journal_feature(mock_image_ctx);
+  expect_test_journal_feature(mock_parent_image_ctx);
+  expect_image_flush();
+  expect_encryption_load(mock_encryption_format, mock_image_ctx);
+  mock_load_request->send();
+  ASSERT_EQ(ETIMEDOUT, finished_cond.wait_for(0));
+  expect_encryption_format_clone(mock_encryption_format);
+  expect_encryption_load(
+          cloned_encryption_format, mock_parent_image_ctx,
+          LoadRequest<MockImageCtx>::UNKNOWN_FORMAT);
+  load_context->complete(0);
+  ASSERT_EQ(ETIMEDOUT, finished_cond.wait_for(0));
+  expect_invalidate_cache();
+  load_context->complete(-EINVAL);
+  ASSERT_EQ(0, finished_cond.wait());
+  ASSERT_EQ(mock_encryption_format, mock_image_ctx->encryption_format.get());
+  ASSERT_EQ(nullptr, mock_parent_image_ctx->encryption_format.get());
+}
+
+TEST_F(TestMockCryptoLoadRequest, LoadClonedParentDetectionError) {
+  expect_test_journal_feature(mock_image_ctx);
+  expect_test_journal_feature(mock_parent_image_ctx);
+  expect_image_flush();
+  expect_encryption_load(mock_encryption_format, mock_image_ctx);
+  mock_load_request->send();
+  ASSERT_EQ(ETIMEDOUT, finished_cond.wait_for(0));
+  expect_encryption_format_clone(mock_encryption_format);
+  expect_encryption_load(
+          cloned_encryption_format, mock_parent_image_ctx, "");
+  load_context->complete(0);
+  ASSERT_EQ(ETIMEDOUT, finished_cond.wait_for(0));
+  load_context->complete(-EINVAL);
+  ASSERT_EQ(-EINVAL, finished_cond.wait());
+  ASSERT_EQ(nullptr, mock_image_ctx->encryption_format.get());
+  ASSERT_EQ(nullptr, mock_parent_image_ctx->encryption_format.get());
+}
+
 TEST_F(TestMockCryptoLoadRequest, LoadParentFail) {
   delete mock_load_request;
   mock_encryption_format = new MockEncryptionFormat();
index 0275b146761ba778d3ffe352e546033c5bba905c..3ad1a54db10b43c308642309ec35c4e11d2548c8 100644 (file)
@@ -15,7 +15,7 @@ namespace crypto {
 struct MockEncryptionFormat {
   MOCK_CONST_METHOD0(clone, std::unique_ptr<MockEncryptionFormat>());
   MOCK_METHOD2(format, void(MockImageCtx*, Context*));
-  MOCK_METHOD2(load, void(MockImageCtx*, Context*));
+  MOCK_METHOD3(load, void(MockImageCtx*, std::string*, Context*));
   MOCK_METHOD2(flatten, void(MockImageCtx*, Context*));
   MOCK_METHOD0(get_crypto, MockCryptoInterface*());
 };