Skip to content

Commit

Permalink
Store imported modules in a radix-tree structure
Browse files Browse the repository at this point in the history
  • Loading branch information
arkq committed Apr 7, 2021
1 parent bedd1a3 commit 0e3b5fc
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 47 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def get_abs_path(pathname):
packages=["flake8_requirements"],
install_requires=[
"flake8 >= 2.0.0",
"setuptools",
"toml",
"setuptools >= 10.0.0",
"toml >= 0.7.0",
],
setup_requires=["pytest-runner"],
tests_require=["mock", "pytest"],
Expand Down
100 changes: 57 additions & 43 deletions src/flake8_requirements/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,6 @@ def modsplit(module):
return tuple(module.split("."))


def modcmp(lib=(), test=()):
"""Compare import modules."""
if len(lib) > len(test):
return False
return all(a == b for a, b in zip(lib, test))


def project2module(project):
"""Convert project name into a module name."""
# Name unification in accordance with PEP 426.
Expand Down Expand Up @@ -91,7 +84,7 @@ class ImportVisitor(ast.NodeVisitor):
"""Import statement visitor."""

# Convenience structure for storing import statement.
Import = namedtuple('Import', ('line', 'offset', 'mod', 'alt'))
Import = namedtuple('Import', ('line', 'offset', 'module'))

def __init__(self, tree):
"""Initialize import statement visitor."""
Expand All @@ -102,8 +95,7 @@ def visit_Import(self, node):
self.imports.append(ImportVisitor.Import(
node.lineno,
node.col_offset,
node.names[0].name,
node.names[0].name,
modsplit(node.names[0].name),
))

def visit_ImportFrom(self, node):
Expand All @@ -113,10 +105,9 @@ def visit_ImportFrom(self, node):
self.imports.append(ImportVisitor.Import(
node.lineno,
node.col_offset,
node.module,
# Alternative module name which covers:
# Module name which covers:
# > from namespace import module
".".join((node.module, node.names[0].name)),
modsplit(node.module) + modsplit(node.names[0].name),
))


Expand Down Expand Up @@ -277,6 +268,26 @@ def visit_Call(self, node):
self.redirected = True


class ModuleSet(dict):
"""Radix-tree-like structure for modules lookup."""

requirement = None

def add(self, module, requirement):
for mod in module:
self = self.setdefault(mod, ModuleSet())
self.requirement = requirement

def __contains__(self, module):
for mod in module:
self = self.get(mod)
if self is None:
return False
if self.requirement is not None:
return True
return False


class Flake8Checker(object):
"""Package requirements checker."""

Expand Down Expand Up @@ -547,15 +558,16 @@ def get_setup_py(cls):
@classmethod
@memoize
def get_mods_1st_party(cls):
mods_1st_party = set()
mods_1st_party = ModuleSet()
# Get 1st party modules (used for absolute imports).
modules = [project2module(
cls.get_setup_py().keywords.get('name') or
cls.get_pyproject_toml_poetry().get('name') or
"")]
if modules[0] in cls.known_modules:
modules = cls.known_modules[modules[0]]
mods_1st_party.update(modsplit(x) for x in modules)
for module in modules:
mods_1st_party.add(modsplit(module), True)
return mods_1st_party

def get_mods_3rd_party_requirements(self):
Expand Down Expand Up @@ -590,18 +602,18 @@ def get_mods_3rd_party_requirements(self):

@memoize
def get_mods_3rd_party(self):
mods_3rd_party = set()
mods_3rd_party = ModuleSet()
# Get 3rd party module names based on requirements.
for requirement in self.get_mods_3rd_party_requirements():
modules = [project2module(requirement.project_name)]
if modules[0] in self.known_3rd_parties:
modules = self.known_3rd_parties[modules[0]]
if modules[0] in self.known_host_3rd_parties:
elif modules[0] in self.known_host_3rd_parties:
modules = self.known_host_3rd_parties[modules[0]]
if modules[0] in self.known_modules:
elif modules[0] in self.known_modules:
modules = self.known_modules[modules[0]]
mods_3rd_party.update(modsplit(x) for x in modules)

for module in modules:
mods_3rd_party.add(modsplit(module), requirement)
return mods_3rd_party

@property
Expand All @@ -612,33 +624,35 @@ def processing_setup_py(self):
except OSError:
return False

def run(self):
"""Run checker."""
mods_1st_party = self.get_mods_1st_party()
mods_3rd_party = self.get_mods_3rd_party()

def check_I900(self, node):
"""Run missing requirement checker."""
if node.module[0] in STDLIB:
return
if node.module in self.get_mods_3rd_party():
return
if node.module in self.get_mods_1st_party():
return
# When processing setup.py file, forcefully add setuptools to the
# project requirements. Setuptools might be required to build the
# project, even though it is not listed as a requirement - this
# package is required to run setup.py, so listing it as a setup
# requirement would be pointless.
if self.processing_setup_py:
mods_3rd_party.add(modsplit("setuptools"))
if (self.processing_setup_py and
node.module[0] in KNOWN_3RD_PARTIES["setuptools"]):
return
return ERRORS['I900'].format(pkg=node.module[0])

def check_I901(self, node):
"""Run not-used requirement checker."""
return

def run(self):
"""Run checker."""

checkers = []
checkers.append(self.check_I900)
checkers.append(self.check_I901)

for node in ImportVisitor(self.tree).imports:
_mod = modsplit(node.mod)
_alt = modsplit(node.alt)
if _mod[0] in STDLIB:
continue
if any([modcmp(x, _mod) or modcmp(x, _alt)
for x in mods_1st_party]):
continue
if any([modcmp(x, _mod) or modcmp(x, _alt)
for x in mods_3rd_party]):
continue
yield (
node.line,
node.offset,
ERRORS['I900'].format(pkg=node.mod),
Flake8Checker,
)
for err in filter(None, map(lambda c: c(node), checkers)):
yield (node.line, node.offset, err, Flake8Checker)
5 changes: 3 additions & 2 deletions test/test_poetry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import unittest

from flake8_requirements.checker import Flake8Checker
from flake8_requirements.checker import ModuleSet
from flake8_requirements.checker import memoize

try:
Expand Down Expand Up @@ -33,7 +34,7 @@ def test_1st_party(self):

checker = Flake8Checker(None, None)
mods = checker.get_mods_1st_party()
self.assertEqual(mods, set([("book",)]))
self.assertEqual(mods, ModuleSet({"book": {}}))

def test_3rd_party(self):
content = "[tool.poetry.dependencies]\ntools='1.0'\n"
Expand All @@ -47,4 +48,4 @@ def test_3rd_party(self):

checker = Flake8Checker(None, None)
mods = checker.get_mods_3rd_party()
self.assertEqual(mods, set([("tools",), ("dev_tools",)]))
self.assertEqual(mods, ModuleSet({"tools": {}, "dev_tools": {}}))

0 comments on commit 0e3b5fc

Please sign in to comment.