Description
Currently, we use `kwargs` plumbing extensively to route different sets of arguments to model parameters, training, dataloaders, etc. For example, for `SCVI`:
```python
model = SCVI(adata, **vae_kwargs)  # kwargs for the VAE
model.train(
    datasplitter_kwargs=...,  # kwargs for DataSplitter
    plan_kwargs=...,  # kwargs for TrainingPlan
    **trainer_kwargs,  # kwargs for Trainer
)
```
All these different sets of kwargs can be confusing, since users need to refer to different parts of the documentation and might not know where a certain parameter should go (e.g., should the learning rate go to the `Trainer` or the `TrainingPlan`?).
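To make the routing problem concrete, here is a toy, pure-Python sketch of a `train`-style method that fans its arguments out to three components. `toy_train` is hypothetical and is not scvi-tools code, and treating `lr` as a `TrainingPlan` argument is an assumption of the toy. The point is that nothing at the routing level tells the user a misplaced learning rate belongs in `plan_kwargs`; it simply lands in the trainer's kwargs, and any error only surfaces later, when that component is constructed.

```python
from typing import Any, Dict, Optional


def toy_train(
    datasplitter_kwargs: Optional[Dict[str, Any]] = None,
    plan_kwargs: Optional[Dict[str, Any]] = None,
    **trainer_kwargs: Any,
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical stand-in for a kwargs-routing train() method.

    Instead of building a DataSplitter, TrainingPlan and Trainer, it returns
    the kwargs each of those components would receive.
    """
    return {
        "DataSplitter": dict(datasplitter_kwargs or {}),
        "TrainingPlan": dict(plan_kwargs or {}),
        "Trainer": dict(trainer_kwargs),
    }


# Intended use (in this toy): the learning rate belongs to the TrainingPlan.
ok = toy_train(plan_kwargs={"lr": 1e-3}, max_epochs=10)
print(ok["TrainingPlan"])  # {'lr': 0.001}

# Easy mistake: passing lr at the top level silently routes it to the Trainer.
oops = toy_train(lr=1e-3, max_epochs=10)
print(oops["Trainer"])  # {'lr': 0.001, 'max_epochs': 10}
```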
One way to address this would be to use configs defined by dataclasses. For example:
```python
from dataclasses import dataclass

@dataclass
class VAEConfig:
    n_hidden: int = 128
    n_latent: int = 10
    # ... further VAE parameters
```
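A minimal sketch of how a model could consume such a config, assuming a hypothetical `ToySCVI` class with a single `vae_config` parameter (neither exists in scvi-tools). Only two `VAEConfig` fields are shown. Note that a plain dataclass rejects unknown field names but does not enforce type hints, so the `__post_init__` value check is an extra assumption added here, not something `@dataclass` provides on its own.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class VAEConfig:
    """Hypothetical VAE config; only two of its fields are shown."""

    n_hidden: int = 128
    n_latent: int = 10

    def __post_init__(self) -> None:
        # Dataclasses do not enforce type hints, so validate values explicitly.
        for name in ("n_hidden", "n_latent"):
            value = getattr(self, name)
            if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
                raise ValueError(f"{name} must be a positive int, got {value!r}")


class ToySCVI:
    """Hypothetical model that takes one config object instead of **kwargs."""

    def __init__(self, vae_config: Optional[VAEConfig] = None) -> None:
        # Fall back to the default config when none is given.
        self.vae_config = vae_config if vae_config is not None else VAEConfig()


default_model = ToySCVI()
custom_model = ToySCVI(VAEConfig(n_hidden=256))
print(default_model.vae_config)  # VAEConfig(n_hidden=128, n_latent=10)
print(custom_model.vae_config)  # VAEConfig(n_hidden=256, n_latent=10)

# A misspelled field name is rejected by the generated __init__ (TypeError),
# and an invalid value is rejected by __post_init__ (ValueError).
for bad_kwargs in ({"n_hiden": 256}, {"n_hidden": -1}):
    try:
        VAEConfig(**bad_kwargs)
    except (TypeError, ValueError) as err:
        print(type(err).__name__, err)
```

Making the config `frozen=True` is a design choice in this sketch: a config shared between models cannot be mutated by one of them behind the other's back.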
So just initializing `SCVI()` would use the default config, but to change a parameter, a user would create a config instance, e.g. `VAEConfig(n_hidden=256)`, whose arguments are checked automatically (for example, an unknown or misspelled parameter name raises a `TypeError`). The advantage would be even clearer for the `train` method, whose docstring would only need to describe `DataSplitterConfig`, `TrainingPlanConfig`, and `TrainerConfig`. This means that all models using, e.g., `TrainingPlan` would share a single config, leading to less redundancy.
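A sketch of what a config-based `train` signature could look like, in plain Python and under loud assumptions: `toy_train_with_configs` is hypothetical, and the field names and defaults of the three config classes are illustrative only, not a statement of what scvi-tools' `DataSplitter`, `TrainingPlan`, or `Trainer` actually accept.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class DataSplitterConfig:
    train_size: float = 0.9  # illustrative field and default
    batch_size: int = 128


@dataclass(frozen=True)
class TrainingPlanConfig:
    lr: float = 1e-3  # illustrative field and default
    weight_decay: float = 1e-6


@dataclass(frozen=True)
class TrainerConfig:
    max_epochs: int = 400  # illustrative field and default


def toy_train_with_configs(
    datasplitter: Optional[DataSplitterConfig] = None,
    plan: Optional[TrainingPlanConfig] = None,
    trainer: Optional[TrainerConfig] = None,
) -> Tuple[DataSplitterConfig, TrainingPlanConfig, TrainerConfig]:
    """Hypothetical train(): one typed config per component, no kwargs dicts."""
    if datasplitter is None:
        datasplitter = DataSplitterConfig()
    if plan is None:
        plan = TrainingPlanConfig()
    if trainer is None:
        trainer = TrainerConfig()
    # A real implementation would build DataSplitter, TrainingPlan and Trainer
    # from these configs; the toy just returns them.
    return datasplitter, plan, trainer


_, plan, trainer = toy_train_with_configs(plan=TrainingPlanConfig(lr=5e-4))
print(plan)  # TrainingPlanConfig(lr=0.0005, weight_decay=1e-06)
print(trainer)  # TrainerConfig(max_epochs=400)
```

Because each component has exactly one config class in this sketch, a parameter such as the learning rate can only be set in one place, and putting it in the wrong config fails immediately: `TrainerConfig(lr=1e-3)` raises a `TypeError` at construction time.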