diff --git a/pre_commit_hooks/debug_statement_hook.py b/pre_commit_hooks/debug_statement_hook.py index 7e6be95e..9139720c 100644 --- a/pre_commit_hooks/debug_statement_hook.py +++ b/pre_commit_hooks/debug_statement_hook.py @@ -74,8 +74,24 @@ def check_file(filename: str) -> int: def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to run') + parser.add_argument( + '--forbid', + type=str, action='append', + help='Extra module name(s) to forbid', + ) + parser.add_argument( + '--allow', + type=str, + action='append', + help='Extra module name(s) to allow', + ) args = parser.parse_args(argv) + for name in args.forbid or (): + DEBUG_STATEMENTS.add(name) + for name in args.allow or (): + DEBUG_STATEMENTS.discard(name) + retv = 0 for filename in args.filenames: retv |= check_file(filename) diff --git a/tests/debug_statement_hook_test.py b/tests/debug_statement_hook_test.py index 5a8e0bb2..8200ea29 100644 --- a/tests/debug_statement_hook_test.py +++ b/tests/debug_statement_hook_test.py @@ -32,6 +32,20 @@ def test_finds_breakpoint(): assert visitor.breakpoints == [Debug(1, 0, 'breakpoint', 'called')] +def test_allow(tmpdir): + f_py = tmpdir.join('f.py') + f_py.write('import q') + ret = main([str(f_py), '--allow', 'q']) + assert ret == 0 + + +def test_forbid(tmpdir): + f_py = tmpdir.join('f.py') + f_py.write('import foo') + ret = main([str(f_py), '--forbid', 'foo']) + assert ret == 1 + + def test_returns_one_for_failing_file(tmpdir): f_py = tmpdir.join('f.py') f_py.write('def f():\n import pdb; pdb.set_trace()')