Maybe something like:

```python
import sys

from pytorch_lightning.utilities.cli import LightningCLI


def solve_cli_model(Model, args, config_file, default_seed=123):
    # Hack: overwrite sys.argv so LightningCLI parses `args` as --key=val flags
    sys.argv = ["dummy.py"] + [f"--{key}={val}" for key, val in args.items()]
    cli = LightningCLI(
        Model,
        run=False,
        seed_everything_default=default_seed,
        save_config_overwrite=True,
        parser_kwargs={"default_config_files": [config_file]},
    )
    # Solves the model
    trainer = cli.instantiate_trainer(
        logger=None,
        checkpoint_callback=None,
        callbacks=[],  # not using the early stopping/etc.
    )
    trainer.fit(cli.model)
    # Calculates the "test" values for it
    trainer.test(cli.model)
    # Switch to evaluation mode (dropout/batch norm behave as at inference).
    # This does not turn off gradient tracking; that needs torch.no_grad().
    cli.model.eval()
    return cli.model, cli
```
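The `sys.argv` hack above leaves the process's arguments permanently overwritten after the call returns. A minimal sketch of one way to scope it, assuming you want the original `sys.argv` restored afterward: the context manager name `patched_argv` is hypothetical, and it relies only on the standard library's `unittest.mock.patch.object`.

```python
import sys
from contextlib import contextmanager
from typing import Any, Dict, Iterator
from unittest import mock


@contextmanager
def patched_argv(args: Dict[str, Any], prog: str = "dummy.py") -> Iterator[None]:
    """Temporarily replace sys.argv with --key=val flags built from `args`.

    Hypothetical helper: the original sys.argv is restored on exit, even if
    an exception is raised while the CLI is being constructed.
    """
    argv = [prog] + [f"--{key}={val}" for key, val in args.items()]
    with mock.patch.object(sys, "argv", argv):
        yield
```

Inside `solve_cli_model`, the `sys.argv = ...` assignment would be dropped and the `LightningCLI(...)` construction placed under `with patched_argv(args):`, since that is the only point where the arguments are parsed.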
Except maybe with a few options and defaults such as:

- `use_logger=False`: turns off the logger if they ask.
- `checkpoint_callback=False` and `callbacks=False`: if `False`, zeros them out; otherwise leaves them be.
- `test=True`: whether to run the `test` or not.

Etc.
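A minimal sketch of how the "`False` zeros it out, otherwise leave it be" flags could be turned into keyword arguments for `cli.instantiate_trainer`. The helper name `trainer_overrides` is hypothetical, and the override values (`None`, `None`, `[]`) are copied from the snippet above rather than checked against a specific Lightning release.

```python
from typing import Any, Dict


def trainer_overrides(
    use_logger: bool = False,
    checkpoint_callback: bool = False,
    callbacks: bool = False,
) -> Dict[str, Any]:
    """Build keyword overrides for cli.instantiate_trainer(**overrides).

    Hypothetical helper: a flag set to False adds an override that disables
    the matching Trainer setting; a flag set to True adds nothing, so the
    value from the config file is kept.
    """
    overrides: Dict[str, Any] = {}
    if not use_logger:
        overrides["logger"] = None
    if not checkpoint_callback:
        overrides["checkpoint_callback"] = None
    if not callbacks:
        overrides["callbacks"] = []
    return overrides
```

Returning a dict instead of always passing all three keywords is what lets `True` mean "leave it be": omitted keys fall through to the config. The `test` option would simply guard the `trainer.test(cli.model)` call with `if test:`.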