Skip to content

Commit

Permalink
User defined project->modules mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
arkq committed Nov 27, 2017
1 parent 9f7648c commit 83b8774
Showing 1 changed file with 36 additions and 3 deletions.
39 changes: 36 additions & 3 deletions src/flake8_requirements/checker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import re
import sys
from collections import namedtuple
from itertools import chain
Expand Down Expand Up @@ -196,11 +197,40 @@ class Flake8Checker(object):
name = "flake8-requires"
version = __version__

# User defined project->modules mapping.
known_modules = {}

def __init__(self, tree, filename, lines=None):
"""Initialize requirements checker."""
self.setup = self.get_setup()
self.tree = tree

@classmethod
def add_options(cls, manager):
"""Register plug-in specific options."""
manager.add_option(
"--known-modules",
action='store',
parse_from_config=True,
default="",
help=(
"User defined mapping between a project name and a list of"
" provided modules. For example: ``--known-modules=project:"
"[Project],extra-project:[extras,utilities]``."
),
)

@classmethod
def parse_options(cls, options):
"""Parse plug-in specific options."""
cls.known_modules = {
project2module(k): v.split(",")
for k, v in [
x.split(":[")
for x in re.split(r"],?", options.known_modules)[:-1]
]
}

def get_setup(self):
"""Get package setup."""
with open("setup.py") as f:
Expand All @@ -223,15 +253,18 @@ def modcmp(lib=(), test=()):
mods_3rd_party = set()

# Get 1st party modules (used for absolute imports).
mods_1st_party.add(
split(project2module(self.setup.keywords['name'])),
)
modules = [project2module(self.setup.keywords.get('name', ""))]
if modules[0] in self.known_modules:
modules = self.known_modules[modules[0]]
mods_1st_party.update(split(x) for x in modules)

# Get 3rd party module names based on requirements.
for requirement in self.setup.get_requirements():
modules = [project2module(requirement.project_name)]
if modules[0] in KNOWN_3RD_PARTIES:
modules = KNOWN_3RD_PARTIES[modules[0]]
if modules[0] in self.known_modules:
modules = self.known_modules[modules[0]]
mods_3rd_party.update(split(x) for x in modules)

for node in ImportVisitor(self.tree).imports:
Expand Down

0 comments on commit 83b8774

Please sign in to comment.