]> git-server-git.apps.pok.os.sepia.ceph.com Git - ceph.git/commitdiff
mgr/dashboard: reduce complexity of awsauth.py and rest_client.py 45007/head
authorVallari Agrawal <val.agl002@gmail.com>
Thu, 17 Mar 2022 07:44:53 +0000 (13:14 +0530)
committerVallari Agrawal <val.agl002@gmail.com>
Thu, 17 Mar 2022 07:56:32 +0000 (13:26 +0530)
Signed-off-by: Vallari Agrawal <val.agl002@gmail.com>
src/pybind/mgr/dashboard/awsauth.py
src/pybind/mgr/dashboard/rest_client.py

index ad8dc20faaa6037ae215973b97536455c9047fdd..285a2c08847e938b7c7564f8ad804cba8d3aada1 100644 (file)
@@ -81,16 +81,7 @@ class S3Auth(AuthBase):
         h = hmac.new(key, msg, digestmod=sha)
         return encodestring(h.digest()).strip()
 
-    def get_canonical_string(self, url, headers, method):
-        parsedurl = urlparse(url)
-        objectkey = parsedurl.path[1:]
-        query_args = sorted(parsedurl.query.split('&'))
-
-        bucket = parsedurl.netloc[:-len(self.service_base_url)]
-        if len(bucket) > 1:
-            # remove last dot
-            bucket = bucket[:-1]
-
+    def get_interesting_headers(self, headers):
         interesting_headers = {
             'content-md5': '',
             'content-type': '',
@@ -109,6 +100,19 @@ class S3Auth(AuthBase):
         # If x-amz-date is used it supersedes the date header.
         if 'x-amz-date' in interesting_headers:
             interesting_headers['date'] = ''
+        return interesting_headers
+
+    def get_canonical_string(self, url, headers, method):
+        parsedurl = urlparse(url)
+        objectkey = parsedurl.path[1:]
+        query_args = sorted(parsedurl.query.split('&'))
+
+        bucket = parsedurl.netloc[:-len(self.service_base_url)]
+        if len(bucket) > 1:
+            # remove last dot
+            bucket = bucket[:-1]
+
+        interesting_headers = self.get_interesting_headers(headers)
 
         buf = '%s\n' % method
         for key in sorted(interesting_headers.keys()):
index bd03dc7067f9efb719ba0b639cab909383ebbc69..69240bace86678ebd58607b0566960cad1ba22ea 100644 (file)
@@ -211,16 +211,12 @@ class _ResponseValidator(object):
                 _ResponseValidator._validate_array(array_seq[1:], level_next,
                                                    resp[idx])
             elif array_seq[0] == '*':
-                for r in resp:
-                    _ResponseValidator._validate_array(array_seq[1:],
-                                                       level_next, r)
+                _ResponseValidator.validate_all_resp(resp, array_seq, level_next)
             elif array_seq[0] == '+':
                 if len(resp) < 1:
                     raise BadResponseFormatException(
                         "array should not be empty")
-                for r in resp:
-                    _ResponseValidator._validate_array(array_seq[1:],
-                                                       level_next, r)
+                _ResponseValidator.validate_all_resp(resp, array_seq, level_next)
             else:
                 raise Exception(
                     "Response structure is invalid: only <int> | '*' are "
@@ -229,6 +225,12 @@ class _ResponseValidator(object):
             if level_next:
                 _ResponseValidator._validate_level(level_next, resp)
 
+    @staticmethod
+    def validate_all_resp(resp, array_seq, level_next):
+        for r in resp:
+            _ResponseValidator._validate_array(array_seq[1:],
+                                               level_next, r)
+
     @staticmethod
     def _validate_key(key, level_next, resp):
         array_access = [a.strip() for a in key.split("[")]
@@ -393,33 +395,7 @@ class RestClient(object):
         if headers:
             request_headers.update(headers)
         try:
-            if method.lower() == 'get':
-                resp = self.session.get(
-                    url, headers=request_headers, params=params, auth=self.auth)
-            elif method.lower() == 'post':
-                resp = self.session.post(
-                    url,
-                    headers=request_headers,
-                    params=params,
-                    data=data,
-                    auth=self.auth)
-            elif method.lower() == 'put':
-                resp = self.session.put(
-                    url,
-                    headers=request_headers,
-                    params=params,
-                    data=data,
-                    auth=self.auth)
-            elif method.lower() == 'delete':
-                resp = self.session.delete(
-                    url,
-                    headers=request_headers,
-                    params=params,
-                    data=data,
-                    auth=self.auth)
-            else:
-                raise RequestException('Method "{}" not supported'.format(
-                    method.upper()), None)
+            resp = self.send_request(method, url, request_headers, params, data)
             if resp.ok:
                 logger.debug("%s REST API %s res status: %s content: %s",
                              self.client_name, method.upper(),
@@ -451,53 +427,7 @@ class RestClient(object):
                     self._handle_response_status_code(resp.status_code),
                     resp.content)
         except ConnectionError as ex:
-            if ex.args:
-                if isinstance(ex.args[0], SSLError):
-                    errno = "n/a"
-                    strerror = "SSL error. Probably trying to access a non " \
-                               "SSL connection."
-                    logger.error("%s REST API failed %s, SSL error (url=%s).",
-                                 self.client_name, method.upper(), ex.request.url)
-                else:
-                    try:
-                        match = re.match(r'.*: \[Errno (-?\d+)\] (.+)',
-                                         ex.args[0].reason.args[0])
-                    except AttributeError:
-                        match = None
-                    if match:
-                        errno = match.group(1)
-                        strerror = match.group(2)
-                        logger.error(
-                            "%s REST API failed %s, connection error (url=%s): "
-                            "[errno: %s] %s",
-                            self.client_name, method.upper(), ex.request.url,
-                            errno, strerror)
-                    else:
-                        errno = "n/a"
-                        strerror = "n/a"
-                        logger.error(
-                            "%s REST API failed %s, connection error (url=%s).",
-                            self.client_name, method.upper(), ex.request.url)
-            else:
-                errno = "n/a"
-                strerror = "n/a"
-                logger.error("%s REST API failed %s, connection error (url=%s).",
-                             self.client_name, method.upper(), ex.request.url)
-
-            if errno != "n/a":
-                ex_msg = (
-                    "{} REST API cannot be reached: {} [errno {}]. "
-                    "Please check your configuration and that the API endpoint"
-                    " is accessible"
-                    .format(self.client_name, strerror, errno))
-            else:
-                ex_msg = (
-                    "{} REST API cannot be reached. Please check "
-                    "your configuration and that the API endpoint is"
-                    " accessible"
-                    .format(self.client_name))
-            raise RequestException(
-                ex_msg, conn_errno=errno, conn_strerror=strerror)
+            self.handle_connection_error(ex, method)
         except InvalidURL as ex:
             logger.exception("%s REST API failed %s: %s", self.client_name,
                              method.upper(), str(ex))
@@ -509,6 +439,84 @@ class RestClient(object):
             logger.exception(msg)
             raise RequestException(msg)
 
+    def send_request(self, method, url, request_headers, params, data):
+        if method.lower() == 'get':
+            resp = self.session.get(
+                url, headers=request_headers, params=params, auth=self.auth)
+        elif method.lower() == 'post':
+            resp = self.session.post(
+                url,
+                headers=request_headers,
+                params=params,
+                data=data,
+                auth=self.auth)
+        elif method.lower() == 'put':
+            resp = self.session.put(
+                url,
+                headers=request_headers,
+                params=params,
+                data=data,
+                auth=self.auth)
+        elif method.lower() == 'delete':
+            resp = self.session.delete(
+                url,
+                headers=request_headers,
+                params=params,
+                data=data,
+                auth=self.auth)
+        else:
+            raise RequestException('Method "{}" not supported'.format(
+                method.upper()), None)
+        return resp
+
+    def handle_connection_error(self, exception, method):
+        if exception.args:
+            if isinstance(exception.args[0], SSLError):
+                errno = "n/a"
+                strerror = "SSL error. Probably trying to access a non " \
+                           "SSL connection."
+                logger.error("%s REST API failed %s, SSL error (url=%s).",
+                             self.client_name, method.upper(), exception.request.url)
+            else:
+                try:
+                    match = re.match(r'.*: \[Errno (-?\d+)\] (.+)',
+                                     exception.args[0].reason.args[0])
+                except AttributeError:
+                    match = None
+                if match:
+                    errno = match.group(1)
+                    strerror = match.group(2)
+                    logger.error(
+                        "%s REST API failed %s, connection error (url=%s): "
+                        "[errno: %s] %s",
+                        self.client_name, method.upper(), exception.request.url,
+                        errno, strerror)
+                else:
+                    errno = "n/a"
+                    strerror = "n/a"
+                    logger.error(
+                        "%s REST API failed %s, connection error (url=%s).",
+                        self.client_name, method.upper(), exception.request.url)
+        else:
+            errno = "n/a"
+            strerror = "n/a"
+            logger.error("%s REST API failed %s, connection error (url=%s).",
+                         self.client_name, method.upper(), exception.request.url)
+        if errno != "n/a":
+            exception_msg = (
+                "{} REST API cannot be reached: {} [errno {}]. "
+                "Please check your configuration and that the API endpoint"
+                " is accessible"
+                .format(self.client_name, strerror, errno))
+        else:
+            exception_msg = (
+                "{} REST API cannot be reached. Please check "
+                "your configuration and that the API endpoint is"
+                " accessible"
+                .format(self.client_name))
+        raise RequestException(
+            exception_msg, conn_errno=errno, conn_strerror=strerror)
+
     @staticmethod
     def _handle_response_status_code(status_code: int) -> int:
         """