
Commit dc2e583

Enable flake8-bugbear (#530)
1 parent b42d95b commit dc2e583

12 files changed: +39 -35 lines changed

benchmark/data_frame_text_benchmark.py (+1 -1)

@@ -529,7 +529,7 @@ def main_torch(
                 "Currently Trompt with finetuning is too expensive")
         model_cls = Trompt
         stype_encoder_dicts = []
-        for i in range(train_cfg["num_layers"]):
+        for _ in range(train_cfg["num_layers"]):
             stype_encoder_dicts.append(
                 get_stype_encoder_dict(text_stype, text_encoder,
                                        train_tensor_frame))
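
This is flake8-bugbear's B007 fix ("loop control variable not used within the loop body"): the counter is never read, so renaming it to `_` makes that explicit. A self-contained sketch of the pattern (hypothetical values, not repository code):

    stype_encoder_dicts = []
    num_layers = 3  # hypothetical stand-in for train_cfg["num_layers"]
    for _ in range(num_layers):         # `_` marks the value as deliberately unused
        stype_encoder_dicts.append({})  # stand-in for get_stype_encoder_dict(...)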

pyproject.toml (+6 -7)

@@ -70,7 +70,13 @@ name="torch_frame"
 
 [tool.ruff] # https://docs.astral.sh/ruff/rules
 target-version = "py39"
+src = ["torch_frame", "test", "examples", "benchmark"]
+line-length = 80
+indent-width = 4
+
+[tool.ruff.lint]
 select = [
+    "B", # flake8-bugbear
     "D", # pydocstyle
     "UP", # pyupgrade
 ]
@@ -83,13 +89,6 @@ ignore = [
     "D107", # Ignore "Missing docstring in __init__"
     "D205", # Ignore "blank line required between summary line and description"
 ]
-src = ["torch_frame"]
-line-length = 80
-indent-width = 4
-
-# [tool.ruff.per-files-ignores]
-
-
 
 [tool.ruff.lint.pydocstyle]
 convention = "google"
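
Lint-specific options (`select` and the `ignore` list) now live under `[tool.ruff.lint]`, where current Ruff releases expect them, and adding "B" to `select` enables the flake8-bugbear rule family for the listed source trees; the remaining file diffs in this commit are the resulting per-rule fixes. As a rough illustration (a sketch, not repository code), this is the kind of pattern the B rules catch, and a deliberate exception can still be kept with an inline suppression:

    import time


    # B008: function call in an argument default, evaluated only once at
    # definition time rather than on every call.
    def log(message: str, when: float = time.time()) -> None:
        print(when, message)


    # Keep a single, intentional occurrence with an inline suppression:
    def log_at_import(message: str, when: float = time.time()) -> None:  # noqa: B008
        print(when, message)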

torch_frame/config/image_embedder.py (-3)

@@ -17,9 +17,6 @@ class ImageEmbedder(ABC):
     override :meth:`forward_retrieve` which takes the paths to images and
     return a list of :obj:`PIL.Image.Image`.
     """
-    def __init__(self, *args, **kwargs):
-        pass
-
     def forward_retrieve(self, path_to_images: list[str]) -> list[Image.Image]:
         r"""Retrieval function that reads a list of images from
         a list of file paths with the :obj:`RGB` mode.
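
Dropping the do-nothing `__init__` is most likely a response to bugbear's B027 check, which flags empty, non-abstract methods on abstract base classes; either way, the method added nothing over the inherited `object.__init__`. A condensed sketch of the rule (hypothetical class, not repository code):

    from abc import ABC, abstractmethod


    class Embedder(ABC):
        def reset(self) -> None:  # B027: empty method in an ABC with no @abstractmethod
            pass

        @abstractmethod
        def embed(self, paths: list) -> list:
            ...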

torch_frame/data/dataset.py (+2 -3)

@@ -3,7 +3,6 @@
 import copy
 import functools
 import os.path as osp
-from abc import ABC
 from collections import defaultdict
 from typing import Any
 
@@ -324,7 +323,7 @@ def __call__(
         return self._merge_feat(tf)
 
 
-class Dataset(ABC):
+class Dataset:
     r"""A base class for creating tabular datasets.
 
     Args:
@@ -382,7 +381,7 @@ def __init__(
         col_to_image_embedder_cfg: dict[str, ImageEmbedderConfig]
         | ImageEmbedderConfig | None = None,
         col_to_time_format: str | None | dict[str, str | None] = None,
-    ):
+    ) -> None:
         self.df = df
         self.target_col = target_col
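
`Dataset` apparently declares no `@abstractmethod`, so inheriting from `ABC` trips bugbear's B024 ("abstract base class has no abstract methods"); dropping the base class and its now-unused import clears the warning without changing behavior, and the `-> None` annotation on `__init__` is a small typing addition. A condensed sketch of B024 (hypothetical classes):

    from abc import ABC


    class Base(ABC):  # B024: claims to be abstract, but nothing here is abstract
        def run(self) -> None:
            print("running")


    class PlainBase:  # a plain class states the same intent more honestly
        def run(self) -> None:
            print("running")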

torch_frame/data/multi_embedding_tensor.py (+2 -2)

@@ -194,7 +194,7 @@ def _single_index_select(
                 values=values,
                 offset=offset,
             )
-        assert False, "Should not reach here."
+        raise AssertionError("Should not reach here.")
 
     def fillna_col(
         self,
@@ -290,4 +290,4 @@ def cat(
         offset = torch.tensor(offset_list)
         return MultiEmbeddingTensor(num_rows, num_cols, values, offset)
 
-        assert False, "Should not reach here."
+        raise AssertionError("Should not reach here.")

torch_frame/data/multi_tensor.py (+4 -4)

@@ -59,7 +59,7 @@ def size(self, dim: int) -> int:
             return self.num_rows
         elif dim == 1:
             return self.num_cols
-        assert False, "Should not reach here."
+        raise AssertionError("Should not reach here.")
 
     def dim(self) -> int:
         return self.ndim
@@ -243,7 +243,7 @@ def index_select(self, index: Tensor, dim: int) -> _MultiTensor:
             return self._row_index_select(idx)
         elif dim == 1:
             return self._col_index_select(idx)
-        assert False, "Should not reach here."
+        raise AssertionError("Should not reach here.")
 
     def _row_index_select(self, index: Tensor) -> _MultiTensor:
         raise NotImplementedError
@@ -300,7 +300,7 @@ def narrow(self, dim: int, start: int, length: int) -> _MultiTensor:
             return self._row_narrow(start, length)
         elif dim == 1:
             return self._col_narrow(start, length)
-        assert False, "Should not reach here."
+        raise AssertionError("Should not reach here.")
 
     def _row_narrow(self, start: int, length: int) -> _MultiTensor:
         raise NotImplementedError
@@ -339,7 +339,7 @@ def select(
             torch.tensor(index, dtype=torch.long, device=self.device),
             dim=dim,
         )
-        assert False, "Should not reach here."
+        raise AssertionError("Should not reach here.")
 
     def _single_index_select(self, index: int, dim: int) -> _MultiTensor:
         raise NotImplementedError

torch_frame/datasets/data_frame_benchmark.py (+1 -1)

@@ -777,7 +777,7 @@ def __init__(
 
         # Check the scale
         if dataset.num_rows < 5000:
-            assert False
+            raise AssertionError()
         elif dataset.num_rows < 50000:
             assert scale == "small"
         elif dataset.num_rows < 500000:

torch_frame/datasets/fake.py (+3 -2)

@@ -58,7 +58,7 @@ def __init__(
         self,
         num_rows: int,
         with_nan: bool = False,
-        stypes: list[stype] = [stype.categorical, stype.numerical],
+        stypes: list[stype] | None = None,
         create_split: bool = False,
         task_type: TaskType = TaskType.REGRESSION,
         col_to_text_embedder_cfg: dict[str, TextEmbedderConfig]
@@ -69,6 +69,7 @@ def __init__(
         | ImageEmbedderConfig | None = None,
         tmp_path: str | None = None,
     ) -> None:
+        stypes = stypes or [stype.categorical, stype.numerical]
         assert len(stypes) > 0
         df_dict: dict[str, list | np.ndarray]
         arr: list | np.ndarray
@@ -137,7 +138,7 @@ def __init__(
         if stype.sequence_numerical in stypes:
             for col_name in ['seq_num_1', 'seq_num_2']:
                 arr = []
-                for i in range(num_rows):
+                for _ in range(num_rows):
                     sequence_length = random.randint(1, 5)
                     sequence = [
                         random.random() for _ in range(sequence_length)
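
The signature change is the textbook bugbear B006 fix: a mutable default such as `[stype.categorical, stype.numerical]` is built once at function-definition time and shared by every call, so a caller that mutates it affects all later calls. Defaulting to `None` and resolving it in the body gives each call a fresh list; the `for i` -> `for _` change further down is the B007 cleanup described earlier. A self-contained sketch of the hazard:

    def append_bad(value, acc=[]):   # B006: one list object reused across calls
        acc.append(value)
        return acc


    def append_good(value, acc=None):
        acc = acc or []              # mirrors `stypes = stypes or [...]` above
        acc.append(value)
        return acc


    print(append_bad(1), append_bad(2))    # [1, 2] [1, 2] -- shared state leaks
    print(append_good(1), append_good(2))  # [1] [2]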

torch_frame/datasets/huggingface_dataset.py (+3 -2)

@@ -86,8 +86,9 @@ def __init__(
     ) -> None:
         try:
             from datasets import DatasetDict, load_dataset
-        except ImportError: # pragma: no cover
-            raise ImportError("Please run `pip install datasets` at first.")
+        except ImportError as e: # pragma: no cover
+            raise ImportError(
+                "Please run `pip install datasets` first.") from e
         dataset = load_dataset(path, name=name)
         if not isinstance(dataset, DatasetDict):
             raise ValueError(f"{self.__class__} only supports `DatasetDict`")

torch_frame/gbdt/gbdt.py (+9 -3)

@@ -68,8 +68,14 @@ def is_fitted(self) -> bool:
         r"""Whether the GBDT is already fitted."""
         return self._is_fitted
 
-    def tune(self, tf_train: TensorFrame, tf_val: TensorFrame, num_trials: int,
-             *args, **kwargs):
+    def tune(
+        self,
+        tf_train: TensorFrame,
+        tf_val: TensorFrame,
+        num_trials: int,
+        *args,
+        **kwargs,
+    ) -> None:
         r"""Fit the model by performing hyperparameter tuning using Optuna. The
         number of trials is specified by num_trials.
 
@@ -85,7 +91,7 @@ def tune(self, tf_train: TensorFrame, tf_val: TensorFrame, num_trials: int,
             raise RuntimeError("tf_train.y must be a Tensor, but None given.")
         if tf_val.y is None:
             raise RuntimeError("tf_val.y must be a Tensor, but None given.")
-        self._tune(tf_train, tf_val, num_trials=num_trials, *args, **kwargs)
+        self._tune(tf_train, tf_val, *args, num_trials=num_trials, **kwargs)
         self._is_fitted = True
 
     def predict(self, tf_test: TensorFrame) -> Tensor:
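
The call-site change fixes bugbear's B026 ("star-arg unpacking after a keyword argument"): writing `num_trials=num_trials, *args` is legal, but the unpacked positionals are still placed before the keyword when the call is evaluated, which misreads easily and can even raise "got multiple values" errors when the keyword also matches a positional parameter; putting `*args` first matches the real evaluation order. The exploded signature and `-> None` are formatting and typing touch-ups. A small sketch (hypothetical function with a keyword-only `num_trials` for clarity):

    def tune(a, b, *args, num_trials=None, **kwargs):
        return a, b, args, num_trials, kwargs


    # Discouraged (B026): star-arg unpacking written after a keyword argument.
    #   tune(1, 2, num_trials=5, *(3, 4))
    # The equivalent call, written in evaluation order:
    print(tune(1, 2, *(3, 4), num_trials=5))  # (1, 2, (3, 4), 5, {})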

torch_frame/utils/infer_stype.py (+3 -3)

@@ -150,9 +150,9 @@ def infer_series_stype(ser: Series) -> stype | None:
             try:
                 min_count_list.append(
                     _min_count(
-                        ser.apply(
-                            lambda row: MultiCategoricalTensorMapper.
-                            split_by_sep(row, sep)).explode()))
+                        ser.apply(lambda row, sep=sep:
+                                  MultiCategoricalTensorMapper.
+                                  split_by_sep(row, sep)).explode()))
             except Exception as e:
                 logging.warn(
                     "Mapping series into multicategorical stype "

torch_frame/utils/io.py (+5 -4)

@@ -110,11 +110,12 @@ def load(
                           "compatible in your case.")
         match = re.search(r'add_safe_globals\(.*?\)', error_msg)
         if match is not None:
-            warnings.warn(f"{warn_msg} Please use "
-                          f"`torch.serialization.{match.group()}` to "
-                          f"allowlist this global.")
+            warnings.warn(
+                f"{warn_msg} Please use "
+                f"`torch.serialization.{match.group()}` to "
+                f"allowlist this global.", stacklevel=2)
         else:
-            warnings.warn(warn_msg)
+            warnings.warn(warn_msg, stacklevel=2)
 
         tf_dict, col_stats = torch.load(path, weights_only=False)
     else:
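
Passing `stacklevel=2` addresses bugbear's B028 ("no explicit stacklevel" on `warnings.warn`): by default the warning is reported against the line inside the library that issued it, whereas `stacklevel=2` attributes it to the caller, which is usually the location a user actually needs to see. A minimal sketch:

    import warnings


    def load(path: str) -> None:
        # With stacklevel=2 the emitted warning points at the caller's line,
        # not at this line inside load().
        warnings.warn(f"Falling back to an unsafe load for {path!r}.", stacklevel=2)


    load("model.pt")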
