Skip to content

Commit

Permalink
Support for module namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
arkq committed Nov 27, 2017
1 parent cdefc92 commit 9f7648c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
45 changes: 34 additions & 11 deletions src/flake8_requirements/checker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
import sys
from collections import namedtuple
from itertools import chain

from pkg_resources import parse_requirements
Expand Down Expand Up @@ -37,19 +38,34 @@ def project2module(project):
class ImportVisitor(ast.NodeVisitor):
"""Import statement visitor."""

# Convenience structure for storing import statement.
Import = namedtuple('Import', ('line', 'offset', 'mod', 'alt'))

def __init__(self, tree):
"""Initialize import statement visitor."""
self.imports = []
self.visit(tree)

def visit_Import(self, node):
self.imports.append((node, node.names[0].name))
self.imports.append(ImportVisitor.Import(
node.lineno,
node.col_offset,
node.names[0].name,
node.names[0].name,
))

def visit_ImportFrom(self, node):
if node.level != 0:
# Omit relative imports (local modules).
return
self.imports.append((node, node.module))
self.imports.append(ImportVisitor.Import(
node.lineno,
node.col_offset,
node.module,
# Alternative module name which covers:
# > from namespace import module
".".join((node.module, node.names[0].name)),
))


class SetupVisitor(ast.NodeVisitor):
Expand Down Expand Up @@ -197,9 +213,11 @@ def split(module):
"""Split module into submodules."""
return tuple(module.split("."))

def modcmp(mod1=(), mod2=()):
def modcmp(lib=(), test=()):
"""Compare import modules."""
return all(a == b for a, b in zip(mod1, mod2))
if len(lib) > len(test):
return False
return all(a == b for a, b in zip(lib, test))

mods_1st_party = set()
mods_3rd_party = set()
Expand All @@ -216,15 +234,20 @@ def modcmp(mod1=(), mod2=()):
modules = KNOWN_3RD_PARTIES[modules[0]]
mods_3rd_party.update(split(x) for x in modules)

for node, module in ImportVisitor(self.tree).imports:
_module = module.split(".")
if any([_module[0] == x for x in STDLIB]):
for node in ImportVisitor(self.tree).imports:
_mod = split(node.mod)
_alt = split(node.alt)
if any([_mod[0] == x for x in STDLIB]):
continue
if any([modcmp(x, _mod) or modcmp(x, _alt)
for x in mods_1st_party]):
continue
if any([modcmp(_module, x) for x in requirements]):
if any([modcmp(x, _mod) or modcmp(x, _alt)
for x in mods_3rd_party]):
continue
yield (
node.lineno,
node.col_offset,
ERRORS['I900'].format(pkg=module),
node.line,
node.offset,
ERRORS['I900'].format(pkg=node.mod),
Flake8Checker,
)
11 changes: 10 additions & 1 deletion test/test_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def get_requirements(self):
"hyp-hen",
"python-boom",
"setuptools",
"space.module",
))


Expand Down Expand Up @@ -82,7 +83,15 @@ def test_non_top_level_import(self):
"I900 'cat' not listed as a requirement",
)

def test_relative_import(self):
def test_namespace(self):
errors = check("import space.module")
self.assertEqual(len(errors), 0)
errors = check("from space import module")
self.assertEqual(len(errors), 0)
errors = check("import space")
self.assertEqual(len(errors), 1)

def test_relative(self):
errors = check("from . import local")
self.assertEqual(len(errors), 0)
errors = check("from ..local import local")
Expand Down

0 comments on commit 9f7648c

Please sign in to comment.