
Commit 6db5213

Jayant-kernel and claude committed
[ENH] Update ruff linting standards to py310
Closes #649

Changes:
- pyproject.toml: move deprecated top-level ruff keys to the lint section
  - [tool.ruff.per-file-ignores] -> [tool.ruff.lint.per-file-ignores]
  - [tool.ruff.pydocstyle] -> [tool.ruff.lint.pydocstyle]
  - Remove the deprecated ignore-init-module-imports option
- Apply pyupgrade (UP) rules across the codebase for the py310 target:
  - UP006: replace typing.List/Dict/Tuple with list/dict/tuple
  - UP007: replace Optional[X]/Union[X, Y] with X | None / X | Y
  - UP035: remove deprecated typing imports, use built-ins directly
  - UP045: replace Optional[X] with X | None in annotations
- Fix import ordering (isort) after the typing imports cleanup

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
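For readers unfamiliar with these pyupgrade codes, a minimal before/after sketch of the transformation (illustrative names, not code from this commit):

```python
# Before: py38-style annotations (flagged by UP006/UP007/UP035/UP045)
from typing import Dict, List, Optional

def select(frame: Dict[str, list], cols: Optional[List[str]] = None) -> Dict[str, list]:
    ...

# After: py310 built-in generics (UP006) and PEP 604 unions (UP007/UP045),
# with the now-unneeded typing imports removed (UP035)
def select(frame: dict[str, list], cols: list[str] | None = None) -> dict[str, list]:
    ...
```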
1 parent 8495f64 commit 6db5213

32 files changed

Lines changed: 353 additions & 331 deletions

pyproject.toml

Lines changed: 3 additions & 3 deletions
@@ -149,14 +149,14 @@ exclude = [
     "dist",
     "docs"
 ]
-ignore-init-module-imports = true
 
-[tool.ruff.per-file-ignores]
+
+[tool.ruff.lint.per-file-ignores]
 "setup.py" = ["D100", "SIM115"]
 "__about__.py" = ["D100"]
 "__init__.py" = ["D100"]
 
-[tool.ruff.pydocstyle]
+[tool.ruff.lint.pydocstyle]
 # Use numpy-style docstrings.
 convention = "numpy"
 
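For reference, the resulting section of pyproject.toml, assembled from this hunk. The [tool.ruff] target-version line is an assumption based on the commit title; it is not part of the diff shown:

```toml
[tool.ruff]
# Assumed from the commit title ("py310"); not visible in this hunk
target-version = "py310"

[tool.ruff.lint.per-file-ignores]
"setup.py" = ["D100", "SIM115"]
"__about__.py" = ["D100"]
"__init__.py" = ["D100"]

[tool.ruff.lint.pydocstyle]
# Use numpy-style docstrings.
convention = "numpy"
```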
src/pytorch_tabular/config/config.py

Lines changed: 46 additions & 45 deletions
@@ -5,8 +5,9 @@
 
 import os
 import re
+from collections.abc import Iterable
 from dataclasses import MISSING, dataclass, field
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any
 
 from omegaconf import OmegaConf
 
@@ -103,22 +104,22 @@ class DataConfig:
 
     """
 
-    target: Optional[List[str]] = field(
+    target: list[str] | None = field(
         default=None,
         metadata={
             "help": "A list of strings with the names of the target column(s)."
             " It is mandatory for all except SSL tasks."
         },
     )
-    continuous_cols: List = field(
+    continuous_cols: list = field(
         default_factory=list,
         metadata={"help": "Column names of the numeric fields. Defaults to []"},
     )
-    categorical_cols: List = field(
+    categorical_cols: list = field(
         default_factory=list,
         metadata={"help": "Column names of the categorical fields to treat differently. Defaults to []"},
     )
-    date_columns: List = field(
+    date_columns: list = field(
         default_factory=list,
         metadata={
             "help": "(Column names, Freq) tuples of the date fields. For eg. a field named"
@@ -131,14 +132,14 @@ class DataConfig:
         default=True,
         metadata={"help": "Whether or not to encode the derived variables from date"},
     )
-    validation_split: Optional[float] = field(
+    validation_split: float | None = field(
         default=0.2,
         metadata={
             "help": "Percentage of Training rows to keep aside as validation."
             " Used only if Validation Data is not given separately"
         },
     )
-    continuous_feature_transform: Optional[str] = field(
+    continuous_feature_transform: str | None = field(
         default=None,
         metadata={
             "help": "Whether or not to transform the features before modelling. By default it is turned off.",
@@ -164,7 +165,7 @@ class DataConfig:
             " the noise is only applied for QuantileTransformer"
         },
     )
-    num_workers: Optional[int] = field(
+    num_workers: int | None = field(
         default=0,
         metadata={"help": "The number of workers used for data loading. For windows always set to 0"},
     )
@@ -186,7 +187,7 @@ class DataConfig:
         metadata={"help": "pickle protocol version passed to `torch.save` for dataset caching to disk"},
     )
 
-    dataloader_kwargs: Dict[str, Any] = field(
+    dataloader_kwargs: dict[str, Any] = field(
         default_factory=dict,
         metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
     )
@@ -229,19 +230,19 @@ class InferredConfig:
     continuous_dim: int = field(
         metadata={"help": "The number of continuous features"},
     )
-    output_dim: Optional[int] = field(
+    output_dim: int | None = field(
         default=None,
         metadata={"help": "The number of output targets"},
    )
-    output_cardinality: Optional[List[int]] = field(
+    output_cardinality: list[int] | None = field(
        default=None,
         metadata={"help": "The number of unique values in classification output"},
     )
-    categorical_cardinality: Optional[List[int]] = field(
+    categorical_cardinality: list[int] | None = field(
         default=None,
         metadata={"help": "The number of unique values in categorical features"},
     )
-    embedding_dims: Optional[List] = field(
+    embedding_dims: list | None = field(
         default=None,
         metadata={
             "help": "The dimensions of the embedding for each categorical column as a list of tuples "
@@ -384,30 +385,30 @@ class TrainerConfig:
         },
     )
     max_epochs: int = field(default=10, metadata={"help": "Maximum number of epochs to be run"})
-    min_epochs: Optional[int] = field(
+    min_epochs: int | None = field(
         default=1,
         metadata={"help": "Force training for at least these many epochs. 1 by default"},
     )
-    max_time: Optional[int] = field(
+    max_time: int | None = field(
         default=None,
         metadata={"help": "Stop training after this amount of time has passed. Disabled by default (None)"},
     )
-    accelerator: Optional[str] = field(
+    accelerator: str | None = field(
        default="auto",
         metadata={
             "help": "The accelerator to use for training. Can be one of 'cpu','gpu','tpu','ipu','auto'."
             " Defaults to 'auto'",
             "choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"],
         },
     )
-    devices: Optional[int] = field(
+    devices: int | None = field(
         default=-1,
         metadata={
             "help": "Number of devices to train on. -1 uses all available devices."
             " By default uses all available devices (-1)",
         },
     )
-    devices_list: Optional[List[int]] = field(
+    devices_list: list[int] | None = field(
         default=None,
         metadata={
             "help": "List of devices to train on (list). If specified, takes precedence over `devices` argument."
@@ -454,15 +455,15 @@ class TrainerConfig:
             "help": "If true enables cudnn.deterministic. Might make your system slower, but ensures reproducibility."
         },
     )
-    profiler: Optional[str] = field(
+    profiler: str | None = field(
         default=None,
         metadata={
             "help": "To profile individual steps during training and assist in identifying bottlenecks."
             " None, simple or advanced, pytorch",
             "choices": [None, "simple", "advanced", "pytorch"],
         },
     )
-    early_stopping: Optional[str] = field(
+    early_stopping: str | None = field(
         default="valid_loss",
         metadata={
             "help": "The loss/metric that needed to be monitored for early stopping."
@@ -484,14 +485,14 @@ class TrainerConfig:
         default=3,
         metadata={"help": "The number of epochs to wait until there is no further improvements in loss/metric"},
     )
-    early_stopping_kwargs: Optional[Dict[str, Any]] = field(
+    early_stopping_kwargs: dict[str, Any] | None = field(
         default_factory=lambda: {},
         metadata={
             "help": "Additional keyword arguments for the early stopping callback."
             " See the documentation for the PyTorch Lightning EarlyStopping callback for more details."
         },
     )
-    checkpoints: Optional[str] = field(
+    checkpoints: str | None = field(
         default="valid_loss",
         metadata={
             "help": "The loss/metric that needed to be monitored for checkpoints. If None, there will be no checkpoints"
@@ -505,7 +506,7 @@ class TrainerConfig:
         default=1,
         metadata={"help": "Number of training steps between checkpoints"},
     )
-    checkpoints_name: Optional[str] = field(
+    checkpoints_name: str | None = field(
         default=None,
         metadata={
             "help": "The name under which the models will be saved. If left blank,"
@@ -521,7 +522,7 @@ class TrainerConfig:
         default=1,
         metadata={"help": "The number of best models to save"},
     )
-    checkpoints_kwargs: Optional[Dict[str, Any]] = field(
+    checkpoints_kwargs: dict[str, Any] | None = field(
         default_factory=lambda: {},
         metadata={
             "help": "Additional keyword arguments for the checkpoints callback. See the documentation"
@@ -553,7 +554,7 @@ class TrainerConfig:
         default=42,
         metadata={"help": "Seed for random number generators. Defaults to 42"},
     )
-    trainer_kwargs: Dict[str, Any] = field(
+    trainer_kwargs: dict[str, Any] = field(
         default_factory=dict,
         metadata={"help": "Additional kwargs to be passed to PyTorch Lightning Trainer."},
     )
@@ -611,14 +612,14 @@ class ExperimentConfig:
         },
     )
 
-    run_name: Optional[str] = field(
+    run_name: str | None = field(
         default=None,
         metadata={
             "help": "The name of the run; a specific identifier to recognize the run."
             " If left blank, will be assigned a auto-generated name"
         },
     )
-    exp_watch: Optional[str] = field(
+    exp_watch: str | None = field(
         default=None,
         metadata={
             "help": "The level of logging required. Can be `gradients`, `parameters`, `all` or `None`."
@@ -690,29 +691,29 @@ class OptimizerConfig:
             " for example 'torch_optimizer.RAdam'."
         },
     )
-    optimizer_params: Dict = field(
+    optimizer_params: dict = field(
         default_factory=lambda: {},
         metadata={"help": "The parameters for the optimizer. If left blank, will use default parameters."},
     )
-    lr_scheduler: Optional[str] = field(
+    lr_scheduler: str | None = field(
         default=None,
         metadata={
             "help": "The name of the LearningRateScheduler to use, if any, from"
             " https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate."
             " If None, will not use any scheduler. Defaults to `None`",
         },
     )
-    lr_scheduler_params: Optional[Dict] = field(
+    lr_scheduler_params: dict | None = field(
         default_factory=lambda: {},
         metadata={"help": "The parameters for the LearningRateScheduler. If left blank, will use default parameters."},
     )
 
-    lr_scheduler_monitor_metric: Optional[str] = field(
+    lr_scheduler_monitor_metric: str | None = field(
         default="valid_loss",
         metadata={"help": "Used with ReduceLROnPlateau, where the plateau is decided based on this metric"},
     )
 
-    lr_scheduler_interval: Optional[str] = field(
+    lr_scheduler_interval: str | None = field(
         default="epoch",
         metadata={"help": "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`."},
     )
@@ -823,7 +824,7 @@ class ModelConfig:
         }
     )
 
-    head: Optional[str] = field(
+    head: str | None = field(
         default="LinearHead",
         metadata={
             "help": "The head to be used for the model. Should be one of the heads defined"
@@ -832,14 +833,14 @@ class ModelConfig:
         },
     )
 
-    head_config: Optional[Dict] = field(
+    head_config: dict | None = field(
         default_factory=lambda: {"layers": ""},
         metadata={
             "help": "The config as a dict which defines the head."
             " If left empty, will be initialized as default linear head."
         },
     )
-    embedding_dims: Optional[List] = field(
+    embedding_dims: list | None = field(
         default=None,
         metadata={
             "help": "The dimensions of the embedding for each categorical column as a list of tuples "
@@ -860,15 +861,15 @@ class ModelConfig:
         default=1e-3,
         metadata={"help": "The learning rate of the model. Defaults to 1e-3."},
     )
-    loss: Optional[str] = field(
+    loss: str | None = field(
         default=None,
         metadata={
             "help": "The loss function to be applied. By Default it is MSELoss for regression "
             "and CrossEntropyLoss for classification. Unless you are sure what you are doing, "
             "leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification"
         },
     )
-    metrics: Optional[List[str]] = field(
+    metrics: list[str] | None = field(
         default=None,
         metadata={
             "help": "the list of metrics you need to track during training. The metrics should be one "
@@ -877,23 +878,23 @@ class ModelConfig:
             "and mean_squared_error for regression"
         },
     )
-    metrics_prob_input: Optional[List[bool]] = field(
+    metrics_prob_input: list[bool] | None = field(
         default=None,
         metadata={
             "help": "Is a mandatory parameter for classification metrics defined in the config. This defines "
             "whether the input to the metric function is the probability or the class. Length should be same "
             "as the number of metrics. Defaults to None."
         },
     )
-    metrics_params: Optional[List] = field(
+    metrics_params: list | None = field(
         default=None,
         metadata={
             "help": "The parameters to be passed to the metrics function. `task` is forced to be `multiclass`` "
             "because the multiclass version can handle binary as well and for simplicity we are only using "
             "`multiclass`."
         },
     )
-    target_range: Optional[List] = field(
+    target_range: list | None = field(
         default=None,
         metadata={
             "help": "The range in which we should limit the output variable. "
@@ -902,7 +903,7 @@ class ModelConfig:
         },
     )
 
-    virtual_batch_size: Optional[int] = field(
+    virtual_batch_size: int | None = field(
         default=None,
         metadata={
             "help": "If not None, all BatchNorms will be converted to GhostBatchNorm's "
@@ -1001,23 +1002,23 @@ class SSLModelConfig:
 
     task: str = field(init=False, default="ssl")
 
-    encoder_config: Optional[ModelConfig] = field(
+    encoder_config: ModelConfig | None = field(
         default=None,
         metadata={
             "help": "The config of the encoder to be used for the model."
             " Should be one of the model configs defined in PyTorch Tabular",
         },
     )
 
-    decoder_config: Optional[ModelConfig] = field(
+    decoder_config: ModelConfig | None = field(
         default=None,
         metadata={
             "help": "The config of decoder to be used for the model."
             " Should be one of the model configs defined in PyTorch Tabular. Defaults to nn.Identity",
         },
     )
 
-    embedding_dims: Optional[List] = field(
+    embedding_dims: list | None = field(
         default=None,
         metadata={
             "help": "The dimensions of the embedding for each categorical column as a list of tuples "
@@ -1033,7 +1034,7 @@ class SSLModelConfig:
         default=True,
         metadata={"help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer."},
     )
-    virtual_batch_size: Optional[int] = field(
+    virtual_batch_size: int | None = field(
         default=None,
         metadata={
             "help": "If not None, all BatchNorms will be converted to GhostBatchNorm's "
