File tree (Expand / Collapse): 5 files changed, +20 −7
lines changed Original file line number Diff line number Diff line change @@ -65,6 +65,9 @@ def __init__(
6565
6666 def configure_model (self ):
6767 # Save this for PyTorch-Lightning to infer the input/output shapes of the network.
68+ if self .network is not None :
69+ logger .info ("Network is already instantiated." )
70+ return
6871 self .example_input_array = torch .zeros ((self .datamodule .batch_size , * self .datamodule .dims ))
6972 with torch .random .fork_rng ():
7073 # deterministic weight initialization
Original file line number Diff line number Diff line change 2222from project .datamodules .image_classification .mnist import MNISTDataModule
2323from project .utils .typing_utils import HydraConfigFor
2424
25+ logger = logging .getLogger (__name__ )
26+
2527
2628def flatten (x : jax .Array ) -> jax .Array :
2729 return x .reshape ((x .shape [0 ], - 1 ))
@@ -95,6 +97,9 @@ def __init__(
9597 self .save_hyperparameters (ignore = ["datamodule" ])
9698
9799 def configure_model (self ):
100+ if self .network is not None :
101+ logger .info ("Network is already instantiated." )
102+ return
98103 example_input = torch .zeros (
99104 (self .datamodule .batch_size , * self .datamodule .dims ),
100105 )
Original file line number Diff line number Diff line change @@ -532,8 +532,13 @@ def __init__(
532532 self .num_train_iterations = np .ceil (self .learner .hp .eval_freq / iteration_steps ).astype (
533533 int
534534 )
535+ self .actor_params : torch .nn .ParameterList | None = None
536+ self .critic_params : torch .nn .ParameterList | None = None
535537
536538 def configure_model (self ):
539+ if self .actor_params is not None :
540+ logger .info ("Networks are already instantiated." )
541+ return
537542 self .actor_params = torch .nn .ParameterList (
538543 jax .tree .leaves (
539544 jax .tree .map (
Original file line number Diff line number Diff line change 1+ import logging
12from datetime import datetime
23
34import evaluate
1415from project .datamodules .text .text_classification import TextClassificationDataModule
1516from project .utils .typing_utils import HydraConfigFor
1617
18+ logger = logging .getLogger (__name__ )
19+
1720
1821class TextClassifier (LightningModule ):
1922 """Example of a lightning module used to train a huggingface model for text classification."""
@@ -51,6 +54,9 @@ def __init__(
5154 self .save_hyperparameters (ignore = ["datamodule" ])
5255
5356 def configure_model (self ) -> None :
57+ if self .network is not None :
58+ logger .info ("Network is already instantiated." )
59+ return
5460 with torch .random .fork_rng (devices = [self .device ]):
5561 # deterministic weight initialization
5662 torch .manual_seed (self .init_seed )
Original file line number Diff line number Diff line change @@ -112,16 +112,10 @@ def train_lightning(
112112 algorithm : lightning .LightningModule ,
113113 / ,
114114 * ,
115- trainer : lightning .Trainer | None ,
115+ trainer : lightning .Trainer ,
116116 datamodule : lightning .LightningDataModule | None = None ,
117117 config : Config ,
118118):
119- # Create the Trainer from the config.
120- if trainer is None :
121- _trainer = instantiate_trainer (config .trainer )
122- assert isinstance (_trainer , lightning .Trainer )
123- trainer = _trainer
124-
125119 # Train the model using the dataloaders of the datamodule:
126120 # The Algorithm gets to "wrap" the datamodule if it wants to. This could be useful for
127121 # example in RL, where we need to set the actor to use in the environment, as well as
You can’t perform that action at this time.
0 commit comments