]> git.apps.os.sepia.ceph.com Git - teuthology.git/commitdiff
xxx: improve sqlite machine pool
authorJohn Mulligan <phlogistonjohn@asynchrono.us>
Sat, 10 Aug 2024 18:33:26 +0000 (14:33 -0400)
committerJohn Mulligan <phlogistonjohn@asynchrono.us>
Sat, 10 Aug 2024 18:33:26 +0000 (14:33 -0400)
Signed-off-by: John Mulligan <phlogistonjohn@asynchrono.us>
teuthology/lock/sqlite_pool.py

index f647dfd3f654d2eb488d3ce22a76ea2bf9b50ef2..4e3c8b039616beb2785923cb27f8ecbb2f9c0ba2 100644 (file)
@@ -7,45 +7,41 @@ log = logging.getLogger(__name__)
 
 class SqliteMachinePool:
     def __init__(
-        self, path: str,
+        self,
+        path: str,
     ) -> None:
         self._path = path
         self._connect()
         self._create_tables()
 
-    def _select_machines(self, ):
-        with self._conn:
-            cur = self._conn.cursor()
-            cur.execute(
-                "SELECT rowid,jobdesc FROM jobs ORDER BY rowid LIMIT 1"
-            )
-            rows = [(jid, data) for jid, data in cur.fetchall()]
-        if rows:
-            assert len(rows) == 1
-            return rows[0]
-        return None
-
-    def _delete(self, jid: int) -> None:
-        with self._conn:
-            self._conn.execute("DELETE FROM jobs WHERE rowid=?", (jid,))
-
     def _create_tables(self) -> None:
         try:
             with self._conn:
-                self._conn.execute("""
+                self._conn.execute(
+                    """
                     CREATE TABLE IF NOT EXISTS machines (
-                        name TEXT,
+                        name TEXT UNIQUE,
                         mtype TEXT,
                         up INTEGER,
                         in_use INTEGER,
+                        cookie TEXT,
                         info JSON
                     )
-                """)
+                    """
+                )
         except sqlite3.OperationalError:
             pass
 
-    def _select(self, machine_type, up, locked, count):
-        query = "SELECT name, mtype, up, in_use, info FROM machines"
+    def _select(
+        self,
+        *,
+        machine_type=None,
+        up=None,
+        locked=None,
+        cookie=None,
+        count=None,
+    ):
+        query = "SELECT name, mtype, up, in_use, cookie, info FROM machines"
         where = []
         params = []
         if machine_type is not None:
@@ -57,12 +53,16 @@ class SqliteMachinePool:
         if locked is not None:
             where.append('in_use=?')
             params.append(1 if locked else 0)
+        if cookie is not None:
+            where.append('cookie=?')
+            params.append(cookie)
         if where:
             query += ' WHERE ' + (' AND '.join(where))
         if count is not None:
-            query += ' LIMIT '+ str(int(count))
+            query += ' LIMIT ' + str(int(count))
 
         with self._conn:
+            self._conn.row_factory = sqlite3.Row
             cur = self._conn.cursor()
             cur.execute(query, tuple(params))
             rows = cur.fetchall()
@@ -72,7 +72,10 @@ class SqliteMachinePool:
     def add_machine(self, name, machine_type, info):
         with self._conn:
             cur = self._conn.cursor()
-            cur.execute("INSERT INTO machines VALUES (?,?, 1, 0, ?)", (name, machine_type,info))
+            cur.execute(
+                "INSERT INTO machines VALUES (?,?, 1, 0, '', ?)",
+                (name, machine_type, info),
+            )
             cur.close()
 
     def remove_machine(self, name):
@@ -81,12 +84,18 @@ class SqliteMachinePool:
             cur.execute("DELETE FROM machines WHERE name=?", (name,))
             cur.close()
 
-    def _take(self, machine_type, count):
+    def remove_all_machines(self):
+        with self._conn:
+            cur = self._conn.cursor()
+            cur.execute('DELETE FROM machines')
+            cur.close()
+
+    def _take(self, machine_type, count, cookie):
         count = int(count)
-        query = "UPDATE machines SET in_use=1 WHERE rowid IN (SELECT rowid FROM machines WHERE in_use=0 AND mtype=? LIMIT ?)"
+        query = "UPDATE machines SET in_use=1, cookie=? WHERE rowid IN (SELECT rowid FROM machines WHERE in_use=0 AND mtype=? LIMIT ?)"
         with self._conn:
             cur = self._conn.cursor()
-            cur.execute(query, (machine_type, count))
+            cur.execute(query, (cookie, machine_type, count))
             cur.close()
 
     def _connect(self) -> None:
@@ -97,9 +106,10 @@ class SqliteMachinePool:
             path = path[7:]
         log.warning("P:%s", path)
         self._conn = sqlite3.connect(path)
+        # self._conn.set_trace_callback(print)
 
     def everything(self):
-        return self._select(None, None, None, None)
+        return [dict(v) for v in self._select()]
 
     def list_locks(
         self,
@@ -110,7 +120,12 @@ class SqliteMachinePool:
         count,
         tries=None,
     ):
-        return {v[0]: None for v in self._select(machine_type, up, locked, count)}
+        return {
+            v[0]: None
+            for v in self._select(
+                machine_type=machine_type, up=up, locked=locked, count=count
+            )
+        }
 
     def acquire(
         self,
@@ -124,8 +139,25 @@ class SqliteMachinePool:
         arch=None,
         reimage=True,
     ):
-        self._take(machine_type, num)
-        return {v[0]: None for v in self._select(machine_type, True, True, num)}
+        cookie = getattr(ctx, 'job_cookie', None)
+        if cookie is None:
+            user = user or 'default'
+            description = description or 'missing'
+            cookie = f'{user}/{description}'
+            if ctx:
+                setattr(ctx, 'job_cookie', cookie)
+
+        self._take(machine_type, num, cookie)
+        return {
+            v[0]: None
+            for v in self._select(
+                machine_type=machine_type,
+                up=True,
+                locked=True,
+                count=num,
+                cookie=cookie,
+            )
+        }
 
     def is_vm(self, name):
         return False
@@ -137,24 +169,33 @@ def main():
     import sys
     import yaml
 
+    class Context:
+        pass
+
     parser = argparse.ArgumentParser()
     parser.add_argument('--list', action='store_true')
     parser.add_argument('--add', action='append')
+    parser.add_argument('--rm-all', action='store_true')
     parser.add_argument('--rm', action='append')
     parser.add_argument('--acquire', type=int)
+    parser.add_argument('--cookie', type=str)
     parser.add_argument('--machine-type')
     parser.add_argument('--info')
     cli = parser.parse_args()
 
     mpool = SqliteMachinePool(config.machine_pool)
+    if cli.rm_all:
+        mpool.remove_all_machines()
     for name in cli.rm or []:
         mpool.remove_machine(name)
     for name in cli.add or []:
         mpool.add_machine(name, cli.machine_type, cli.info)
     if cli.acquire:
-        mpool.acquire(None, cli.acquire, cli.machine_type)
+        ctx = Context()
+        setattr(ctx, 'job_cookie', cli.cookie)
+        mpool.acquire(ctx, cli.acquire, cli.machine_type)
     if cli.list:
-        yaml.safe_dump(mpool.everything(), sys.stdout)
+        yaml.safe_dump(mpool.everything(), sys.stdout, sort_keys=False)
 
 
 if __name__ == '__main__':