Skip to content

Commit

Permalink
Automatically discover project's root directory
Browse files Browse the repository at this point in the history
Fixes arkq#7
  • Loading branch information
arkq committed Feb 4, 2020
1 parent 8e219e2 commit d2b7eab
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 70 deletions.
43 changes: 30 additions & 13 deletions src/flake8_requirements/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .modules import STDLIB_PY3

# NOTE: Changing this number will alter package version as well.
__version__ = "1.2.0"
__version__ = "1.3.0"
__license__ = "MIT"

LOG = getLogger('flake8.plugin.requirements')
Expand Down Expand Up @@ -169,7 +169,7 @@ class SetupVisitor(ast.NodeVisitor):
'zip_safe': 0.6,
}

def __init__(self, tree):
def __init__(self, tree, cwd):
"""Initialize package setup visitor."""
self.redirected = False
self.keywords = {}
Expand All @@ -188,13 +188,13 @@ def setup(**kw):
# have to add its root directory to the import search path.
# Note however, that this hack might break further imports
# for OUR Python instance (we're changing our own sys.path)!
sys.path.insert(0, os.getcwd())
sys.path.insert(0, cwd)

try:
tree = ast.fix_missing_locations(tree)
eval(compile(tree, "<str>", mode='exec'), {
'__name__': "__main__",
'__file__': "setup.py",
'__file__': os.path.join(cwd, "setup.py"),
'__f8r_setup': setup,
})
except Exception as e:
Expand All @@ -204,6 +204,7 @@ def setup(**kw):
# heuristic for function arguments). Anyway, we shall not
# break flake8 execution due to out eval() usage.
LOG.exception("Couldn't evaluate setup.py: %s", e)
self.redirected = False

# Restore import search path.
sys.path.pop(0)
Expand Down Expand Up @@ -288,6 +289,9 @@ class Flake8Checker(object):
# Max depth to resolve recursive requirements.
requirements_max_depth = 1

# Root directory of the project.
root_dir = ""

def __init__(self, tree, filename, lines=None):
"""Initialize requirements checker."""
self.tree = tree
Expand Down Expand Up @@ -333,6 +337,20 @@ def parse_options(cls, options):
]
}
cls.requirements_max_depth = options.requirements_max_depth
cls.discover_project_root_dir()

@classmethod
def discover_project_root_dir(cls):
"""Discover project's root directory."""
root_dir = os.getcwd()
root_files = ["pyproject.toml", "requirements.txt", "setup.py"]
while root_dir != os.path.abspath(os.sep):
paths = [os.path.join(root_dir, x) for x in root_files]
if any(map(os.path.exists, paths)):
LOG.info("Discovered root directory: %s", root_dir)
cls.root_dir = root_dir
break
root_dir = os.path.abspath(os.path.join(root_dir, ".."))

@classmethod
def resolve_requirement(cls, requirement, max_depth=0):
Expand All @@ -355,7 +373,7 @@ def resolve_requirement(cls, requirement, max_depth=0):
raise RuntimeError(msg.format(requirement))
resolved = []
# Error out if requirements file cannot be opened.
with open(requirement) as f:
with open(os.path.join(cls.root_dir, requirement)) as f:
for line in joinlines(f.readlines()):
resolved.extend(cls.resolve_requirement(
line, max_depth - 1))
Expand All @@ -370,10 +388,10 @@ def resolve_requirement(cls, requirement, max_depth=0):
def get_pyproject_toml(cls):
"""Try to load PEP 518 configuration file."""
try:
with open("pyproject.toml") as f:
with open(os.path.join(cls.root_dir, "pyproject.toml")) as f:
return toml.loads(f.read())
except IOError as e:
LOG.warning("Couldn't load project setup: %s", e)
LOG.debug("Couldn't load project setup: %s", e)
return {}

@classmethod
Expand All @@ -387,10 +405,9 @@ def get_pyproject_toml_poetry(cls):
def get_requirements_txt(cls):
"""Try to load requirements from text file."""
try:
if not os.path.exists("requirements.txt"):
return ()
path = os.path.join(cls.root_dir, "requirements.txt")
return tuple(parse_requirements(cls.resolve_requirement(
"-r requirements.txt", cls.requirements_max_depth + 1)))
"-r {}".format(path), cls.requirements_max_depth + 1)))
except IOError as e:
LOG.debug("Couldn't load requirements: %s", e)
return ()
Expand All @@ -400,11 +417,11 @@ def get_requirements_txt(cls):
def get_setup_py(cls):
"""Try to load standard setup file."""
try:
with open("setup.py") as f:
return SetupVisitor(ast.parse(f.read()))
with open(os.path.join(cls.root_dir, "setup.py")) as f:
return SetupVisitor(ast.parse(f.read()), cls.root_dir)
except IOError as e:
LOG.debug("Couldn't load project setup: %s", e)
return SetupVisitor(ast.parse(""))
return SetupVisitor(ast.parse(""), cls.root_dir)

@classmethod
@memoize
Expand Down
94 changes: 41 additions & 53 deletions test/test_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,71 +80,59 @@ def test_resolve_requirement_with_file_recursion(self):
)

def test_init_with_no_requirements(self):
with mock.patch("os.path.exists", return_value=False) as exists:
with mock.patch(builtins_open, mock.mock_open()) as m:
m.side_effect = IOError("No such file or directory"),
checker = Flake8Checker(None, None)
requirements = checker.get_requirements_txt()
self.assertEqual(requirements, ())
exists.assert_called_once_with("requirements.txt")
self.assertEqual(checker.get_requirements_txt(), ())

def test_init_with_simple_requirements(self):
content = "foo >= 1.0.0\nbar <= 1.0.0\n"
with mock.patch(builtins_open, mock.mock_open(read_data=content)):

with mock.patch("os.path.exists", return_value=True):
with mock.patch(builtins_open, mock.mock_open()) as m:
m.side_effect = (
mock.mock_open(read_data=content).return_value,
)

checker = Flake8Checker(None, None)
requirements = checker.get_requirements_txt()

self.assertEqual(
sorted(requirements, key=lambda x: x.project_name),
sorted(parse_requirements([
"foo >= 1.0.0",
"bar <= 1.0.0",
]), key=lambda x: x.project_name),
)
checker = Flake8Checker(None, None)
self.assertEqual(
checker.get_requirements_txt(),
tuple(parse_requirements([
"foo >= 1.0.0",
"bar <= 1.0.0",
])),
)

def test_init_with_recursive_requirements_beyond_max_depth(self):
content = "foo >= 1.0.0\n-r inner.txt\nbar <= 1.0.0\n"
inner_content = "# inner\nbaz\n\nqux\n"

with mock.patch("os.path.exists", return_value=True):
with mock.patch(builtins_open, mock.mock_open()) as m:
m.side_effect = (
mock.mock_open(read_data=content).return_value,
mock.mock_open(read_data=inner_content).return_value,
)

with self.assertRaises(RuntimeError):
try:
Flake8Checker.requirements_max_depth = 0
checker = Flake8Checker(None, None)
checker.get_requirements_txt()
finally:
Flake8Checker.requirements_max_depth = 1
with mock.patch(builtins_open, mock.mock_open()) as m:
m.side_effect = (
mock.mock_open(read_data=content).return_value,
mock.mock_open(read_data=inner_content).return_value,
)

with self.assertRaises(RuntimeError):
try:
Flake8Checker.requirements_max_depth = 0
checker = Flake8Checker(None, None)
checker.get_requirements_txt()
finally:
Flake8Checker.requirements_max_depth = 1

def test_init_with_recursive_requirements(self):
content = "foo >= 1.0.0\n-r inner.txt\nbar <= 1.0.0\n"
inner_content = "# inner\nbaz\n\nqux\n"

with mock.patch("os.path.exists", return_value=True):
with mock.patch(builtins_open, mock.mock_open()) as m:
m.side_effect = (
mock.mock_open(read_data=content).return_value,
mock.mock_open(read_data=inner_content).return_value,
)

checker = Flake8Checker(None, None)
requirements = checker.get_requirements_txt()

self.assertEqual(
sorted(requirements, key=lambda x: x.project_name),
sorted(parse_requirements([
"foo >= 1.0.0",
"baz",
"qux",
"bar <= 1.0.0",
]), key=lambda x: x.project_name),
)
with mock.patch(builtins_open, mock.mock_open()) as m:
m.side_effect = (
mock.mock_open(read_data=content).return_value,
mock.mock_open(read_data=inner_content).return_value,
)

checker = Flake8Checker(None, None)
self.assertEqual(
checker.get_requirements_txt(),
tuple(parse_requirements([
"foo >= 1.0.0",
"baz",
"qux",
"bar <= 1.0.0",
])),
)
8 changes: 4 additions & 4 deletions test/test_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_detect_setup(self):
"packages=['']",
"url='URL'",
)))
setup = SetupVisitor(ast.parse(code))
setup = SetupVisitor(ast.parse(code), "")
self.assertEqual(setup.redirected, True)
self.assertDictEqual(setup.keywords, {
'name': 'A',
Expand All @@ -30,7 +30,7 @@ def test_detect_setup(self):
"{}='{}'".format(x, x)
for x in SetupVisitor.attributes
))
setup = SetupVisitor(ast.parse(code))
setup = SetupVisitor(ast.parse(code), "")
self.assertEqual(setup.redirected, True)
self.assertDictEqual(setup.keywords, {
x: x for x in SetupVisitor.attributes
Expand All @@ -43,7 +43,7 @@ def test_detect_setup(self):
"processing=True",
"verbose=True",
)))
setup = SetupVisitor(ast.parse(code))
setup = SetupVisitor(ast.parse(code), "")
self.assertEqual(setup.redirected, False)

def test_get_requirements(self):
Expand All @@ -55,7 +55,7 @@ def test_get_requirements(self):
'extras_require': {
'extra': ["extra < 10"],
},
}))))
}))), "")
self.assertEqual(setup.redirected, True)
self.assertEqual(
sorted(setup.get_requirements(), key=lambda x: x.project_name),
Expand Down

0 comments on commit d2b7eab

Please sign in to comment.