Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions project/algorithms/image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __init__(

def configure_model(self):
# Save this for PyTorch-Lightning to infer the input/output shapes of the network.
if self.network is not None:
logger.info("Network is already instantiated.")
return
self.example_input_array = torch.zeros((self.datamodule.batch_size, *self.datamodule.dims))
with torch.random.fork_rng():
# deterministic weight initialization
Expand Down
5 changes: 5 additions & 0 deletions project/algorithms/jax_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from project.datamodules.image_classification.mnist import MNISTDataModule
from project.utils.typing_utils import HydraConfigFor

logger = logging.getLogger(__name__)


def flatten(x: jax.Array) -> jax.Array:
    """Collapse every axis after the first into a single trailing axis.

    The length of the leading axis is preserved and all remaining axes are merged,
    so the result is always two-dimensional: ``(x.shape[0], -1)``.
    """
    leading = x.shape[0]
    return x.reshape((leading, -1))
Expand Down Expand Up @@ -95,6 +97,9 @@ def __init__(
self.save_hyperparameters(ignore=["datamodule"])

def configure_model(self):
if self.network is not None:
logger.info("Network is already instantiated.")
return
example_input = torch.zeros(
(self.datamodule.batch_size, *self.datamodule.dims),
)
Expand Down
5 changes: 5 additions & 0 deletions project/algorithms/jax_ppo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,13 @@ def __init__(
self.num_train_iterations = np.ceil(self.learner.hp.eval_freq / iteration_steps).astype(
int
)
self.actor_params: torch.nn.ParameterList | None = None
self.critic_params: torch.nn.ParameterList | None = None

def configure_model(self):
if self.actor_params is not None:
logger.info("Networks are already instantiated.")
return
self.actor_params = torch.nn.ParameterList(
jax.tree.leaves(
jax.tree.map(
Expand Down
6 changes: 6 additions & 0 deletions project/algorithms/text_classifier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from datetime import datetime

import evaluate
Expand All @@ -14,6 +15,8 @@
from project.datamodules.text.text_classification import TextClassificationDataModule
from project.utils.typing_utils import HydraConfigFor

logger = logging.getLogger(__name__)


class TextClassifier(LightningModule):
"""Example of a lightning module used to train a huggingface model for text classification."""
Expand Down Expand Up @@ -51,6 +54,9 @@ def __init__(
self.save_hyperparameters(ignore=["datamodule"])

def configure_model(self) -> None:
if self.network is not None:
logger.info("Network is already instantiated.")
return
with torch.random.fork_rng(devices=[self.device]):
# deterministic weight initialization
torch.manual_seed(self.init_seed)
Expand Down
8 changes: 1 addition & 7 deletions project/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,16 +112,10 @@ def train_lightning(
algorithm: lightning.LightningModule,
/,
*,
trainer: lightning.Trainer | None,
trainer: lightning.Trainer,
datamodule: lightning.LightningDataModule | None = None,
config: Config,
):
# Create the Trainer from the config.
if trainer is None:
_trainer = instantiate_trainer(config.trainer)
assert isinstance(_trainer, lightning.Trainer)
trainer = _trainer

# Train the model using the dataloaders of the datamodule:
# The Algorithm gets to "wrap" the datamodule if it wants to. This could be useful for
# example in RL, where we need to set the actor to use in the environment, as well as
Expand Down