Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,369 changes: 134 additions & 1,235 deletions RLBasics.md

Large diffs are not rendered by default.

Binary file modified averaging_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified cliff_walking_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified discount_factor_visualization.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified epsilon_greedy_vs_softmax.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified function_approximation_diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
206 changes: 96 additions & 110 deletions generate_visuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,61 +28,61 @@

def create_gridworld_example():
"""Create a simple gridworld visualization"""
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
# 5x6 grid
grid_height, grid_width = 5, 6
fig, ax = plt.subplots(1, 1, figsize=(8, 3.5))

# 3x6 grid (removed bottom 2 rows)
grid_height, grid_width = 3, 6

# Create grid
for i in range(grid_height + 1):
ax.plot([0, grid_width], [i, i], 'k-', linewidth=1)
for j in range(grid_width + 1):
ax.plot([j, j], [0, grid_height], 'k-', linewidth=1)

# Fill cells with colors
# Start position (green)
start_rect = patches.Rectangle((0, 4), 1, 1, linewidth=2,
# Start position (green) - top row
start_rect = patches.Rectangle((0, 2), 1, 1, linewidth=2,
edgecolor='darkgreen', facecolor='lightgreen', alpha=0.7)
ax.add_patch(start_rect)
ax.text(0.5, 4.5, 'S\nStart', ha='center', va='center', fontsize=13, fontweight='bold')
# Goal position (gold)
goal_rect = patches.Rectangle((5, 4), 1, 1, linewidth=2,
ax.text(0.5, 2.5, 'Start', ha='center', va='center', fontsize=13, fontweight='bold')

# Goal position (gold) - top row
goal_rect = patches.Rectangle((5, 2), 1, 1, linewidth=2,
edgecolor='darkgoldenrod', facecolor='gold', alpha=0.7)
ax.add_patch(goal_rect)
ax.text(5.5, 4.5, 'G\nGoal\n+10', ha='center', va='center', fontsize=13, fontweight='bold')
# Walls (gray)
walls = [(2, 2), (2, 3), (3, 2), (3, 3)]
ax.text(5.5, 2.5, 'Goal\n(+10)', ha='center', va='center', fontsize=13, fontweight='bold')

# Walls (gray) - blocking top middle two squares
walls = [(2, 2), (3, 2)]
for x, y in walls:
wall_rect = patches.Rectangle((x, y), 1, 1, linewidth=1,
edgecolor='black', facecolor='gray', alpha=0.8)
ax.add_patch(wall_rect)
ax.text(x + 0.5, y + 0.5, 'Wall\n-1', ha='center', va='center',
ax.text(x + 0.5, y + 0.5, 'Wall\n(-1)', ha='center', va='center',
fontsize=10, fontweight='bold', color='white')

# Sample path (blue arrows)
path = [(0.5, 4.5), (1.5, 4.5), (2.5, 4.5), (3.5, 4.5), (4.5, 4.5), (5.5, 4.5)]

# Sample path (blue arrows) - go around the walls
path = [(0.8, 2.5), (1.5, 2.5), (1.5, 1.5), (2.5, 1.5), (3.5, 1.5),
(4.5, 1.5), (4.5, 2.5), (5.2, 2.5)]
for i in range(len(path) - 1):
arrow = FancyArrowPatch(path[i], path[i+1],
arrowstyle='->', mutation_scale=20,
linewidth=2.5, color='blue', alpha=0.6)
arrow = FancyArrowPatch(path[i], path[i+1],
arrowstyle='->', mutation_scale=15,
linewidth=2, color='#1f77b4', alpha=0.5)
ax.add_patch(arrow)
# Add action labels
ax.text(3, 0.3, 'Actions: ↑ (up), ↓ (down), ← (left), → (right)',

# Add action labels (positioned below the grid with more spacing)
ax.text(3, -0.4, 'Actions: ↑ (up), ↓ (down), ← (left), → (right)',
ha='center', fontsize=11, style='italic')

ax.set_xlim(-0.2, grid_width + 0.2)
ax.set_ylim(-0.5, grid_height + 0.2)
ax.set_ylim(-0.7, grid_height + 0.2)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('GridWorld Example: Navigate from Start to Goal', fontsize=15, fontweight='bold', pad=10)


plt.tight_layout()
plt.savefig('gridworld_example.png', dpi=150, bbox_inches='tight')
plt.close()
print(" Created gridworld_example.png")
print("[OK] Created gridworld_example.png")


def create_epsilon_greedy_vs_softmax():
Expand Down Expand Up @@ -143,7 +143,7 @@ def create_epsilon_greedy_vs_softmax():
plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.savefig('epsilon_greedy_vs_softmax.png', dpi=150, bbox_inches='tight')
plt.close()
print(" Created epsilon_greedy_vs_softmax.png")
print("[OK] Created epsilon_greedy_vs_softmax.png")


def create_averaging_comparison():
Expand Down Expand Up @@ -186,21 +186,24 @@ def create_averaging_comparison():

ax.axvline(x=300, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax.axvline(x=600, color='gray', linestyle='--', alpha=0.5, linewidth=1)
ax.text(150, 3.5, 'Stationary', ha='center', fontsize=11, style='italic')
ax.text(450, 3.5, 'Drifting', ha='center', fontsize=11, style='italic')
ax.text(800, 3.5, 'Stationary', ha='center', fontsize=11, style='italic')

ax.text(150, 3.35, 'Stationary', ha='center', fontsize=10, style='italic',
bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8, edgecolor='gray', linewidth=1))
ax.text(450, 3.35, 'Drifting', ha='center', fontsize=10, style='italic',
bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8, edgecolor='gray', linewidth=1))
ax.text(800, 3.35, 'Stationary', ha='center', fontsize=10, style='italic',
bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8, edgecolor='gray', linewidth=1))

ax.set_xlabel('Step', fontsize=13, fontweight='bold')
ax.set_ylabel('Estimated Value', fontsize=13, fontweight='bold')
ax.set_title('Simple vs. Exponential Averaging (Non-Stationary Environment)',
fontsize=15, fontweight='bold')
ax.legend(loc='upper left', fontsize=12, framealpha=0.9)
ax.set_title('Simple vs. Exponential Averaging (Non-Stationary Environment)',
fontsize=15, fontweight='bold', pad=15)
ax.legend(loc='upper left', fontsize=12, framealpha=0.9, bbox_to_anchor=(0, 0.95))
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('averaging_comparison.png', dpi=150, bbox_inches='tight')
plt.close()
print(" Created averaging_comparison.png")
print("[OK] Created averaging_comparison.png")


def create_discount_factor_visualization():
Expand Down Expand Up @@ -234,7 +237,7 @@ def create_discount_factor_visualization():
plt.tight_layout()
plt.savefig('discount_factor_visualization.png', dpi=150, bbox_inches='tight')
plt.close()
print(" Created discount_factor_visualization.png")
print("[OK] Created discount_factor_visualization.png")


def create_td_update_diagram():
Expand Down Expand Up @@ -296,7 +299,7 @@ def create_td_update_diagram():
plt.tight_layout()
plt.savefig('td_update_diagram.png', dpi=150, bbox_inches='tight')
plt.close()
print(" Created td_update_diagram.png")
print("[OK] Created td_update_diagram.png")


def create_sarsa_vs_qlearning_diagram():
Expand Down Expand Up @@ -394,7 +397,7 @@ def create_sarsa_vs_qlearning_diagram():
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig('sarsa_vs_qlearning_diagram.png', dpi=150, bbox_inches='tight')
plt.close()
print(" Created sarsa_vs_qlearning_diagram.png")
print("[OK] Created sarsa_vs_qlearning_diagram.png")


def create_cliff_walking_comparison():
Expand Down Expand Up @@ -464,7 +467,7 @@ def create_cliff_walking_comparison():
plt.tight_layout(rect=[0, 0.05, 1, 1])
plt.savefig('cliff_walking_comparison.png', dpi=150, bbox_inches='tight')
plt.close()
print(" Created cliff_walking_comparison.png")
print("[OK] Created cliff_walking_comparison.png")


def create_function_approximation_diagram():
Expand Down Expand Up @@ -526,7 +529,7 @@ def create_function_approximation_diagram():
plt.tight_layout()
plt.savefig('function_approximation_diagram.png', dpi=150, bbox_inches='tight')
plt.close()
print(" Created function_approximation_diagram.png")
print("[OK] Created function_approximation_diagram.png")


def create_learning_curves():
Expand Down Expand Up @@ -574,80 +577,63 @@ def create_learning_curves():
plt.tight_layout()
plt.savefig('learning_curves_comparison.png', dpi=150, bbox_inches='tight')
plt.close()
print(" Created learning_curves_comparison.png")
print("[OK] Created learning_curves_comparison.png")


def create_rl_loop_diagram():
"""Create the basic RL loop diagram"""
fig, ax = plt.subplots(1, 1, figsize=(10, 8))
# Agent box
agent_box = FancyBboxPatch((0.35, 0.65), 0.3, 0.15,
"""Create a simple, clean RL interaction diagram - no fluff"""
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# Agent box (top, with gap below)
agent_box = FancyBboxPatch((0.20, 0.58), 0.6, 0.30,
boxstyle="round,pad=0.02",
edgecolor='blue', facecolor='lightblue',
linewidth=3, alpha=0.8)
edgecolor='#2E86AB', facecolor='#A9D6E5',
linewidth=6, alpha=0.95)
ax.add_patch(agent_box)
ax.text(0.5, 0.725, 'Agent', ha='center', va='center',
fontsize=18, fontweight='bold')
# Environment box
env_box = FancyBboxPatch((0.35, 0.2), 0.3, 0.15,
ax.text(0.5, 0.73, 'AGENT', ha='center', va='center',
fontsize=40, fontweight='bold', color='#01497C')

# Environment box (bottom, with gap above)
env_box = FancyBboxPatch((0.20, 0.12), 0.6, 0.30,
boxstyle="round,pad=0.02",
edgecolor='green', facecolor='lightgreen',
linewidth=3, alpha=0.8)
edgecolor='#2D6A4F', facecolor='#95D5B2',
linewidth=6, alpha=0.95)
ax.add_patch(env_box)
ax.text(0.5, 0.275, 'Environment', ha='center', va='center',
fontsize=18, fontweight='bold')

# Arrows and labels
# Action arrow (agent to env)
action_arrow = FancyArrowPatch((0.5, 0.65), (0.5, 0.35),
arrowstyle='->', mutation_scale=30,
linewidth=3, color='purple')
ax.text(0.5, 0.27, 'ENVIRONMENT', ha='center', va='center',
fontsize=40, fontweight='bold', color='#1B4332')

# Action arrow (Agent → Environment) - centered, arrowhead inside environment
action_arrow = FancyArrowPatch((0.30, 0.61), (0.30, 0.44),
arrowstyle='->', mutation_scale=60,
linewidth=10, color='#6A4C93', zorder=3)
ax.add_patch(action_arrow)
ax.text(0.42, 0.5, 'Action', ha='center', fontsize=13,
fontweight='bold', color='purple',
bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

# State arrow (env to agent, left side)
state_arrow = FancyArrowPatch((0.36, 0.35), (0.36, 0.65),
arrowstyle='->', mutation_scale=30,
linewidth=3, color='orange')
ax.add_patch(state_arrow)
ax.text(0.25, 0.5, 'State', ha='center', fontsize=13,
fontweight='bold', color='orange',
bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

# Reward arrow (env to agent, right side)
reward_arrow = FancyArrowPatch((0.64, 0.35), (0.64, 0.65),
arrowstyle='->', mutation_scale=30,
linewidth=3, color='red')
ax.add_patch(reward_arrow)
ax.text(0.75, 0.5, 'Reward', ha='center', fontsize=13,
fontweight='bold', color='red',
bbox=dict(boxstyle='round', facecolor='white', alpha=0.9))

# Add numbered steps
steps = [
(0.5, 0.92, '1. Observe State'),
(0.5, 0.57, '2. Choose Action'),
(0.5, 0.13, '3. Get Reward & New State'),
(0.1, 0.725, '4. Update Estimates'),
]

for x, y, text in steps:
ax.text(x, y, text, ha='center', fontsize=12, style='italic',
bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# Action label - moved towards center
ax.text(0.18, 0.50, 'Action', ha='center', va='center',
fontsize=24, fontweight='bold', color='white',
bbox=dict(boxstyle='round,pad=0.8', facecolor='#6A4C93',
edgecolor='#472D5B', linewidth=3))

# Return arrow (Environment → Agent) - centered, arrowhead inside agent
return_arrow = FancyArrowPatch((0.70, 0.39), (0.70, 0.56),
arrowstyle='->', mutation_scale=60,
linewidth=10, color='#D62828', zorder=3)
ax.add_patch(return_arrow)

# State + Reward label - moved towards center
ax.text(0.82, 0.50, 'State +\nReward', ha='center', va='center',
fontsize=24, fontweight='bold', color='white',
bbox=dict(boxstyle='round,pad=0.8', facecolor='#D62828',
edgecolor='#9B2226', linewidth=3))

ax.set_xlim(0.05, 0.95)
ax.set_ylim(0.08, 0.92)
ax.axis('off')
ax.set_title('The Reinforcement Learning Loop', fontsize=16, fontweight='bold', pad=20)

plt.tight_layout()
plt.savefig('rl_loop_diagram.png', dpi=150, bbox_inches='tight')

plt.tight_layout(pad=0)
plt.savefig('rl_loop_diagram.png', dpi=150, bbox_inches='tight', pad_inches=0.1)
plt.close()
print(" Created rl_loop_diagram.png")
print("[OK] Created rl_loop_diagram.png")


def main():
Expand All @@ -668,7 +654,7 @@ def main():
create_learning_curves()

print("\n" + "="*50)
print(" All visualizations created successfully!")
print("[OK] All visualizations created successfully!")
print("="*50 + "\n")


Expand Down
Binary file modified gridworld_example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified learning_curves_comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified rl_loop_diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified sarsa_vs_qlearning_diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified td_update_diagram.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.