git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
rgw/test_multi: add support for elasticsearch testing
authorYehuda Sadeh <yehuda@redhat.com>
Fri, 14 Apr 2017 00:17:46 +0000 (17:17 -0700)
committerYehuda Sadeh <yehuda@redhat.com>
Tue, 30 May 2017 20:26:55 +0000 (13:26 -0700)
Add support for different zone types, and create an elasticsearch
zone type that deals with es testing.

Signed-off-by: Yehuda Sadeh <yehuda@redhat.com>
src/rgw/rgw_sync_module_es.cc
src/rgw/rgw_sync_module_es_rest.cc
src/test/rgw/rgw_multi/conn.py [new file with mode: 0644]
src/test/rgw/rgw_multi/multisite.py
src/test/rgw/rgw_multi/tests.py
src/test/rgw/rgw_multi/zone_es.py [new file with mode: 0644]
src/test/rgw/rgw_multi/zone_rados.py [new file with mode: 0644]
src/test/rgw/test_multi.py

index 4c83cad1f72b789687621cedddca6bce655af5e5..efd1f4b32c8a4dbfcd3b5716c3d28f75ec3b76ef 100644 (file)
@@ -162,14 +162,18 @@ using ElasticConfigRef = std::shared_ptr<ElasticConfig>;
 struct es_dump_type {
   const char *type;
   const char *format;
+  bool analyzed;
 
-  es_dump_type(const char *t, const char *f = nullptr) : type(t), format(f) {}
+  es_dump_type(const char *t, const char *f = nullptr, bool a = false) : type(t), format(f), analyzed(a) {}
 
   void dump(Formatter *f) const {
     encode_json("type", type, f);
     if (format) {
       encode_json("format", format, f);
     }
+    if (!analyzed && strcmp(type, "string") == 0) {
+      encode_json("index", "not_analyzed", f);
+    }
   }
 };
 
@@ -178,10 +182,7 @@ struct es_index_mappings {
     f->open_object_section(section);
     ::encode_json("type", "nested", f);
     f->open_object_section("properties");
-    f->open_object_section("name");
-    ::encode_json("type", "string", f);
-    ::encode_json("index", "not_analyzed", f);
-    f->close_section(); // name
+    encode_json("name", es_dump_type("string"), f);
     encode_json("value", es_dump_type(type, format), f);
     f->close_section(); // entry
     f->close_section(); // custom-string
index 6999f5d59acfa3fb211fffe747ac3a4125a7ec5a..edbbab2d18566f3ef861b1a29c63dea66e40dd3f 100644 (file)
@@ -19,6 +19,7 @@ struct es_index_obj_response {
     uint64_t size{0};
     ceph::real_time mtime;
     string etag;
+    string content_type;
     map<string, string> custom_str;
     map<string, int64_t> custom_int;
     map<string, string> custom_date;
@@ -39,6 +40,7 @@ struct es_index_obj_response {
       JSONDecoder::decode_json("mtime", mtime_str, obj);
       parse_time(mtime_str.c_str(), &mtime);
       JSONDecoder::decode_json("etag", etag, obj);
+      JSONDecoder::decode_json("content_type", content_type, obj);
       list<_custom_entry<string> > str_entries;
       JSONDecoder::decode_json("custom-string", str_entries, obj);
       for (auto& e : str_entries) {
@@ -312,16 +314,21 @@ public:
     if (is_truncated) {
       s->formatter->dump_string("NextMarker", next_marker);
     }
+    if (s->format == RGW_FORMAT_JSON) {
+      s->formatter->open_array_section("Objects");
+    }
     for (auto& i : response.hits.hits) {
-      es_index_obj_response& e = i.source;
       s->formatter->open_object_section("Contents");
+      es_index_obj_response& e = i.source;
       s->formatter->dump_string("Bucket", e.bucket);
       s->formatter->dump_string("Key", e.key.name);
       string instance = (!e.key.instance.empty() ? e.key.instance : "null");
       s->formatter->dump_string("Instance", instance.c_str());
       s->formatter->dump_int("VersionedEpoch", e.versioned_epoch);
       dump_time(s, "LastModified", &e.meta.mtime);
+      s->formatter->dump_int("Size", e.meta.size);
       s->formatter->dump_format("ETag", "\"%s\"", e.meta.etag.c_str());
+      s->formatter->dump_string("ContentType", e.meta.content_type.c_str());
       dump_owner(s, e.owner.get_id(), e.owner.get_display_name());
       s->formatter->open_array_section("CustomMetadata");
       for (auto& m : e.meta.custom_str) {
@@ -343,9 +350,12 @@ public:
         s->formatter->close_section();
       }
       s->formatter->close_section();
-      s->formatter->close_section();
       rgw_flush_formatter(s, s->formatter);
+      s->formatter->close_section();
     };
+    if (s->format == RGW_FORMAT_JSON) {
+      s->formatter->close_section();
+    }
     s->formatter->close_section();
    rgw_flush_formatter_and_reset(s, s->formatter);
   }
diff --git a/src/test/rgw/rgw_multi/conn.py b/src/test/rgw/rgw_multi/conn.py
new file mode 100644 (file)
index 0000000..1099664
--- /dev/null
@@ -0,0 +1,16 @@
+import boto
+import boto.s3.connection
+
+
+def get_gateway_connection(gateway, credentials):
+    """ connect to the given gateway """
+    if gateway.connection is None:
+        gateway.connection = boto.connect_s3(
+                aws_access_key_id = credentials.access_key,
+                aws_secret_access_key = credentials.secret,
+                host = gateway.host,
+                port = gateway.port,
+                is_secure = False,
+                calling_format = boto.s3.connection.OrdinaryCallingFormat())
+    return gateway.connection
+
index 2d392cf8893c092a6ce15ea72a72ccc4ae580346..278b74b6ac8e8b05fcd2b0eaf8ffeba3e09b907a 100644 (file)
@@ -2,6 +2,8 @@ from abc import ABCMeta, abstractmethod
 from cStringIO import StringIO
 import json
 
+from conn import get_gateway_connection
+
 class Cluster:
     """ interface to run commands against a distinct ceph cluster """
     __metaclass__ = ABCMeta
@@ -154,6 +156,27 @@ class Zone(SystemObject, SystemObject.CreateDelete, SystemObject.GetSet, SystemO
     def realm(self):
         return self.zonegroup.realm() if self.zonegroup else None
 
+    def is_read_only(self):
+        return False
+
+    def tier_type(self):
+        raise NotImplementedError
+
+    def has_buckets(self):
+        return True
+
+    def get_connection(self, credentials):
+        """ connect to the zone's first gateway """
+        if isinstance(credentials, list):
+            credentials = credentials[0]
+        return get_gateway_connection(self.gateways[0], credentials)
+
+    def get_bucket(self, bucket_name, credentials):
+        raise NotImplementedError
+
+    def check_bucket_eq(self, zone, bucket_name, credentials):
+        raise NotImplementedError
+
 class ZoneGroup(SystemObject, SystemObject.CreateDelete, SystemObject.GetSet, SystemObject.Modify):
     def __init__(self, name, period = None, data = None, zonegroup_id = None, zones = None, master_zone  = None):
         self.name = name
@@ -161,6 +184,13 @@ class ZoneGroup(SystemObject, SystemObject.CreateDelete, SystemObject.GetSet, Sy
         self.zones = zones or []
         self.master_zone = master_zone
         super(ZoneGroup, self).__init__(data, zonegroup_id)
+        self.rw_zones = []
+        self.ro_zones = []
+        for z in self.zones:
+            if z.is_read_only():
+                self.ro_zones.append(z)
+            else:
+                self.rw_zones.append(z)
 
     def zonegroup_arg(self):
         """ command-line argument to specify this zonegroup """
index 3633b22a33622d313049c225c8000fd18845cc5c..c89229fa848ecd35b709487d7efc0c9a5eac1f03 100644 (file)
@@ -4,6 +4,7 @@ import string
 import sys
 import time
 import logging
+
 try:
     from itertools import izip_longest as zip_longest
 except ImportError:
@@ -19,6 +20,8 @@ from nose.plugins.skip import SkipTest
 
 from .multisite import Zone
 
+from .conn import get_gateway_connection
+
 class Config:
     """ test configuration """
     def __init__(self, **kwargs):
@@ -73,6 +76,15 @@ def mdlog_list(zone, period = None):
     mdlog_json = mdlog_json.decode('utf-8')
     return json.loads(mdlog_json)
 
+def meta_sync_status(zone):
+    while True:
+        cmd = ['metadata', 'sync', 'status'] + zone.zone_args()
+        meta_sync_status_json, retcode = zone.cluster.admin(cmd, check_retcode=False, read_only=True)
+        if retcode == 0:
+            return json.loads(meta_sync_status_json)
+        assert(retcode == 2) # ENOENT
+        time.sleep(5)
+
 def mdlog_autotrim(zone):
     zone.cluster.admin(['mdlog', 'autotrim'])
 
@@ -376,7 +388,10 @@ def gen_bucket_name():
     return run_prefix + '-' + str(num_buckets)
 
 def check_all_buckets_exist(zone, buckets):
-    conn = get_zone_connection(zone, user.credentials)
+    if not zone.has_buckets():
+        return True
+
+    conn = zone.get_connection(user.credentials)
     for b in buckets:
         try:
             conn.get_bucket(b)
@@ -387,7 +402,10 @@ def check_all_buckets_exist(zone, buckets):
     return True
 
 def check_all_buckets_dont_exist(zone, buckets):
-    conn = get_zone_connection(zone, user.credentials)
+    if not zone.has_buckets():
+        return True
+
+    conn = zone.get_connection(user.credentials)
     for b in buckets:
         try:
             conn.get_bucket(b)
@@ -402,8 +420,8 @@ def check_all_buckets_dont_exist(zone, buckets):
 def create_bucket_per_zone(zonegroup):
     buckets = []
     zone_bucket = {}
-    for zone in zonegroup.zones:
-        conn = get_zone_connection(zone, user.credentials)
+    for zone in zonegroup.rw_zones:
+        conn = zone.get_connection(user.credentials)
         bucket_name = gen_bucket_name()
         log.info('create bucket zone=%s name=%s', zone.name, bucket_name)
         bucket = conn.create_bucket(bucket_name)
@@ -438,9 +456,9 @@ def test_bucket_recreate():
         assert check_all_buckets_exist(zone, buckets)
 
     # recreate buckets on all zones, make sure they weren't removed
-    for zone in zonegroup.zones:
+    for zone in zonegroup.rw_zones:
         for bucket_name in buckets:
-            conn = get_zone_connection(zone, user.credentials)
+            conn = zone.get_connection(user.credentials)
             bucket = conn.create_bucket(bucket_name)
 
     for zone in zonegroup.zones:
@@ -460,7 +478,7 @@ def test_bucket_remove():
         assert check_all_buckets_exist(zone, buckets)
 
     for zone, bucket_name in zone_bucket.items():
-        conn = get_zone_connection(zone, user.credentials)
+        conn = zone.get_connection(user.credentials)
         conn.delete_bucket(bucket_name)
 
     zonegroup_meta_checkpoint(zonegroup)
@@ -469,7 +487,7 @@ def test_bucket_remove():
         assert check_all_buckets_dont_exist(zone, buckets)
 
 def get_bucket(zone, bucket_name):
-    conn = get_zone_connection(zone, user.credentials)
+    conn = zone.get_connection(user.credentials)
     return conn.get_bucket(bucket_name)
 
 def get_key(zone, bucket_name, obj_name):
@@ -480,58 +498,8 @@ def new_key(zone, bucket_name, obj_name):
     b = get_bucket(zone, bucket_name)
     return b.new_key(obj_name)
 
-def check_object_eq(k1, k2, check_extra = True):
-    assert k1
-    assert k2
-    log.debug('comparing key name=%s', k1.name)
-    eq(k1.name, k2.name)
-    eq(k1.get_contents_as_string(), k2.get_contents_as_string())
-    eq(k1.metadata, k2.metadata)
-    eq(k1.cache_control, k2.cache_control)
-    eq(k1.content_type, k2.content_type)
-    eq(k1.content_encoding, k2.content_encoding)
-    eq(k1.content_disposition, k2.content_disposition)
-    eq(k1.content_language, k2.content_language)
-    eq(k1.etag, k2.etag)
-    eq(k1.last_modified, k2.last_modified)
-    if check_extra:
-        eq(k1.owner.id, k2.owner.id)
-        eq(k1.owner.display_name, k2.owner.display_name)
-    eq(k1.storage_class, k2.storage_class)
-    eq(k1.size, k2.size)
-    eq(k1.version_id, k2.version_id)
-    eq(k1.encrypted, k2.encrypted)
-
-def check_bucket_eq(zone1, zone2, bucket_name):
-    log.info('comparing bucket=%s zones={%s, %s}', bucket_name, zone1.name, zone2.name)
-    b1 = get_bucket(zone1, bucket_name)
-    b2 = get_bucket(zone2, bucket_name)
-
-    log.debug('bucket1 objects:')
-    for o in b1.get_all_versions():
-        log.debug('o=%s', o.name)
-    log.debug('bucket2 objects:')
-    for o in b2.get_all_versions():
-        log.debug('o=%s', o.name)
-
-    for k1, k2 in zip_longest(b1.get_all_versions(), b2.get_all_versions()):
-        if k1 is None:
-            log.critical('key=%s is missing from zone=%s', k2.name, zone1.name)
-            assert False
-        if k2 is None:
-            log.critical('key=%s is missing from zone=%s', k1.name, zone2.name)
-            assert False
-
-        check_object_eq(k1, k2)
-
-        # now get the keys through a HEAD operation, verify that the available data is the same
-        k1_head = b1.get_key(k1.name)
-        k2_head = b2.get_key(k2.name)
-
-        check_object_eq(k1_head, k2_head, False)
-
-    log.info('success, bucket identical: bucket=%s zones={%s, %s}', bucket_name, zone1.name, zone2.name)
-
+def check_bucket_eq(zone1, zone2, bucket):
+    return zone2.check_bucket_eq(zone1, bucket.name, user.credentials)
 
 def test_object_sync():
     zonegroup = realm.master_zonegroup()
@@ -617,7 +585,7 @@ def test_versioned_object_incremental_sync():
 
     for _, bucket in zone_bucket.items():
         # create and delete multiple versions of an object from each zone
-        for zone in zonegroup.zones:
+        for zone in zonegroup.rw_zones:
             obj = 'obj-' + zone.name
             k = new_key(zone, bucket, obj)
 
@@ -668,7 +636,7 @@ def test_bucket_delete_notempty():
 
     for zone, bucket_name in zone_bucket.items():
         # upload an object to each bucket on its own zone
-        conn = get_zone_connection(zone, user.credentials)
+        conn = zone.get_connection(user.credentials)
         bucket = conn.get_bucket(bucket_name)
         k = bucket.new_key('foo')
         k.set_contents_from_string('bar')
@@ -681,7 +649,7 @@ def test_bucket_delete_notempty():
         assert False # expected 409 BucketNotEmpty
 
     # assert that each bucket still exists on the master
-    c1 = get_zone_connection(zonegroup.master_zone, user.credentials)
+    c1 = zonegroup.master_zone.get_connection(user.credentials)
     for _, bucket_name in zone_bucket.items():
         assert c1.get_bucket(bucket_name)
 
diff --git a/src/test/rgw/rgw_multi/zone_es.py b/src/test/rgw/rgw_multi/zone_es.py
new file mode 100644 (file)
index 0000000..c9bd685
--- /dev/null
@@ -0,0 +1,186 @@
+import json
+import urllib
+import logging
+
+import boto
+import boto.s3.connection
+
+from nose.tools import eq_ as eq
+try:
+    from itertools import izip_longest as zip_longest
+except ImportError:
+    from itertools import zip_longest
+
+from rgw_multi.multisite import *
+
+log = logging.getLogger(__name__)
+
+def check_object_eq(k1, k2, check_extra = True):
+    assert k1
+    assert k2
+    log.debug('comparing key name=%s', k1.name)
+    eq(k1.name, k2.name)
+    eq(k1.metadata, k2.metadata)
+    # eq(k1.cache_control, k2.cache_control)
+    eq(k1.content_type, k2.content_type)
+    # eq(k1.content_encoding, k2.content_encoding)
+    # eq(k1.content_disposition, k2.content_disposition)
+    # eq(k1.content_language, k2.content_language)
+    eq(k1.etag, k2.etag)
+    eq(k1.last_modified, k2.last_modified)
+    if check_extra:
+        eq(k1.owner.id, k2.owner.id)
+        eq(k1.owner.display_name, k2.owner.display_name)
+    # eq(k1.storage_class, k2.storage_class)
+    eq(k1.size, k2.size)
+    eq(k1.version_id, k2.version_id)
+    # eq(k1.encrypted, k2.encrypted)
+
+def make_request(conn, method, bucket, key, query_args, headers):
+    result = conn.make_request(method, bucket=bucket, key=key, query_args=query_args, headers=headers)
+    if result.status / 100 != 2:
+        raise boto.exception.S3ResponseError(result.status, result.reason, result.read())
+    return result
+
+def dump_json(o):
+    return json.dumps(o, indent=4)
+
+def append_query_arg(s, n, v):
+    if not v:
+        return s
+    nv = '{n}={v}'.format(n=n, v=v)
+    if not s:
+        return nv
+    return '{s}&{nv}'.format(s=s, nv=nv)
+
+class MDSearch:
+    def __init__(self, conn, bucket_name, query, query_args = None, marker = None):
+        self.conn = conn
+        self.bucket_name = bucket_name or ''
+        self.query = query
+        self.query_args = query_args
+        self.max_keys = None
+        self.marker = marker
+
+    def search(self):
+        q = self.query or ''
+        query_args = append_query_arg(self.query_args, 'query', urllib.quote_plus(q))
+        if self.max_keys is not None:
+            query_args = append_query_arg(query_args, 'max-keys', self.max_keys)
+        if self.marker:
+            query_args = append_query_arg(query_args, 'marker', self.marker)
+
+        query_args = append_query_arg(query_args, 'format', 'json')
+
+        headers = {}
+
+        result = make_request(self.conn, "GET", bucket=self.bucket_name, key='', query_args=query_args, headers=headers)
+        return json.loads(result.read())
+
+
+class ESZoneBucket:
+    def __init__(self, zone, name, credentials):
+        self.zone = zone
+        self.name = name
+        self.conn = zone.get_connection(credentials)
+
+        self.bucket = boto.s3.bucket.Bucket(name=name)
+
+    def get_all_versions(self):
+
+        marker = None
+        is_done = False
+
+        l = []
+
+        while not is_done:
+            req = MDSearch(self.conn, self.name, 'bucket == ' + self.name, marker=marker)
+
+            result = req.search()
+
+            for entry in result['Objects']:
+                k = boto.s3.key.Key(self.bucket, entry['Key'])
+
+                k.version_id = entry['Instance']
+                k.etag = entry['ETag']
+                k.owner = boto.s3.user.User(id=entry['Owner']['ID'], display_name=entry['Owner']['DisplayName'])
+                k.last_modified = entry['LastModified']
+                k.size = entry['Size']
+                k.content_type = entry['ContentType']
+                k.versioned_epoch = entry['VersionedEpoch']
+
+                k.metadata = {}
+                for e in entry['CustomMetadata']:
+                    k.metadata[e['Name']] = e['Value']
+
+                l.append(k)
+
+            is_done = (result['IsTruncated'] == "false")
+            marker = result['Marker']
+
+        l.sort(key = lambda l: (l.name, -l.versioned_epoch))
+
+        for k in l:
+            yield k
+
+
+
+
+class ESZone(Zone):
+    def __init__(self, name, es_endpoint, zonegroup = None, cluster = None, data = None, zone_id = None, gateways = []):
+        self.es_endpoint = es_endpoint
+        super(ESZone, self).__init__(name, zonegroup, cluster, data, zone_id, gateways)
+
+    def is_read_only(self):
+        return True
+
+    def tier_type(self):
+        return "elasticsearch"
+
+    def create(self, cluster, args = None, check_retcode = True):
+        """ create the object with the given arguments """
+
+        if args is None:
+            args = []
+
+        tier_config = ','.join([ 'endpoint=' + self.es_endpoint, 'explicit_custom_meta=false' ])
+
+        args += [ '--tier-type', self.tier_type(), '--tier-config', tier_config ]
+
+        return self.json_command(cluster, 'create', args, check_retcode=check_retcode)
+
+    def has_buckets(self):
+        return False
+
+    def get_bucket(self, bucket_name, credentials):
+        return ESZoneBucket(self, bucket_name, credentials)
+
+    def check_bucket_eq(self, zone, bucket_name, credentials):
+        assert(zone.tier_type() == "rados")
+
+        log.info('comparing bucket=%s zones={%s, %s}', bucket_name, self.name, zone.name)
+        b1 = self.get_bucket(bucket_name, credentials)
+        b2 = zone.get_bucket(bucket_name, credentials)
+
+        log.debug('bucket1 objects:')
+        for o in b1.get_all_versions():
+            log.debug('o=%s', o.name)
+        log.debug('bucket2 objects:')
+        for o in b2.get_all_versions():
+            log.debug('o=%s', o.name)
+
+        for k1, k2 in zip_longest(b1.get_all_versions(), b2.get_all_versions()):
+            if k1 is None:
+                log.critical('key=%s is missing from zone=%s', k2.name, self.name)
+                assert False
+            if k2 is None:
+                log.critical('key=%s is missing from zone=%s', k1.name, zone.name)
+                assert False
+
+            check_object_eq(k1, k2)
+
+
+        log.info('success, bucket identical: bucket=%s zones={%s, %s}', bucket_name, self.name, zone.name)
+
+
+        return True
diff --git a/src/test/rgw/rgw_multi/zone_rados.py b/src/test/rgw/rgw_multi/zone_rados.py
new file mode 100644 (file)
index 0000000..675dd5b
--- /dev/null
@@ -0,0 +1,78 @@
+import logging
+
+try:
+    from itertools import izip_longest as zip_longest
+except ImportError:
+    from itertools import zip_longest
+
+from nose.tools import eq_ as eq
+
+from rgw_multi.multisite import *
+
+log = logging.getLogger(__name__)
+
+def check_object_eq(k1, k2, check_extra = True):
+    assert k1
+    assert k2
+    log.debug('comparing key name=%s', k1.name)
+    eq(k1.name, k2.name)
+    eq(k1.get_contents_as_string(), k2.get_contents_as_string())
+    eq(k1.metadata, k2.metadata)
+    eq(k1.cache_control, k2.cache_control)
+    eq(k1.content_type, k2.content_type)
+    eq(k1.content_encoding, k2.content_encoding)
+    eq(k1.content_disposition, k2.content_disposition)
+    eq(k1.content_language, k2.content_language)
+    eq(k1.etag, k2.etag)
+    eq(k1.last_modified, k2.last_modified)
+    if check_extra:
+        eq(k1.owner.id, k2.owner.id)
+        eq(k1.owner.display_name, k2.owner.display_name)
+    eq(k1.storage_class, k2.storage_class)
+    eq(k1.size, k2.size)
+    eq(k1.version_id, k2.version_id)
+    eq(k1.encrypted, k2.encrypted)
+
+
+class RadosZone(Zone):
+    def __init__(self, name, zonegroup = None, cluster = None, data = None, zone_id = None, gateways = []):
+        super(RadosZone, self).__init__(name, zonegroup, cluster, data, zone_id, gateways)
+
+    def tier_type(self):
+        return "rados"
+
+    def get_bucket(self, name, credentials):
+        conn = self.get_connection(credentials)
+        return conn.get_bucket(name)
+
+    def check_bucket_eq(self, zone, bucket_name, credentials):
+        log.info('comparing bucket=%s zones={%s, %s}', bucket_name, self.name, zone.name)
+        b1 = self.get_bucket(bucket_name, credentials)
+        b2 = zone.get_bucket(bucket_name, credentials)
+
+        log.debug('bucket1 objects:')
+        for o in b1.get_all_versions():
+            log.debug('o=%s', o.name)
+        log.debug('bucket2 objects:')
+        for o in b2.get_all_versions():
+            log.debug('o=%s', o.name)
+
+        for k1, k2 in zip_longest(b1.get_all_versions(), b2.get_all_versions()):
+            if k1 is None:
+                log.critical('key=%s is missing from zone=%s', k2.name, self.name)
+                assert False
+            if k2 is None:
+                log.critical('key=%s is missing from zone=%s', k1.name, zone.name)
+                assert False
+
+            check_object_eq(k1, k2)
+
+            # now get the keys through a HEAD operation, verify that the available data is the same
+            k1_head = b1.get_key(k1.name)
+            k2_head = b2.get_key(k2.name)
+
+            check_object_eq(k1_head, k2_head, False)
+
+        log.info('success, bucket identical: bucket=%s zones={%s, %s}', bucket_name, self.name, zone.name)
+
+
index e380acc99038e9eed4cfa7018ad9b98d15641d92..2e0870868ff171034b990d6a9b81c6c877692226 100644 (file)
@@ -13,6 +13,9 @@ except ImportError:
 import nose.core
 
 from rgw_multi import multisite
+from rgw_multi.zone_rados import RadosZone as RadosZone
+from rgw_multi.zone_es  import ESZone as ESZone
+
 # make tests from rgw_multi.tests available to nose
 from rgw_multi.tests import *
 
@@ -146,6 +149,7 @@ def init(parse_args):
     cfg = configparser.RawConfigParser({
                                          'num_zonegroups': 1,
                                          'num_zones': 3,
+                                         'num_es_zones': 0,
                                          'gateways_per_zone': 2,
                                          'no_bootstrap': 'false',
                                          'log_level': 20,
@@ -155,6 +159,7 @@ def init(parse_args):
                                          'checkpoint_retries': 60,
                                          'checkpoint_delay': 5,
                                          'reconfigure_delay': 5,
+                                         'es_endpoint': None,
                                          })
     try:
         path = os.environ['RGW_MULTI_TEST_CONF']
@@ -175,6 +180,7 @@ def init(parse_args):
     section = 'DEFAULT'
     parser.add_argument('--num-zonegroups', type=int, default=cfg.getint(section, 'num_zonegroups'))
     parser.add_argument('--num-zones', type=int, default=cfg.getint(section, 'num_zones'))
+    parser.add_argument('--num-es-zones', type=int, default=cfg.getint(section, 'num_es_zones'))
     parser.add_argument('--gateways-per-zone', type=int, default=cfg.getint(section, 'gateways_per_zone'))
     parser.add_argument('--no-bootstrap', action='store_true', default=cfg.getboolean(section, 'no_bootstrap'))
     parser.add_argument('--log-level', type=int, default=cfg.getint(section, 'log_level'))
@@ -184,6 +190,7 @@ def init(parse_args):
     parser.add_argument('--checkpoint-retries', type=int, default=cfg.getint(section, 'checkpoint_retries'))
     parser.add_argument('--checkpoint-delay', type=int, default=cfg.getint(section, 'checkpoint_delay'))
     parser.add_argument('--reconfigure-delay', type=int, default=cfg.getint(section, 'reconfigure_delay'))
+    parser.add_argument('--es-endpoint', type=str, default=cfg.get(section, 'es_endpoint'))
 
     argv = []
 
@@ -193,6 +200,9 @@ def init(parse_args):
     args = parser.parse_args(argv)
     bootstrap = not args.no_bootstrap
 
+    # if num_es_zones is defined, need to have es_endpoint defined too
+    assert(args.num_es_zones == 0 or args.es_endpoint)
+
     setup_logging(args.log_level, args.log_file, args.file_log_level)
 
     # start first cluster
@@ -217,6 +227,8 @@ def init(parse_args):
     period = multisite.Period(realm=realm)
     realm.current_period = period
 
+    num_zones = args.num_zones + args.num_es_zones
+
     for zg in range(0, args.num_zonegroups):
         zonegroup = multisite.ZoneGroup(zonegroup_name(zg), period)
         period.zonegroups.append(zonegroup)
@@ -225,7 +237,7 @@ def init(parse_args):
         if is_master_zg:
             period.master_zonegroup = zonegroup
 
-        for z in range(0, args.num_zones):
+        for z in range(0, num_zones):
             is_master = z == 0
             # start a cluster, or use c1 for first zone
             cluster = None
@@ -253,8 +265,14 @@
                 else:
                     zonegroup.get(cluster)
 
+            es_zone = (z >= args.num_zones)
+
             # create the zone in its zonegroup
-            zone = multisite.Zone(zone_name(zg, z), zonegroup, cluster)
+            if es_zone:
+                zone = ESZone(zone_name(zg, z), args.es_endpoint, zonegroup, cluster)
+            else:
+                zone = RadosZone(zone_name(zg, z), zonegroup, cluster)
+
             if bootstrap:
                 arg = admin_creds.credential_args()
                 if is_master:
@@ -268,6 +287,11 @@ def init(parse_args):
             if is_master:
                 zonegroup.master_zone = zone
 
+            if zone.is_read_only():
+                zonegroup.ro_zones.append(zone)
+            else:
+                zonegroup.rw_zones.append(zone)
+
             # update/commit the period
             if bootstrap:
                 period.update(zone, commit=True)