A workflow I implemented in another CLI is a fit_and_test command, so testing runs in the same invocation as training. Let me know if you want this implemented and I can refine the approach below (it currently assumes a best checkpoint exists).
Doc
To run evaluation immediately after training in a single invocation, use the combined command:
rslearn model fit_and_test --config land_cover_model.yaml
This runs the fit loop, reloads the best checkpoint if one was saved, runs the test split, and then exits, so no second CLI call is needed.
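The "reload the best checkpoint if available" step comes down to one decision: pass `ckpt_path="best"` to `trainer.test` only when the trainer's checkpoint callback actually recorded a best model. A minimal sketch of that decision, assuming it receives `trainer.checkpoint_callback` after `trainer.fit()` returns (the helper `choose_test_ckpt_path` and the `StubCheckpoint` class are hypothetical, used only for illustration):

```python
from dataclasses import dataclass
from typing import Optional


def choose_test_ckpt_path(checkpoint_callback: Optional[object]) -> Optional[str]:
    """Return "best" if the checkpoint callback recorded a best model, else None.

    Hypothetical helper: checkpoint_callback stands in for
    trainer.checkpoint_callback once trainer.fit() has returned.
    """
    # ModelCheckpoint.best_model_path is an empty string until a checkpoint
    # is saved, so an empty or missing value means there is nothing to reload.
    best_model_path = getattr(checkpoint_callback, "best_model_path", "")
    return "best" if best_model_path else None


@dataclass
class StubCheckpoint:
    """Minimal stand-in for ModelCheckpoint, used only in this example."""

    best_model_path: str = ""


print(choose_test_ckpt_path(StubCheckpoint("checkpoints/epoch=3.ckpt")))  # prints: best
print(choose_test_ckpt_path(StubCheckpoint()))  # prints: None
print(choose_test_ckpt_path(None))  # prints: None
```

The result could be passed directly as `trainer.test(model=model, datamodule=datamodule, ckpt_path=...)`: in recent Lightning versions, when a model is passed explicitly and `ckpt_path` is `None`, the model's current in-memory weights are tested, which matches the fallback branch.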
Implementation in main.py (to be refined)
import sys

from lightning.pytorch.callbacks import ModelCheckpoint

# register_handler, logger, RslearnLightningCLI, RslearnLightningModule,
# RslearnDataModule and RslearnArgumentParser are assumed to be the names
# main.py already defines or imports for the existing model handlers.


@register_handler("model", "fit_and_test")
def model_fit_and_test() -> None:
    """Handler that runs training immediately followed by testing."""
    # run=False builds the parser without subcommands, so drop the program,
    # category ("model") and command ("fit_and_test") from sys.argv.
    cli = RslearnLightningCLI(
        model_class=RslearnLightningModule,
        datamodule_class=RslearnDataModule,
        args=sys.argv[3:],
        subclass_mode_model=True,
        subclass_mode_data=True,
        save_config_kwargs={"overwrite": True},
        parser_class=RslearnArgumentParser,
        run=False,
    )
    trainer = cli.trainer
    model = cli.model
    datamodule = cli.datamodule

    # No manual datamodule.setup() calls: the trainer calls setup("fit")
    # and setup("test") itself inside fit() and test().
    logger.info("Starting training phase")
    trainer.fit(model=model, datamodule=datamodule)

    # ckpt_path="best" reads best_model_path from trainer.checkpoint_callback
    # (the first ModelCheckpoint), so check that same callback; checking any
    # other callback could select "best" when this one has no best path.
    checkpoint_callback = trainer.checkpoint_callback
    ckpt_available = isinstance(checkpoint_callback, ModelCheckpoint) and bool(
        checkpoint_callback.best_model_path
    )

    logger.info("Starting test phase")
    if ckpt_available:
        trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
    else:
        # Fall back to the in-memory weights from the end of training.
        trainer.test(model=model, datamodule=datamodule)
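Because `run=False` makes LightningCLI build its parser without subcommands, the arguments handed to it must not include the command name, only the options that follow it. A minimal sketch of that slicing, assuming rslearn is invoked as `rslearn <category> <command> [options...]` so that `argv[1]` is the category and `argv[2]` the command (the helper `lightning_cli_args` is hypothetical):

```python
import sys
from typing import List, Optional


def lightning_cli_args(argv: Optional[List[str]] = None) -> List[str]:
    """Return the arguments meant for LightningCLI in a run=False handler.

    Assumes argv[0] is the program, argv[1] the category ("model") and
    argv[2] the command ("fit_and_test").
    """
    if argv is None:
        argv = sys.argv
    # With no subcommand parser, the command name must be dropped along with
    # the program and category; only the options remain.
    return argv[3:]


print(lightning_cli_args(["rslearn", "model", "fit_and_test", "--config", "land_cover_model.yaml"]))
# prints: ['--config', 'land_cover_model.yaml']
```

Handlers that keep `run=True` need the opposite: they must keep the subcommand (for example `fit`) in the argument list so LightningCLI can dispatch on it.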