diff --git a/project/algorithms/image_classifier.py b/project/algorithms/image_classifier.py index 789fab7c..2abd7e5e 100644 --- a/project/algorithms/image_classifier.py +++ b/project/algorithms/image_classifier.py @@ -65,6 +65,9 @@ def __init__( def configure_model(self): # Save this for PyTorch-Lightning to infer the input/output shapes of the network. + if self.network is not None: + logger.info("Network is already instantiated.") + return self.example_input_array = torch.zeros((self.datamodule.batch_size, *self.datamodule.dims)) with torch.random.fork_rng(): # deterministic weight initialization diff --git a/project/algorithms/jax_image_classifier.py b/project/algorithms/jax_image_classifier.py index 4d8ad3d2..c3fe012e 100644 --- a/project/algorithms/jax_image_classifier.py +++ b/project/algorithms/jax_image_classifier.py @@ -22,6 +22,8 @@ from project.datamodules.image_classification.mnist import MNISTDataModule from project.utils.typing_utils import HydraConfigFor +logger = logging.getLogger(__name__) + def flatten(x: jax.Array) -> jax.Array: return x.reshape((x.shape[0], -1)) @@ -95,6 +97,9 @@ def __init__( self.save_hyperparameters(ignore=["datamodule"]) def configure_model(self): + if self.network is not None: + logger.info("Network is already instantiated.") + return example_input = torch.zeros( (self.datamodule.batch_size, *self.datamodule.dims), ) diff --git a/project/algorithms/jax_ppo_test.py b/project/algorithms/jax_ppo_test.py index d90a6c4d..05f2c0f1 100644 --- a/project/algorithms/jax_ppo_test.py +++ b/project/algorithms/jax_ppo_test.py @@ -532,8 +532,13 @@ def __init__( self.num_train_iterations = np.ceil(self.learner.hp.eval_freq / iteration_steps).astype( int ) + self.actor_params: torch.nn.ParameterList | None = None + self.critic_params: torch.nn.ParameterList | None = None def configure_model(self): + if self.actor_params is not None: + logger.info("Networks are already instantiated.") + return self.actor_params = torch.nn.ParameterList( jax.tree.leaves( jax.tree.map( diff --git a/project/algorithms/text_classifier.py b/project/algorithms/text_classifier.py index 1d9ea4e1..213721c6 100644 --- a/project/algorithms/text_classifier.py +++ b/project/algorithms/text_classifier.py @@ -1,3 +1,4 @@ +import logging from datetime import datetime import evaluate @@ -14,6 +15,8 @@ from project.datamodules.text.text_classification import TextClassificationDataModule from project.utils.typing_utils import HydraConfigFor +logger = logging.getLogger(__name__) + class TextClassifier(LightningModule): """Example of a lightning module used to train a huggingface model for text classification.""" @@ -51,6 +54,9 @@ def __init__( self.save_hyperparameters(ignore=["datamodule"]) def configure_model(self) -> None: + if self.network is not None: + logger.info("Network is already instantiated.") + return with torch.random.fork_rng(devices=[self.device]): # deterministic weight initialization torch.manual_seed(self.init_seed) diff --git a/project/experiment.py b/project/experiment.py index 4ab69459..6a5945c6 100644 --- a/project/experiment.py +++ b/project/experiment.py @@ -112,16 +112,10 @@ def train_lightning( algorithm: lightning.LightningModule, /, *, - trainer: lightning.Trainer | None, + trainer: lightning.Trainer, datamodule: lightning.LightningDataModule | None = None, config: Config, ): - # Create the Trainer from the config. - if trainer is None: - _trainer = instantiate_trainer(config.trainer) - assert isinstance(_trainer, lightning.Trainer) - trainer = _trainer - # Train the model using the dataloaders of the datamodule: # The Algorithm gets to "wrap" the datamodule if it wants to. This could be useful for # example in RL, where we need to set the actor to use in the environment, as well as