]> git-server-git.apps.pok.os.sepia.ceph.com Git - ceph-ci.git/commitdiff
python-common/cryptotools: move actual crypto opts into a class
authorJohn Mulligan <jmulligan@redhat.com>
Mon, 21 Apr 2025 19:07:59 +0000 (15:07 -0400)
committerKefu Chai <k.chai@proxmox.com>
Thu, 5 Feb 2026 02:50:07 +0000 (10:50 +0800)
The functions now handle the i/o but allow the crypto function class
to centralize the functions that actually use the crypto libs.

Signed-off-by: John Mulligan <jmulligan@redhat.com>
(cherry picked from commit 4e4cfa58c4b124c0b0406619cc14ced0b2422550)

src/python-common/ceph/cryptotools/cryptotools.py

index 2610213525084861f81aa88a8ab3704afebdd57c..52c28d3f6ec92a8954a48a67037bc24ac9ffc0b0 100644 (file)
@@ -14,7 +14,128 @@ import warnings
 from argparse import Namespace
 from OpenSSL import crypto, SSL
 from uuid import uuid4
-from typing import Tuple, Optional
+from typing import Tuple, Any, Dict, Union
+
+
+class InternalError(ValueError):
+    pass
+
+
+class InternalCryptoCaller:
+    def fail(self, msg: str) -> None:
+        raise ValueError(msg)
+
+    def password_hash(self, password: str, salt_password: str) -> str:
+        salt = salt_password.encode() if salt_password else bcrypt.gensalt()
+        return bcrypt.hashpw(password.encode(), salt).decode()
+
+    def verify_password(self, password: str, hashed_password: str) -> bool:
+        _password = password.encode()
+        _hashed_password = hashed_password.encode()
+        try:
+            ok = bcrypt.checkpw(_password, _hashed_password)
+        except ValueError as err:
+            self.fail(str(err))
+        return ok
+
+    def create_private_key(self) -> str:
+        pkey = crypto.PKey()
+        pkey.generate_key(crypto.TYPE_RSA, 2048)
+        return crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey).decode()
+
+    def create_self_signed_cert(
+        self, dname: Dict[str, str], pkey: str
+    ) -> str:
+        _pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, pkey)
+
+        # Create a "subject" object
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            req = crypto.X509Req()
+        subj = req.get_subject()
+
+        # populate the subject with the dname settings
+        for k, v in dname.items():
+            setattr(subj, k, v)
+
+        # create a self-signed cert
+        cert = crypto.X509()
+        cert.set_subject(req.get_subject())
+        cert.set_serial_number(int(uuid4()))
+        cert.gmtime_adj_notBefore(0)
+        cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)  # 10 years
+        cert.set_issuer(cert.get_subject())
+        cert.set_pubkey(_pkey)
+        cert.sign(_pkey, 'sha512')
+        return crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode()
+
+    def _load_cert(self, crt: Union[str, bytes]) -> Any:
+        crt_buffer = crt.encode() if isinstance(crt, str) else crt
+        cert = crypto.load_certificate(crypto.FILETYPE_PEM, crt_buffer)
+        return cert
+
+    def _issuer_info(self, cert: Any) -> Tuple[str, str]:
+        components = cert.get_issuer().get_components()
+        org_name = cn = ''
+        for c in components:
+            if c[0].decode() == 'O':  # org comp
+                org_name = c[1].decode()
+            elif c[0].decode() == 'CN':  # common name comp
+                cn = c[1].decode()
+        return (org_name, cn)
+
+    def certificate_days_to_expire(self, crt: str) -> int:
+        x509 = self._load_cert(crt)
+        no_after = x509.get_notAfter()
+        if not no_after:
+            self.fail("Certificate does not have an expiration date.")
+
+        end_date = datetime.datetime.strptime(
+            no_after.decode(), '%Y%m%d%H%M%SZ'
+        )
+
+        if x509.has_expired():
+            org, cn = self._issuer_info(x509)
+            msg = 'Certificate issued by "%s/%s" expired on %s' % (
+                org,
+                cn,
+                end_date,
+            )
+            self.fail(msg)
+
+        # Certificate still valid, calculate and return days until expiration
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            days_until_exp = (end_date - datetime.datetime.utcnow()).days
+        return int(days_until_exp)
+
+    def get_cert_issuer_info(self, crt: str) -> Tuple[str, str]:
+        return self._issuer_info(self._load_cert(crt))
+
+    def verify_tls(self, crt: str, key: str) -> None:
+        try:
+            _key = crypto.load_privatekey(crypto.FILETYPE_PEM, key)
+            _key.check()
+        except (ValueError, crypto.Error) as e:
+            self.fail('Invalid private key: %s' % str(e))
+        try:
+            _crt = self._load_cert(crt)
+        except ValueError as e:
+            self.fail('Invalid certificate key: %s' % str(e))
+
+        try:
+            context = SSL.Context(SSL.TLSv1_METHOD)
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")
+                context.use_certificate(_crt)
+                context.use_privatekey(_key)
+            context.check_privatekey()
+        except crypto.Error as e:
+            self.fail(
+                'Private key and certificate do not match up: %s' % str(e)
+            )
+        except SSL.Error as e:
+            self.fail(f'Invalid cert/key pair: {e}')
 
 
 # subcommand functions
@@ -24,118 +145,49 @@ def password_hash(args: Namespace) -> None:
     password = data['password']
     salt_password = data['salt_password']
 
-    if not salt_password:
-        salt = bcrypt.gensalt()
-    else:
-        salt = salt_password.encode()
-
-    hash_str = bcrypt.hashpw(password.encode(), salt).decode()
+    hash_str = InternalCryptoCaller().password_hash(password, salt_password)
     json.dump({'hash': hash_str}, sys.stdout)
 
 
 def verify_password(args: Namespace) -> None:
+    icc = InternalCryptoCaller()
     data = json.loads(sys.stdin.read())
-    password = data.encode('utf-8')
-    hashed_password = data.encode('utf-8')
+    password = data.get('password', '')
+    hashed_password = data.get('hashed_password', '')
     try:
-        ok = bcrypt.checkpw(password, hashed_password)
+        icc.verify_password(password, hashed_password)
     except ValueError as err:
         _fail_message(str(err))
     json.dump({'ok': ok}, sys.stdout)
 
 
 def create_self_signed_cert(args: Namespace) -> None:
-
+    icc = InternalCryptoCaller()
     # Generate private key
     if args.private_key:
         # create a key pair
-        pkey = crypto.PKey()
-        pkey.generate_key(crypto.TYPE_RSA, 2048)
-        print(crypto.dump_privatekey(crypto.FILETYPE_PEM, pkey).decode())
+        print(icc.create_private_key())
         return
 
     data = json.loads(sys.stdin.read())
-
     dname = data['dname']
-    pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, data['private_key'])
-
-    # Create a "subject" object
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        req = crypto.X509Req()
-    subj = req.get_subject()
-
-    # populate the subject with the dname settings
-    for k, v in dname.items():
-        setattr(subj, k, v)
-
-    # create a self-signed cert
-    cert = crypto.X509()
-    cert.set_subject(req.get_subject())
-    cert.set_serial_number(int(uuid4()))
-    cert.gmtime_adj_notBefore(0)
-    cert.gmtime_adj_notAfter(10 * 365 * 24 * 60 * 60)  # 10 years
-    cert.set_issuer(cert.get_subject())
-    cert.set_pubkey(pkey)
-    cert.sign(pkey, 'sha512')
-
-    print(crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode())
-
-
-def _get_cert_issuer_info(crt: str) -> Tuple[Optional[str], Optional[str]]:
-    """Basic validation of a CA cert
-    """
-
-    crt_buffer = crt.encode() if isinstance(crt, str) else crt
-    (org_name, cn) = (None, None)
-    cert = crypto.load_certificate(crypto.FILETYPE_PEM, crt_buffer)
-    components = cert.get_issuer().get_components()
-    for c in components:
-        if c[0].decode() == 'O':  # org comp
-            org_name = c[1].decode()
-        elif c[0].decode() == 'CN':  # common name comp
-            cn = c[1].decode()
-
-    return (org_name, cn)
+    print(icc.create_self_signed_cert(dname, data['private_key']))
 
 
 def certificate_days_to_expire(args: Namespace) -> None:
+    icc = InternalCryptoCaller()
     crt = sys.stdin.read()
-
-    crt_buffer = crt.encode() if isinstance(crt, str) else crt
-    x509 = crypto.load_certificate(crypto.FILETYPE_PEM, crt_buffer)
-    no_after = x509.get_notAfter()
-    if not no_after:
-        print("Certificate does not have an expiration date.", file=sys.stderr)
-        sys.exit(1)
-
-    end_date = datetime.datetime.strptime(no_after.decode(), '%Y%m%d%H%M%SZ')
-
-    if x509.has_expired():
-        org, cn = _get_cert_issuer_info(crt)
-        msg = 'Certificate issued by "%s/%s" expired on %s' % (org, cn, end_date)
-        print(msg, file=sys.stderr)
+    try:
+        days_until_exp = icc.certificate_days_to_expire(crt)
+    except InternalError as err:
+        print(str(err), file=sys.stderr)
         sys.exit(1)
-
-    # Certificate still valid, calculate and return days until expiration
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        days_until_exp = (end_date - datetime.datetime.utcnow()).days
-        json.dump({'days_until_expiration': int(days_until_exp)}, sys.stdout)
+    json.dump({'days_until_expiration': days_until_exp}, sys.stdout)
 
 
 def get_cert_issuer_info(args: Namespace) -> None:
     crt = sys.stdin.read()
-
-    crt_buffer = crt.encode() if isinstance(crt, str) else crt
-    (org_name, cn) = (None, None)
-    cert = crypto.load_certificate(crypto.FILETYPE_PEM, crt_buffer)
-    components = cert.get_issuer().get_components()
-    for c in components:
-        if c[0].decode() == 'O':  # org comp
-            org_name = c[1].decode()
-        elif c[0].decode() == 'CN':  # common name comp
-            cn = c[1].decode()
+    org_name, cn = InternalCryptoCaller().get_cert_issuer_info(crt)
     json.dump({'org_name': org_name, 'cn': cn}, sys.stdout)
 
 
@@ -151,28 +203,9 @@ def verify_tls(args: Namespace) -> None:
     key = data['key']
 
     try:
-        _key = crypto.load_privatekey(crypto.FILETYPE_PEM, key)
-        _key.check()
-    except (ValueError, crypto.Error) as e:
-        _fail_message('Invalid private key: %s' % str(e))
-    try:
-        crt_buffer = crt.encode() if isinstance(crt, str) else crt
-        _crt = crypto.load_certificate(crypto.FILETYPE_PEM, crt_buffer)
-    except ValueError as e:
-        _fail_message('Invalid certificate key: %s' % str(e))
-
-    try:
-        context = SSL.Context(SSL.TLSv1_METHOD)
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            context.use_certificate(_crt)
-            context.use_privatekey(_key)
-
-        context.check_privatekey()
-    except crypto.Error as e:
-        _fail_message('Private key and certificate do not match up: %s' % str(e))
-    except SSL.Error as e:
-        _fail_message(f'Invalid cert/key pair: {e}')
+        InternalCryptoCaller().verify_tls(crt, key)
+    except ValueError as err:
+        json.dump({'error': str(err)}, sys.stdout)
     json.dump({'ok': True}, sys.stdout)  # need to emit something on success