diff --git a/src/flake8_requirements/checker.py b/src/flake8_requirements/checker.py index ba4a66c..58c2070 100644 --- a/src/flake8_requirements/checker.py +++ b/src/flake8_requirements/checker.py @@ -15,7 +15,7 @@ from .modules import STDLIB_PY3 # NOTE: Changing this number will alter package version as well. -__version__ = "1.2.0" +__version__ = "1.3.0" __license__ = "MIT" LOG = getLogger('flake8.plugin.requirements') @@ -169,7 +169,7 @@ class SetupVisitor(ast.NodeVisitor): 'zip_safe': 0.6, } - def __init__(self, tree): + def __init__(self, tree, cwd): """Initialize package setup visitor.""" self.redirected = False self.keywords = {} @@ -188,13 +188,13 @@ def setup(**kw): # have to add its root directory to the import search path. # Note however, that this hack might break further imports # for OUR Python instance (we're changing our own sys.path)! - sys.path.insert(0, os.getcwd()) + sys.path.insert(0, cwd) try: tree = ast.fix_missing_locations(tree) eval(compile(tree, "", mode='exec'), { '__name__': "__main__", - '__file__': "setup.py", + '__file__': os.path.join(cwd, "setup.py"), '__f8r_setup': setup, }) except Exception as e: @@ -204,6 +204,7 @@ def setup(**kw): # heuristic for function arguments). Anyway, we shall not # break flake8 execution due to out eval() usage. LOG.exception("Couldn't evaluate setup.py: %s", e) + self.redirected = False # Restore import search path. sys.path.pop(0) @@ -288,6 +289,9 @@ class Flake8Checker(object): # Max depth to resolve recursive requirements. requirements_max_depth = 1 + # Root directory of the project. + root_dir = "" + def __init__(self, tree, filename, lines=None): """Initialize requirements checker.""" self.tree = tree @@ -333,6 +337,20 @@ def parse_options(cls, options): ] } cls.requirements_max_depth = options.requirements_max_depth + cls.discover_project_root_dir() + + @classmethod + def discover_project_root_dir(cls): + """Discover project's root directory.""" + root_dir = os.getcwd() + root_files = ["pyproject.toml", "requirements.txt", "setup.py"] + while root_dir != os.path.abspath(os.sep): + paths = [os.path.join(root_dir, x) for x in root_files] + if any(map(os.path.exists, paths)): + LOG.info("Discovered root directory: %s", root_dir) + cls.root_dir = root_dir + break + root_dir = os.path.abspath(os.path.join(root_dir, "..")) @classmethod def resolve_requirement(cls, requirement, max_depth=0): @@ -355,7 +373,7 @@ def resolve_requirement(cls, requirement, max_depth=0): raise RuntimeError(msg.format(requirement)) resolved = [] # Error out if requirements file cannot be opened. - with open(requirement) as f: + with open(os.path.join(cls.root_dir, requirement)) as f: for line in joinlines(f.readlines()): resolved.extend(cls.resolve_requirement( line, max_depth - 1)) @@ -370,10 +388,10 @@ def resolve_requirement(cls, requirement, max_depth=0): def get_pyproject_toml(cls): """Try to load PEP 518 configuration file.""" try: - with open("pyproject.toml") as f: + with open(os.path.join(cls.root_dir, "pyproject.toml")) as f: return toml.loads(f.read()) except IOError as e: - LOG.warning("Couldn't load project setup: %s", e) + LOG.debug("Couldn't load project setup: %s", e) return {} @classmethod @@ -387,10 +405,9 @@ def get_pyproject_toml_poetry(cls): def get_requirements_txt(cls): """Try to load requirements from text file.""" try: - if not os.path.exists("requirements.txt"): - return () + path = os.path.join(cls.root_dir, "requirements.txt") return tuple(parse_requirements(cls.resolve_requirement( - "-r requirements.txt", cls.requirements_max_depth + 1))) + "-r {}".format(path), cls.requirements_max_depth + 1))) except IOError as e: LOG.debug("Couldn't load requirements: %s", e) return () @@ -400,11 +417,11 @@ def get_requirements_txt(cls): def get_setup_py(cls): """Try to load standard setup file.""" try: - with open("setup.py") as f: - return SetupVisitor(ast.parse(f.read())) + with open(os.path.join(cls.root_dir, "setup.py")) as f: + return SetupVisitor(ast.parse(f.read()), cls.root_dir) except IOError as e: LOG.debug("Couldn't load project setup: %s", e) - return SetupVisitor(ast.parse("")) + return SetupVisitor(ast.parse(""), cls.root_dir) @classmethod @memoize diff --git a/test/test_requirements.py b/test/test_requirements.py index 25f74cc..325b6dc 100644 --- a/test/test_requirements.py +++ b/test/test_requirements.py @@ -80,71 +80,59 @@ def test_resolve_requirement_with_file_recursion(self): ) def test_init_with_no_requirements(self): - with mock.patch("os.path.exists", return_value=False) as exists: + with mock.patch(builtins_open, mock.mock_open()) as m: + m.side_effect = IOError("No such file or directory"), checker = Flake8Checker(None, None) - requirements = checker.get_requirements_txt() - self.assertEqual(requirements, ()) - exists.assert_called_once_with("requirements.txt") + self.assertEqual(checker.get_requirements_txt(), ()) def test_init_with_simple_requirements(self): content = "foo >= 1.0.0\nbar <= 1.0.0\n" + with mock.patch(builtins_open, mock.mock_open(read_data=content)): - with mock.patch("os.path.exists", return_value=True): - with mock.patch(builtins_open, mock.mock_open()) as m: - m.side_effect = ( - mock.mock_open(read_data=content).return_value, - ) - - checker = Flake8Checker(None, None) - requirements = checker.get_requirements_txt() - - self.assertEqual( - sorted(requirements, key=lambda x: x.project_name), - sorted(parse_requirements([ - "foo >= 1.0.0", - "bar <= 1.0.0", - ]), key=lambda x: x.project_name), - ) + checker = Flake8Checker(None, None) + self.assertEqual( + checker.get_requirements_txt(), + tuple(parse_requirements([ + "foo >= 1.0.0", + "bar <= 1.0.0", + ])), + ) def test_init_with_recursive_requirements_beyond_max_depth(self): content = "foo >= 1.0.0\n-r inner.txt\nbar <= 1.0.0\n" inner_content = "# inner\nbaz\n\nqux\n" - with mock.patch("os.path.exists", return_value=True): - with mock.patch(builtins_open, mock.mock_open()) as m: - m.side_effect = ( - mock.mock_open(read_data=content).return_value, - mock.mock_open(read_data=inner_content).return_value, - ) - - with self.assertRaises(RuntimeError): - try: - Flake8Checker.requirements_max_depth = 0 - checker = Flake8Checker(None, None) - checker.get_requirements_txt() - finally: - Flake8Checker.requirements_max_depth = 1 + with mock.patch(builtins_open, mock.mock_open()) as m: + m.side_effect = ( + mock.mock_open(read_data=content).return_value, + mock.mock_open(read_data=inner_content).return_value, + ) + + with self.assertRaises(RuntimeError): + try: + Flake8Checker.requirements_max_depth = 0 + checker = Flake8Checker(None, None) + checker.get_requirements_txt() + finally: + Flake8Checker.requirements_max_depth = 1 def test_init_with_recursive_requirements(self): content = "foo >= 1.0.0\n-r inner.txt\nbar <= 1.0.0\n" inner_content = "# inner\nbaz\n\nqux\n" - with mock.patch("os.path.exists", return_value=True): - with mock.patch(builtins_open, mock.mock_open()) as m: - m.side_effect = ( - mock.mock_open(read_data=content).return_value, - mock.mock_open(read_data=inner_content).return_value, - ) - - checker = Flake8Checker(None, None) - requirements = checker.get_requirements_txt() - - self.assertEqual( - sorted(requirements, key=lambda x: x.project_name), - sorted(parse_requirements([ - "foo >= 1.0.0", - "baz", - "qux", - "bar <= 1.0.0", - ]), key=lambda x: x.project_name), - ) + with mock.patch(builtins_open, mock.mock_open()) as m: + m.side_effect = ( + mock.mock_open(read_data=content).return_value, + mock.mock_open(read_data=inner_content).return_value, + ) + + checker = Flake8Checker(None, None) + self.assertEqual( + checker.get_requirements_txt(), + tuple(parse_requirements([ + "foo >= 1.0.0", + "baz", + "qux", + "bar <= 1.0.0", + ])), + ) diff --git a/test/test_setup.py b/test/test_setup.py index 4580f0d..8390c56 100644 --- a/test/test_setup.py +++ b/test/test_setup.py @@ -16,7 +16,7 @@ def test_detect_setup(self): "packages=['']", "url='URL'", ))) - setup = SetupVisitor(ast.parse(code)) + setup = SetupVisitor(ast.parse(code), "") self.assertEqual(setup.redirected, True) self.assertDictEqual(setup.keywords, { 'name': 'A', @@ -30,7 +30,7 @@ def test_detect_setup(self): "{}='{}'".format(x, x) for x in SetupVisitor.attributes )) - setup = SetupVisitor(ast.parse(code)) + setup = SetupVisitor(ast.parse(code), "") self.assertEqual(setup.redirected, True) self.assertDictEqual(setup.keywords, { x: x for x in SetupVisitor.attributes @@ -43,7 +43,7 @@ def test_detect_setup(self): "processing=True", "verbose=True", ))) - setup = SetupVisitor(ast.parse(code)) + setup = SetupVisitor(ast.parse(code), "") self.assertEqual(setup.redirected, False) def test_get_requirements(self): @@ -55,7 +55,7 @@ def test_get_requirements(self): 'extras_require': { 'extra': ["extra < 10"], }, - })))) + }))), "") self.assertEqual(setup.redirected, True) self.assertEqual( sorted(setup.get_requirements(), key=lambda x: x.project_name),