Commits
24 commits
2b55d23
lightning example
kingjr Dec 11, 2024
113e828
up
kingjr Dec 11, 2024
5077a8b
add checkpoint
kingjr Dec 11, 2024
a23bccf
check that default config match original class
kingjr Dec 11, 2024
10ae19b
address comments
kingjr Dec 13, 2024
e1d381a
config.build
kingjr Dec 16, 2024
9c36d9c
Update example_lightning.py
jrapin Dec 24, 2024
2751334
Merge branch 'main' into lightning
jrapin Dec 24, 2024
7912ba6
Add packages for examples in docs
jrapin Dec 24, 2024
642aeb0
Update .github/workflows/test-type-lint.yaml
jrapin Dec 24, 2024
51ee016
Merge branch 'test/add-packages-for-docs' into lightning
jrapin Dec 24, 2024
2ac8a53
Merge branch 'test/add-packages-for-docs' into lightning
jrapin Dec 24, 2024
588ef19
add
jrapin Dec 24, 2024
2d4d226
Merge remote main via HTTPS, resolve CI conflict
kingjr Feb 25, 2026
bc98b74
Add to_step and to_chain helpers for function-to-Step conversion
kingjr Feb 25, 2026
da885c1
Fix mypy errors: add type: ignore for dynamic model fields
kingjr Feb 25, 2026
747ff86
Fix test failures: remove leading underscores from helpers, fix typos
kingjr Feb 25, 2026
7312d5a
Fix to_chain infra: use field default instead of post_init override
kingjr Feb 25, 2026
cf6db3a
Fix to_chain infra: validate dict into Backend via model_validate
kingjr Feb 25, 2026
2711ca5
Fix black and isort formatting
kingjr Feb 25, 2026
2d94b8f
Remove extra blank line (black 24.3.0 compat)
kingjr Feb 25, 2026
db228c2
Simplify to_step/to_chain tests: merge into two test functions
kingjr Feb 25, 2026
d2bd8e3
Fix mypy: rename variable g -> gen to avoid redefinition
kingjr Feb 25, 2026
cd9dd7e
Fix black formatting (24.3.0)
kingjr Feb 25, 2026
192 changes: 192 additions & 0 deletions docs/infra/example_lightning.py
@@ -0,0 +1,192 @@
import inspect
import sys
import typing as tp

import exca
import pydantic
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
from torchvision import datasets, transforms
from torchvision.models import resnet18


class ResNet(pl.LightningModule):
    def __init__(self, pretrained: bool = True, learning_rate: float = 0.001):
        super().__init__()
        self.pretrained = pretrained
        self.learning_rate = learning_rate
        self.model = resnet18(pretrained=pretrained)
        self.model.conv1 = torch.nn.Conv2d(
            1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
        )
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, 10)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def _step(self, batch):
        x, y = batch
        y_hat = self(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._step(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._step(batch)
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)


class Mnist(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def _dataloader(self, train: bool):
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
        )
        dset = datasets.MNIST("", train=train, download=True, transform=transform)
        return torch.utils.data.DataLoader(dset, batch_size=self.batch_size)

    def train_dataloader(self):
        return self._dataloader(train=True)

    def val_dataloader(self):
        return self._dataloader(train=False)


class AutoConfig(pydantic.BaseModel):  # TODO move to exca.helpers?
    model_config = pydantic.ConfigDict(extra="forbid")
    _cls: tp.ClassVar[tp.Any]

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs: tp.Any) -> None:
        """Checks that the config default values match the _cls defaults"""
        super().__pydantic_init_subclass__(**kwargs)
        super().__init_subclass__()
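        # Illustrative consequence (hypothetical subclass): declaring
        # `learning_rate: float = 0.01` together with `_cls = ResNet` would raise
        # a ValueError at class-creation time, since ResNet defaults to 0.001.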

        if isinstance(cls._cls, type):
            func_or_class = cls._cls.__init__
        else:
            func_or_class = cls._cls

        # Get the function signature
        signature = inspect.signature(func_or_class)
        func_params = signature.parameters

        # Iterate through the class fields and verify their defaults
        for field_name, field_info in cls.model_fields.items():
            # Check if the field has a default value or is required
            model_default = field_info.default
            model_required = field_info.is_required()

            # Check if the parameter exists in the function signature
            if field_name not in func_params:
                raise ValueError(
                    f"Field '{field_name}' is missing in the function parameters."
                )

            func_param = func_params[field_name]
            func_default = func_param.default

            # Check if the field is required in both the function and the model
            if model_required != (func_default is inspect.Parameter.empty):
                raise ValueError(
                    f"Field '{field_name}' is required in the model "
                    "but not in the function, or vice versa."
                )

            # If it has a default in both, compare them
            if (
                model_default != func_default
                and func_default is not inspect.Parameter.empty
            ):
                raise ValueError(
                    f"Field '{field_name}' default value mismatch: model has "
                    f"'{model_default}', function has '{func_default}'."
                )

    def model_post_init(self, log__: tp.Any) -> None:
        """Check that the parameters are compatible with _cls"""
        super().model_post_init(log__)
        exca.helpers.validate_kwargs(self._cls, self.model_dump())

    def build(self, **kwargs):  # /!\ **kwargs needed for trainer checkpoint, but bad api for uid?
        return self._cls(**self.model_dump(), **kwargs)


def args_to_nested_dict(args: list[str]) -> tp.Dict[str, tp.Any]:  # TODO move to exca.helpers?
    """
    Parses a list of Bash-style arguments (e.g., --key=value) into a nested dict.
    """
    nested_dict = {}
    for arg in args:
        # Split argument into key and value
        key, value = arg.lstrip("--").split("=", 1)
        # Convert flat key into a nested dictionary
        keys = key.split(".")
        current_level = nested_dict
        for k in keys[:-1]:
            current_level = current_level.setdefault(k, {})
        current_level[keys[-1]] = value
    return nested_dict
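
# Illustrative behaviour: args_to_nested_dict(["--trainer.max_epochs=5"]) returns
# {"trainer": {"max_epochs": "5"}}; values stay strings here and are coerced by
# pydantic when the Experiment below is instantiated.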


class ModelConfig(AutoConfig):  # question: right design?
    pretrained: bool = True
    learning_rate: float = 0.001
    _cls = ResNet


class MnistConfig(AutoConfig):
    batch_size: int = 64  # question: uid change if add new param, but corresponds to default trainer?
    _cls = Mnist


class TrainerConfig(AutoConfig):
    max_epochs: tp.Optional[int] = None
    _cls = Trainer
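
# Illustrative usage of the configs (not part of the script's entry point):
# ModelConfig(learning_rate=0.01).build() constructs ResNet(pretrained=True,
# learning_rate=0.01); MnistConfig and TrainerConfig build Mnist and Trainer the same way.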


class Experiment(pydantic.BaseModel):
    model: ModelConfig = ModelConfig()
    data: MnistConfig = MnistConfig()
    trainer: TrainerConfig = TrainerConfig()

    infra: exca.TaskInfra = exca.TaskInfra()
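    # `infra` controls execution/caching: with `infra.folder` set, `infra.apply`
    # (used on `fit` below) should cache the method's output in a uid-specific
    # folder, so re-running an identical configuration reuses the stored result.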

    def build(self):
        mnist = self.data.build()
        model = self.model.build()
        callbacks = None
        if self.infra.folder:
            callbacks = [
                ModelCheckpoint(
                    dirpath=self.infra.uid_folder() / "checkpoint",
                    save_top_k=1,
                    monitor="val_loss",
                    mode="min",
                )
            ]
        trainer = self.trainer.build(callbacks=callbacks)
        return mnist, model, trainer

    @infra.apply
    def fit(self):
        data_loaders, model, trainer = self.build()
        ckpt_path = None
        if self.infra.folder:
            # Define the checkpoint directory
            checkpoint_dir = self.infra.uid_folder() / "checkpoint"
            # Find the latest checkpoint if it exists
            checkpoints = sorted(checkpoint_dir.glob("*.ckpt"))
            ckpt_path = checkpoints[-1] if checkpoints else None

        # Fit model
        trainer.fit(model, data_loaders, ckpt_path=ckpt_path)
        return model

    def validate(self):
        data_loaders, _, trainer = self.build()
        model = self.fit()
        return trainer.validate(model, dataloaders=data_loaders.val_dataloader())


if __name__ == "__main__":
    config = args_to_nested_dict(["--trainer.max_epochs=5"] + sys.argv[1:])
    exp = Experiment(**config)
    score = exp.validate()
    print(score)
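
# Example invocation (flags are illustrative; any --key=value pair maps onto the
# nested Experiment config via args_to_nested_dict):
#   python docs/infra/example_lightning.py --model.learning_rate=0.01 --data.batch_size=128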