mgr/smb: add staging area class to handler
author    John Mulligan <jmulligan@redhat.com>
          Mon, 29 Apr 2024 18:55:01 +0000 (14:55 -0400)
committer John Mulligan <jmulligan@redhat.com>
          Thu, 2 May 2024 21:06:34 +0000 (17:06 -0400)
This in-memory staging area for new resources being applied to the
config helps simplify validation.

Signed-off-by: John Mulligan <jmulligan@redhat.com>
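
Roughly, the idea is: instead of writing each resource into the internal store as it
is validated, the handler stages all pending changes in memory, runs the checks
against the combined "virtual" view (existing store contents plus staged additions,
minus staged removals), and only writes to the destination store once every check
passes. Below is a minimal, self-contained sketch of that pattern; the names
(TinyStaging, a plain dict-backed store, the stage/save signatures) are simplified
illustrations and are not the actual _Staging API from handler.py.

    # Simplified illustration of the staging pattern (hypothetical names).
    from typing import Dict, Iterator, Tuple

    EntryKey = Tuple[str, str]  # (namespace, name)

    class TinyStaging:
        """Overlay pending changes on top of an existing key->value store."""

        def __init__(self, store: Dict[EntryKey, dict]) -> None:
            self.destination_store = store
            self.incoming: Dict[EntryKey, dict] = {}
            self.deleted: Dict[EntryKey, dict] = {}

        def stage(self, key: EntryKey, value: dict, removed: bool = False) -> None:
            # Record the change in memory only; the store is not touched yet.
            if removed:
                self.incoming.pop(key, None)
                self.deleted[key] = value
            else:
                self.deleted.pop(key, None)
                self.incoming[key] = value

        def __iter__(self) -> Iterator[EntryKey]:
            # The "virtual" view: existing keys minus removals, plus additions.
            for key in self.destination_store:
                if key not in self.deleted:
                    yield key
            for key in self.incoming:
                if key not in self.destination_store:
                    yield key

        def save(self) -> None:
            # Apply removals and additions to the real store in one pass,
            # only after all validation has succeeded.
            for key in self.deleted:
                self.destination_store.pop(key, None)
            self.destination_store.update(self.incoming)

    # Usage: validation sees staged changes, the store changes only on save().
    store = {('clusters', 'c1'): {'cluster_id': 'c1'}}
    staging = TinyStaging(store)
    staging.stage(('shares', 'c1.s1'), {'cluster_id': 'c1', 'share_id': 's1'})
    assert ('shares', 'c1.s1') in set(staging)   # visible in the virtual view
    assert ('shares', 'c1.s1') not in store      # store untouched until save()
    staging.save()
    assert ('shares', 'c1.s1') in store

In the real handler (see the diff below), _Staging plays the TinyStaging role, the
check helpers iterate over the staging object instead of the ConfigStore, and
_Staging.save() produces the Result objects that apply() returns.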
src/pybind/mgr/smb/handler.py

index 40cc4e02794db59f433134bca2b7cc224bffc341..6e819ebd1a7162a2f28487c0053fe62414eb0a4b 100644 (file)
@@ -3,6 +3,7 @@ from typing import (
     Collection,
     Dict,
     Iterable,
+    Iterator,
     List,
     Optional,
     Set,
@@ -28,14 +29,16 @@ from .enums import (
 from .internal import (
     ClusterEntry,
     JoinAuthEntry,
-    ResourceEntry,
     ShareEntry,
     UsersAndGroupsEntry,
+    resource_entry,
+    resource_key,
 )
 from .proto import (
     AccessAuthorizer,
     ConfigEntry,
     ConfigStore,
+    EntryKey,
     OrchSubmitter,
     PathResolver,
     Simplified,
@@ -180,6 +183,72 @@ class _Matcher:
         )
 
 
+class _Staging:
+    def __init__(self, store: ConfigStore) -> None:
+        self.destination_store = store
+        self.incoming: Dict[EntryKey, SMBResource] = {}
+        self.deleted: Dict[EntryKey, SMBResource] = {}
+        self._keycache: Set[EntryKey] = set()
+
+    def stage(self, resource: SMBResource) -> None:
+        self._keycache = set()
+        ekey = resource_key(resource)
+        if resource.intent == Intent.REMOVED:
+            self.deleted[ekey] = resource
+        else:
+            self.deleted.pop(ekey, None)
+            self.incoming[ekey] = resource
+
+    def _virtual_keys(self) -> Iterator[EntryKey]:
+        new = set(self.incoming.keys())
+        for ekey in self.destination_store:
+            if ekey in self.deleted:
+                continue
+            yield ekey
+            new.discard(ekey)
+        for ekey in new:
+            yield ekey
+
+    def __iter__(self) -> Iterator[EntryKey]:
+        self._keycache = set(self._virtual_keys())
+        return iter(self._keycache)
+
+    def namespaces(self) -> Collection[str]:
+        return {k[0] for k in self}
+
+    def contents(self, ns: str) -> Collection[str]:
+        return {kname for kns, kname in self if kns == ns}
+
+    def get_cluster(self, cluster_id: str) -> resources.Cluster:
+        ekey = (str(ClusterEntry.namespace), cluster_id)
+        if ekey in self.incoming:
+            res = self.incoming[ekey]
+            assert isinstance(res, resources.Cluster)
+            return res
+        return ClusterEntry.from_store(
+            self.destination_store, cluster_id
+        ).get_cluster()
+
+    def save(self) -> ResultGroup:
+        results = ResultGroup()
+        for res in self.deleted.values():
+            results.append(self._save(res))
+        for res in self.incoming.values():
+            results.append(self._save(res))
+        return results
+
+    def _save(self, resource: SMBResource) -> Result:
+        entry = resource_entry(self.destination_store, resource)
+        if resource.intent == Intent.REMOVED:
+            removed = entry.remove()
+            state = State.REMOVED if removed else State.NOT_PRESENT
+        else:
+            state = entry.create_or_update(resource)
+        log.debug('saved resource: %r; state: %s', resource, state)
+        result = Result(resource, success=True, status={'state': state})
+        return result
+
+
 class ClusterConfigHandler:
     """The central class for ingesting and handling smb configuration change
     requests.
@@ -247,20 +316,25 @@ class ClusterConfigHandler:
     def apply(self, inputs: Iterable[SMBResource]) -> ResultGroup:
         log.debug('applying changes to internal data store')
         results = ResultGroup()
-        for resource in order_resources(inputs):
-            try:
-                result = self._update_resource(resource)
-            except ErrorResult as err:
-                result = err
-            except Exception as err:
-                log.exception("error updating resource")
-                result = ErrorResult(resource, msg=str(err))
+        staging = _Staging(self.internal_store)
+        try:
+            incoming = order_resources(inputs)
+            for resource in incoming:
+                staging.stage(resource)
+            for resource in incoming:
+                results.append(self._check(resource, staging))
+        except ErrorResult as err:
+            results.append(err)
+        except Exception as err:
+            log.exception("error updating resource")
+            result = ErrorResult(resource, msg=str(err))
             results.append(result)
         if results.success:
             log.debug(
                 'successfully updated %s resources. syncing changes to public stores',
                 len(list(results)),
             )
+            results = staging.save()
             self._sync_modified(results)
         return results
 
@@ -324,38 +398,24 @@ class ClusterConfigHandler:
         log.debug("search found %d resources", len(out))
         return out
 
-    def _update_resource(self, resource: SMBResource) -> Result:
-        """Update the internal store with a new resource object."""
-        entry: ResourceEntry
-        log.debug('updating resource: %r', resource)
+    def _check(self, resource: SMBResource, staging: _Staging) -> Result:
+        """Check/validate a staged resource."""
+        log.debug('staging resource: %r', resource)
         if isinstance(
             resource, (resources.Cluster, resources.RemovedCluster)
         ):
-            check_cluster(resource, self.internal_store)
-            entry = self._cluster_entry(resource.cluster_id)
+            _check_cluster(resource, staging)
         elif isinstance(resource, (resources.Share, resources.RemovedShare)):
-            check_share(resource, self.internal_store, self._path_resolver)
-            entry = self._share_entry(resource.cluster_id, resource.share_id)
+            _check_share(resource, staging, self._path_resolver)
         elif isinstance(resource, resources.JoinAuth):
-            check_join_auths(resource, self.internal_store)
-            entry = self._join_auth_entry(resource.auth_id)
+            _check_join_auths(resource, staging)
         elif isinstance(resource, resources.UsersAndGroups):
-            check_users_and_groups(resource, self.internal_store)
-            entry = self._users_and_groups_entry(resource.users_groups_id)
+            _check_users_and_groups(resource, staging)
         else:
             raise TypeError('not a valid smb resource')
-        state = self._save(entry, resource)
-        result = Result(resource, success=True, status={'state': state})
-        log.debug('saved resource: %r; state: %s', resource, state)
+        result = Result(resource, success=True, status={'checked': True})
         return result
 
-    def _save(self, entry: ResourceEntry, resource: SMBResource) -> State:
-        # Returns the Intent indicating the previous state.
-        if resource.intent == Intent.REMOVED:
-            removed = entry.remove()
-            return State.REMOVED if removed else State.NOT_PRESENT
-        return entry.create_or_update(resource)
-
     def _sync_clusters(
         self, modified_cluster_ids: Optional[Collection[str]] = None
     ) -> None:
@@ -631,10 +691,10 @@ def order_resources(
     return sorted(resource_objs, key=_keyfunc)
 
 
-def check_cluster(cluster: ClusterRef, store: ConfigStore) -> None:
+def _check_cluster(cluster: ClusterRef, staging: _Staging) -> None:
     """Check that the cluster resource can be updated."""
     if cluster.intent == Intent.REMOVED:
-        share_ids = ShareEntry.ids(store)
+        share_ids = ShareEntry.ids(staging)
         clusters_used = {cid for cid, _ in share_ids}
         if cluster.cluster_id in clusters_used:
             raise ErrorResult(
@@ -653,15 +713,15 @@ def check_cluster(cluster: ClusterRef, store: ConfigStore) -> None:
     cluster.validate()
 
 
-def check_share(
-    share: ShareRef, store: ConfigStore, resolver: PathResolver
+def _check_share(
+    share: ShareRef, staging: _Staging, resolver: PathResolver
 ) -> None:
     """Check that the share resource can be updated."""
     if share.intent == Intent.REMOVED:
         return
     assert isinstance(share, resources.Share)
     share.validate()
-    if share.cluster_id not in ClusterEntry.ids(store):
+    if share.cluster_id not in ClusterEntry.ids(staging):
         raise ErrorResult(
             share,
             msg="no matching cluster id",
@@ -681,15 +741,15 @@ def check_share(
         )
 
 
-def check_join_auths(
-    join_auth: resources.JoinAuth, store: ConfigStore
+def _check_join_auths(
+    join_auth: resources.JoinAuth, staging: _Staging
 ) -> None:
     """Check that the JoinAuth resource can be updated."""
     if join_auth.intent == Intent.PRESENT:
         return  # adding is always ok
     refs_in_use: Dict[str, List[str]] = {}
-    for cluster_id in ClusterEntry.ids(store):
-        cluster = ClusterEntry.from_store(store, cluster_id).get_cluster()
+    for cluster_id in ClusterEntry.ids(staging):
+        cluster = staging.get_cluster(cluster_id)
         for ref in _auth_refs(cluster):
             refs_in_use.setdefault(ref, []).append(cluster_id)
     log.debug('refs_in_use: %r', refs_in_use)
@@ -703,15 +763,15 @@ def check_join_auths(
         )
 
 
-def check_users_and_groups(
-    users_and_groups: resources.UsersAndGroups, store: ConfigStore
+def _check_users_and_groups(
+    users_and_groups: resources.UsersAndGroups, staging: _Staging
 ) -> None:
     """Check that the UsersAndGroups resource can be updated."""
     if users_and_groups.intent == Intent.PRESENT:
         return  # adding is always ok
     refs_in_use: Dict[str, List[str]] = {}
-    for cluster_id in ClusterEntry.ids(store):
-        cluster = ClusterEntry.from_store(store, cluster_id).get_cluster()
+    for cluster_id in ClusterEntry.ids(staging):
+        cluster = staging.get_cluster(cluster_id)
         for ref in _ug_refs(cluster):
             refs_in_use.setdefault(ref, []).append(cluster_id)
     log.debug('refs_in_use: %r', refs_in_use)