test/rgw/notification: use simpler multithreaded http server
author     Yuval Lifshitz <ylifshit@ibm.com>
           Mon, 25 Mar 2024 11:11:31 +0000 (11:11 +0000)
committer  Casey Bodley <cbodley@redhat.com>
           Wed, 10 Apr 2024 13:18:07 +0000 (09:18 -0400)
Fixes: https://tracker.ceph.com/issues/63909
Signed-off-by: Yuval Lifshitz <ylifshit@ibm.com>
(cherry picked from commit 673adcbdbd01e64c8b76c3176e062571fb8710ac)
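
In short: the per-worker StreamingHTTPServer/HTTPServerThread pair that shared one listening socket is replaced by a single HTTPServerWithEvents derived from http.server.ThreadingHTTPServer, which starts its own serving thread and guards the shared event list with a lock. A minimal usage sketch from the test's point of view, assuming a topic/notification already points at the endpoint (host and port values here are illustrative, not taken from the commit):

    host = '127.0.0.1'                                # the tests use get_ip()
    port = 19000                                      # the tests use random.randint(10000, 20000)

    # before: http_server = StreamingHTTPServer(host, port, num_workers=10)
    http_server = HTTPServerWithEvents((host, port))  # starts serving in a background thread

    # ... upload/delete the keys that should trigger notifications ...

    http_server.verify_s3_events(keys, exact_match=True)  # asserts on and clears stored events
    http_server.close()                                    # shuts down the ThreadingHTTPServer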

src/test/rgw/bucket_notification/test_bn.py

index b73da03cd6cce39263da85a694f8e450e4ae0e7a..9e23dcfa94e99c54314e9fe0d0b4bc2515195a92 100644 (file)
@@ -12,7 +12,7 @@ import string
 # XXX this should be converted to use boto3
 import boto
 from botocore.exceptions import ClientError
-from http import server as http_server
+from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
 from random import randint
 import hashlib
 # XXX this should be converted to use pytest
@@ -22,6 +22,7 @@ from boto.s3.connection import S3Connection
 import datetime
 from cloudevents.http import from_http
 from dateutil import parser
+import requests
 
 from . import(
     get_config_host,
@@ -66,7 +67,72 @@ def set_contents_from_string(key, content):
         print('Error: ' + str(e))
 
 
-class HTTPPostHandler(http_server.BaseHTTPRequestHandler):
+def verify_s3_records_by_elements(records, keys, exact_match=False, deletions=False, expected_sizes={}, etags=[]):
+    """ verify there is at least one record per element """
+    err = ''
+    for key in keys:
+        key_found = False
+        object_size = 0
+        if type(records) is list:
+            for record_list in records:
+                if key_found:
+                    break
+                for record in record_list['Records']:
+                    assert_in('eTag', record['s3']['object'])
+                    if record['s3']['bucket']['name'] == key.bucket.name and \
+                        record['s3']['object']['key'] == key.name:
+                        # Assertion Error needs to be fixed
+                        #assert_equal(key.etag[1:-1], record['s3']['object']['eTag'])
+                        if etags:
+                            assert_in(key.etag[1:-1], etags)
+                        if len(record['s3']['object']['metadata']) > 0:
+                            for meta in record['s3']['object']['metadata']:
+                                assert(meta['key'].startswith(META_PREFIX))
+                        if deletions and record['eventName'].startswith('ObjectRemoved'):
+                            key_found = True
+                            object_size = record['s3']['object']['size']
+                            break
+                        elif not deletions and record['eventName'].startswith('ObjectCreated'):
+                            key_found = True
+                            object_size = record['s3']['object']['size']
+                            break
+        else:
+            for record in records['Records']:
+                assert_in('eTag', record['s3']['object'])
+                if record['s3']['bucket']['name'] == key.bucket.name and \
+                    record['s3']['object']['key'] == key.name:
+                    assert_equal(key.etag, record['s3']['object']['eTag'])
+                    if etags:
+                        assert_in(key.etag[1:-1], etags)
+                    if len(record['s3']['object']['metadata']) > 0:
+                        for meta in record['s3']['object']['metadata']:
+                            assert(meta['key'].startswith(META_PREFIX))
+                    if deletions and record['eventName'].startswith('ObjectRemoved'):
+                        key_found = True
+                        object_size = record['s3']['object']['size']
+                        break
+                    elif not deletions and record['eventName'].startswith('ObjectCreated'):
+                        key_found = True
+                        object_size = record['s3']['object']['size']
+                        break
+
+        if not key_found:
+            err = 'no ' + ('deletion' if deletions else 'creation') + ' event found for key: ' + str(key)
+            assert False, err
+        elif expected_sizes:
+            assert_equal(object_size, expected_sizes.get(key.name))
+
+    if not len(records) == len(keys):
+        err = 'superfluous records are found'
+        log.warning(err)
+        if exact_match:
+            for record_list in records:
+                for record in record_list['Records']:
+                    log.error(str(record['s3']['bucket']['name']) + ',' + str(record['s3']['object']['key']))
+            assert False, err
+
+
+class HTTPPostHandler(BaseHTTPRequestHandler):
     """HTTP POST hanler class storing the received events in its http server"""
     def do_POST(self):
         """implementation of POST handler"""
@@ -82,7 +148,7 @@ class HTTPPostHandler(http_server.BaseHTTPRequestHandler):
             assert_equal(event['datacontenttype'], 'application/json') 
             assert_equal(event['subject'], record['s3']['object']['key'])
             assert_equal(parser.parse(event['time']), parser.parse(record['eventTime']))
-        log.info('HTTP Server (%d) received event: %s', self.server.worker_id, str(body))
+        log.info('HTTP Server received event: %s', str(body))
         self.server.append(json.loads(body))
         if self.headers.get('Expect') == '100-continue':
             self.send_response(100)
@@ -93,92 +159,81 @@ class HTTPPostHandler(http_server.BaseHTTPRequestHandler):
         self.end_headers()
 
 
-class HTTPServerWithEvents(http_server.HTTPServer):
-    """HTTP server used by the handler to store events"""
-    def __init__(self, addr, handler, worker_id, delay=0, cloudevents=False):
-        http_server.HTTPServer.__init__(self, addr, handler, False)
-        self.worker_id = worker_id
+class HTTPServerWithEvents(ThreadingHTTPServer):
+    """multithreaded HTTP server used by the handler to store events"""
+    def __init__(self, addr, delay=0, cloudevents=False):
         self.events = []
         self.delay = delay
         self.cloudevents = cloudevents
+        self.addr = addr
+        self.lock = threading.Lock()
+        ThreadingHTTPServer.__init__(self, addr, HTTPPostHandler)
+        log.info('http server created on %s', str(self.addr))
+        self.proc = threading.Thread(target=self.run)
+        self.proc.start()
+        retries = 0
+        while self.proc.is_alive() == False and retries < 5:
+            retries += 1
+            time.sleep(5)
+            log.warning('http server on %s did not start yet', str(self.addr))
+        if not self.proc.is_alive():
+            log.error('http server on %s failed to start. closing...', str(self.addr))
+            self.close()
+            assert False
 
-    def append(self, event):
-        self.events.append(event)
-
-class HTTPServerThread(threading.Thread):
-    """thread for running the HTTP server. reusing the same socket for all threads"""
-    def __init__(self, i, sock, addr, delay=0, cloudevents=False):
-        threading.Thread.__init__(self)
-        self.i = i
-        self.daemon = True
-        self.httpd = HTTPServerWithEvents(addr, HTTPPostHandler, i, delay, cloudevents)
-        self.httpd.socket = sock
-        # prevent the HTTP server from re-binding every handler
-        self.httpd.server_bind = self.server_close = lambda self: None
-        self.start()
 
     def run(self):
-        try:
-            log.info('HTTP Server (%d) started on: %s', self.i, self.httpd.server_address)
-            self.httpd.serve_forever()
-            log.info('HTTP Server (%d) ended', self.i)
-        except Exception as error:
-            # could happen if the server r/w to a closing socket during shutdown
-            log.info('HTTP Server (%d) ended unexpectedly: %s', self.i, str(error))
-
-    def close(self):
-        self.httpd.shutdown()
+        log.info('http server started on %s', str(self.addr))
+        self.serve_forever()
+        self.server_close()
+        log.info('http server ended on %s', str(self.addr))
 
-    def get_events(self):
-        return self.httpd.events
-
-    def reset_events(self):
-        self.httpd.events = []
-
-class StreamingHTTPServer:
-    """multi-threaded http server class also holding list of events received into the handler
-    each thread has its own server, and all servers share the same socket"""
-    def __init__(self, host, port, num_workers=100, delay=0, cloudevents=False):
-        addr = (host, port)
-        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
-        self.sock.bind(addr)
-        self.sock.listen(num_workers)
-        self.workers = [HTTPServerThread(i, self.sock, addr, delay, cloudevents) for i in range(num_workers)]
+    def acquire_lock(self):
+        if self.lock.acquire(timeout=5) == False:
+            self.close()
+            raise AssertionError('failed to acquire lock in HTTPServerWithEvents')
 
+    def append(self, event):
+        self.acquire_lock()
+        self.events.append(event)
+        self.lock.release()
+    
     def verify_s3_events(self, keys, exact_match=False, deletions=False, expected_sizes={}):
         """verify stored s3 records agains a list of keys"""
-        events = []
-        for worker in self.workers:
-            events += worker.get_events()
-            worker.reset_events()
-        verify_s3_records_by_elements(events, keys, exact_match=exact_match, deletions=deletions, expected_sizes=expected_sizes)
-
-    def verify_events(self, keys, exact_match=False, deletions=False):
-        """verify stored events agains a list of keys"""
-        events = []
-        for worker in self.workers:
-            events += worker.get_events()
-            worker.reset_events()
-        verify_events_by_elements(events, keys, exact_match=exact_match, deletions=deletions)
+        self.acquire_lock()
+        log.info('verify_s3_events: http server has %d events', len(self.events))
+        try:
+            verify_s3_records_by_elements(self.events, keys, exact_match=exact_match, deletions=deletions, expected_sizes=expected_sizes)
+        except AssertionError as err:
+            self.close()
+            raise err
+        finally:
+            self.lock.release()
+            self.events = []
 
     def get_and_reset_events(self):
-        events = []
-        for worker in self.workers:
-            events += worker.get_events()
-            worker.reset_events()
+        self.acquire_lock()
+        log.info('get_and_reset_events: http server has %d events', len(self.events))
+        events = self.events
+        self.events = []
+        self.lock.release()
         return events
 
     def close(self):
-        """close all workers in the http server and wait for it to finish"""
-        # make sure that the shared socket is closed
-        # this is needed in case that one of the threads is blocked on the socket
-        self.sock.shutdown(socket.SHUT_RDWR)
-        self.sock.close()
-        # wait for server threads to finish
-        for worker in self.workers:
-            worker.close()
-            worker.join()
+        log.info('http server on %s starting shutdown', str(self.addr))
+        t = threading.Thread(target=self.shutdown)
+        t.start()
+        t.join(5)
+        retries = 0
+        while self.proc.is_alive() and retries < 5:
+            retries += 1
+            t.join(5)
+            log.warning('http server on %s still alive', str(self.addr))
+        if self.proc.is_alive():
+            log.error('http server on %s failed to shutdown', str(self.addr))
+            self.server_close()
+        else:
+            log.info('http server on %s shutdown ended', str(self.addr))
 
 # AMQP endpoint functions
 
@@ -243,11 +298,6 @@ class AMQPReceiver(object):
         verify_s3_records_by_elements(self.events, keys, exact_match=exact_match, deletions=deletions, expected_sizes=expected_sizes)
         self.events = []
 
-    def verify_events(self, keys, exact_match=False, deletions=False):
-        """verify stored events agains a list of keys"""
-        verify_events_by_elements(self.events, keys, exact_match=exact_match, deletions=deletions)
-        self.events = []
-
     def get_and_reset_events(self):
         tmp = self.events
         self.events = []
@@ -308,114 +358,8 @@ def clean_rabbitmq(proc):
         log.info('rabbitmq server already terminated')
 
 
-def verify_events_by_elements(events, keys, exact_match=False, deletions=False):
-    """ verify there is at least one event per element """
-    err = ''
-    for key in keys:
-        key_found = False
-        if type(events) is list:
-            for event_list in events:
-                if key_found:
-                    break
-                for event in event_list['events']:
-                    if event['info']['bucket']['name'] == key.bucket.name and \
-                        event['info']['key']['name'] == key.name:
-                        if deletions and event['event'] == 'OBJECT_DELETE':
-                            key_found = True
-                            break
-                        elif not deletions and event['event'] == 'OBJECT_CREATE':
-                            key_found = True
-                            break
-        else:
-            for event in events['events']:
-                if event['info']['bucket']['name'] == key.bucket.name and \
-                    event['info']['key']['name'] == key.name:
-                    if deletions and event['event'] == 'OBJECT_DELETE':
-                        key_found = True
-                        break
-                    elif not deletions and event['event'] == 'OBJECT_CREATE':
-                        key_found = True
-                        break
-
-        if not key_found:
-            err = 'no ' + ('deletion' if deletions else 'creation') + ' event found for key: ' + str(key)
-            log.error(events)
-            assert False, err
-
-    if not len(events) == len(keys):
-        err = 'superfluous events are found'
-        log.debug(err)
-        if exact_match:
-            log.error(events)
-            assert False, err
-
 META_PREFIX = 'x-amz-meta-'
 
-def verify_s3_records_by_elements(records, keys, exact_match=False, deletions=False, expected_sizes={}, etags=[]):
-    """ verify there is at least one record per element """
-    err = ''
-    for key in keys:
-        key_found = False
-        object_size = 0
-        if type(records) is list:
-            for record_list in records:
-                if key_found:
-                    break
-                for record in record_list['Records']:
-                    assert_in('eTag', record['s3']['object'])
-                    if record['s3']['bucket']['name'] == key.bucket.name and \
-                        record['s3']['object']['key'] == key.name:
-                        # Assertion Error needs to be fixed
-                        #assert_equal(key.etag[1:-1], record['s3']['object']['eTag'])
-                        if etags:
-                            assert_in(key.etag[1:-1], etags)
-                        if len(record['s3']['object']['metadata']) > 0:
-                            for meta in record['s3']['object']['metadata']:
-                                assert(meta['key'].startswith(META_PREFIX))
-                        if deletions and record['eventName'].startswith('ObjectRemoved'):
-                            key_found = True
-                            object_size = record['s3']['object']['size']
-                            break
-                        elif not deletions and record['eventName'].startswith('ObjectCreated'):
-                            key_found = True
-                            object_size = record['s3']['object']['size']
-                            break
-        else:
-            for record in records['Records']:
-                assert_in('eTag', record['s3']['object'])
-                if record['s3']['bucket']['name'] == key.bucket.name and \
-                    record['s3']['object']['key'] == key.name:
-                    assert_equal(key.etag, record['s3']['object']['eTag'])
-                    if etags:
-                        assert_in(key.etag[1:-1], etags)
-                    if len(record['s3']['object']['metadata']) > 0:
-                        for meta in record['s3']['object']['metadata']:
-                            assert(meta['key'].startswith(META_PREFIX))
-                    if deletions and record['eventName'].startswith('ObjectRemoved'):
-                        key_found = True
-                        object_size = record['s3']['object']['size']
-                        break
-                    elif not deletions and record['eventName'].startswith('ObjectCreated'):
-                        key_found = True
-                        object_size = record['s3']['object']['size']
-                        break
-
-        if not key_found:
-            err = 'no ' + ('deletion' if deletions else 'creation') + ' event found for key: ' + str(key)
-            assert False, err
-        elif expected_sizes:
-            assert_equal(object_size, expected_sizes.get(key.name))
-
-    if not len(records) == len(keys):
-        err = 'superfluous records are found'
-        log.warning(err)
-        if exact_match:
-            for record_list in records:
-                for record in record_list['Records']:
-                    log.error(str(record['s3']['bucket']['name']) + ',' + str(record['s3']['object']['key']))
-            assert False, err
-
-
 # Kafka endpoint functions
 
 kafka_server = 'localhost'
@@ -1585,7 +1529,7 @@ def test_ps_s3_notification_multi_delete_on_master():
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
     number_of_objects = 10
-    http_server = StreamingHTTPServer(host, port, num_workers=number_of_objects)
+    http_server = HTTPServerWithEvents((host, port))
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -1653,7 +1597,7 @@ def test_ps_s3_notification_push_http_on_master():
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
     number_of_objects = 10
-    http_server = StreamingHTTPServer(host, port, num_workers=number_of_objects)
+    http_server = HTTPServerWithEvents((host, port))
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -1737,7 +1681,7 @@ def test_ps_s3_notification_push_cloudevents_on_master():
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
     number_of_objects = 10
-    http_server = StreamingHTTPServer(host, port, num_workers=number_of_objects, cloudevents=True)
+    http_server = HTTPServerWithEvents((host, port), cloudevents=True)
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -1821,7 +1765,7 @@ def test_ps_s3_opaque_data_on_master():
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
     number_of_objects = 10
-    http_server = StreamingHTTPServer(host, port, num_workers=number_of_objects)
+    http_server = HTTPServerWithEvents((host, port))
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -1890,7 +1834,7 @@ def test_ps_s3_lifecycle_on_master():
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
     number_of_objects = 10
-    http_server = StreamingHTTPServer(host, port, num_workers=number_of_objects)
+    http_server = HTTPServerWithEvents((host, port))
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -2005,7 +1949,7 @@ def test_ps_s3_lifecycle_abort_mpu_on_master():
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
     number_of_objects = 1
-    http_server = StreamingHTTPServer(host, port, num_workers=number_of_objects)
+    http_server = HTTPServerWithEvents((host, port))
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -2410,7 +2354,7 @@ def test_ps_s3_multipart_on_master_http():
     host = get_ip()
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
-    http_server = StreamingHTTPServer(host, port, num_workers=10)
+    http_server = HTTPServerWithEvents((host, port))
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -3049,7 +2993,7 @@ def test_ps_s3_persistent_cleanup():
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
     number_of_objects = 200
-    http_server = StreamingHTTPServer(host, port, num_workers=number_of_objects)
+    http_server = HTTPServerWithEvents((host, port))
 
     gw = conn
 
@@ -3149,10 +3093,10 @@ def test_ps_s3_persistent_topic_stats():
 
     # create random port for the http server
     host = get_ip()
-    http_port = random.randint(10000, 20000)
+    port = random.randint(10000, 20000)
 
     # start an http server in a separate thread
-    http_server = StreamingHTTPServer(host, http_port, num_workers=10)
+    http_server = HTTPServerWithEvents((host, port))
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -3160,7 +3104,7 @@ def test_ps_s3_persistent_topic_stats():
     topic_name = bucket_name + TOPIC_SUFFIX
 
     # create s3 topic
-    endpoint_address = 'http://'+host+':'+str(http_port)
+    endpoint_address = 'http://'+host+':'+str(port)
     endpoint_args = 'push-endpoint='+endpoint_address+'&persistent=true'
     topic_conf = PSTopicS3(conn, topic_name, zonegroup, endpoint_args=endpoint_args)
     topic_arn = topic_conf.set_config()
@@ -3227,7 +3171,7 @@ def test_ps_s3_persistent_topic_stats():
     assert_equal(result[1], 0)
 
     # start an http server in a separate thread
-    http_server = StreamingHTTPServer(host, http_port, num_workers=10)
+    http_server = HTTPServerWithEvents((host, port))
 
     print('wait for '+str(delay*2)+'sec for the messages...')
     time.sleep(delay*2)
@@ -3252,10 +3196,10 @@ def ps_s3_persistent_topic_configs(persistency_time, config_dict):
 
     # create random port for the http server
     host = get_ip()
-    http_port = random.randint(10000, 20000)
+    port = random.randint(10000, 20000)
 
     # start an http server in a separate thread
-    http_server = StreamingHTTPServer(host, http_port, num_workers=10)
+    http_server = HTTPServerWithEvents((host, port))
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -3263,7 +3207,7 @@ def ps_s3_persistent_topic_configs(persistency_time, config_dict):
     topic_name = bucket_name + TOPIC_SUFFIX
 
     # create s3 topic
-    endpoint_address = 'http://'+host+':'+str(http_port)
+    endpoint_address = 'http://'+host+':'+str(port)
     endpoint_args = 'push-endpoint='+endpoint_address+'&persistent=true&'+create_persistency_config_string(config_dict)
     topic_conf = PSTopicS3(conn, topic_name, zonegroup, endpoint_args=endpoint_args)
     topic_arn = topic_conf.set_config()
@@ -3394,7 +3338,7 @@ def test_ps_s3_persistent_notification_pushback():
     host = get_ip()
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
-    http_server = StreamingHTTPServer(host, port, num_workers=10, delay=0.5)
+    http_server = HTTPServerWithEvents((host, port), delay=0.5)
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -3806,7 +3750,7 @@ def test_ps_s3_persistent_multiple_endpoints():
     port = random.randint(10000, 20000)
     # start an http server in a separate thread
     number_of_objects = 10
-    http_server = StreamingHTTPServer(host, port, num_workers=number_of_objects)
+    http_server = HTTPServerWithEvents((host, port))
 
     # create bucket
     bucket_name = gen_bucket_name()
@@ -3896,7 +3840,7 @@ def persistent_notification(endpoint_type):
         host = get_ip_http()
         port = random.randint(10000, 20000)
         # start an http server in a separate thread
-        receiver = StreamingHTTPServer(host, port, num_workers=10)
+        receiver = HTTPServerWithEvents((host, port))
         endpoint_address = 'http://'+host+':'+str(port)
         endpoint_args = 'push-endpoint='+endpoint_address+'&persistent=true'
         # the http server does not guarantee order, so duplicates are expected
@@ -4801,7 +4745,7 @@ def test_persistent_ps_s3_data_path_v2_migration():
         assert_equal(result[1], 0)
 
         # start an http server in a separate thread
-        http_server = StreamingHTTPServer(host, http_port, num_workers=number_of_objects)
+        http_server = HTTPServerWithEvents((host, http_port))
 
         delay = 30
         print('wait for '+str(delay)+'sec for the messages...')
@@ -4846,8 +4790,7 @@ def test_ps_s3_data_path_v2_migration():
     http_port = random.randint(10000, 20000)
 
     # start an http server in a separate thread
-    number_of_objects = 10
-    http_server = StreamingHTTPServer(host, http_port, num_workers=number_of_objects)
+    http_server = HTTPServerWithEvents((host, http_port))
 
     # disable v2 notification
     result = admin(['zonegroup', 'modify', '--disable-feature=notification_v2'], get_config_cluster())
@@ -4878,6 +4821,7 @@ def test_ps_s3_data_path_v2_migration():
     assert_equal(status/100, 2)
 
     # create objects in the bucket (async)
+    number_of_objects = 10
     client_threads = []
     start_time = time.time()
     for i in range(number_of_objects):
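
As an aside, here is a simplified illustration (not code from the commit; names are made up for the sketch) of the lock-guarded event store that the new HTTPServerWithEvents relies on: handler threads spawned by ThreadingHTTPServer append events concurrently while the test thread drains them under the same lock, and an acquire timeout turns a stuck lock into a test failure instead of a hang.

    import threading

    class EventStore:
        """toy stand-in for the event list inside HTTPServerWithEvents"""
        def __init__(self):
            self.events = []
            self.lock = threading.Lock()

        def _acquire(self):
            # mirror the commit's 5-second acquire timeout: fail loudly rather than hang
            if not self.lock.acquire(timeout=5):
                raise AssertionError('failed to acquire lock')

        def append(self, event):
            # called from the per-request handler threads
            self._acquire()
            try:
                self.events.append(event)
            finally:
                self.lock.release()

        def get_and_reset(self):
            # called from the test thread to drain stored events
            self._acquire()
            try:
                events, self.events = self.events, []
                return events
            finally:
                self.lock.release()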