|
493 | 493 | "elite_states = [1, 2, 3, 4, 2, 0, 2, 3, 1]\n", |
494 | 494 | "elite_actions = [0, 2, 4, 3, 2, 0, 1, 3, 3]\n", |
495 | 495 | "\n", |
496 | | - "new_policy = update_policy(elite_states, elite_actions)\n", |
| 496 | + "new_policy = update_policy(elite_states, elite_actions, n_states, n_actions)\n", |
497 | 497 | "\n", |
498 | 498 | "assert np.isfinite(new_policy).all(\n", |
499 | 499 | "), \"Your new policy contains NaNs or +-inf. Make sure you don't divide by zero.\"\n", |
|
587 | 587 | "\n", |
588 | 588 | "for i in range(100):\n", |
589 | 589 | "\n", |
590 | | - " %time sessions = [generate_session(policy) for _ in range(n_sessions)]\n", |
| 590 | + " %time sessions = [generate_session(env, policy) for _ in range(n_sessions)]\n", |
591 | 591 | "\n", |
592 | 592 | " states_batch, actions_batch, rewards_batch = zip(*sessions)\n", |
593 | 593 | "\n", |
594 | 594 | " elite_states, elite_actions = select_elites(states_batch, actions_batch, rewards_batch, percentile)\n", |
595 | 595 | "\n", |
596 | | - " new_policy = update_policy(elite_states, elite_actions)\n", |
| 596 | + " new_policy = update_policy(elite_states, elite_actions, n_states, n_actions)\n", |
597 | 597 | "\n", |
598 | 598 | " policy = learning_rate*new_policy + (1-learning_rate)*policy\n", |
599 | 599 | "\n", |
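For reference, a hedged sketch of what `generate_session` might look like after gaining the explicit `env` parameter, since the loop above now calls `generate_session(env, policy)`. Everything beyond the `(env, policy)` signature is an assumption: a Gymnasium-style `reset`/`step` API and an illustrative `t_max` cap that are not part of this diff.

```python
import numpy as np

def generate_session(env, policy, t_max=10**4):
    # Sketch under assumptions: env follows the Gymnasium API and policy is an
    # (n_states, n_actions) array whose rows are probability distributions.
    states, actions = [], []
    total_reward = 0.0

    s, _ = env.reset()
    for _ in range(t_max):
        # Sample an action from the policy's distribution for the current state.
        a = np.random.choice(len(policy[s]), p=policy[s])
        new_s, r, terminated, truncated, _ = env.step(a)

        states.append(s)
        actions.append(a)
        total_reward += r

        s = new_s
        if terminated or truncated:
            break

    return states, actions, total_reward
```

The final line of the loop, `policy = learning_rate*new_policy + (1-learning_rate)*policy`, then blends the freshly estimated elite policy into the running one, so a single noisy batch of sessions cannot wipe out everything learned in earlier iterations.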
|