Commit 516db1a

Fix bug causing network to be re-initialized! (#132)
Signed-off-by: Fabrice Normandin <[email protected]>
1 parent e190484 commit 516db1a

File tree

5 files changed (+20, -7 lines)

project/algorithms/image_classifier.py

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ def __init__(
 
     def configure_model(self):
         # Save this for PyTorch-Lightning to infer the input/output shapes of the network.
+        if self.network is not None:
+            logger.info("Network is already instantiated.")
+            return
         self.example_input_array = torch.zeros((self.datamodule.batch_size, *self.datamodule.dims))
         with torch.random.fork_rng():
             # deterministic weight initialization
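The same early-return guard is applied in the other modules below. As a rough illustration of the pattern (a minimal sketch, not code from this repository; the Example class and the Linear layer are made up), making configure_model idempotent means a second call of the hook no longer rebuilds the network and silently discards its weights:

import logging

import torch
from lightning import LightningModule

logger = logging.getLogger(__name__)


class Example(LightningModule):
    def __init__(self, num_features: int = 32, num_classes: int = 10):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.network: torch.nn.Module | None = None

    def configure_model(self):
        if self.network is not None:
            # Already built (e.g. the hook ran once before): re-creating the
            # layer here would reset its trained weights.
            logger.info("Network is already instantiated.")
            return
        self.network = torch.nn.Linear(self.num_features, self.num_classes)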

project/algorithms/jax_image_classifier.py

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,8 @@
 from project.datamodules.image_classification.mnist import MNISTDataModule
 from project.utils.typing_utils import HydraConfigFor
 
+logger = logging.getLogger(__name__)
+
 
 def flatten(x: jax.Array) -> jax.Array:
     return x.reshape((x.shape[0], -1))
@@ -95,6 +97,9 @@ def __init__(
         self.save_hyperparameters(ignore=["datamodule"])
 
     def configure_model(self):
+        if self.network is not None:
+            logger.info("Network is already instantiated.")
+            return
         example_input = torch.zeros(
             (self.datamodule.batch_size, *self.datamodule.dims),
         )

project/algorithms/jax_ppo_test.py

Lines changed: 5 additions & 0 deletions
@@ -532,8 +532,13 @@ def __init__(
         self.num_train_iterations = np.ceil(self.learner.hp.eval_freq / iteration_steps).astype(
             int
         )
+        self.actor_params: torch.nn.ParameterList | None = None
+        self.critic_params: torch.nn.ParameterList | None = None
 
     def configure_model(self):
+        if self.actor_params is not None:
+            logger.info("Networks are already instantiated.")
+            return
         self.actor_params = torch.nn.ParameterList(
             jax.tree.leaves(
                 jax.tree.map(
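The PPO test module has no single network attribute to test against, so the fix also declares both parameter lists as None in __init__ and uses the first one as the sentinel for the guard. A stripped-down sketch of that idea (the class name and parameter shapes below are illustrative, not taken from the repository):

import torch


class ActorCriticExample:
    def __init__(self) -> None:
        # Declared up front so configure_model() can tell whether it already ran.
        self.actor_params: torch.nn.ParameterList | None = None
        self.critic_params: torch.nn.ParameterList | None = None

    def configure_model(self) -> None:
        if self.actor_params is not None:
            # Networks already built; keep the existing parameters.
            return
        self.actor_params = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros(8))])
        self.critic_params = torch.nn.ParameterList([torch.nn.Parameter(torch.zeros(8))])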

project/algorithms/text_classifier.py

Lines changed: 6 additions & 0 deletions
@@ -1,3 +1,4 @@
+import logging
 from datetime import datetime
 
 import evaluate
@@ -14,6 +15,8 @@
 from project.datamodules.text.text_classification import TextClassificationDataModule
 from project.utils.typing_utils import HydraConfigFor
 
+logger = logging.getLogger(__name__)
+
 
 class TextClassifier(LightningModule):
     """Example of a lightning module used to train a huggingface model for text classification."""
@@ -51,6 +54,9 @@ def __init__(
         self.save_hyperparameters(ignore=["datamodule"])
 
     def configure_model(self) -> None:
+        if self.network is not None:
+            logger.info("Network is already instantiated.")
+            return
         with torch.random.fork_rng(devices=[self.device]):
             # deterministic weight initialization
             torch.manual_seed(self.init_seed)

project/experiment.py

Lines changed: 1 addition & 7 deletions
@@ -112,16 +112,10 @@ def train_lightning(
     algorithm: lightning.LightningModule,
     /,
     *,
-    trainer: lightning.Trainer | None,
+    trainer: lightning.Trainer,
     datamodule: lightning.LightningDataModule | None = None,
     config: Config,
 ):
-    # Create the Trainer from the config.
-    if trainer is None:
-        _trainer = instantiate_trainer(config.trainer)
-        assert isinstance(_trainer, lightning.Trainer)
-        trainer = _trainer
-
     # Train the model using the dataloaders of the datamodule:
     # The Algorithm gets to "wrap" the datamodule if it wants to. This could be useful for
     # example in RL, where we need to set the actor to use in the environment, as well as
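With trainer no longer optional, train_lightning expects the caller to construct the Trainer (for example from the Hydra config) before calling it. A hypothetical call site under the new signature (a sketch only; the surrounding objects are assumed to exist, and only train_lightning comes from this file):

import lightning

# `algorithm`, `datamodule` and `config` are assumed to have been instantiated
# elsewhere (e.g. via Hydra) before this point.
trainer = lightning.Trainer(max_epochs=1)
train_lightning(algorithm, trainer=trainer, datamodule=datamodule, config=config)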

0 commit comments
