@@ -5,8 +5,9 @@
 
 import os
 import re
+from collections.abc import Iterable
 from dataclasses import MISSING, dataclass, field
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any
 
 from omegaconf import OmegaConf
 
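Note on the import swap above: it relies on PEP 585 (Python 3.9), which made `collections.abc.Iterable` and the builtin containers usable as generics, and PEP 604 (Python 3.10), which added the `X | None` union syntax. A minimal sketch of the equivalences, not part of the commit:

from collections.abc import Iterable  # PEP 585 home of Iterable
from typing import Optional, Union

# PEP 604 unions compare equal to the typing spellings they replace.
print((int | None) == Optional[int])   # True
print((int | str) == Union[int, str])  # True

# PEP 585: builtins are subscriptable at runtime, no typing.List/Dict needed.
print(list[str], dict[str, int])
print(isinstance([], Iterable))        # True; the abc still works for checks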
@@ -103,22 +104,22 @@ class DataConfig:
 
     """
 
-    target: Optional[List[str]] = field(
+    target: list[str] | None = field(
         default=None,
         metadata={
             "help": "A list of strings with the names of the target column(s)."
             " It is mandatory for all except SSL tasks."
         },
     )
-    continuous_cols: List = field(
+    continuous_cols: list = field(
         default_factory=list,
         metadata={"help": "Column names of the numeric fields. Defaults to []"},
     )
-    categorical_cols: List = field(
+    categorical_cols: list = field(
         default_factory=list,
         metadata={"help": "Column names of the categorical fields to treat differently. Defaults to []"},
     )
-    date_columns: List = field(
+    date_columns: list = field(
         default_factory=list,
         metadata={
             "help": "(Column names, Freq) tuples of the date fields. For eg. a field named"
@@ -131,14 +132,14 @@ class DataConfig:
         default=True,
         metadata={"help": "Whether or not to encode the derived variables from date"},
     )
-    validation_split: Optional[float] = field(
+    validation_split: float | None = field(
         default=0.2,
         metadata={
             "help": "Percentage of Training rows to keep aside as validation."
             " Used only if Validation Data is not given separately"
         },
     )
-    continuous_feature_transform: Optional[str] = field(
+    continuous_feature_transform: str | None = field(
         default=None,
         metadata={
             "help": "Whether or not to transform the features before modelling. By default it is turned off.",
@@ -164,7 +165,7 @@ class DataConfig:
164165 " the noise is only applied for QuantileTransformer"
165166 },
166167 )
167- num_workers : Optional [ int ] = field (
168+ num_workers : int | None = field (
168169 default = 0 ,
169170 metadata = {"help" : "The number of workers used for data loading. For windows always set to 0" },
170171 )
@@ -186,7 +187,7 @@ class DataConfig:
         metadata={"help": "pickle protocol version passed to `torch.save` for dataset caching to disk"},
     )
 
-    dataloader_kwargs: Dict[str, Any] = field(
+    dataloader_kwargs: dict[str, Any] = field(
         default_factory=dict,
         metadata={"help": "Additional kwargs to be passed to PyTorch DataLoader."},
     )
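All the list- and dict-valued fields in this diff keep `default_factory` (note the `default_factory=list` and `default_factory=lambda: {}` lines) because dataclasses forbid mutable defaults, which would otherwise be shared across instances. A minimal sketch with an illustrative field name:

from dataclasses import dataclass, field

@dataclass
class ConfigSketch:
    # A bare `cols: list = []` would raise ValueError at class creation;
    # the factory gives every instance its own fresh list.
    cols: list = field(default_factory=list)

a, b = ConfigSketch(), ConfigSketch()
a.cols.append("age")
print(b.cols)  # []: mutating `a` does not leak into `b`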
@@ -229,19 +230,19 @@ class InferredConfig:
     continuous_dim: int = field(
         metadata={"help": "The number of continuous features"},
     )
-    output_dim: Optional[int] = field(
+    output_dim: int | None = field(
         default=None,
         metadata={"help": "The number of output targets"},
     )
-    output_cardinality: Optional[List[int]] = field(
+    output_cardinality: list[int] | None = field(
         default=None,
         metadata={"help": "The number of unique values in classification output"},
     )
-    categorical_cardinality: Optional[List[int]] = field(
+    categorical_cardinality: list[int] | None = field(
         default=None,
         metadata={"help": "The number of unique values in categorical features"},
     )
-    embedding_dims: Optional[List] = field(
+    embedding_dims: list | None = field(
         default=None,
         metadata={
             "help": "The dimensions of the embedding for each categorical column as a list of tuples "
@@ -384,30 +385,30 @@ class TrainerConfig:
         },
     )
     max_epochs: int = field(default=10, metadata={"help": "Maximum number of epochs to be run"})
-    min_epochs: Optional[int] = field(
+    min_epochs: int | None = field(
         default=1,
         metadata={"help": "Force training for at least these many epochs. 1 by default"},
     )
-    max_time: Optional[int] = field(
+    max_time: int | None = field(
         default=None,
         metadata={"help": "Stop training after this amount of time has passed. Disabled by default (None)"},
     )
-    accelerator: Optional[str] = field(
+    accelerator: str | None = field(
         default="auto",
         metadata={
             "help": "The accelerator to use for training. Can be one of 'cpu','gpu','tpu','ipu','auto'."
             " Defaults to 'auto'",
             "choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"],
         },
     )
-    devices: Optional[int] = field(
+    devices: int | None = field(
         default=-1,
         metadata={
             "help": "Number of devices to train on. -1 uses all available devices."
             " By default uses all available devices (-1)",
         },
     )
-    devices_list: Optional[List[int]] = field(
+    devices_list: list[int] | None = field(
         default=None,
         metadata={
             "help": "List of devices to train on (list). If specified, takes precedence over `devices` argument."
@@ -454,15 +455,15 @@ class TrainerConfig:
454455 "help" : "If true enables cudnn.deterministic. Might make your system slower, but ensures reproducibility."
455456 },
456457 )
457- profiler : Optional [ str ] = field (
458+ profiler : str | None = field (
458459 default = None ,
459460 metadata = {
460461 "help" : "To profile individual steps during training and assist in identifying bottlenecks."
461462 " None, simple or advanced, pytorch" ,
462463 "choices" : [None , "simple" , "advanced" , "pytorch" ],
463464 },
464465 )
465- early_stopping : Optional [ str ] = field (
466+ early_stopping : str | None = field (
466467 default = "valid_loss" ,
467468 metadata = {
468469 "help" : "The loss/metric that needed to be monitored for early stopping."
@@ -484,14 +485,14 @@ class TrainerConfig:
         default=3,
         metadata={"help": "The number of epochs to wait until there is no further improvements in loss/metric"},
     )
-    early_stopping_kwargs: Optional[Dict[str, Any]] = field(
+    early_stopping_kwargs: dict[str, Any] | None = field(
         default_factory=lambda: {},
         metadata={
             "help": "Additional keyword arguments for the early stopping callback."
             " See the documentation for the PyTorch Lightning EarlyStopping callback for more details."
         },
     )
-    checkpoints: Optional[str] = field(
+    checkpoints: str | None = field(
         default="valid_loss",
         metadata={
             "help": "The loss/metric that needed to be monitored for checkpoints. If None, there will be no checkpoints"
@@ -505,7 +506,7 @@ class TrainerConfig:
         default=1,
         metadata={"help": "Number of training steps between checkpoints"},
     )
-    checkpoints_name: Optional[str] = field(
+    checkpoints_name: str | None = field(
         default=None,
         metadata={
             "help": "The name under which the models will be saved. If left blank,"
@@ -521,7 +522,7 @@ class TrainerConfig:
         default=1,
         metadata={"help": "The number of best models to save"},
     )
-    checkpoints_kwargs: Optional[Dict[str, Any]] = field(
+    checkpoints_kwargs: dict[str, Any] | None = field(
         default_factory=lambda: {},
         metadata={
             "help": "Additional keyword arguments for the checkpoints callback. See the documentation"
@@ -553,7 +554,7 @@ class TrainerConfig:
         default=42,
         metadata={"help": "Seed for random number generators. Defaults to 42"},
     )
-    trainer_kwargs: Dict[str, Any] = field(
+    trainer_kwargs: dict[str, Any] = field(
         default_factory=dict,
         metadata={"help": "Additional kwargs to be passed to PyTorch Lightning Trainer."},
     )
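`trainer_kwargs`, like `dataloader_kwargs` earlier, is a typed escape hatch: a `dict[str, Any]` that presumably gets splatted into the underlying constructor. A hypothetical sketch of that pattern (the class and the hand-off below are illustrative, not the library's code):

from dataclasses import dataclass, field
from typing import Any

@dataclass
class TrainerConfigSketch:
    max_epochs: int = 10
    trainer_kwargs: dict[str, Any] = field(default_factory=dict)

cfg = TrainerConfigSketch(trainer_kwargs={"gradient_clip_val": 1.0})
# Assumed hand-off to a Lightning-style Trainer:
# trainer = pl.Trainer(max_epochs=cfg.max_epochs, **cfg.trainer_kwargs)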
@@ -611,14 +612,14 @@ class ExperimentConfig:
         },
     )
 
-    run_name: Optional[str] = field(
+    run_name: str | None = field(
         default=None,
         metadata={
             "help": "The name of the run; a specific identifier to recognize the run."
             " If left blank, will be assigned a auto-generated name"
         },
     )
-    exp_watch: Optional[str] = field(
+    exp_watch: str | None = field(
         default=None,
         metadata={
             "help": "The level of logging required. Can be `gradients`, `parameters`, `all` or `None`."
@@ -690,29 +691,29 @@ class OptimizerConfig:
690691 " for example 'torch_optimizer.RAdam'."
691692 },
692693 )
693- optimizer_params : Dict = field (
694+ optimizer_params : dict = field (
694695 default_factory = lambda : {},
695696 metadata = {"help" : "The parameters for the optimizer. If left blank, will use default parameters." },
696697 )
697- lr_scheduler : Optional [ str ] = field (
698+ lr_scheduler : str | None = field (
698699 default = None ,
699700 metadata = {
700701 "help" : "The name of the LearningRateScheduler to use, if any, from"
701702 " https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate."
702703 " If None, will not use any scheduler. Defaults to `None`" ,
703704 },
704705 )
705- lr_scheduler_params : Optional [ Dict ] = field (
706+ lr_scheduler_params : dict | None = field (
706707 default_factory = lambda : {},
707708 metadata = {"help" : "The parameters for the LearningRateScheduler. If left blank, will use default parameters." },
708709 )
709710
710- lr_scheduler_monitor_metric : Optional [ str ] = field (
711+ lr_scheduler_monitor_metric : str | None = field (
711712 default = "valid_loss" ,
712713 metadata = {"help" : "Used with ReduceLROnPlateau, where the plateau is decided based on this metric" },
713714 )
714715
715- lr_scheduler_interval : Optional [ str ] = field (
716+ lr_scheduler_interval : str | None = field (
716717 default = "epoch" ,
717718 metadata = {"help" : "Interval at which to step the LR Scheduler, one of `epoch` or `step`. Defaults to `epoch`." },
718719 )
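The optimizer help text above allows either a bare `torch.optim` name or a fully qualified path such as `torch_optimizer.RAdam`. A sketch of how such a string can be resolved to a class; this resolver is illustrative, not lifted from PyTorch Tabular:

import importlib

import torch

def resolve_optimizer(name: str) -> type:
    # Dotted path, e.g. "torch_optimizer.RAdam": import the module, grab the class.
    if "." in name:
        module_name, _, cls_name = name.rpartition(".")
        return getattr(importlib.import_module(module_name), cls_name)
    # Bare name, e.g. "Adam": look it up on torch.optim.
    return getattr(torch.optim, name)

opt_cls = resolve_optimizer("Adam")
optimizer = opt_cls([torch.nn.Parameter(torch.zeros(3))], lr=1e-3)
print(type(optimizer).__name__)  # Adam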
@@ -823,7 +824,7 @@ class ModelConfig:
         }
     )
 
-    head: Optional[str] = field(
+    head: str | None = field(
         default="LinearHead",
         metadata={
             "help": "The head to be used for the model. Should be one of the heads defined"
@@ -832,14 +833,14 @@ class ModelConfig:
         },
     )
 
-    head_config: Optional[Dict] = field(
+    head_config: dict | None = field(
         default_factory=lambda: {"layers": ""},
         metadata={
             "help": "The config as a dict which defines the head."
             " If left empty, will be initialized as default linear head."
         },
     )
-    embedding_dims: Optional[List] = field(
+    embedding_dims: list | None = field(
         default=None,
         metadata={
             "help": "The dimensions of the embedding for each categorical column as a list of tuples "
@@ -860,15 +861,15 @@ class ModelConfig:
         default=1e-3,
         metadata={"help": "The learning rate of the model. Defaults to 1e-3."},
     )
-    loss: Optional[str] = field(
+    loss: str | None = field(
         default=None,
         metadata={
             "help": "The loss function to be applied. By Default it is MSELoss for regression "
             "and CrossEntropyLoss for classification. Unless you are sure what you are doing, "
             "leave it at MSELoss or L1Loss for regression and CrossEntropyLoss for classification"
         },
     )
-    metrics: Optional[List[str]] = field(
+    metrics: list[str] | None = field(
         default=None,
         metadata={
             "help": "the list of metrics you need to track during training. The metrics should be one "
@@ -877,23 +878,23 @@ class ModelConfig:
877878 "and mean_squared_error for regression"
878879 },
879880 )
880- metrics_prob_input : Optional [ List [ bool ]] = field (
881+ metrics_prob_input : list [ bool ] | None = field (
881882 default = None ,
882883 metadata = {
883884 "help" : "Is a mandatory parameter for classification metrics defined in the config. This defines "
884885 "whether the input to the metric function is the probability or the class. Length should be same "
885886 "as the number of metrics. Defaults to None."
886887 },
887888 )
888- metrics_params : Optional [ List ] = field (
889+ metrics_params : list | None = field (
889890 default = None ,
890891 metadata = {
891892 "help" : "The parameters to be passed to the metrics function. `task` is forced to be `multiclass`` "
892893 "because the multiclass version can handle binary as well and for simplicity we are only using "
893894 "`multiclass`."
894895 },
895896 )
896- target_range : Optional [ List ] = field (
897+ target_range : list | None = field (
897898 default = None ,
898899 metadata = {
899900 "help" : "The range in which we should limit the output variable. "
@@ -902,7 +903,7 @@ class ModelConfig:
         },
     )
 
-    virtual_batch_size: Optional[int] = field(
+    virtual_batch_size: int | None = field(
         default=None,
         metadata={
             "help": "If not None, all BatchNorms will be converted to GhostBatchNorm's "
@@ -1001,23 +1002,23 @@ class SSLModelConfig:
 
     task: str = field(init=False, default="ssl")
 
-    encoder_config: Optional[ModelConfig] = field(
+    encoder_config: ModelConfig | None = field(
         default=None,
         metadata={
             "help": "The config of the encoder to be used for the model."
             " Should be one of the model configs defined in PyTorch Tabular",
         },
     )
 
-    decoder_config: Optional[ModelConfig] = field(
+    decoder_config: ModelConfig | None = field(
         default=None,
         metadata={
             "help": "The config of decoder to be used for the model."
             " Should be one of the model configs defined in PyTorch Tabular. Defaults to nn.Identity",
         },
     )
 
-    embedding_dims: Optional[List] = field(
+    embedding_dims: list | None = field(
         default=None,
         metadata={
             "help": "The dimensions of the embedding for each categorical column as a list of tuples "
@@ -1033,7 +1034,7 @@ class SSLModelConfig:
         default=True,
         metadata={"help": "If True, we will normalize the continuous layer by passing it through a BatchNorm layer."},
     )
-    virtual_batch_size: Optional[int] = field(
+    virtual_batch_size: int | None = field(
         default=None,
         metadata={
             "help": "If not None, all BatchNorms will be converted to GhostBatchNorm's "