Source code for flake8_dunder_all.__init__

#!/usr/bin/env python3
#
#  __init__.py
"""
A Flake8 plugin and pre-commit hook which checks to ensure modules have defined ``__all__``.
"""
#
#  Copyright (c) 2020-2022 Dominic Davis-Foster <dominic@davis-foster.co.uk>
#
#  Based on flake8_2020
#  Copyright (c) 2019 Anthony Sottile
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in
#  all copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
#  THE SOFTWARE.
#

# stdlib
import ast
import sys
from typing import Any, Generator, Iterator, List, Set, Tuple, Type, Union

# 3rd party
from consolekit.terminal_colours import Fore
from domdf_python_tools.paths import PathPlus
from domdf_python_tools.typing import PathLike
from domdf_python_tools.utils import stderr_writer

# this package
from flake8_dunder_all.utils import find_noqa, get_docstring_lineno, mark_text_ranges

__author__: str = "Dominic Davis-Foster"
__copyright__: str = "2020 Dominic Davis-Foster"
__license__: str = "MIT"
__version__: str = "0.4.1"
__email__: str = "dominic@davis-foster.co.uk"

# Public API of this module.
__all__ = (
    "Visitor",
    "Plugin",
    "check_and_add_all",
    "DALL000",
)

#: Error code and message reported by :class:`Plugin` for a module lacking ``__all__``.
DALL000 = "DALL000 Module lacks __all__."


class Visitor(ast.NodeVisitor):
    """
    AST :class:`~ast.NodeVisitor` to check a module has defined ``__all__``, and add one if it has not.

    :param use_endlineno: Flag to indicate whether the end_lineno functionality is available.
        This functionality is available on Python 3.8 and above, or when the tree has been passed through
        :func:`flake8_dunder_all.utils.mark_text_ranges`.
    """

    found_all: bool  #: Flag to indicate a ``__all__`` declaration has been found in the AST.
    last_import: int  #: The lineno of the last top-level or conditional import
    members: Set[str]  #: Set of public function and class names defined in the AST
    use_endlineno: bool  #: Whether ``end_lineno`` may be used to find where a statement ends.

    def __init__(self, use_endlineno: bool = False) -> None:
        self.found_all = False
        self.members = set()
        self.last_import = 0
        self.use_endlineno = use_endlineno

    def visit_Name(self, node: ast.Name) -> None:
        """
        Visit a variable.

        Any occurrence of the name ``__all__`` (regardless of context) marks it as found.

        :param node: The node being visited.
        """
        if node.id == "__all__":
            self.found_all = True
        else:
            self.generic_visit(node)

    def handle_def(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]) -> None:
        """
        Handles ``def foo(): ...``, ``async def foo(): ...`` and ``class Foo: ...``.

        The name is recorded as a member unless it starts with an underscore or the
        definition is decorated with ``@overload`` / ``@<something>.overload``.

        :param node: The node being visited.
        """
        decorators = []
        for deco in node.decorator_list:
            if isinstance(deco, ast.Name):
                decorators.append(deco.id)
            elif isinstance(deco, ast.Attribute):
                # Only the final attribute is kept, so ``@typing.overload`` is recorded as ``overload``.
                decorators.append(deco.attr)

        if not node.name.startswith('_') and "overload" not in decorators:
            self.members.add(node.name)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """
        Visit ``def foo(): ...``.

        :param node: The node being visited.
        """
        # Don't generic visit
        self.handle_def(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        """
        Visit ``async def foo(): ...``.

        :param node: The node being visited.
        """
        # Don't generic visit
        self.handle_def(node)

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """
        Visit ``class Foo: ...``.

        :param node: The node being visited.
        """
        # Don't generic visit
        self.handle_def(node)

    def handle_import(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
        """
        Handles ``import foo`` and ``from foo import bar``.

        :param node: The node being visited
        """
        # With end_lineno, a multi-line import is tracked up to its closing line.
        if self.use_endlineno and node.end_lineno is not None:
            self.last_import = max(self.last_import, node.end_lineno)
        else:
            self.last_import = max(self.last_import, node.lineno)

    def visit_Import(self, node: ast.Import) -> None:
        """
        Visit ``import foo``.

        :param node: The node being visited
        """
        # Don't generic visit
        self.handle_import(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """
        Visit ``from foo import bar``.

        :param node: The node being visited
        """
        # Don't generic visit
        self.handle_import(node)

    def visit_If(self, node: ast.If) -> None:
        """
        Visit an if statement and check if it's for `TYPE_CHECKING`.

        :param node: The node being visited.
        """
        if _is_type_checking(node.test):
            if self.use_endlineno and node.end_lineno is not None:
                self.last_import = max(self.last_import, node.end_lineno)
            else:
                # Without end_lineno, approximate the end of the statement from the
                # line numbers in both branches (``end_lineno`` covers ``else`` too).
                self.last_import = max(
                    self.last_import,
                    *_descend_node(node),
                    *_descend_node(node, "orelse"),
                )
        self.generic_visit(node)

    def visit_Try(self, node: ast.Try) -> None:
        """
        Visit a Try statement.

        If the ``try`` body contains an import, the whole statement is treated as an import block.

        :param node: The node being visited.
        """
        if any(isinstance(n, (ast.Import, ast.ImportFrom)) for n in node.body):
            if self.use_endlineno and node.end_lineno is not None and sys.implementation.name != "pypy":  # pragma: no cover (pypy)
                self.last_import = max(self.last_import, node.end_lineno)
            else:  # pragma: no cover (!pypy)
                # On PyPy, or without end_lineno, approximate the end of the statement
                # from the line numbers of every clause.
                self.last_import = max(
                    self.last_import,
                    *_descend_node(node),
                    *_descend_node(node, "handlers"),
                    *_descend_node(node, "orelse"),
                    *_descend_node(node, "finalbody"),
                )
        self.generic_visit(node)

    # ``try ... except* ...`` (Python 3.11+) parses to ``ast.TryStar``, which has the same
    # fields as ``ast.Try``; without this alias its imports would be visited individually
    # and ``__all__`` could be inserted inside the statement.
    visit_TryStar = visit_Try
def _descend_node(node: ast.AST, attr: str = "body") -> Iterator[int]:
    """
    Yield the line numbers of the statements in ``node.<attr>``, recursing into each statement's ``body``.

    :param node: The node to descend into.
    :param attr: The name of the attribute holding the child statements.
    """
    for child in getattr(node, attr, []):
        yield child.lineno
        yield from _descend_node(child)


# The literal ``False`` parses to ``ast.Constant`` on Python 3.8+ and to ``ast.NameConstant`` before that.
_nameconstant = ast.Constant if sys.version_info >= (3, 8) else ast.NameConstant


def _is_type_checking(node: ast.AST) -> bool:
    """
    Does the given ``if`` test expression indicate a `TYPE_CHECKING` block?

    Recognises a bare ``TYPE_CHECKING``, attribute access such as ``typing.TYPE_CHECKING``,
    the literal ``False``, and ``and``/``or`` expressions containing any of these.
    """  # noqa: D400
    if isinstance(node, ast.Name) and node.id == "TYPE_CHECKING":
        return True
    elif isinstance(node, ast.Attribute) and node.attr == "TYPE_CHECKING":
        # ``if typing.TYPE_CHECKING:`` -- previously missed, which let ``__all__`` be
        # inserted inside the block when it had an ``else`` branch.
        return True
    elif isinstance(node, _nameconstant) and node.value is False:
        return True
    elif isinstance(node, ast.BoolOp):
        return any(_is_type_checking(value) for value in node.values)
    return False
class Plugin:
    """
    Flake8 plugin entry point that reports modules which define public members but no ``__all__``.

    :param tree: The abstract syntax tree (AST) to check.
    """

    name: str = __name__
    version: str = __version__  #: The plugin version

    def __init__(self, tree: ast.AST):
        self._tree = tree

    def run(self) -> Generator[Tuple[int, int, str, Type[Any]], None, None]:
        """
        Run the plugin.

        Yields four-element tuples, consisting of:

        #. The line number of the error.
        #. The column offset of the error.
        #. The error message.
        #. The class of the plugin raising the error.

        At most one tuple is produced (at line 1, column 0), and only when the module
        has public members but no ``__all__``.
        """
        checker = Visitor()
        checker.visit(self._tree)

        lacks_all = checker.members and not checker.found_all
        if lacks_all:
            yield 1, 0, DALL000, type(self)
def _has_dall000_noqa(source: str) -> bool:
    """
    Return whether any line of ``source`` has a ``# noqa`` comment listing ``DALL000``.

    The codes are split on both commas and whitespace, so ``# noqa: E501, DALL000``
    is recognised; splitting on ``','`` alone left ``' DALL000'`` with a leading space.

    :param source: The Python source code to scan.
    """
    for line in source.splitlines():
        noqas = find_noqa(line)
        if noqas is not None and noqas["codes"]:
            codes = noqas["codes"].upper().replace(',', ' ').split()
            if "DALL000" in codes:
                return True
    return False


def check_and_add_all(filename: PathLike, quote_type: str = '"', use_tuple: bool = False) -> int:
    """
    Check the given filename for the presence of a ``__all__`` declaration, and add one if none is found.

    :param filename: The filename of the Python source file (``.py``) to check.
    :param quote_type: The type of quote to use for strings.
    :param use_tuple: Whether to use tuples instead of lists for ``__all__``.

    :returns:

    * ``0`` if the file already contains a ``__all__`` declaration,
      has no public function or class definitions, or has a ``# noqa: DALL000`` comment.
    * ``1`` if ``__all__`` is absent (it is then added to the file).
    * ``4`` if an error was encountered when parsing the file.

    .. versionchanged:: 0.2.0

        Now returns ``0`` and doesn't add ``__all__`` if the file contains a ``noqa: DALL000`` comment.

    .. versionchanged:: 0.3.0

        Added the ``use_tuple`` argument.
    """
    filename = PathPlus(filename)

    try:
        source = filename.read_text()
        if _has_dall000_noqa(source):
            return 0

        tree = ast.parse(source)
        if sys.version_info < (3, 8):  # pragma: no cover (py38+)
            mark_text_ranges(tree, source)

    except (SyntaxError, ValueError):
        # ValueError covers undecodable text (UnicodeDecodeError) and, on some Python
        # versions, source containing null bytes -- neither is a parseable module.
        stderr_writer(Fore.RED(f"'{filename}' does not appear to be a valid Python source file."))
        return 4

    visitor = Visitor(use_endlineno=True)
    visitor.visit(tree)

    if visitor.found_all or not visitor.members:
        return 0

    # 1-based line on which the module docstring ends (0 when there is no docstring).
    docstring_start = (get_docstring_lineno(tree) or 0) - 1
    docstring = ast.get_docstring(tree, clean=False) or ''
    docstring_end = len(docstring.split('\n')) + docstring_start

    # Used as a 0-based index into ``lines``: one line below the later of the
    # docstring end and the last import.
    insertion_position = max(docstring_end, visitor.last_import) + 1

    members = f"{quote_type}, {quote_type}".join(sorted(visitor.members))
    lines = source.split('\n')

    # Ensure there don't end up too many lines.
    # The bounds check covers files that end (without a trailing newline) right after
    # the docstring or last import, where the index points past the last line;
    # ``list.insert`` then simply appends.
    if insertion_position < len(lines) and lines[insertion_position].strip():
        lines.insert(insertion_position, '\n')
    else:
        lines.insert(insertion_position, '')

    if use_tuple:
        lines.insert(insertion_position, f"__all__ = ({quote_type}{members}{quote_type}, )")
    else:
        lines.insert(insertion_position, f"__all__ = [{quote_type}{members}{quote_type}]")

    filename.write_clean('\n'.join(lines))

    return 1