From becac177a2bd824f36c6a78eaba26f9166940b7a Mon Sep 17 00:00:00 2001 From: Zack Cerza Date: Thu, 9 Feb 2017 13:50:20 -0700 Subject: [PATCH] Add cloud.util.AuthToken This provides a mechanism for caching OpenStack authentication tokens Signed-off-by: Zack Cerza --- .../provision/cloud/test/test_cloud_util.py | 108 ++++++++++++++++++ teuthology/provision/cloud/util.py | 62 ++++++++++ 2 files changed, 170 insertions(+) diff --git a/teuthology/provision/cloud/test/test_cloud_util.py b/teuthology/provision/cloud/test/test_cloud_util.py index 022797c667..40e8157459 100644 --- a/teuthology/provision/cloud/test/test_cloud_util.py +++ b/teuthology/provision/cloud/test/test_cloud_util.py @@ -1,3 +1,7 @@ +import datetime +import dateutil +import json + from mock import patch, MagicMock from pytest import mark @@ -64,3 +68,107 @@ def test_get_user_ssh_pubkey(path, exists): ) def test_combine_dicts(input_, func, expected): assert util.combine_dicts(input_, func) == expected + + +def get_datetime(offset_hours=0): + delta = datetime.timedelta(hours=offset_hours) + return datetime.datetime.now(dateutil.tz.tzutc()) + delta + + +def get_datetime_string(offset_hours=0): + obj = get_datetime(offset_hours) + return obj.strftime(util.AuthToken.time_format) + + +class TestAuthToken(object): + klass = util.AuthToken + + def setup(self): + default_expires = get_datetime_string(0) + self.test_data = dict( + value='token_value', + endpoint='endpoint', + expires=default_expires, + ) + self.patchers = dict() + self.patchers['m_open'] = patch( + 'teuthology.provision.cloud.util.open' + ) + self.patchers['m_exists'] = patch( + 'os.path.exists' + ) + self.patchers['m_file_lock'] = patch( + 'teuthology.provision.cloud.util.FileLock' + ) + self.mocks = dict() + for name, patcher in self.patchers.items(): + self.mocks[name] = patcher.start() + self.mocks['m_open'].return_value = MagicMock(spec=file) + + def teardown(self): + for patcher in self.patchers.values(): + patcher.stop() + + def get_obj(self, name='name', directory='/fake/directory'): + return self.klass( + name=name, + directory=directory, + ) + + def test_no_token(self): + obj = self.get_obj() + self.mocks['m_exists'].return_value = False + with obj: + assert obj.value is None + assert obj.expired is True + + @mark.parametrize( + 'test_data, expired', + [ + [ + dict( + value='token_value', + endpoint='endpoint', + expires=get_datetime_string(-1), + ), + True + ], + [ + dict( + value='token_value', + endpoint='endpoint', + expires=get_datetime_string(1), + ), + False + ], + ] + ) + def test_token_read(self, test_data, expired): + obj = self.get_obj() + self.mocks['m_exists'].return_value = True + self.mocks['m_open'].return_value.__enter__.return_value.read.return_value = \ + json.dumps(test_data) + with obj: + if expired: + assert obj.value is None + assert obj.expired is True + else: + assert obj.value == test_data['value'] + + def test_token_write(self): + obj = self.get_obj() + datetime_obj = get_datetime(0) + datetime_string = get_datetime_string(0) + self.mocks['m_exists'].return_value = False + with obj: + obj.write('value', datetime_obj, 'endpoint') + m_open = self.mocks['m_open'] + write_calls = m_open.return_value.__enter__.return_value.write\ + .call_args_list + assert len(write_calls) == 1 + expected = json.dumps(dict( + value='value', + expires=datetime_string, + endpoint='endpoint', + )) + assert write_calls[0][0][0] == expected diff --git a/teuthology/provision/cloud/util.py b/teuthology/provision/cloud/util.py index fa232fa0a5..e9afa22b4a 100644 --- a/teuthology/provision/cloud/util.py +++ b/teuthology/provision/cloud/util.py @@ -1,5 +1,11 @@ +import datetime +import dateutil.tz +import dateutil.parser +import json import os +from teuthology.util.flock import FileLock + def get_user_ssh_pubkey(path='~/.ssh/id_rsa.pub'): full_path = os.path.expanduser(path) @@ -53,3 +59,59 @@ def selective_update(a, b, func): selective_update(a[key], value, func) if func(value, a[key]): a[key] = value + + +class AuthToken(object): + time_format = '%Y-%m-%d %H:%M:%S%z' + + def __init__(self, name, directory=os.path.expanduser('~/.cache/')): + self.name = name + self.directory = directory + self.path = os.path.join(directory, name) + self.lock_path = "%s.lock" % self.path + self.expires = None + self.value = None + self.endpoint = None + + def read(self): + if not os.path.exists(self.path): + self.value = None + self.expires = None + self.endpoint = None + return + with open(self.path, 'r') as obj: + string = obj.read() + obj = json.loads(string) + self.expires = dateutil.parser.parse(obj['expires']) + if self.expired: + self.value = None + self.endpoint = None + else: + self.value = obj['value'] + self.endpoint = obj['endpoint'] + + def write(self, value, expires, endpoint): + obj = dict( + value=value, + expires=datetime.datetime.strftime(expires, self.time_format), + endpoint=endpoint, + ) + string = json.dumps(obj) + with open(self.path, 'w') as obj: + obj.write(string) + + @property + def expired(self): + if self.expires is None: + return True + utcnow = datetime.datetime.now(dateutil.tz.tzutc()) + offset = datetime.timedelta(minutes=30) + return self.expires < (utcnow + offset) + + def __enter__(self): + with FileLock(self.lock_path): + self.read() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass -- 2.39.5