From 9f7648c53a637e2cf8ab2ad7b81a384bf4c5c13b Mon Sep 17 00:00:00 2001 From: Arkadiusz Bokowy Date: Sat, 25 Nov 2017 19:30:10 +0100 Subject: [PATCH] Support for module namespace --- src/flake8_requirements/checker.py | 45 ++++++++++++++++++++++-------- test/test_checker.py | 11 +++++++- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/flake8_requirements/checker.py b/src/flake8_requirements/checker.py index 9b60547..695dc53 100644 --- a/src/flake8_requirements/checker.py +++ b/src/flake8_requirements/checker.py @@ -1,5 +1,6 @@ import ast import sys +from collections import namedtuple from itertools import chain from pkg_resources import parse_requirements @@ -37,19 +38,34 @@ def project2module(project): class ImportVisitor(ast.NodeVisitor): """Import statement visitor.""" + # Convenience structure for storing import statement. + Import = namedtuple('Import', ('line', 'offset', 'mod', 'alt')) + def __init__(self, tree): """Initialize import statement visitor.""" self.imports = [] self.visit(tree) def visit_Import(self, node): - self.imports.append((node, node.names[0].name)) + self.imports.append(ImportVisitor.Import( + node.lineno, + node.col_offset, + node.names[0].name, + node.names[0].name, + )) def visit_ImportFrom(self, node): if node.level != 0: # Omit relative imports (local modules). return - self.imports.append((node, node.module)) + self.imports.append(ImportVisitor.Import( + node.lineno, + node.col_offset, + node.module, + # Alternative module name which covers: + # > from namespace import module + ".".join((node.module, node.names[0].name)), + )) class SetupVisitor(ast.NodeVisitor): @@ -197,9 +213,11 @@ def split(module): """Split module into submodules.""" return tuple(module.split(".")) - def modcmp(mod1=(), mod2=()): + def modcmp(lib=(), test=()): """Compare import modules.""" - return all(a == b for a, b in zip(mod1, mod2)) + if len(lib) > len(test): + return False + return all(a == b for a, b in zip(lib, test)) mods_1st_party = set() mods_3rd_party = set() @@ -216,15 +234,20 @@ def modcmp(mod1=(), mod2=()): modules = KNOWN_3RD_PARTIES[modules[0]] mods_3rd_party.update(split(x) for x in modules) - for node, module in ImportVisitor(self.tree).imports: - _module = module.split(".") - if any([_module[0] == x for x in STDLIB]): + for node in ImportVisitor(self.tree).imports: + _mod = split(node.mod) + _alt = split(node.alt) + if any([_mod[0] == x for x in STDLIB]): + continue + if any([modcmp(x, _mod) or modcmp(x, _alt) + for x in mods_1st_party]): continue - if any([modcmp(_module, x) for x in requirements]): + if any([modcmp(x, _mod) or modcmp(x, _alt) + for x in mods_3rd_party]): continue yield ( - node.lineno, - node.col_offset, - ERRORS['I900'].format(pkg=module), + node.line, + node.offset, + ERRORS['I900'].format(pkg=node.mod), Flake8Checker, ) diff --git a/test/test_checker.py b/test/test_checker.py index 525083b..042ea39 100644 --- a/test/test_checker.py +++ b/test/test_checker.py @@ -18,6 +18,7 @@ def get_requirements(self): "hyp-hen", "python-boom", "setuptools", + "space.module", )) @@ -82,7 +83,15 @@ def test_non_top_level_import(self): "I900 'cat' not listed as a requirement", ) - def test_relative_import(self): + def test_namespace(self): + errors = check("import space.module") + self.assertEqual(len(errors), 0) + errors = check("from space import module") + self.assertEqual(len(errors), 0) + errors = check("import space") + self.assertEqual(len(errors), 1) + + def test_relative(self): errors = check("from . import local") self.assertEqual(len(errors), 0) errors = check("from ..local import local")