from contextlib import contextmanager
from pyfakefs import fake_filesystem
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Any
def import_cephadm():
else:
yield ctx
+
+@pytest.fixture()
+def funkypatch(monkeypatch):
+ """Defines the funkypatch fixtures that acts like a mixture between
+ mock.patch and pytest's monkeypatch fixture.
+ """
+ fp = FunkyPatcher(monkeypatch)
+ yield fp
+
+
+class FunkyPatcher:
+ """FunkyPatcher monkeypatches all imported instances of an object.
+
+ Use `patch` to patch the canonical location of an object and FunkyPatcher
+ will automatically replace other imports of that object.
+ """
+
+ def __init__(self, monkeypatcher):
+ self._mp = monkeypatcher
+ # keep track of objects we've already patched. this dictionary
+ # maps a (module-name, object-name) tuple to the original object
+ # before patching. This could be used to determine if a name has
+ # already been patched or compare a patched object to the original.
+ self._originals: Dict[Tuple[str, str], Any] = {}
+
+ def patch(
+ self,
+ mod: str,
+ name: str = '',
+ *,
+ dest: Any = None,
+ force: bool = False,
+ ) -> Any:
+ """Patch an object and all existing imports of that object.
+ Specify mod as `my.mod.name.obj` where obj is name of the object to be
+ patched or as `my.mod.name` and specify `name` as the name of the
+ object to be patched.
+ If the object to be patched is not imported as the same name in `mod`
+ it will *not* be automatically patched. In other words, `from
+ my.mod.name import foo` will work, but `from my.mod.name import foo as
+ _foo` will not.
+ Use the keyword-only argument `dest` to specify the new object to be
+ used. A MagicMock will be created and used if dest is None.
+ Use the keyword-only argument `force` to override checks that a mocked
+ objects are the same across modules. This can be used in the case that
+ some other code already patched an object and you want funkypatch to
+ override that patch (use with caution).
+ Returns the patched object (the MagicMock or supplied dest).
+ """
+ import sys
+ import importlib
+
+ if not name:
+ mod, name = mod.rsplit('.', 1)
+ modname = (mod, name)
+ # We don't strictly need the check but patching already patched objs is
+ # confusing to think about. It's better to block it for now and perhaps
+ # later we can relax these restrictions or be clever in some way.
+ if modname in self._originals:
+ raise KeyError(f'{modname} already patched')
+
+ if dest is None:
+ dest = mock.MagicMock()
+
+ imod = importlib.import_module(mod)
+ self._originals[modname] = getattr(imod, name)
+
+ for mname, imod in sys.modules.items():
+ try:
+ obj = getattr(imod, name)
+ except AttributeError:
+ # no matching name in module
+ continue
+ # make sure that the module imported the same object as the
+ # one we want to patch out, and not just some naming collision.
+ # ensure the original object and the one in the module are the
+ # same object
+ if obj is self._originals[modname] or force:
+ self._mp.setattr(imod, name, dest)
+ return dest