From cdefc929622820749ccd84939a09dddd46771cdf Mon Sep 17 00:00:00 2001 From: Arkadiusz Bokowy Date: Sat, 25 Nov 2017 19:29:37 +0100 Subject: [PATCH] Heuristic for project to module conversion --- src/flake8_requirements/checker.py | 33 +++++++++++++++++++++++------- test/test_checker.py | 13 ++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/flake8_requirements/checker.py b/src/flake8_requirements/checker.py index 4fbf7fb..9b60547 100644 --- a/src/flake8_requirements/checker.py +++ b/src/flake8_requirements/checker.py @@ -24,6 +24,16 @@ STDLIB.update(STDLIB_PY3) +def project2module(project): + """Convert project name into a module name.""" + # Name unification in accordance with PEP 426. + project = project.lower().replace("-", "_") + if project.startswith("python_"): + # Remove conventional "python-" prefix. + project = project[7:] + return project + + class ImportVisitor(ast.NodeVisitor): """Import statement visitor.""" @@ -183,19 +193,28 @@ def get_setup(self): def run(self): """Run checker.""" + def split(module): + """Split module into submodules.""" + return tuple(module.split(".")) + def modcmp(mod1=(), mod2=()): """Compare import modules.""" return all(a == b for a, b in zip(mod1, mod2)) - requirements = set() + mods_1st_party = set() + mods_3rd_party = set() + + # Get 1st party modules (used for absolute imports). + mods_1st_party.add( + split(project2module(self.setup.keywords['name'])), + ) - # Get module names based on requirements. + # Get 3rd party module names based on requirements. for requirement in self.setup.get_requirements(): - project = requirement.project_name.lower() - modules = [project.replace("-", "_")] - if project in KNOWN_3RD_PARTIES: - modules = KNOWN_3RD_PARTIES[project] - requirements.update(tuple(x.split(".")) for x in modules) + modules = [project2module(requirement.project_name)] + if modules[0] in KNOWN_3RD_PARTIES: + modules = KNOWN_3RD_PARTIES[modules[0]] + mods_3rd_party.update(split(x) for x in modules) for node, module in ImportVisitor(self.tree).imports: _module = module.split(".") diff --git a/test/test_checker.py b/test/test_checker.py index 6632a28..525083b 100644 --- a/test/test_checker.py +++ b/test/test_checker.py @@ -7,11 +7,16 @@ class SetupVisitorMock: + keywords = { + 'name': "flake8-requires", + } + def get_requirements(self): return parse_requirements(( "foo", "bar", "hyp-hen", + "python-boom", "setuptools", )) @@ -41,10 +46,18 @@ def test_stdlib_case(self): "I900 'cprofile' not listed as a requirement", ) + def test_1st_party(self): + errors = check("import flake8_requires") + self.assertEqual(len(errors), 0) + def test_3rd_party(self): errors = check("import foo\nfrom bar import Bar") self.assertEqual(len(errors), 0) + def test_3rd_party_python_prefix(self): + errors = check("from boom import blast") + self.assertEqual(len(errors), 0) + def test_3rd_party_missing(self): errors = check("import os\nfrom cat import Cat") self.assertEqual(len(errors), 1)