]> git.apps.os.sepia.ceph.com Git - ceph.git/commitdiff
cephadm: add a new funkypatch fixture based on mock.patch and pytest
authorJohn Mulligan <jmulligan@redhat.com>
Sun, 20 Aug 2023 17:50:00 +0000 (13:50 -0400)
committerJohn Mulligan <jmulligan@redhat.com>
Thu, 30 Nov 2023 21:55:50 +0000 (16:55 -0500)
This fixture acts like a combination of mock.patch and pytest's
monkeypatch fixture. It has the additional feature of automatically
finding and patching the same object imported in other modules.  If you
have 'from x import y', where y is a function or class, in both a.py and
b.py it will patch both instances (so long as both a and b are already
imported).
This behavior is useful for cephadm because of the heavy use of the
`from x import y` idiom and how cephadm is being actively refactored.

Signed-off-by: John Mulligan <jmulligan@redhat.com>
src/cephadm/tests/fixtures.py

index d25dffa9e3b44cbfc66ac9b1164554046174ddd3..572c1f9969d66962673c098135d5a3a0244b9183 100644 (file)
@@ -6,7 +6,7 @@ import time
 from contextlib import contextmanager
 from pyfakefs import fake_filesystem
 
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Any
 
 
 def import_cephadm():
@@ -183,3 +183,83 @@ def with_cephadm_ctx(
         else:
             yield ctx
 
+
+@pytest.fixture()
+def funkypatch(monkeypatch):
+    """Defines the funkypatch fixtures that acts like a mixture between
+    mock.patch and pytest's monkeypatch fixture.
+    """
+    fp = FunkyPatcher(monkeypatch)
+    yield fp
+
+
+class FunkyPatcher:
+    """FunkyPatcher monkeypatches all imported instances of an object.
+
+    Use `patch` to patch the canonical location of an object and FunkyPatcher
+    will automatically replace other imports of that object.
+    """
+
+    def __init__(self, monkeypatcher):
+        self._mp = monkeypatcher
+        # keep track of objects we've already patched. this dictionary
+        # maps a (module-name, object-name) tuple to the original object
+        # before patching. This could be used to determine if a name has
+        # already been patched or compare a patched object to the original.
+        self._originals: Dict[Tuple[str, str], Any] = {}
+
+    def patch(
+        self,
+        mod: str,
+        name: str = '',
+        *,
+        dest: Any = None,
+        force: bool = False,
+    ) -> Any:
+        """Patch an object and all existing imports of that object.
+        Specify mod as `my.mod.name.obj` where obj is name of the object to be
+        patched or as `my.mod.name` and specify `name` as the name of the
+        object to be patched.
+        If the object to be patched is not imported as the same name in `mod`
+        it will *not* be automatically patched. In other words, `from
+        my.mod.name import foo` will work, but `from my.mod.name import foo as
+        _foo` will not.
+        Use the keyword-only argument `dest` to specify the new object to be
+        used. A MagicMock will be created and used if dest is None.
+        Use the keyword-only argument `force` to override checks that a mocked
+        objects are the same across modules. This can be used in the case that
+        some other code already patched an object and you want funkypatch to
+        override that patch (use with caution).
+        Returns the patched object (the MagicMock or supplied dest).
+        """
+        import sys
+        import importlib
+
+        if not name:
+            mod, name = mod.rsplit('.', 1)
+        modname = (mod, name)
+        # We don't strictly need the check but patching already patched objs is
+        # confusing to think about. It's better to block it for now and perhaps
+        # later we can relax these restrictions or be clever in some way.
+        if modname in self._originals:
+            raise KeyError(f'{modname} already patched')
+
+        if dest is None:
+            dest = mock.MagicMock()
+
+        imod = importlib.import_module(mod)
+        self._originals[modname] = getattr(imod, name)
+
+        for mname, imod in sys.modules.items():
+            try:
+                obj = getattr(imod, name)
+            except AttributeError:
+                # no matching name in module
+                continue
+            # make sure that the module imported the same object as the
+            # one we want to patch out, and not just some naming collision.
+            # ensure the original object and the one in the module are the
+            # same object
+            if obj is self._originals[modname] or force:
+                self._mp.setattr(imod, name, dest)
+        return dest