
ENHANCEMENT: fit and test #302

@robmarkcole

Description


In another CLI I implemented a fit_and_test command, so that testing runs in the same invocation as training. Let me know if you want this implemented here, and I can refine the approach below (it currently assumes a best checkpoint is being tracked).

Doc

To run evaluation immediately after training in a single invocation, use the combined command:

rslearn model fit_and_test --config land_cover_model.yaml

This executes the fit loop, reloads the best checkpoint if one is available, runs the test split, and then exits, without requiring a second CLI call.
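For context on how the command above reaches a handler, here is a minimal sketch of a `(group, command)` registry. It is an illustration under assumptions: `HANDLERS`, this `register_handler`, and `dispatch` are hypothetical stand-ins rather than rslearn's actual code, and the stand-in handler simply returns the arguments it receives. It shows which tokens are left for the config parser once `rslearn`, `model`, and `fit_and_test` have been consumed. This matters when `LightningCLI` is built with `run=False`, because it then expects no subcommand token.

```python
from collections.abc import Callable

# Hypothetical registry keyed by (group, command); rslearn's real one may differ.
HANDLERS: dict[tuple[str, str], Callable[[list[str]], list[str]]] = {}


def register_handler(group: str, command: str):
    """Hypothetical decorator recording a handler for `rslearn <group> <command>`."""

    def decorator(fn: Callable[[list[str]], list[str]]) -> Callable[[list[str]], list[str]]:
        HANDLERS[(group, command)] = fn
        return fn

    return decorator


@register_handler("model", "fit_and_test")
def model_fit_and_test(cli_args: list[str]) -> list[str]:
    # Stand-in for the real handler: return the args the CLI parser would receive.
    return cli_args


def dispatch(argv: list[str]) -> list[str]:
    """Route argv such as ["rslearn", "model", "fit_and_test", ...] to its handler."""
    group, command = argv[1], argv[2]
    # A CLI built with run=False parses no subcommand, so the handler gets
    # only the tokens after the command name.
    return HANDLERS[(group, command)](argv[3:])
```

For example, `dispatch(["rslearn", "model", "fit_and_test", "--config", "land_cover_model.yaml"])` returns `["--config", "land_cover_model.yaml"]`.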

Implementation in main.py (to be refined)

# Assumes sys, logger, register_handler, RslearnLightningCLI, RslearnLightningModule,
# RslearnDataModule and RslearnArgumentParser are already imported in main.py.
@register_handler("model", "fit_and_test")
def model_fit_and_test() -> None:
    """Handler that runs training immediately followed by testing."""
    cli = RslearnLightningCLI(
        model_class=RslearnLightningModule,
        datamodule_class=RslearnDataModule,
        # With run=False no Lightning subcommand is parsed, so skip
        # "model fit_and_test" and pass only the remaining options.
        args=sys.argv[3:],
        subclass_mode_model=True,
        subclass_mode_data=True,
        save_config_kwargs={"overwrite": True},
        parser_class=RslearnArgumentParser,
        run=False,  # build trainer/model/datamodule without running a subcommand
    )

    trainer = cli.trainer
    model = cli.model
    datamodule = cli.datamodule

    # trainer.fit() and trainer.test() call datamodule.setup() for their own
    # stage, so no manual setup("fit") / setup("test") calls are needed.
    logger.info("Starting training phase")
    trainer.fit(model=model, datamodule=datamodule)

    # ckpt_path="best" is resolved via trainer.checkpoint_callback (the first
    # ModelCheckpoint), so check that callback's best_model_path specifically.
    checkpoint_callback = trainer.checkpoint_callback
    ckpt_available = checkpoint_callback is not None and bool(
        checkpoint_callback.best_model_path
    )

    logger.info("Starting test phase")
    if ckpt_available:
        # Reload the best checkpoint recorded during fit.
        trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
    else:
        # No best checkpoint recorded: test the in-memory weights left by fit.
        trainer.test(model=model, datamodule=datamodule)
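The checkpoint decision at the end of the handler could be isolated into a small pure function, so it can be unit-tested without launching training. This is a sketch: `FakeCheckpointCallback`, `FakeTrainer`, and `resolve_test_ckpt_path` are hypothetical names, and the stand-ins only mimic the `checkpoint_callback.best_model_path` attributes of Lightning's `Trainer` and `ModelCheckpoint`. It assumes, as in recent Lightning versions, that `ckpt_path="best"` is resolved against `trainer.checkpoint_callback`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeCheckpointCallback:
    """Hypothetical stand-in for ModelCheckpoint; only best_model_path matters here."""

    best_model_path: str = ""


@dataclass
class FakeTrainer:
    """Hypothetical stand-in exposing checkpoint_callback like a Lightning Trainer."""

    checkpoint_callback: FakeCheckpointCallback | None = None


def resolve_test_ckpt_path(trainer: FakeTrainer) -> str | None:
    """Return "best" if the primary checkpoint callback recorded a best model, else None.

    Only trainer.checkpoint_callback is inspected, because that is the callback
    ckpt_path="best" resolves against. None means "test the in-memory weights".
    """
    callback = trainer.checkpoint_callback
    if callback is not None and callback.best_model_path:
        return "best"
    return None
```

With a real trainer, the result can be passed straight through as `trainer.test(model=model, datamodule=datamodule, ckpt_path=resolve_test_ckpt_path(trainer))`. This collapses the if/else into a single call, since `ckpt_path=None` with a model passed tests the current weights.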
