Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions examples/rl/actor_critic_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
Title: Actor Critic Method
Author: [Apoorv Nandan](https://twitter.com/NandanApoorv)
Date created: 2020/05/13
Last modified: 2024/02/22
Last modified: 2025/01/07
Description: Implement Actor Critic Method in CartPole environment.
Accelerator: NONE
Accelerator: None
Converted to Keras 3 by: [Sitam Meur](https://github.com/sitamgithub-MSIT)
"""

"""
## Introduction

This script shows an implementation of Actor Critic method on CartPole-V0 environment.
This script shows an implementation of the Actor Critic method on the CartPole-V1 environment.

### Actor Critic Method

Expand All @@ -26,7 +26,7 @@
Actor and Critic learn to perform their tasks, such that the recommended actions
from the actor maximize the rewards.

### CartPole-V0
### CartPole-V1

A pole is attached to a cart placed on a frictionless track. The agent has to apply
force to move the cart. It is rewarded for every time step the pole
Expand All @@ -45,7 +45,7 @@
import os

os.environ["KERAS_BACKEND"] = "tensorflow"
import gym
import gymnasium as gym
import numpy as np
import keras
from keras import ops
Expand All @@ -57,7 +57,7 @@
gamma = 0.99 # Discount factor for past rewards
max_steps_per_episode = 10000
# Adding `render_mode='human'` will show the attempts of the agent
env = gym.make("CartPole-v0") # Create the environment
env = gym.make("CartPole-v1") # Create the environment
env.reset(seed=seed)
eps = np.finfo(np.float32).eps.item() # Smallest number such that 1.0 + eps != 1.0

Expand Down Expand Up @@ -98,12 +98,12 @@
episode_count = 0

while True: # Run until solved
state = env.reset()[0]
obs, _ = env.reset()
episode_reward = 0
with tf.GradientTape() as tape:
for timestep in range(1, max_steps_per_episode):

state = ops.convert_to_tensor(state)
state = ops.convert_to_tensor(obs)
state = ops.expand_dims(state, 0)

# Predict action probabilities and estimated future rewards
Expand All @@ -116,10 +116,11 @@
action_probs_history.append(ops.log(action_probs[0, action]))

# Apply the sampled action in our environment
state, reward, done, *_ = env.step(action)
obs, reward, terminated, truncated, _ = env.step(action)
rewards_history.append(reward)
episode_reward += reward

done = terminated or truncated
if done:
break

Expand Down
39 changes: 20 additions & 19 deletions examples/rl/ipynb/actor_critic_cartpole.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Apoorv Nandan](https://twitter.com/NandanApoorv)<br>\n",
"**Date created:** 2020/05/13<br>\n",
"**Last modified:** 2024/02/22<br>\n",
"**Last modified:** 2025/01/07<br>\n",
"**Description:** Implement Actor Critic Method in CartPole environment."
]
},
Expand All @@ -22,7 +22,7 @@
"source": [
"## Introduction\n",
"\n",
"This script shows an implementation of Actor Critic method on CartPole-V0 environment.\n",
"This script shows an implementation of Actor Critic method on CartPole-V1 environment.\n",
"\n",
"### Actor Critic Method\n",
"\n",
Expand All @@ -37,7 +37,7 @@
"Agent and Critic learn to perform their tasks, such that the recommended actions\n",
"from the actor maximize the rewards.\n",
"\n",
"### CartPole-V0\n",
"### CartPole-V1\n",
"\n",
"A pole is attached to a cart placed on a frictionless track. The agent has to apply\n",
"force to move the cart. It is rewarded for every time step the pole\n",
Expand All @@ -47,7 +47,7 @@
"\n",
"- [Environment documentation](https://gymnasium.farama.org/environments/classic_control/cart_pole/)\n",
"- [CartPole paper](http://www.derongliu.org/adp/adp-cdrom/Barto1983.pdf)\n",
"- [Actor Critic Method](https://hal.inria.fr/hal-00840470/document)\n"
"- [Actor Critic Method](https://hal.inria.fr/hal-00840470/document)"
]
},
{
Expand All @@ -56,12 +56,12 @@
"colab_type": "text"
},
"source": [
"## Setup\n"
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -70,7 +70,7 @@
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
"import gym\n",
"import gymnasium as gym\n",
"import numpy as np\n",
"import keras\n",
"from keras import ops\n",
Expand All @@ -82,7 +82,7 @@
"gamma = 0.99 # Discount factor for past rewards\n",
"max_steps_per_episode = 10000\n",
"# Adding `render_mode='human'` will show the attempts of the agent\n",
"env = gym.make(\"CartPole-v0\") # Create the environment\n",
"env = gym.make(\"CartPole-v1\") # Create the environment\n",
"env.reset(seed=seed)\n",
"eps = np.finfo(np.float32).eps.item() # Smallest number such that 1.0 + eps != 1.0"
]
Expand All @@ -102,12 +102,12 @@
"2. Critic: This takes as input the state of our environment and returns\n",
"an estimate of total rewards in the future.\n",
"\n",
"In our implementation, they share the initial layer.\n"
"In our implementation, they share the initial layer."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -131,12 +131,12 @@
"colab_type": "text"
},
"source": [
"## Train\n"
"## Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -151,12 +151,12 @@
"episode_count = 0\n",
"\n",
"while True: # Run until solved\n",
" state = env.reset()[0]\n",
" obs, _ = env.reset()\n",
" episode_reward = 0\n",
" with tf.GradientTape() as tape:\n",
" for timestep in range(1, max_steps_per_episode):\n",
"\n",
" state = ops.convert_to_tensor(state)\n",
" state = ops.convert_to_tensor(obs)\n",
" state = ops.expand_dims(state, 0)\n",
"\n",
" # Predict action probabilities and estimated future rewards\n",
Expand All @@ -169,10 +169,11 @@
" action_probs_history.append(ops.log(action_probs[0, action]))\n",
"\n",
" # Apply the sampled action in our environment\n",
" state, reward, done, *_ = env.step(action)\n",
" obs, reward, terminated, truncated, _ = env.step(action)\n",
" rewards_history.append(reward)\n",
" episode_reward += reward\n",
"\n",
" done = terminated or truncated\n",
" if done:\n",
" break\n",
"\n",
Expand Down Expand Up @@ -245,12 +246,12 @@
"![Imgur](https://i.imgur.com/5gCs5kH.gif)\n",
"\n",
"In later stages of training:\n",
"![Imgur](https://i.imgur.com/5ziiZUD.gif)\n"
"![Imgur](https://i.imgur.com/5ziiZUD.gif)"
]
}
],
"metadata": {
"accelerator": "GPU",
"accelerator": "None",
"colab": {
"collapsed_sections": [],
"name": "actor_critic_cartpole",
Expand All @@ -273,9 +274,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading