]> git-server-git.apps.pok.os.sepia.ceph.com Git - ceph.git/commitdiff
mgr/dashboard: add SSO through oauth2 protocol 58456/head
authorPedro Gonzalez Gomez <pegonzal@redhat.com>
Mon, 8 Jul 2024 09:19:34 +0000 (11:19 +0200)
committerPedro Gonzalez Gomez <pegonzal@redhat.com>
Mon, 16 Sep 2024 12:03:24 +0000 (14:03 +0200)
Fixes: https://tracker.ceph.com/issues/66900
Signed-off-by: Pedro Gonzalez Gomez <pegonzal@redhat.com>
15 files changed:
qa/tasks/mgr/dashboard/test_auth.py
src/pybind/mgr/dashboard/controllers/auth.py
src/pybind/mgr/dashboard/controllers/oauth2.py [new file with mode: 0644]
src/pybind/mgr/dashboard/controllers/saml2.py
src/pybind/mgr/dashboard/frontend/src/app/shared/api/auth.service.ts
src/pybind/mgr/dashboard/module.py
src/pybind/mgr/dashboard/services/access_control.py
src/pybind/mgr/dashboard/services/auth.py [deleted file]
src/pybind/mgr/dashboard/services/auth/__init__.py [new file with mode: 0644]
src/pybind/mgr/dashboard/services/auth/auth.py [new file with mode: 0644]
src/pybind/mgr/dashboard/services/auth/oauth2.py [new file with mode: 0644]
src/pybind/mgr/dashboard/services/auth/saml2.py [new file with mode: 0644]
src/pybind/mgr/dashboard/services/sso.py
src/pybind/mgr/dashboard/tests/test_auth.py
src/pybind/mgr/dashboard/tests/test_sso.py

index a2266229bef7fd0055d89320c1b9aae14057b291..2b9240b635ec6cd6b749b40a966cdef38efc7284 100644 (file)
@@ -152,7 +152,8 @@ class AuthTest(DashboardTestCase):
         self._post("/api/auth/logout")
         self.assertStatus(200)
         self.assertJsonBody({
-            "redirect_url": "#/login"
+            "redirect_url": "#/login",
+            "protocol": 'local'
         })
         self._get("/api/host", version='1.1')
         self.assertStatus(401)
@@ -167,7 +168,8 @@ class AuthTest(DashboardTestCase):
         self._post("/api/auth/logout", set_cookies=True)
         self.assertStatus(200)
         self.assertJsonBody({
-            "redirect_url": "#/login"
+            "redirect_url": "#/login",
+            "protocol": 'local'
         })
         self._get("/api/host", set_cookies=True, version='1.1')
         self.assertStatus(401)
index 2e6cf855c29773d7ef37798c6f0f3f8d3dc03c62..16276af17e4c4eb55e1ca883f276993db3f40975 100644 (file)
@@ -10,7 +10,7 @@ import cherrypy
 
 from .. import mgr
 from ..exceptions import InvalidCredentialsError, UserDoesNotExist
-from ..services.auth import AuthManager, JwtManager
+from ..services.auth import AuthManager, AuthType, BaseAuth, JwtManager, OAuth2
 from ..services.cluster import ClusterModel
 from ..settings import Settings
 from . import APIDoc, APIRouter, ControllerAuthMixin, EndpointDoc, RESTController, allow_empty_body
@@ -132,7 +132,7 @@ class Auth(RESTController, ControllerAuthMixin):
                     'username': username,
                     'permissions': user_perms,
                     'pwdExpirationDate': pwd_expiration_date,
-                    'sso': mgr.SSO_DB.protocol == 'saml2',
+                    'sso': BaseAuth.from_protocol(mgr.SSO_DB.protocol).sso,
                     'pwdUpdateRequired': pwd_update_required
                 }
             mgr.ACCESS_CTRL_DB.increment_attempt(username)
@@ -156,37 +156,33 @@ class Auth(RESTController, ControllerAuthMixin):
     @RESTController.Collection('POST')
     @allow_empty_body
     def logout(self):
-        logger.debug('Logout successful')
-        token = JwtManager.get_token_from_header()
+        logger.debug('Logout started')
+        token = JwtManager.get_token(cherrypy.request)
         JwtManager.blocklist_token(token)
         self._delete_token_cookie(token)
-        redirect_url = '#/login'
-        if mgr.SSO_DB.protocol == 'saml2':
-            redirect_url = 'auth/saml2/slo'
         return {
-            'redirect_url': redirect_url
+            'redirect_url': BaseAuth.from_db(mgr.SSO_DB).LOGOUT_URL,
+            'protocol': BaseAuth.from_db(mgr.SSO_DB).get_auth_name()
         }
 
-    def _get_login_url(self):
-        if mgr.SSO_DB.protocol == 'saml2':
-            return 'auth/saml2/login'
-        return '#/login'
-
     @RESTController.Collection('POST', query_params=['token'])
     @EndpointDoc("Check token Authentication",
                  parameters={'token': (str, 'Authentication Token')},
                  responses={201: AUTH_CHECK_SCHEMA})
     def check(self, token):
         if token:
-            user = JwtManager.get_user(token)
+            if mgr.SSO_DB.protocol == AuthType.OAUTH2:
+                user = OAuth2.get_user(token)
+            else:
+                user = JwtManager.get_user(token)
             if user:
                 return {
                     'username': user.username,
                     'permissions': user.permissions_dict(),
-                    'sso': mgr.SSO_DB.protocol == 'saml2',
+                    'sso': BaseAuth.from_db(mgr.SSO_DB).sso,
                     'pwdUpdateRequired': user.pwd_update_required
                 }
         return {
-            'login_url': self._get_login_url(),
+            'login_url': BaseAuth.from_db(mgr.SSO_DB).LOGIN_URL,
             'cluster_status': ClusterModel.from_db().dict()['status']
         }
diff --git a/src/pybind/mgr/dashboard/controllers/oauth2.py b/src/pybind/mgr/dashboard/controllers/oauth2.py
new file mode 100644 (file)
index 0000000..ae37c4a
--- /dev/null
@@ -0,0 +1,32 @@
+import cherrypy
+
+from dashboard.exceptions import DashboardException
+from dashboard.services.auth.oauth2 import OAuth2
+
+from . import Endpoint, RESTController, Router
+
+
+@Router('/auth/oauth2', secure=False)
+class Oauth2(RESTController):
+
+    @Endpoint(json_response=False, version=None)
+    def login(self):
+        if not OAuth2.enabled():
+            raise DashboardException(500, msg='Failed to login: SSO OAuth2 is not enabled')
+
+        token = OAuth2.get_token(cherrypy.request)
+        if not token:
+            raise cherrypy.HTTPError()
+
+        raise cherrypy.HTTPRedirect(OAuth2.get_login_redirect_url(token))
+
+    @Endpoint(json_response=False, version=None)
+    def logout(self):
+        if not OAuth2.enabled():
+            raise DashboardException(500, msg='Failed to logout: SSO OAuth2 is not enabled')
+
+        token = OAuth2.get_token(cherrypy.request)
+        if not token:
+            raise cherrypy.HTTPError()
+
+        raise cherrypy.HTTPRedirect(OAuth2.get_logout_redirect_url(token))
index c11b18a27bc7e7e41caf8272ebd13fa0889c3bd6..f834be9587ee425fb594fb2c4008a46559239bce 100644 (file)
@@ -37,7 +37,7 @@ class Saml2(BaseController, ControllerAuthMixin):
         if not python_saml_imported:
             raise cherrypy.HTTPError(400, 'Required library not found: `python3-saml`')
         try:
-            OneLogin_Saml2_Settings(mgr.SSO_DB.saml2.onelogin_settings)
+            OneLogin_Saml2_Settings(mgr.SSO_DB.config.onelogin_settings)
         except OneLogin_Saml2_Error:
             raise cherrypy.HTTPError(400, 'Single Sign-On is not configured.')
 
@@ -46,19 +46,19 @@ class Saml2(BaseController, ControllerAuthMixin):
     def auth_response(self, **kwargs):
         Saml2._check_python_saml()
         req = Saml2._build_req(self._request, kwargs)
-        auth = OneLogin_Saml2_Auth(req, mgr.SSO_DB.saml2.onelogin_settings)
+        auth = OneLogin_Saml2_Auth(req, mgr.SSO_DB.config.onelogin_settings)
         auth.process_response()
         errors = auth.get_errors()
 
         if auth.is_authenticated():
             JwtManager.reset_user()
-            username_attribute = auth.get_attribute(mgr.SSO_DB.saml2.get_username_attribute())
+            username_attribute = auth.get_attribute(mgr.SSO_DB.config.get_username_attribute())
             if username_attribute is None:
                 raise cherrypy.HTTPError(400,
                                          'SSO error - `{}` not found in auth attributes. '
                                          'Received attributes: {}'
                                          .format(
-                                             mgr.SSO_DB.saml2.get_username_attribute(),
+                                             mgr.SSO_DB.config.get_username_attribute(),
                                              auth.get_attributes()))
             username = username_attribute[0]
             url_prefix = prepare_url_prefix(mgr.get_module_option('url_prefix', default=''))
@@ -85,21 +85,21 @@ class Saml2(BaseController, ControllerAuthMixin):
     @Endpoint(xml=True, version=None)
     def metadata(self):
         Saml2._check_python_saml()
-        saml_settings = OneLogin_Saml2_Settings(mgr.SSO_DB.saml2.onelogin_settings)
+        saml_settings = OneLogin_Saml2_Settings(mgr.SSO_DB.config.onelogin_settings)
         return saml_settings.get_sp_metadata()
 
     @Endpoint(json_response=False, version=None)
     def login(self):
         Saml2._check_python_saml()
         req = Saml2._build_req(self._request, {})
-        auth = OneLogin_Saml2_Auth(req, mgr.SSO_DB.saml2.onelogin_settings)
+        auth = OneLogin_Saml2_Auth(req, mgr.SSO_DB.config.onelogin_settings)
         raise cherrypy.HTTPRedirect(auth.login())
 
     @Endpoint(json_response=False, version=None)
     def slo(self):
         Saml2._check_python_saml()
         req = Saml2._build_req(self._request, {})
-        auth = OneLogin_Saml2_Auth(req, mgr.SSO_DB.saml2.onelogin_settings)
+        auth = OneLogin_Saml2_Auth(req, mgr.SSO_DB.config.onelogin_settings)
         raise cherrypy.HTTPRedirect(auth.logout())
 
     @Endpoint(json_response=False, version=None)
@@ -107,7 +107,7 @@ class Saml2(BaseController, ControllerAuthMixin):
         # pylint: disable=unused-argument
         Saml2._check_python_saml()
         JwtManager.reset_user()
-        token = JwtManager.get_token_from_header()
+        token = JwtManager.get_token(cherrypy.request)
         self._delete_token_cookie(token)
         url_prefix = prepare_url_prefix(mgr.get_module_option('url_prefix', default=''))
         raise cherrypy.HTTPRedirect("{}/#/login".format(url_prefix))
index 8a291799235b398f24b14c990269627764df4c6b..c209c7ffdb292f75fd39037a561545619cb5f213 100644 (file)
@@ -42,6 +42,9 @@ export class AuthService {
   logout(callback: Function = null) {
     return this.http.post('api/auth/logout', null).subscribe((resp: any) => {
       this.authStorageService.remove();
+      if (resp.protocol == 'oauth2') {
+        return window.location.replace(resp.redirect_url);
+      }
       const url = _.get(this.route.snapshot.queryParams, 'returnUrl', '/login');
       this.router.navigate([url], { skipLocationChange: true });
       if (callback) {
index 341a4f00f1be0c625b9eb85c603df5469a069c08..777f368a83fc08faa783748a94bb864106a34087 100644 (file)
@@ -275,6 +275,7 @@ class Module(MgrModule, CherryPyConfig):
                min=400, max=599),
         Option(name='redirect_resolve_ip_addr', type='bool', default=False),
         Option(name='cross_origin_url', type='str', default=''),
+        Option(name='sso_oauth2', type='bool', default=False),
     ]
     MODULE_OPTIONS.extend(options_schema_list())
     for options in PLUGIN_MANAGER.hook.get_options() or []:
index b45f81fb9b1ddfa16e73072923b29ecc086c6370..21c1a9572bb6a9aaa61583132eeb11863af40d4d 100644 (file)
@@ -193,6 +193,15 @@ class Role(object):
         return Role(r_dict['name'], r_dict['description'],
                     r_dict['scopes_permissions'])
 
+    @classmethod
+    def map_to_system_roles(cls, roles) -> List['Role']:
+        matches = []
+        for rn in SYSTEM_ROLES_NAMES:
+            for role in roles:
+                if role in SYSTEM_ROLES_NAMES[rn]:
+                    matches.append(rn)
+        return matches
+
 
 # static pre-defined system roles
 # this roles cannot be deleted nor updated
@@ -283,6 +292,12 @@ SYSTEM_ROLES = {
     GANESHA_MGR_ROLE.name: GANESHA_MGR_ROLE,
 }
 
+# static name-like roles list for role mapping
+SYSTEM_ROLES_NAMES = {
+    ADMIN_ROLE: [ADMIN_ROLE.name, 'admin'],
+    READ_ONLY_ROLE: [READ_ONLY_ROLE.name, 'read', 'guest', 'monitor']
+}
+
 
 class User(object):
     def __init__(self, username, password, name=None, email=None, roles=None,
diff --git a/src/pybind/mgr/dashboard/services/auth.py b/src/pybind/mgr/dashboard/services/auth.py
deleted file mode 100644 (file)
index 3b8d5ed..0000000
+++ /dev/null
@@ -1,279 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import base64
-import hashlib
-import hmac
-import json
-import logging
-import os
-import threading
-import time
-import uuid
-from typing import Optional
-
-import cherrypy
-
-from .. import mgr
-from ..exceptions import ExpiredSignatureError, InvalidAlgorithmError, InvalidTokenError
-from .access_control import LocalAuthenticator, UserDoesNotExist
-
-cherrypy.config.update({
-    'response.headers.server': 'Ceph-Dashboard',
-    'response.headers.content-security-policy': "frame-ancestors 'self';",
-    'response.headers.x-content-type-options': 'nosniff',
-    'response.headers.strict-transport-security': 'max-age=63072000; includeSubDomains; preload'
-})
-
-
-class JwtManager(object):
-    JWT_TOKEN_BLOCKLIST_KEY = "jwt_token_block_list"
-    JWT_TOKEN_TTL = 28800  # default 8 hours
-    JWT_ALGORITHM = 'HS256'
-    _secret = None
-
-    LOCAL_USER = threading.local()
-
-    @staticmethod
-    def _gen_secret():
-        secret = os.urandom(16)
-        return base64.b64encode(secret).decode('utf-8')
-
-    @classmethod
-    def init(cls):
-        cls.logger = logging.getLogger('jwt')  # type: ignore
-        # generate a new secret if it does not exist
-        secret = mgr.get_store('jwt_secret')
-        if secret is None:
-            secret = cls._gen_secret()
-            mgr.set_store('jwt_secret', secret)
-        cls._secret = secret
-
-    @classmethod
-    def array_to_base64_string(cls, message):
-        jsonstr = json.dumps(message, sort_keys=True).replace(" ", "")
-        string_bytes = base64.urlsafe_b64encode(bytes(jsonstr, 'UTF-8'))
-        return string_bytes.decode('UTF-8').replace("=", "")
-
-    @classmethod
-    def encode(cls, message, secret):
-        header = {"alg": cls.JWT_ALGORITHM, "typ": "JWT"}
-        base64_header = cls.array_to_base64_string(header)
-        base64_message = cls.array_to_base64_string(message)
-        base64_secret = base64.urlsafe_b64encode(hmac.new(
-            bytes(secret, 'UTF-8'),
-            msg=bytes(base64_header + "." + base64_message, 'UTF-8'),
-            digestmod=hashlib.sha256
-        ).digest()).decode('UTF-8').replace("=", "")
-        return base64_header + "." + base64_message + "." + base64_secret
-
-    @classmethod
-    def decode(cls, message, secret):
-        split_message = message.split(".")
-        base64_header = split_message[0]
-        base64_message = split_message[1]
-        base64_secret = split_message[2]
-
-        decoded_header = json.loads(base64.urlsafe_b64decode(base64_header))
-
-        if decoded_header['alg'] != cls.JWT_ALGORITHM:
-            raise InvalidAlgorithmError()
-
-        incoming_secret = base64.urlsafe_b64encode(hmac.new(
-            bytes(secret, 'UTF-8'),
-            msg=bytes(base64_header + "." + base64_message, 'UTF-8'),
-            digestmod=hashlib.sha256
-        ).digest()).decode('UTF-8').replace("=", "")
-
-        if base64_secret != incoming_secret:
-            raise InvalidTokenError()
-
-        # We add ==== as padding to ignore the requirement to have correct padding in
-        # the urlsafe_b64decode method.
-        decoded_message = json.loads(base64.urlsafe_b64decode(base64_message + "===="))
-        now = int(time.time())
-        if decoded_message['exp'] < now:
-            raise ExpiredSignatureError()
-
-        return decoded_message
-
-    @classmethod
-    def gen_token(cls, username, ttl: Optional[int] = None):
-        if not cls._secret:
-            cls.init()
-        if ttl is None:
-            ttl = mgr.get_module_option('jwt_token_ttl', cls.JWT_TOKEN_TTL)
-        else:
-            ttl = int(ttl) * 60 * 60  # convert hours to seconds
-        now = int(time.time())
-        payload = {
-            'iss': 'ceph-dashboard',
-            'jti': str(uuid.uuid4()),
-            'exp': now + ttl,
-            'iat': now,
-            'username': username
-        }
-        return cls.encode(payload, cls._secret)  # type: ignore
-
-    @classmethod
-    def decode_token(cls, token):
-        if not cls._secret:
-            cls.init()
-        return cls.decode(token, cls._secret)  # type: ignore
-
-    @classmethod
-    def get_token_from_header(cls):
-        auth_cookie_name = 'token'
-        try:
-            # use cookie
-            return cherrypy.request.cookie[auth_cookie_name].value
-        except KeyError:
-            try:
-                # fall-back: use Authorization header
-                auth_header = cherrypy.request.headers.get('authorization')
-                if auth_header is not None:
-                    scheme, params = auth_header.split(' ', 1)
-                    if scheme.lower() == 'bearer':
-                        return params
-            except IndexError:
-                return None
-
-    @classmethod
-    def set_user(cls, username):
-        cls.LOCAL_USER.username = username
-
-    @classmethod
-    def reset_user(cls):
-        cls.set_user(None)
-
-    @classmethod
-    def get_username(cls):
-        return getattr(cls.LOCAL_USER, 'username', None)
-
-    @classmethod
-    def get_user(cls, token):
-        try:
-            dtoken = cls.decode_token(token)
-            if not cls.is_blocklisted(dtoken['jti']):
-                user = AuthManager.get_user(dtoken['username'])
-                if user.last_update <= dtoken['iat']:
-                    return user
-                cls.logger.debug(  # type: ignore
-                    "user info changed after token was issued, iat=%s last_update=%s",
-                    dtoken['iat'], user.last_update
-                )
-            else:
-                cls.logger.debug('Token is block-listed')  # type: ignore
-        except ExpiredSignatureError:
-            cls.logger.debug("Token has expired")  # type: ignore
-        except InvalidTokenError:
-            cls.logger.debug("Failed to decode token")  # type: ignore
-        except InvalidAlgorithmError:
-            cls.logger.debug("Only the HS256 algorithm is supported.")  # type: ignore
-        except UserDoesNotExist:
-            cls.logger.debug(  # type: ignore
-                "Invalid token: user %s does not exist", dtoken['username']
-            )
-        return None
-
-    @classmethod
-    def blocklist_token(cls, token):
-        token = cls.decode_token(token)
-        blocklist_json = mgr.get_store(cls.JWT_TOKEN_BLOCKLIST_KEY)
-        if not blocklist_json:
-            blocklist_json = "{}"
-        bl_dict = json.loads(blocklist_json)
-        now = time.time()
-
-        # remove expired tokens
-        to_delete = []
-        for jti, exp in bl_dict.items():
-            if exp < now:
-                to_delete.append(jti)
-        for jti in to_delete:
-            del bl_dict[jti]
-
-        bl_dict[token['jti']] = token['exp']
-        mgr.set_store(cls.JWT_TOKEN_BLOCKLIST_KEY, json.dumps(bl_dict))
-
-    @classmethod
-    def is_blocklisted(cls, jti):
-        blocklist_json = mgr.get_store(cls.JWT_TOKEN_BLOCKLIST_KEY)
-        if not blocklist_json:
-            blocklist_json = "{}"
-        bl_dict = json.loads(blocklist_json)
-        return jti in bl_dict
-
-
-class AuthManager(object):
-    AUTH_PROVIDER = None
-
-    @classmethod
-    def initialize(cls):
-        cls.AUTH_PROVIDER = LocalAuthenticator()
-
-    @classmethod
-    def get_user(cls, username):
-        return cls.AUTH_PROVIDER.get_user(username)  # type: ignore
-
-    @classmethod
-    def authenticate(cls, username, password):
-        return cls.AUTH_PROVIDER.authenticate(username, password)  # type: ignore
-
-    @classmethod
-    def authorize(cls, username, scope, permissions):
-        return cls.AUTH_PROVIDER.authorize(username, scope, permissions)  # type: ignore
-
-
-class AuthManagerTool(cherrypy.Tool):
-    def __init__(self):
-        super(AuthManagerTool, self).__init__(
-            'before_handler', self._check_authentication, priority=20)
-        self.logger = logging.getLogger('auth')
-
-    def _check_authentication(self):
-        JwtManager.reset_user()
-        token = JwtManager.get_token_from_header()
-        if token:
-            user = JwtManager.get_user(token)
-            if user:
-                self._check_authorization(user.username)
-                return
-
-        resp_head = cherrypy.response.headers
-        req_head = cherrypy.request.headers
-        req_header_cross_origin_url = req_head.get('Access-Control-Allow-Origin')
-        cross_origin_urls = mgr.get_module_option('cross_origin_url', '')
-        cross_origin_url_list = [url.strip() for url in cross_origin_urls.split(',')]
-
-        if req_header_cross_origin_url in cross_origin_url_list:
-            resp_head['Access-Control-Allow-Origin'] = req_header_cross_origin_url
-
-        self.logger.debug('Unauthorized access to %s',
-                          cherrypy.url(relative='server'))
-        raise cherrypy.HTTPError(401, 'You are not authorized to access '
-                                      'that resource')
-
-    def _check_authorization(self, username):
-        self.logger.debug("checking authorization...")
-        handler = cherrypy.request.handler.callable
-        controller = handler.__self__
-        sec_scope = getattr(controller, '_security_scope', None)
-        sec_perms = getattr(handler, '_security_permissions', None)
-        JwtManager.set_user(username)
-
-        if not sec_scope:
-            # controller does not define any authorization restrictions
-            return
-
-        self.logger.debug("checking '%s' access to '%s' scope", sec_perms,
-                          sec_scope)
-
-        if not sec_perms:
-            self.logger.debug("Fail to check permission on: %s:%s", controller,
-                              handler)
-            raise cherrypy.HTTPError(403, "You don't have permissions to "
-                                          "access that resource")
-
-        if not AuthManager.authorize(username, sec_scope, sec_perms):
-            raise cherrypy.HTTPError(403, "You don't have permissions to "
-                                          "access that resource")
diff --git a/src/pybind/mgr/dashboard/services/auth/__init__.py b/src/pybind/mgr/dashboard/services/auth/__init__.py
new file mode 100644 (file)
index 0000000..52fd040
--- /dev/null
@@ -0,0 +1,16 @@
+from .auth import AuthManager, AuthManagerTool, AuthType, BaseAuth, \
+    JwtManager, SSOAuth, decode_jwt_segment
+from .oauth2 import OAuth2
+from .saml2 import Saml2
+
+__all__ = [
+    'AuthManager',
+    'AuthManagerTool',
+    'AuthType',
+    'BaseAuth',
+    'SSOAuth',
+    'JwtManager',
+    'decode_jwt_segment',
+    'Saml2',
+    'OAuth2'
+]
diff --git a/src/pybind/mgr/dashboard/services/auth/auth.py b/src/pybind/mgr/dashboard/services/auth/auth.py
new file mode 100644 (file)
index 0000000..7f1cdb5
--- /dev/null
@@ -0,0 +1,366 @@
+# -*- coding: utf-8 -*-
+
+import abc
+import base64
+import hashlib
+import hmac
+import json
+import logging
+import os
+import threading
+import time
+import uuid
+from enum import Enum
+from typing import TYPE_CHECKING, Optional, Type, TypedDict
+
+import cherrypy
+
+from ... import mgr
+from ...exceptions import ExpiredSignatureError, InvalidAlgorithmError, InvalidTokenError
+from ..access_control import LocalAuthenticator, UserDoesNotExist
+
+if TYPE_CHECKING:
+    from dashboard.services.sso import SsoDB
+
+cherrypy.config.update({
+    'response.headers.server': 'Ceph-Dashboard',
+    'response.headers.content-security-policy': "frame-ancestors 'self';",
+    'response.headers.x-content-type-options': 'nosniff',
+    'response.headers.strict-transport-security': 'max-age=63072000; includeSubDomains; preload'
+})
+
+
+class AuthType(str, Enum):
+    LOCAL = 'local'
+    SAML2 = 'saml2'
+    OAUTH2 = 'oauth2'
+
+
+class BaseAuth(abc.ABC):
+    LOGIN_URL: str
+    LOGOUT_URL: str
+    sso: bool
+
+    @staticmethod
+    def from_protocol(protocol: AuthType) -> Type["BaseAuth"]:
+        for subclass in BaseAuth.__subclasses__():
+            if subclass.__name__.lower() == protocol:
+                return subclass
+            for subsubclass in subclass.__subclasses__():
+                if subsubclass.__name__.lower() == protocol:
+                    return subsubclass
+        raise ValueError(f"Unknown auth backend: '{protocol}'")
+
+    @classmethod
+    def from_db(cls, db: Optional['SsoDB'] = None) -> Type["BaseAuth"]:
+        if db is None:
+            protocol = mgr.SSO_DB.protocol
+        else:
+            protocol = db.protocol
+        return cls.from_protocol(protocol)
+
+    class Config(TypedDict):  # pylint: disable=inherit-non-class
+        pass
+
+    @abc.abstractmethod
+    def to_dict(self) -> 'Config':
+        pass
+
+    @classmethod
+    @abc.abstractmethod
+    def from_dict(cls, s_dict) -> 'BaseAuth':
+        pass
+
+    @classmethod
+    def get_auth_name(cls):
+        return cls.__name__.lower()
+
+
+class Local(BaseAuth):
+    LOGIN_URL = '#/login'
+    LOGOUT_URL = '#/login'
+    sso = False
+
+    @classmethod
+    def get_auth_name(cls):
+        return cls.__name__.lower()
+
+    def to_dict(self) -> 'BaseAuth.Config':
+        return BaseAuth.Config()
+
+    @classmethod
+    def from_dict(cls, s_dict: BaseAuth.Config) -> 'Local':
+        # pylint: disable=unused-argument
+        return cls()
+
+
+class SSOAuth(BaseAuth):
+    sso = True
+
+
+class JwtManager(object):
+    JWT_TOKEN_BLOCKLIST_KEY = "jwt_token_block_list"
+    JWT_TOKEN_TTL = 28800  # default 8 hours
+    JWT_ALGORITHM = 'HS256'
+    _secret = None
+
+    LOCAL_USER = threading.local()
+
+    @staticmethod
+    def _gen_secret():
+        secret = os.urandom(16)
+        return base64.b64encode(secret).decode('utf-8')
+
+    @classmethod
+    def init(cls):
+        cls.logger = logging.getLogger('jwt')  # type: ignore
+        # generate a new secret if it does not exist
+        secret = mgr.get_store('jwt_secret')
+        if secret is None:
+            secret = cls._gen_secret()
+            mgr.set_store('jwt_secret', secret)
+        cls._secret = secret
+
+    @classmethod
+    def array_to_base64_string(cls, message):
+        jsonstr = json.dumps(message, sort_keys=True).replace(" ", "")
+        string_bytes = base64.urlsafe_b64encode(bytes(jsonstr, 'UTF-8'))
+        return string_bytes.decode('UTF-8').replace("=", "")
+
+    @classmethod
+    def encode(cls, message, secret):
+        header = {"alg": cls.JWT_ALGORITHM, "typ": "JWT"}
+        base64_header = cls.array_to_base64_string(header)
+        base64_message = cls.array_to_base64_string(message)
+        base64_secret = base64.urlsafe_b64encode(hmac.new(
+            bytes(secret, 'UTF-8'),
+            msg=bytes(base64_header + "." + base64_message, 'UTF-8'),
+            digestmod=hashlib.sha256
+        ).digest()).decode('UTF-8').replace("=", "")
+        return base64_header + "." + base64_message + "." + base64_secret
+
+    @classmethod
+    def decode(cls, message, secret):
+        oauth2_sso_protocol = mgr.SSO_DB.protocol == AuthType.OAUTH2
+        split_message = message.split(".")
+        base64_header = split_message[0]
+        base64_message = split_message[1]
+        base64_secret = split_message[2]
+
+        decoded_header = decode_jwt_segment(base64_header)
+
+        if decoded_header['alg'] != cls.JWT_ALGORITHM and not oauth2_sso_protocol:
+            raise InvalidAlgorithmError()
+
+        incoming_secret = ''
+        if decoded_header['alg'] == cls.JWT_ALGORITHM:
+            incoming_secret = base64.urlsafe_b64encode(hmac.new(
+                bytes(secret, 'UTF-8'),
+                msg=bytes(base64_header + "." + base64_message, 'UTF-8'),
+                digestmod=hashlib.sha256
+            ).digest()).decode('UTF-8').replace("=", "")
+
+        if base64_secret != incoming_secret and not oauth2_sso_protocol:
+            raise InvalidTokenError()
+
+        decoded_message = decode_jwt_segment(base64_message)
+        if oauth2_sso_protocol:
+            decoded_message['username'] = decoded_message['sub']
+        now = int(time.time())
+        if decoded_message['exp'] < now:
+            raise ExpiredSignatureError()
+
+        return decoded_message
+
+    @classmethod
+    def gen_token(cls, username, ttl: Optional[int] = None):
+        if not cls._secret:
+            cls.init()
+        if ttl is None:
+            ttl = mgr.get_module_option('jwt_token_ttl', cls.JWT_TOKEN_TTL)
+        else:
+            ttl = int(ttl) * 60 * 60  # convert hours to seconds
+        now = int(time.time())
+        payload = {
+            'iss': 'ceph-dashboard',
+            'jti': str(uuid.uuid4()),
+            'exp': now + ttl,
+            'iat': now,
+            'username': username
+        }
+        return cls.encode(payload, cls._secret)  # type: ignore
+
+    @classmethod
+    def decode_token(cls, token):
+        if not cls._secret:
+            cls.init()
+        return cls.decode(token, cls._secret)  # type: ignore
+
+    @classmethod
+    # pylint: disable=protected-access
+    def get_token(cls, request: cherrypy._ThreadLocalProxy):
+        if mgr.SSO_DB.protocol == AuthType.OAUTH2:
+            # Avoids circular import
+            from .oauth2 import OAuth2
+            return OAuth2.get_token(request)
+        auth_cookie_name = 'token'
+        try:
+            # use cookie
+            return request.cookie[auth_cookie_name].value
+        except KeyError:
+            try:
+                # fall-back: use Authorization header
+                auth_header = request.headers.get('authorization')
+                if auth_header is not None:
+                    scheme, params = auth_header.split(' ', 1)
+                    if scheme.lower() == 'bearer':
+                        return params
+            except IndexError:
+                return None
+
+    @classmethod
+    def set_user(cls, username):
+        cls.LOCAL_USER.username = username
+
+    @classmethod
+    def reset_user(cls):
+        cls.set_user(None)
+
+    @classmethod
+    def get_username(cls):
+        return getattr(cls.LOCAL_USER, 'username', None)
+
+    @classmethod
+    def get_user(cls, token):
+        try:
+            dtoken = cls.decode_token(token)
+            if 'jti' in dtoken and not cls.is_blocklisted(dtoken['jti']):
+                user = AuthManager.get_user(dtoken['username'])
+                if 'iat' in dtoken and user.last_update <= dtoken['iat']:
+                    return user
+                cls.logger.debug(  # type: ignore
+                    "user info changed after token was issued, iat=%s last_update=%s",
+                    dtoken['iat'], user.last_update
+                )
+            else:
+                cls.logger.debug('Token is block-listed')  # type: ignore
+        except ExpiredSignatureError:
+            cls.logger.debug("Token has expired")  # type: ignore
+        except InvalidTokenError:
+            cls.logger.debug("Failed to decode token")  # type: ignore
+        except InvalidAlgorithmError:
+            cls.logger.debug("Only the HS256 algorithm is supported.")  # type: ignore
+        except UserDoesNotExist:
+            cls.logger.debug(  # type: ignore
+                "Invalid token: user %s does not exist", dtoken['username']
+            )
+        return None
+
+    @classmethod
+    def blocklist_token(cls, token):
+        token = cls.decode_token(token)
+        blocklist_json = mgr.get_store(cls.JWT_TOKEN_BLOCKLIST_KEY)
+        if not blocklist_json:
+            blocklist_json = "{}"
+        bl_dict = json.loads(blocklist_json)
+        now = time.time()
+
+        # remove expired tokens
+        to_delete = []
+        for jti, exp in bl_dict.items():
+            if exp < now:
+                to_delete.append(jti)
+        for jti in to_delete:
+            del bl_dict[jti]
+
+        bl_dict[token['jti']] = token['exp']
+        mgr.set_store(cls.JWT_TOKEN_BLOCKLIST_KEY, json.dumps(bl_dict))
+
+    @classmethod
+    def is_blocklisted(cls, jti):
+        blocklist_json = mgr.get_store(cls.JWT_TOKEN_BLOCKLIST_KEY)
+        if not blocklist_json:
+            blocklist_json = "{}"
+        bl_dict = json.loads(blocklist_json)
+        return jti in bl_dict
+
+
+class AuthManager(object):
+    AUTH_PROVIDER = None
+
+    @classmethod
+    def initialize(cls):
+        cls.AUTH_PROVIDER = LocalAuthenticator()
+
+    @classmethod
+    def get_user(cls, username):
+        return cls.AUTH_PROVIDER.get_user(username)  # type: ignore
+
+    @classmethod
+    def authenticate(cls, username, password):
+        return cls.AUTH_PROVIDER.authenticate(username, password)  # type: ignore
+
+    @classmethod
+    def authorize(cls, username, scope, permissions):
+        return cls.AUTH_PROVIDER.authorize(username, scope, permissions)  # type: ignore
+
+
+class AuthManagerTool(cherrypy.Tool):
+    def __init__(self):
+        super(AuthManagerTool, self).__init__(
+            'before_handler', self._check_authentication, priority=20)
+        self.logger = logging.getLogger('auth')
+
+    def _check_authentication(self):
+        JwtManager.reset_user()
+        token = JwtManager.get_token(cherrypy.request)
+        if token:
+            user = JwtManager.get_user(token)
+            if user:
+                self._check_authorization(user.username)
+                return
+
+        resp_head = cherrypy.response.headers
+        req_head = cherrypy.request.headers
+        req_header_cross_origin_url = req_head.get('Access-Control-Allow-Origin')
+        cross_origin_urls = mgr.get_module_option('cross_origin_url', '')
+        cross_origin_url_list = [url.strip() for url in cross_origin_urls.split(',')]
+
+        if req_header_cross_origin_url in cross_origin_url_list:
+            resp_head['Access-Control-Allow-Origin'] = req_header_cross_origin_url
+
+        self.logger.debug('Unauthorized access to %s',
+                          cherrypy.url(relative='server'))
+        raise cherrypy.HTTPError(401, 'You are not authorized to access '
+                                      'that resource')
+
+    def _check_authorization(self, username):
+        self.logger.debug("checking authorization...")
+        handler = cherrypy.request.handler.callable
+        controller = handler.__self__
+        sec_scope = getattr(controller, '_security_scope', None)
+        sec_perms = getattr(handler, '_security_permissions', None)
+        JwtManager.set_user(username)
+
+        if not sec_scope:
+            # controller does not define any authorization restrictions
+            return
+
+        self.logger.debug("checking '%s' access to '%s' scope", sec_perms,
+                          sec_scope)
+
+        if not sec_perms:
+            self.logger.debug("Fail to check permission on: %s:%s", controller,
+                              handler)
+            raise cherrypy.HTTPError(403, "You don't have permissions to "
+                                          "access that resource")
+
+        if not AuthManager.authorize(username, sec_scope, sec_perms):
+            raise cherrypy.HTTPError(403, "You don't have permissions to "
+                                          "access that resource")
+
+
+def decode_jwt_segment(encoded_segment: str):
+    # We add ==== as padding to ignore the requirement to have correct padding in
+    # the urlsafe_b64decode method.
+    return json.loads(base64.urlsafe_b64decode(encoded_segment + "===="))
diff --git a/src/pybind/mgr/dashboard/services/auth/oauth2.py b/src/pybind/mgr/dashboard/services/auth/oauth2.py
new file mode 100644 (file)
index 0000000..5376107
--- /dev/null
@@ -0,0 +1,151 @@
+import json
+from typing import Dict, List
+from urllib.parse import quote
+
+import cherrypy
+import requests
+
+from ... import mgr
+from ...services.auth import BaseAuth, SSOAuth, decode_jwt_segment
+from ...tools import prepare_url_prefix
+from ..access_control import Role, User, UserAlreadyExists
+
+
+class OAuth2(SSOAuth):
+    LOGIN_URL = 'auth/oauth2/login'
+    LOGOUT_URL = 'auth/oauth2/logout'
+    sso = True
+
+    class OAuth2Config(BaseAuth.Config):
+        pass
+
+    @staticmethod
+    def enabled():
+        return mgr.get_module_option('sso_oauth2')
+
+    def to_dict(self) -> 'BaseAuth.Config':
+        return self.OAuth2Config()
+
+    @classmethod
+    def from_dict(cls, s_dict: OAuth2Config) -> 'OAuth2':
+        # pylint: disable=unused-argument
+        return OAuth2()
+
+    @classmethod
+    def get_auth_name(cls):
+        return cls.__name__.lower()
+
+    @classmethod
+    # pylint: disable=protected-access
+    def get_token(cls, request: cherrypy._ThreadLocalProxy) -> str:
+        try:
+            return request.cookie['token'].value
+        except KeyError:
+            return request.headers.get('X-Access-Token')
+
+    @classmethod
+    def set_token(cls, token: str):
+        cherrypy.request.jwt = token
+        cherrypy.request.jwt_payload = cls.get_token_payload()
+        cherrypy.request.user = cls.get_user(token)
+
+    @classmethod
+    def get_token_payload(cls) -> Dict:
+        try:
+            return cherrypy.request.jwt_payload
+        except AttributeError:
+            pass
+        try:
+            return decode_jwt_segment(cherrypy.request.jwt.split(".")[1])
+        except AttributeError:
+            return {}
+
+    @classmethod
+    def set_token_payload(cls, token):
+        cherrypy.request.jwt_payload = decode_jwt_segment(token.split(".")[1])
+
+    @classmethod
+    def get_user_roles(cls):
+        roles: List[Role] = []
+        user_roles: List[Role] = []
+        try:
+            jwt_payload = cherrypy.request.jwt_payload
+        except AttributeError:
+            raise cherrypy.HTTPError()
+
+        # check for client roes
+        if 'resource_access' in jwt_payload:
+            # Find the first value where the key is not 'account'
+            roles = next((value['roles'] for key, value in jwt_payload['resource_access'].items()
+                          if key != "account"), user_roles)
+        # check for global roles
+        elif 'realm_access' in jwt_payload:
+            roles = next((value['roles'] for _, value in jwt_payload['realm_access'].items()),
+                         user_roles)
+        else:
+            raise cherrypy.HTTPError()
+        user_roles = Role.map_to_system_roles(roles)
+        return user_roles
+
+    @classmethod
+    def get_user(cls, token: str) -> User:
+        try:
+            return cherrypy.request.user
+        except AttributeError:
+            cls.set_token_payload(token)
+            cls._create_user()
+        return cherrypy.request.user
+
+    @classmethod
+    def _create_user(cls):
+        try:
+            jwt_payload = cherrypy.request.jwt_payload
+        except AttributeError:
+            raise cherrypy.HTTPError()
+        try:
+            user = mgr.ACCESS_CTRL_DB.create_user(
+                jwt_payload['sub'], None, jwt_payload['name'], jwt_payload['email'])
+        except UserAlreadyExists:
+            user = mgr.ACCESS_CTRL_DB.get_user(jwt_payload['sub'])
+        user.set_roles(cls.get_user_roles())
+        # set user last update to token time issued
+        user.last_update = jwt_payload['iat']
+        cherrypy.request.user = user
+
+    @classmethod
+    def reset_user(cls):
+        try:
+            mgr.ACCESS_CTRL_DB.delete_user(cherrypy.request.user.username)
+            cherrypy.request.user = None
+        except AttributeError:
+            raise cherrypy.HTTPError()
+
+    @classmethod
+    def get_token_iss(cls, token=''):
+        if token:
+            cls.set_token_payload(token)
+        return cls.get_token_payload()['iss']
+
+    @classmethod
+    def get_openid_config(cls, iss):
+        msg = 'Failed to logout: could not contact IDP'
+        try:
+            response = requests.get(f'{iss}/.well-known/openid-configuration')
+        except requests.exceptions.RequestException:
+            raise cherrypy.HTTPError(500, message=msg)
+        if response.status_code != 200:
+            raise cherrypy.HTTPError(500, message=msg)
+        return json.loads(response.text)
+
+    @classmethod
+    def get_login_redirect_url(cls, token) -> str:
+        url_prefix = prepare_url_prefix(mgr.get_module_option('url_prefix', default=''))
+        return f"{url_prefix}/#/login?access_token={token}"
+
+    @classmethod
+    def get_logout_redirect_url(cls, token) -> str:
+        openid_config = OAuth2.get_openid_config(OAuth2.get_token_iss(token))
+        end_session_url = openid_config.get('end_session_endpoint')
+        encoded_end_session_url = quote(end_session_url, safe="")
+        url_prefix = prepare_url_prefix(mgr.get_module_option('url_prefix', default=''))
+        return f'{url_prefix}/oauth2/sign_out?rd={encoded_end_session_url}'
diff --git a/src/pybind/mgr/dashboard/services/auth/saml2.py b/src/pybind/mgr/dashboard/services/auth/saml2.py
new file mode 100644 (file)
index 0000000..110de3e
--- /dev/null
@@ -0,0 +1,35 @@
+from typing import Any
+
+from .auth import BaseAuth, SSOAuth
+
+
+class Saml2(SSOAuth):
+    LOGIN_URL = 'auth/saml2/login'
+    LOGOUT_URL = 'auth/saml2/slo'
+    sso = True
+
+    class Saml2Config(BaseAuth.Config):
+        onelogin_settings: Any
+
+    def __init__(self, onelogin_settings):
+        self.onelogin_settings = onelogin_settings
+
+    def get_username_attribute(self):
+        return self.onelogin_settings['sp']['attributeConsumingService']['requestedAttributes'][0][
+            'name']
+
+    def to_dict(self) -> 'Saml2Config':
+        return {
+            'onelogin_settings': self.onelogin_settings
+        }
+
+    @classmethod
+    def from_dict(cls, s_dict: Saml2Config) -> 'Saml2':
+        try:
+            return Saml2(s_dict['onelogin_settings'])
+        except KeyError:
+            return Saml2({})
+
+    @classmethod
+    def get_auth_name(cls):
+        return cls.__name__.lower()
index 2290e6ea3e15f71d7d16e74d6a3fbd26438fe997..0b607e217df76c149aad391e3ca77455d7f75a9a 100644 (file)
@@ -7,9 +7,15 @@ import logging
 import os
 import threading
 import warnings
+from typing import Dict
 from urllib import parse
 
+from mgr_module import CLIWriteCommand, HandleCommandResult
+
 from .. import mgr
+# Saml2 and OAuth2 needed to be recognized by .__subclasses__()
+# pylint: disable=unused-import
+from ..services.auth import AuthType, BaseAuth, OAuth2, Saml2  # noqa
 from ..tools import prepare_url_prefix
 
 logger = logging.getLogger('sso')
@@ -24,39 +30,22 @@ except ImportError:
     python_saml_imported = False
 
 
-class Saml2(object):
-    def __init__(self, onelogin_settings):
-        self.onelogin_settings = onelogin_settings
-
-    def get_username_attribute(self):
-        return self.onelogin_settings['sp']['attributeConsumingService']['requestedAttributes'][0][
-            'name']
-
-    def to_dict(self):
-        return {
-            'onelogin_settings': self.onelogin_settings
-        }
-
-    @classmethod
-    def from_dict(cls, s_dict):
-        return Saml2(s_dict['onelogin_settings'])
-
-
 class SsoDB(object):
     VERSION = 1
     SSODB_CONFIG_KEY = "ssodb_v"
 
-    def __init__(self, version, protocol, saml2):
+    def __init__(self, version, protocol: AuthType, config: BaseAuth):
         self.version = version
         self.protocol = protocol
-        self.saml2 = saml2
+        self.config = config
         self.lock = threading.RLock()
 
     def save(self):
         with self.lock:
             db = {
                 'protocol': self.protocol,
-                'saml2': self.saml2.to_dict(),
+                'saml2': self.config.to_dict(),
+                'oauth2': self.config.to_dict(),
                 'version': self.version
             }
             mgr.set_store(self.ssodb_config_key(), json.dumps(db))
@@ -79,20 +68,33 @@ class SsoDB(object):
         json_db = mgr.get_store(cls.ssodb_config_key(), None)
         if json_db is None:
             logger.debug("No DB v%s found, creating new...", cls.VERSION)
-            db = cls(cls.VERSION, '', Saml2({}))
+            db = cls(cls.VERSION, AuthType.LOCAL, Saml2({}))
             # check if we can update from a previous version database
             db.check_and_update_db()
             return db
 
-        dict_db = json.loads(json_db)  # type: dict
-        return cls(dict_db['version'], dict_db.get('protocol'),
-                   Saml2.from_dict(dict_db.get('saml2')))
+        dict_db = json.loads(json_db)  # type: Dict
+        protocol = dict_db.get('protocol')
+        # keep backward-compatibility
+        if protocol == '':
+            protocol = AuthType.LOCAL
+        protocol = AuthType(protocol)
+        config: BaseAuth = BaseAuth.from_protocol(protocol).from_dict(dict_db.get(protocol))
+        return cls(dict_db['version'], protocol, config)
 
 
 def load_sso_db():
     mgr.SSO_DB = SsoDB.load()  # type: ignore
 
 
+@CLIWriteCommand("dashboard sso enable oauth2")
+def enable_sso(_):
+    mgr.SSO_DB.protocol = AuthType.OAUTH2
+    mgr.SSO_DB.save()
+    mgr.set_module_option('sso_oauth2', True)
+    return HandleCommandResult(stdout='SSO is "enabled" with "OAuth2" protocol.')
+
+
 SSO_COMMANDS = [
     {
         'cmd': 'dashboard sso enable saml2',
@@ -148,27 +150,28 @@ def handle_sso_command(cmd):
         return -errno.EPERM, '', 'Required library not found: `python3-saml`'
 
     if cmd['prefix'] == 'dashboard sso disable':
-        mgr.SSO_DB.protocol = ''
+        mgr.SSO_DB.protocol = AuthType.LOCAL
         mgr.SSO_DB.save()
+        mgr.set_module_option('sso_oauth2', False)
         return 0, 'SSO is "disabled".', ''
 
     if cmd['prefix'] == 'dashboard sso enable saml2':
         configured = _is_sso_configured()
         if configured:
-            mgr.SSO_DB.protocol = 'saml2'
+            mgr.SSO_DB.protocol = AuthType.SAML2
             mgr.SSO_DB.save()
-            return 0, 'SSO is "enabled" with "SAML2" protocol.', ''
+            return 0, 'SSO is "enabled" with "saml2" protocol.', ''
         return -errno.EPERM, '', 'Single Sign-On is not configured: ' \
             'use `ceph dashboard sso setup saml2`'
 
     if cmd['prefix'] == 'dashboard sso status':
-        if mgr.SSO_DB.protocol == 'saml2':
-            return 0, 'SSO is "enabled" with "SAML2" protocol.', ''
+        if not mgr.SSO_DB.protocol == AuthType.LOCAL:
+            return 0, f'SSO is "enabled" with "{mgr.SSO_DB.protocol}" protocol.', ''
 
         return 0, 'SSO is "disabled".', ''
 
     if cmd['prefix'] == 'dashboard sso show saml2':
-        return 0, json.dumps(mgr.SSO_DB.saml2.to_dict()), ''
+        return 0, json.dumps(mgr.SSO_DB.config.to_dict()), ''
 
     if cmd['prefix'] == 'dashboard sso setup saml2':
         ret = _handle_saml_setup(cmd)
@@ -180,8 +183,8 @@ def handle_sso_command(cmd):
 def _is_sso_configured():
     configured = True
     try:
-        Saml2Settings(mgr.SSO_DB.saml2.onelogin_settings)
-    except Saml2Error:
+        Saml2Settings(mgr.SSO_DB.config.onelogin_settings)
+    except (AttributeError, Saml2Error):
         configured = False
     return configured
 
@@ -192,7 +195,7 @@ def _handle_saml_setup(cmd):
         ret = -errno.EINVAL, '', err
     else:
         _set_saml_settings(cmd, sp_x_509_cert, sp_private_key, has_sp_cert)
-        ret = 0, json.dumps(mgr.SSO_DB.saml2.onelogin_settings), ''
+        ret = 0, json.dumps(mgr.SSO_DB.config.onelogin_settings), ''
     return ret
 
 
@@ -274,8 +277,8 @@ def _set_saml_settings(cmd, sp_x_509_cert, sp_private_key, has_sp_cert):
         }
     }
     settings = Saml2Parser.merge_settings(settings, idp_settings)
-    mgr.SSO_DB.saml2.onelogin_settings = settings
-    mgr.SSO_DB.protocol = 'saml2'
+    mgr.SSO_DB.config.onelogin_settings = settings
+    mgr.SSO_DB.protocol = AuthType.SAML2
     mgr.SSO_DB.save()
 
 
index 70e841a667bed61fa52c49935b65f22d524e44e2..a47a625136a8471182c5b398c68d5487e0e386ae 100644 (file)
@@ -1,6 +1,8 @@
 import unittest
 from unittest.mock import Mock, patch
 
+from dashboard.services.auth import AuthType
+
 from .. import mgr
 from ..controllers.auth import Auth
 from ..services.auth import JwtManager
@@ -10,6 +12,7 @@ mgr.get_module_option.return_value = JwtManager.JWT_TOKEN_TTL
 mgr.get_store.return_value = 'jwt_secret'
 mgr.ACCESS_CTRL_DB = Mock()
 mgr.ACCESS_CTRL_DB.get_attempt.return_value = 1
+mgr.SSO_DB.protocol = AuthType.LOCAL
 
 
 class JwtManagerTest(unittest.TestCase):
@@ -67,5 +70,6 @@ class AuthTest(ControllerTestCase):
         self._post('/api/auth/logout')
         self.assertStatus(200)
         self.assertJsonBody({
-            'redirect_url': '#/login'
+            'redirect_url': '#/login',
+            'protocol': 'local'
         })
index e077dde19e18a65cb7a541c71aace5f81fefefef..9492f0a20ed6698e0f5edeebb645616f9d283863 100644 (file)
@@ -166,7 +166,7 @@ class AccessControlTest(unittest.TestCase, CLICommandTestMixin):
                       idp_metadata=self.IDP_METADATA)
 
         result = self.exec_cmd('sso enable saml2')
-        self.assertEqual(result, 'SSO is "enabled" with "SAML2" protocol.')
+        self.assertEqual(result, 'SSO is "enabled" with "saml2" protocol.')
 
     def test_sso_disable(self):
         result = self.exec_cmd('sso disable')
@@ -181,7 +181,7 @@ class AccessControlTest(unittest.TestCase, CLICommandTestMixin):
                       idp_metadata=self.IDP_METADATA)
 
         result = self.exec_cmd('sso status')
-        self.assertEqual(result, 'SSO is "enabled" with "SAML2" protocol.')
+        self.assertEqual(result, 'SSO is "enabled" with "saml2" protocol.')
 
     def test_sso_show_saml2(self):
         result = self.exec_cmd('sso show saml2')