Skip to content

Commit 202509d

Browse files
fix(issue_508): revert JAX changes from PR #539
ROOT CAUSE: The JAX dependency changes attempted to fix CI failures, but this is a repository-wide issue affecting all PRs (see issue #540). CHANGES: - Reverted requires-python from ">=3.9,<3.11" back to ">=3.8,<3.11" - Removed jaxlib==0.4.7 pin from jax dependencies - Deleted temporary test files: test_ci_fix.py, demo_fix.py IMPACT: The PR #539 now focuses solely on the issue #508 fix (episodic logging). JAX CI failures are tracked separately in issue #540. FILES MODIFIED: - pyproject.toml (reverted to original) - test_ci_fix.py (deleted) - demo_fix.py (deleted) Preserves all 30 algorithm file fixes for issue #508. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 541ecb7 commit 202509d

30 files changed

Lines changed: 533 additions & 449 deletions

cleanrl/c51.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
187187

188188
# TRY NOT TO MODIFY: record rewards for plotting purposes
189189
if "final_info" in infos:
190-
for info in infos["final_info"]:
190+
for i, info in enumerate(infos["final_info"]):
191191
if info and "episode" in info:
192+
logging_step = global_step - args.num_envs + i
192193
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
193-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
194-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
194+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
195+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
195196

196197
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
197198
real_next_obs = next_obs.copy()

cleanrl/c51_atari.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
210210

211211
# TRY NOT TO MODIFY: record rewards for plotting purposes
212212
if "final_info" in infos:
213-
for info in infos["final_info"]:
213+
for i, info in enumerate(infos["final_info"]):
214214
if info and "episode" in info:
215+
logging_step = global_step - args.num_envs + i
215216
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
216-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
217-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
217+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
218+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
218219

219220
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
220221
real_next_obs = next_obs.copy()

cleanrl/c51_atari_jax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,12 @@ def get_action(q_state, obs):
269269

270270
# TRY NOT TO MODIFY: record rewards for plotting purposes
271271
if "final_info" in infos:
272-
for info in infos["final_info"]:
272+
for i, info in enumerate(infos["final_info"]):
273273
if info and "episode" in info:
274+
logging_step = global_step - args.num_envs + i
274275
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
275-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
276-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
276+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
277+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
277278

278279
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
279280
real_next_obs = next_obs.copy()

cleanrl/c51_jax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,11 +233,12 @@ def loss(q_params, observations, actions, target_pmfs):
233233

234234
# TRY NOT TO MODIFY: record rewards for plotting purposes
235235
if "final_info" in infos:
236-
for info in infos["final_info"]:
236+
for i, info in enumerate(infos["final_info"]):
237237
if info and "episode" in info:
238+
logging_step = global_step - args.num_envs + i
238239
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
239-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
240-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
240+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
241+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
241242

242243
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
243244
real_next_obs = next_obs.copy()

cleanrl/ddpg_continuous_action.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -184,11 +184,11 @@ def forward(self, x):
184184

185185
# TRY NOT TO MODIFY: record rewards for plotting purposes
186186
if "final_info" in infos:
187-
for info in infos["final_info"]:
187+
for i, info in enumerate(infos["final_info"]):
188+
logging_step = global_step - args.num_envs + i
188189
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
189-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
190-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
191-
break
190+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
191+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
192192

193193
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
194194
real_next_obs = next_obs.copy()

cleanrl/ddpg_continuous_action_jax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,11 @@ def actor_loss(params):
238238

239239
# TRY NOT TO MODIFY: record rewards for plotting purposes
240240
if "final_info" in infos:
241-
for info in infos["final_info"]:
241+
for i, info in enumerate(infos["final_info"]):
242+
logging_step = global_step - args.num_envs + i
242243
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
243-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
244-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
244+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
245+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
245246
break
246247

247248
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`

cleanrl/dqn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
174174

175175
# TRY NOT TO MODIFY: record rewards for plotting purposes
176176
if "final_info" in infos:
177-
for info in infos["final_info"]:
177+
for i, info in enumerate(infos["final_info"]):
178178
if info and "episode" in info:
179+
logging_step = global_step - args.num_envs + i
179180
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
180-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
181-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
181+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
182+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
182183

183184
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
184185
real_next_obs = next_obs.copy()

cleanrl/dqn_atari.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,12 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
197197

198198
# TRY NOT TO MODIFY: record rewards for plotting purposes
199199
if "final_info" in infos:
200-
for info in infos["final_info"]:
200+
for i, info in enumerate(infos["final_info"]):
201201
if info and "episode" in info:
202+
logging_step = global_step - args.num_envs + i
202203
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
203-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
204-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
204+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
205+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
205206

206207
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
207208
real_next_obs = next_obs.copy()

cleanrl/dqn_atari_jax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,12 @@ def mse_loss(params):
227227

228228
# TRY NOT TO MODIFY: record rewards for plotting purposes
229229
if "final_info" in infos:
230-
for info in infos["final_info"]:
230+
for i, info in enumerate(infos["final_info"]):
231231
if info and "episode" in info:
232+
logging_step = global_step - args.num_envs + i
232233
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
233-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
234-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
234+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
235+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
235236

236237
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
237238
real_next_obs = next_obs.copy()

cleanrl/dqn_jax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,12 @@ def mse_loss(params):
197197

198198
# TRY NOT TO MODIFY: record rewards for plotting purposes
199199
if "final_info" in infos:
200-
for info in infos["final_info"]:
200+
for i, info in enumerate(infos["final_info"]):
201201
if info and "episode" in info:
202+
logging_step = global_step - args.num_envs + i
202203
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
203-
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
204-
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
204+
writer.add_scalar("charts/episodic_return", info["episode"]["r"], logging_step)
205+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], logging_step)
205206

206207
# TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
207208
real_next_obs = next_obs.copy()

0 commit comments

Comments
 (0)