Commit 8990fe6

fix checkpointing in fig 3 and 4
1 parent 480cab5 commit 8990fe6

File tree

5 files changed: +154 -75 lines changed

RED/agents/continuous_agents/rt3d.py (+56 -35)

@@ -734,42 +734,63 @@ def Q_update(self, recurrent=True, monte_carlo=False, policy=True, verbose=False
         self.update_target_network(source=self.Q2_network, target=self.Q2_target, tau=self.polyak)
         self.update_target_network(source=self.policy_network, target=self.policy_target, tau=self.polyak)

-    def save_network(self, save_path):
-        '''
-        Saves networks to directory specified by save_path
-        :param save_path: directory to save networks to
-        '''
-
-        torch.save(self.policy_network, os.path.join(save_path, "policy_network.pth"))
-        torch.save(self.Q1_network, os.path.join(save_path, "Q1_network.pth"))
-        torch.save(self.Q2_network, os.path.join(save_path, "Q2_network.pth"))
-
-        torch.save(self.policy_target, os.path.join(save_path, "policy_target.pth"))
-        torch.save(self.Q1_target, os.path.join(save_path, "Q1_target.pth"))
-        torch.save(self.Q2_target, os.path.join(save_path, "Q2_target.pth"))
-
-    def load_network(self, load_path, load_target_networks=False):
-        '''
-        Loads netoworks from directory specified by load_path.
-        :param load_path: directory to load networks from
-        :param load_target_networks: whether to load target networks
-        '''
-
-        self.policy_network = torch.load(os.path.join(load_path, "policy_network.pth"))
-        self.policy_network_opt = Adam(self.policy_network.parameters(), lr=self.pol_learning_rate)
-
-        self.Q1_network = torch.load(os.path.join(load_path, "Q1_network.pth"))
-        self.Q1_network_opt = Adam(self.Q1_network.parameters(), lr=self.val_learning_rate)
-
-        self.Q2_network = torch.load(os.path.join(load_path, "Q2_network.pth"))
-        self.Q2_etwork_opt = Adam(self.Q2_network.parameters(), lr=self.val_learning_rate)
-
+    def save_ckpt(self, save_path, additional_info=None):
+        '''
+        Creates a full checkpoint (networks, optimizers, memory buffers) and saves it to the specified path.
+        :param save_path: path to save the checkpoint to
+        :param additional_info: additional information to save (Python dictionary)
+        '''
+        ckpt = {
+            "policy_network": self.policy_network.state_dict(),
+            "Q1_network": self.Q1_network.state_dict(),
+            "Q2_network": self.Q2_network.state_dict(),
+            "policy_target": self.policy_target.state_dict(),
+            "Q1_target": self.Q1_target.state_dict(),
+            "Q2_target": self.Q2_target.state_dict(),
+            "policy_network_opt": self.policy_network_opt.state_dict(),
+            "Q1_network_opt": self.Q1_network_opt.state_dict(),
+            "Q2_network_opt": self.Q2_network_opt.state_dict(),
+            "additional_info": additional_info if additional_info is not None else {},
+        }
+
+        ### save buffers
+        for buffer in ("memory", "values", "states", "next_states", "actions", "rewards", "dones",
+                       "sequences", "next_sequences", "all_returns"):
+            ckpt[buffer] = getattr(self, buffer)
+
+        ### save the checkpoint
+        torch.save(ckpt, save_path)
+
+    def load_ckpt(self, load_path, load_target_networks=True):
+        '''
+        Loads a full checkpoint (networks, optimizers, memory buffers) from the specified path.
+        :param load_path: path to load the checkpoint from
+        :param load_target_networks: whether to load the target networks as well
+        '''
+        ckpt = torch.load(load_path)
+
+        ### load networks
+        self.policy_network.load_state_dict(ckpt["policy_network"])
+        self.Q1_network.load_state_dict(ckpt["Q1_network"])
+        self.Q2_network.load_state_dict(ckpt["Q2_network"])
+
+        ### load target networks
         if load_target_networks:
-            self.policy_target = torch.load(os.path.join(load_path, "policy_target.pth"))
-            self.Q1_target = torch.load(os.path.join(load_path, "Q1_target.pth"))
-            self.Q2_target = torch.load(os.path.join(load_path, "Q2_target.pth"))
-        else:
-            print("[WARNING] Not loading target networks")
+            self.policy_target.load_state_dict(ckpt["policy_target"])
+            self.Q1_target.load_state_dict(ckpt["Q1_target"])
+            self.Q2_target.load_state_dict(ckpt["Q2_target"])
+
+        ### load optimizers
+        self.policy_network_opt.load_state_dict(ckpt["policy_network_opt"])
+        self.Q1_network_opt.load_state_dict(ckpt["Q1_network_opt"])
+        self.Q2_network_opt.load_state_dict(ckpt["Q2_network_opt"])
+
+        ### load buffers
+        for buffer in ("memory", "values", "states", "next_states", "actions", "rewards", "dones",
+                       "sequences", "next_sequences", "all_returns"):
+            setattr(self, buffer, ckpt[buffer])
+
+        return ckpt

     def reset_weights(self, policy=True):
         '''
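Note: a minimal usage sketch of the new single-file checkpoint API introduced above. The agent construction is omitted and the directory name and episode number are illustrative placeholders, not part of this commit:

    import os

    # Assume `agent` is an already-constructed RT3D agent from RED.agents.continuous_agents.rt3d.
    ckpt_dir = "results/ckpt_50"  # hypothetical checkpoint directory
    os.makedirs(ckpt_dir, exist_ok=True)

    # Networks, target networks, optimizers and replay buffers all go into one file.
    agent.save_ckpt(
        save_path=os.path.join(ckpt_dir, "agent.pt"),
        additional_info={"episode": 50},  # arbitrary metadata, stored under "additional_info"
    )

    # Restoring returns the raw checkpoint dict, so the metadata can be read back.
    ckpt = agent.load_ckpt(
        load_path=os.path.join(ckpt_dir, "agent.pt"),
        load_target_networks=True,
    )
    starting_episode = ckpt["additional_info"].get("episode", -1) + 1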

RED/configs/example/Figure_3_RT3D_chemostat.yaml (+1)

@@ -9,6 +9,7 @@ explore_rate_mul: 1
 test_episode: False
 save_path: ${hydra:run.dir}
 ckpt_freq: 50
+load_ckpt_dir_path: null # directory containing agent's checkpoint to load ("agent.pt") + optionally "history.json" from which to resume training

 model:
   batch_size: ${eval:'${example.environment.N_control_intervals} * ${example.environment.n_parallel_experiments}'}

RED/configs/example/Figure_4_RT3D_chemostat.yaml (+1)

@@ -9,6 +9,7 @@ explore_rate_mul: 1
 test_episode: False
 save_path: ${hydra:run.dir}
 ckpt_freq: 50
+load_ckpt_dir_path: null # directory containing agent's checkpoint to load ("agent.pt") + optionally "history.json" from which to resume training

 model:
   batch_size: ${eval:'${example.environment.N_control_intervals} * ${example.environment.n_parallel_experiments}'}
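Note: an illustrative (not committed) value for the new key; the run directory below is a placeholder:

    load_ckpt_dir_path: outputs/2024-05-01/12-00-00/ckpt_150  # must contain "agent.pt"; "history.json" is optional and restores the training history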

examples/Figure_3_RT3D_chemostat/train_RT3D.py (+48 -20)

@@ -1,4 +1,5 @@

+import json
 import math
 import os
 import sys
@@ -48,12 +49,35 @@ def train_RT3D(cfg : DictConfig):
     env, n_params = setup_env(cfg)
     total_episodes = cfg.environment.n_episodes // cfg.environment.n_parallel_experiments
     skip_first_n_episodes = cfg.environment.skip_first_n_experiments // cfg.environment.n_parallel_experiments
-
-    history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate"]}
-    update_count = 0
+    starting_episode = 0
+
+    history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate", "update_count"]}
+
+    ### load ckpt
+    if cfg.load_ckpt_dir_path is not None:
+        print(f"Loading checkpoint from: {cfg.load_ckpt_dir_path}")
+        # load the agent
+        agent_path = os.path.join(cfg.load_ckpt_dir_path, "agent.pt")
+        print(f"Loading agent from: {agent_path}")
+        additional_info = agent.load_ckpt(
+            load_path=agent_path,
+            load_target_networks=True,
+        )["additional_info"]
+        # load history
+        history_path = os.path.join(cfg.load_ckpt_dir_path, "history.json")
+        if os.path.exists(history_path):
+            print(f"Loading history from: {history_path}")
+            with open(history_path, "r") as f:
+                history = json.load(f)
+            # load explore rate
+            if "explore_rate" in history and len(history["explore_rate"]) > 0:
+                explore_rate = history["explore_rate"][-1]
+        # load starting episode
+        if "episode" in additional_info:
+            starting_episode = additional_info["episode"] + 1

     ### training loop
-    for episode in range(total_episodes):
+    for episode in range(starting_episode, total_episodes):
         actual_params = np.random.uniform(
             low=cfg.environment.actual_params,
             high=cfg.environment.actual_params,
@@ -108,11 +132,10 @@ def train_RT3D(cfg : DictConfig):
                 sequences[i].append(np.concatenate((state, action)))

                 ### log episode data
-                e_us[i].append(u)
+                e_us[i].append(u.tolist())
                 next_states.append(next_state)
-                if reward != -1: # dont include the unstable trajectories as they override the true return
-                    e_rewards[i].append(reward)
-                    e_returns[i] += reward
+                e_rewards[i].append(reward)
+                e_returns[i] += reward
             states = next_states

             ### do not memorize the test trajectory (the last one)
@@ -129,9 +152,11 @@ def train_RT3D(cfg : DictConfig):
         ### train agent
         if episode > skip_first_n_episodes:
             for _ in range(cfg.environment.n_parallel_experiments):
-                update_count += 1
-                update_policy = update_count % cfg.policy_delay == 0
+                history["update_count"].append(history["update_count"][-1] + 1 if len(history["update_count"]) > 0 else 1)
+                update_policy = history["update_count"][-1] % cfg.policy_delay == 0
                 agent.Q_update(policy=update_policy, recurrent=True)
+        else:
+            history["update_count"].append(history["update_count"][-1] if len(history["update_count"]) > 0 else 0)

         ### update explore rate
         explore_rate = cfg.explore_rate_mul * agent.get_rate(
@@ -143,7 +168,7 @@ def train_RT3D(cfg : DictConfig):

         ### log results
         history["returns"].extend(e_returns)
-        history["actions"].extend(np.array(e_actions).transpose(1, 0, 2))
+        history["actions"].extend(np.array(e_actions).transpose(1, 0, 2).tolist())
         history["rewards"].extend(e_rewards)
         history["us"].extend(e_us)
         history["explore_rate"].append(explore_rate)
@@ -164,17 +189,20 @@ def train_RT3D(cfg : DictConfig):
         )

         ### checkpoint
-        if cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0:
+        if (cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0) \
+                or episode == total_episodes - 1:
             ckpt_dir = os.path.join(cfg.save_path, f"ckpt_{episode}")
             os.makedirs(ckpt_dir, exist_ok=True)
-            agent.save_network(ckpt_dir)
-            for k in history.keys():
-                np.save(os.path.join(ckpt_dir, f"{k}.npy"), np.array(history[k]))
-
-    ### save results and plot
-    agent.save_network(cfg.save_path)
-    for k in history.keys():
-        np.save(os.path.join(cfg.save_path, f"{k}.npy"), np.array(history[k]))
+            agent.save_ckpt(
+                save_path=os.path.join(ckpt_dir, "agent.pt"),
+                additional_info={
+                    "episode": episode,
+                }
+            )
+            with open(os.path.join(ckpt_dir, "history.json"), "w") as f:
+                json.dump(history, f)
+
+    ### plot
     plot_returns(
         returns=history["returns"],
         explore_rates=history["explore_rate"],
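Note: a small self-contained sketch (not repo code) of the bookkeeping change above. Keeping the update count as a per-step list inside `history` means the policy-delay phase survives a checkpoint/resume, because the list is written to and restored from "history.json" along with the rest of the history. The `policy_delay` value below is an assumed stand-in for `cfg.policy_delay`:

    policy_delay = 2  # assumed value of cfg.policy_delay for this illustration

    def record_update(history, trained):
        # mirrors the diff: extend the running count on a training step, repeat it otherwise
        last = history["update_count"][-1] if history["update_count"] else 0
        history["update_count"].append(last + 1 if trained else last)
        return history["update_count"][-1] % policy_delay == 0  # update the policy on this step?

    history = {"update_count": []}
    for _ in range(5):
        record_update(history, trained=True)      # counts 1..5

    # Simulate a resume: the list is reloaded from history.json, so the next training step
    # continues the count at 6 instead of restarting the policy-delay phase from zero.
    resumed = {"update_count": list(history["update_count"])}
    print(record_update(resumed, trained=True))   # True, since 6 % 2 == 0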

examples/Figure_4_RT3D_chemostat/train_RT3D.py (+48 -20)

@@ -1,4 +1,5 @@

+import json
 import math
 import os
 import sys
@@ -47,12 +48,35 @@ def train_RT3D(cfg : DictConfig):
     env, n_params = setup_env(cfg)
     total_episodes = cfg.environment.n_episodes // cfg.environment.n_parallel_experiments
     skip_first_n_episodes = cfg.environment.skip_first_n_experiments // cfg.environment.n_parallel_experiments
-
-    history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate"]}
-    update_count = 0
+    starting_episode = 0
+
+    history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate", "update_count"]}
+
+    ### load ckpt
+    if cfg.load_ckpt_dir_path is not None:
+        print(f"Loading checkpoint from: {cfg.load_ckpt_dir_path}")
+        # load the agent
+        agent_path = os.path.join(cfg.load_ckpt_dir_path, "agent.pt")
+        print(f"Loading agent from: {agent_path}")
+        additional_info = agent.load_ckpt(
+            load_path=agent_path,
+            load_target_networks=True,
+        )["additional_info"]
+        # load history
+        history_path = os.path.join(cfg.load_ckpt_dir_path, "history.json")
+        if os.path.exists(history_path):
+            print(f"Loading history from: {history_path}")
+            with open(history_path, "r") as f:
+                history = json.load(f)
+            # load explore rate
+            if "explore_rate" in history and len(history["explore_rate"]) > 0:
+                explore_rate = history["explore_rate"][-1]
+        # load starting episode
+        if "episode" in additional_info:
+            starting_episode = additional_info["episode"] + 1

     ### training loop
-    for episode in range(total_episodes):
+    for episode in range(starting_episode, total_episodes):
         # sample params from uniform distribution
         actual_params = np.random.uniform(
             low=cfg.environment.lb,
@@ -108,11 +132,10 @@ def train_RT3D(cfg : DictConfig):
                 sequences[i].append(np.concatenate((state, action)))

                 ### log episode data
-                e_us[i].append(u)
+                e_us[i].append(u.tolist())
                 next_states.append(next_state)
-                if reward != -1: # dont include the unstable trajectories as they override the true return
-                    e_rewards[i].append(reward)
-                    e_returns[i] += reward
+                e_rewards[i].append(reward)
+                e_returns[i] += reward
             states = next_states

             ### do not memorize the test trajectory (the last one)
@@ -129,9 +152,11 @@ def train_RT3D(cfg : DictConfig):
         ### train agent
         if episode > skip_first_n_episodes:
             for _ in range(cfg.environment.n_parallel_experiments):
-                update_count += 1
-                update_policy = update_count % cfg.policy_delay == 0
+                history["update_count"].append(history["update_count"][-1] + 1 if len(history["update_count"]) > 0 else 1)
+                update_policy = history["update_count"][-1] % cfg.policy_delay == 0
                 agent.Q_update(policy=update_policy, recurrent=True)
+        else:
+            history["update_count"].append(history["update_count"][-1] if len(history["update_count"]) > 0 else 0)

         ### update explore rate
         explore_rate = cfg.explore_rate_mul * agent.get_rate(
@@ -143,7 +168,7 @@ def train_RT3D(cfg : DictConfig):

         ### log results
         history["returns"].extend(e_returns)
-        history["actions"].extend(np.array(e_actions).transpose(1, 0, 2))
+        history["actions"].extend(np.array(e_actions).transpose(1, 0, 2).tolist())
         history["rewards"].extend(e_rewards)
         history["us"].extend(e_us)
         history["explore_rate"].append(explore_rate)
@@ -164,17 +189,20 @@ def train_RT3D(cfg : DictConfig):
         )

         ### checkpoint
-        if cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0:
+        if (cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0) \
+                or episode == total_episodes - 1:
             ckpt_dir = os.path.join(cfg.save_path, f"ckpt_{episode}")
             os.makedirs(ckpt_dir, exist_ok=True)
-            agent.save_network(ckpt_dir)
-            for k in history.keys():
-                np.save(os.path.join(ckpt_dir, f"{k}.npy"), np.array(history[k]))
-
-    ### save results and plot
-    agent.save_network(cfg.save_path)
-    for k in history.keys():
-        np.save(os.path.join(cfg.save_path, f"{k}.npy"), np.array(history[k]))
+            agent.save_ckpt(
+                save_path=os.path.join(ckpt_dir, "agent.pt"),
+                additional_info={
+                    "episode": episode,
+                }
+            )
+            with open(os.path.join(ckpt_dir, "history.json"), "w") as f:
+                json.dump(history, f)
+
+    ### plot
     plot_returns(
         returns=history["returns"],
         explore_rates=history["explore_rate"],
