[PATCH 1 of 5] import-checker: resolve relative imports

Gregory Szorc gregory.szorc at gmail.com
Sun Jun 28 19:49:39 UTC 2015


# HG changeset patch
# User Gregory Szorc <gregory.szorc at gmail.com>
# Date 1435509418 25200
#      Sun Jun 28 09:36:58 2015 -0700
# Node ID 772bfce29007817eb4422e6aa085afe3836adf5c
# Parent  ff5172c830022b64cc5bd1bae36b2276e9dc6e5d
import-checker: resolve relative imports

"from . import X" will produce an ImportFrom ast node with .module =
None. This resulted in a runtime error when attempting to concatenate
None with a str.

Another problem with relative imports is that the prefix may be dynamic,
based on the "level" attribute of the import. For example, "from ." has
level 1 and "from .." has level 2.

We teach the "fromlocal" function how to cope with relative imports.
Where appropriate, the caller passes in the level so relative module
names can be resolved properly.

diff --git a/contrib/import-checker.py b/contrib/import-checker.py
--- a/contrib/import-checker.py
+++ b/contrib/import-checker.py
@@ -77,15 +77,28 @@ def fromlocalfunc(modulename, localmods)
     ('baz.baz1', 'baz.baz1', False)
     >>> # unknown = maybe standard library
     >>> fromlocal('os')
     False
+    >>> fromlocal(None, 1)
+    ('foo', 'foo.__init__', True)
+    >>> fromlocal2 = fromlocalfunc('foo.xxx.yyy', localmods)
+    >>> fromlocal2(None, 2)
+    ('foo', 'foo.__init__', True)
     """
     prefix = '.'.join(modulename.split('.')[:-1])
     if prefix:
         prefix += '.'
-    def fromlocal(name):
-        # check relative name at first
-        for n in prefix + name, name:
+    def fromlocal(name, level=0):
+        # name is None when relative imports are used.
+        if name is None:
+            # If relative imports are used, level must not be absolute.
+            assert level > 0
+            candidates = ['.'.join(modulename.split('.')[:-level])]
+        else:
+            # Check relative name first.
+            candidates = [prefix + name, name]
+
+        for n in candidates:
             if n in localmods:
                 return (n, n, False)
             dottedpath = n + '.__init__'
             if dottedpath in localmods:
@@ -238,9 +251,9 @@ def imported_modules(source, modulename,
                     # this should import standard library
                     continue
                 yield found[1]
         elif isinstance(node, ast.ImportFrom):
-            found = fromlocal(node.module)
+            found = fromlocal(node.module, node.level)
             if not found:
                 # this should import standard library
                 continue
 


More information about the Mercurial-devel mailing list