Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions bugbear.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.check_for_b903(node)
self.check_for_b021(node)
self.check_for_b024_and_b027(node)
self.check_for_b042(node)
self.generic_visit(node)

def visit_Try(self, node) -> None:
Expand Down Expand Up @@ -1721,6 +1722,70 @@ def check(num_args: int, param_name: str) -> None:
elif func.attr == "split":
check(2, "maxsplit")

def check_for_b042(self, node: ast.ClassDef) -> None: # noqa: C901 # too-complex
def is_exception(s: str):
for ending in "Exception", "Error", "Warning", "ExceptionGroup":
if s.endswith(ending):
return True
return False

# A class must inherit from a super class to be an exception, and we also
# require the class name or any of the base names to look like an exception name.
if not (is_exception(node.name) and node.bases):
for base in node.bases:
if isinstance(base, ast.Name) and is_exception(base.id):
break
else:
return

# iterate body nodes looking for __init__
for fun in node.body:
if not (isinstance(fun, ast.FunctionDef) and fun.name == "__init__"):
continue
if fun.args.kwonlyargs or fun.args.kwarg:
# kwargs cannot be passed to super().__init__()
self.add_error("B042", fun)
return
# -1 to exclude the `self` argument
expected_arg_count = (
len(fun.args.posonlyargs)
+ len(fun.args.args)
- 1
+ (1 if fun.args.vararg else 0)
)
if expected_arg_count == 0:
# no arguments, don't need to call super().__init__()
return

# Look for super().__init__()
# We only check top-level nodes instead of doing an `ast.walk`.
# Small risk of false alarm if the user does something weird.
for b in fun.body:
if (
isinstance(b, ast.Expr)
and isinstance(b.value, ast.Call)
and isinstance(b.value.func, ast.Attribute)
and isinstance(b.value.func.value, ast.Call)
and isinstance(b.value.func.value.func, ast.Name)
and b.value.func.value.func.id == "super"
and b.value.func.attr == "__init__"
):
if len(b.value.args) != expected_arg_count:
self.add_error("B042", fun)
elif fun.args.vararg:
for arg in b.value.args:
if isinstance(arg, ast.Starred):
return
else:
# no Starred argument despite vararg
self.add_error("B042", fun)
return
else:
# no super().__init__() found
self.add_error("B042", fun)
return
# no `def __init__` found, which is fine

def check_for_b909(self, node: ast.For) -> None:
if isinstance(node.iter, ast.Name):
name = _to_name_str(node.iter)
Expand Down Expand Up @@ -2332,6 +2397,13 @@ def __call__(self, lineno: int, col: int, vars: tuple[object, ...] = ()) -> erro
message="B040 Exception with added note not used. Did you forget to raise it?"
),
"B041": Error(message=("B041 Repeated key-value pair in dictionary literal.")),
"B042": Error(
message=(
"B042 Exception class with `__init__` should pass all args to "
"`super().__init__()` in order to work with `copy.copy()`. "
"It should also not take any kwargs."
)
),
# Warnings disabled by default.
"B901": Error(
message=(
Expand Down
74 changes: 74 additions & 0 deletions tests/eval_files/b042.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
class MyError_no_args(Exception):
def __init__(self): # safe
...


class MyError_args_good(Exception):
def __init__(self, foo, bar=3):
super().__init__(foo, bar)


class MyError_args_bad(Exception):
def __init__(self, foo, bar=3): # B042: 4
super().__init__(foo)


class MyError_kwonlyargs(Exception):
def __init__(self, *, foo): # B042: 4
super().__init__(foo=foo)


class MyError_kwargs(Exception):
def __init__(self, **kwargs): # B042: 4
super().__init__(**kwargs)


class MyError_vararg_good(Exception):
def __init__(self, *args): # safe
super().__init__(*args)


class MyError_vararg_bad(Exception):
def __init__(self, *args): # B042: 4
super().__init__()


class MyError_args_nothing(Exception):
def __init__(self, *args): ... # B042: 4


class MyError_nested_init(Exception):
def __init__(self, x): # B042: 4
if True:
super().__init__(x)

class MyError_posonlyargs(Exception):
def __init__(self, x, /, y):
super().__init__(x, y)

# triggers if class name ends with, or
# if it inherits from a class whose name ends with, any of
# 'Error', 'Exception', 'ExceptionGroup', 'Warning', 'ExceptionGroup'
class Anything(ValueError):
def __init__(self, x): ... # B042: 4
class Anything2(BaseException):
def __init__(self, x): ... # B042: 4
class Anything3(ExceptionGroup):
def __init__(self, x): ... # B042: 4
class Anything4(UserWarning):
def __init__(self, x): ... # B042: 4

class MyError(Anything):
def __init__(self, x): ... # B042: 4
class MyException(Anything):
def __init__(self, x): ... # B042: 4
class MyExceptionGroup(Anything):
def __init__(self, x): ... # B042: 4
class MyWarning(Anything):
def __init__(self, x): ... # B042: 4

class ExceptionHandler(Anything):
def __init__(self, x): ... # safe

class FooException:
def __init__(self, x): ...