[PATCH 1 of 3] extensions: introduce a class interposition function

Bryan O'Sullivan bos at serpentine.com
Tue Nov 20 16:55:43 CST 2012


# HG changeset patch
# User Bryan O'Sullivan <bryano at fb.com>
# Date 1353452136 28800
# Node ID cbf274dc847b5c28a282ae98a7d8b9bd48b6137d
# Parent  4ae21a7568f353dcb781df28d2acf458b06afcad
extensions: introduce a class interposition function

This allows an existing class to be augmented in a transparent way,
without its subclasses or callers needing to participate.
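The transparency depends on a standard Python feature: a class's `__bases__` can be reassigned after the class is created, and method lookup follows the updated MRO on every call. Because of this, even instances created earlier see the new base. A minimal standalone sketch of that mechanism, independent of the patch below (`base`, `augmented` and `child` are hypothetical names):

```python
class base(object):
    def greet(self):
        return 'hello'

class child(base):
    pass

# An instance created before any interposition happens.
early = child()

class augmented(base):
    # Extends base.greet through the ordinary super() chain.
    def greet(self):
        return super(augmented, self).greet() + ', world'

# Splice augmented between child and base by rewriting child's bases.
child.__bases__ = (augmented,)

print(early.greet())                        # hello, world
print([c.__name__ for c in child.__mro__])  # ['child', 'augmented', 'base', 'object']
```

Neither `child` nor the code holding `early` had to cooperate. `base` itself is unchanged.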

The manual class override mechanism currently in use introduces
names into an outer scope that can be accidentally (and incorrectly)
used in monkeypatched classes. It also does not make clear which
class is intended to be monkeypatched. Finally, it requires explicit
support from the to-be-monkeypatched code (e.g. reposetup).
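For comparison, here is a simplified sketch of a per-instance override. It assumes the manual mechanism follows the familiar shape of an extension's `reposetup(ui, repo)` hook, which subclasses `repo.__class__` and swaps the class of one instance. `repository` and `myrepo` are hypothetical stand-ins, not Mercurial's real classes.

```python
class repository(object):
    # Hypothetical stand-in for a core class.
    def status(self):
        return 'clean'

def reposetup(ui, repo):
    # Only works because the core code explicitly calls this hook
    # for each instance it creates.
    class myrepo(repo.__class__):
        def status(self):
            return 'extended ' + super(myrepo, self).status()
    # Swap the class of this single instance.
    repo.__class__ = myrepo

repo = repository()
reposetup(None, repo)
print(repo.status())          # extended clean
print(repository().status())  # clean: instances never passed to the hook are untouched
```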

diff --git a/mercurial/extensions.py b/mercurial/extensions.py
--- a/mercurial/extensions.py
+++ b/mercurial/extensions.py
@@ -193,6 +193,34 @@ def wrapfunction(container, funcname, wr
     setattr(container, funcname, wrap)
     return origfn
 
+def replaceclass(container, classname):
+    '''Replace a class with another in a module, and interpose it into
+    the hierarchies of all loaded subclasses. This function is
+    intended for use as a decorator.
+
+      import mymodule
+      @replaceclass(mymodule, 'myclass')
+      class mysubclass(mymodule.myclass):
+          def foo(self):
+              f = super(mysubclass, self).foo()
+              return f + ' bar'
+
+    Existing instances of the class being replaced will not have their
+    __class__ modified, so call this function before creating any
+    objects of the target type.
+    '''
+    def wrap(cls):
+        oldcls = getattr(container, classname)
+        oldbases = (oldcls,)
+        newbases = (cls,)
+        for subcls in oldcls.__subclasses__():
+            if subcls is not cls:
+                assert subcls.__bases__ == oldbases
+                subcls.__bases__ = newbases
+        setattr(container, classname, cls)
+        return cls
+    return wrap
+
 def _disabledpaths(strip_init=False):
     '''find paths of disabled extensions. returns a dict of {name: path}
     removes /__init__.py from packages if strip_init is True'''
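The following self-contained sketch exercises `replaceclass` as its docstring describes. The function body is copied from the patch. `mymodule` (built with `types.ModuleType`), `myclass`, `existingsub` and `stale` are hypothetical. The sketch checks two things: a subclass defined before the decorator runs picks up the new behavior, and an instance of the original class does not, as the docstring warns.

```python
import types

def replaceclass(container, classname):
    # Body copied from the patch above.
    def wrap(cls):
        oldcls = getattr(container, classname)
        oldbases = (oldcls,)
        newbases = (cls,)
        # Re-parent every direct subclass except the replacement itself.
        for subcls in oldcls.__subclasses__():
            if subcls is not cls:
                assert subcls.__bases__ == oldbases
                subcls.__bases__ = newbases
        # Later lookups of container.<classname> find the replacement.
        setattr(container, classname, cls)
        return cls
    return wrap

# A throwaway module standing in for the module being patched.
mymodule = types.ModuleType('mymodule')

class myclass(object):
    def foo(self):
        return 'foo'

class existingsub(myclass):
    pass

mymodule.myclass = myclass
stale = mymodule.myclass()   # created before the replacement

@replaceclass(mymodule, 'myclass')
class mysubclass(mymodule.myclass):
    def foo(self):
        f = super(mysubclass, self).foo()
        return f + ' bar'

print(mymodule.myclass().foo())  # foo bar
print(existingsub().foo())       # foo bar (re-parented onto mysubclass)
print(stale.foo())               # foo (its __class__ is still myclass)
```

The `assert` in `wrap` limits the function to subclasses with a single base. A subclass that also inherits from another class would trip the assertion rather than be re-parented.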