Internal #20

Open · wants to merge 1 commit into base: main
180 changes: 169 additions & 11 deletions disentangled_rnns/library/two_armed_bandits.py
@@ -14,6 +14,7 @@

"""Two armed bandit experiments. Generate synthetic data, plot data."""

import abc
from collections.abc import Callable
from typing import NamedTuple, Optional, Union

@@ -24,7 +25,50 @@
import numpy as np


# Alias so the abstract methods below can use the bare @abstractmethod name.
abstractmethod = abc.abstractmethod

################
# ENVIRONMENTS #
################


class BaseEnvironment(abc.ABC):
"""Base class for two-armed bandit environments.

Subclasses must implement the following methods:
- new_sess()
- step(choice)
"""

def __init__(self, seed: Optional[int] = None):
self._random_state = np.random.RandomState(seed)
self.n_arms = 2 # For now we only support 2-armed bandits

@abstractmethod
def new_sess(self):
"""Starts a new session (e.g., resets environment parameters).

This method should be implemented by subclasses to initialize or
reset the environment's state at the beginning of a new session or episode.
"""

@abstractmethod
def step(self, attempted_choice: int) -> tuple[int, float, int]:
"""Executes a single step in the environment.

Args:
attempted_choice: The action chosen by the agent.

Returns:
choice: The action actually taken. May be different from the attempted
choice if the environment decides the choice should be instructed on
that trial.
reward: The reward received after taking the action.
instructed: 1 if the choice was instructed, 0 otherwise
"""


class EnvironmentBanditsDrift(BaseEnvironment):
"""Environment for a drifting two-armed bandit task.

Reward probabilities on each arm are sampled randomly between 0 and
@@ -38,15 +82,18 @@ class EnvironmentBanditsDrift:

def __init__(self,
sigma: float,
p_instructed: float = 0.0,
seed: Optional[int] = None,
):
super().__init__(seed=seed)

# Check inputs
if sigma < 0:
msg = ('sigma was {}, but must be nonnegative')
raise ValueError(msg.format(sigma))
# Initialize persistent properties
self._sigma = sigma
self._p_instructed = p_instructed

# Sample new reward probabilities
@@ -58,20 +105,37 @@ def new_sess(self):
self._reward_probs = self._random_state.rand(2)

def step(self,
attempted_choice: int) -> tuple[int, float, int]:
"""Run a single trial of the task.

Args:
attempted_choice: The choice made by the agent. 0 or 1

Returns:
choice: The action actually taken. May be different from the attempted
choice if the environment decides the choice should be instructed on
that trial.
reward: The reward to be given to the agent. 0 or 1.
instructed: 1 if the choice was instructed, 0 otherwise

"""
if attempted_choice == -1:
choice = -1
reward = -1
instructed = -1
return choice, reward, instructed

# Check inputs
if not np.logical_or(attempted_choice == 0, attempted_choice == 1):
msg = ('choice given was {}, but must be either 0 or 1')
raise ValueError(msg.format(attempted_choice))

# If choice was instructed, overrule it and decide randomly
instructed = self._random_state.rand() < self._p_instructed
if instructed:
choice = self._random_state.choice(2)
else:
choice = attempted_choice

# Sample reward with the probability of the chosen side
reward = self._random_state.rand() < self._reward_probs[choice]
@@ -82,13 +146,106 @@ def step(self,
# Fix reward probs that've drifted below 0 or above 1
self._reward_probs = np.clip(self._reward_probs, 0, 1)

return choice, float(reward), int(instructed)

@property
def reward_probs(self) -> np.ndarray:
return self._reward_probs.copy()
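A usage sketch (not part of the diff) driving the drift environment for a few trials; parameter values are arbitrary:

env = EnvironmentBanditsDrift(sigma=0.1, p_instructed=0.2, seed=0)
env.new_sess()
for _ in range(5):
  attempted = 0  # Stand-in for an agent that always tries arm 0
  choice, reward, instructed = env.step(attempted)
  # instructed == 1 means the environment overrode the attempted choice.
  print(choice, reward, instructed, env.reward_probs)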


class NoMoreTrialsInSessionError(ValueError):
pass


class NoMoreSessionsInDatasetError(ValueError):
pass


class EnvironmentPayoutMatrix(BaseEnvironment):
"""Environment for a two-armed bandit task with a specified payout matrix."""

def __init__(
self,
payout_matrix: np.ndarray,
instructed_matrix: Optional[np.ndarray] = None,
):
"""Initialize the environment.

Args:
payout_matrix: A numpy array of shape (n_sessions, n_trials, n_actions)
giving the reward for each session, trial, and action. These are
deterministic, i.e. for the same session_num, trial_num, and action, the
reward will always be the same. (If you'd like stochastic rewards you
can populate this matrix ahead of time).
instructed_matrix: A numpy array of shape (n_sessions, n_trials) giving
the choice that should be made, if any, for each session and trial.
Elements should be ints or nan. If nan, the choice is not instructed. If
None, no choices are instructed.
"""
super().__init__()
self._payout_matrix = payout_matrix
self._n_sessions = payout_matrix.shape[0]
self._n_trials = payout_matrix.shape[1]
self._n_actions = payout_matrix.shape[2]

if instructed_matrix is not None:
self._instructed_matrix = instructed_matrix
else:
self._instructed_matrix = np.full((self._n_sessions, self._n_trials), np.nan)

self._current_session = 0
self._current_trial = 0

def new_sess(self):
self._current_session += 1
if self._current_session >= self._n_sessions:
raise NoMoreSessionsInDatasetError(
'No more sessions in dataset. '
f'Current session {self._current_session} is out of range '
f'[0, {self._n_sessions})'
)
self._current_trial = 0

def step(self, attempted_choice: int) -> tuple[int, float, int]:
if attempted_choice == -1:
choice = -1
reward = -1
instructed = -1
return choice, reward, instructed
if attempted_choice > self._n_actions - 1:
raise ValueError(
'Choice given was {}, but must be less than {}'.format(
attempted_choice, self._n_actions
)
)
if self._current_trial >= self._n_trials:
raise NoMoreTrialsInSessionError(
'No more trials in session. '
f'Current trial {self._current_trial} is out of range '
f'[0, {self._n_trials})'
)
# If choice was instructed, overrule and replace with the instructed choice
instruction = self._instructed_matrix[
self._current_session, self._current_trial
]
instructed = not np.isnan(instruction)
if instructed:
choice = int(instruction)
else:
choice = attempted_choice

reward = self._payout_matrix[
self._current_session, self._current_trial, choice
]
self._current_trial += 1
return choice, float(reward), int(instructed)
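A construction sketch (not from the PR) following the docstring's suggestion to pre-populate stochastic rewards, together with the intended role of NoMoreSessionsInDatasetError; all names and values are illustrative:

rng = np.random.RandomState(0)
n_sessions, n_trials, n_actions = 10, 200, 2
# Pre-sample Bernoulli rewards so the deterministic matrix encodes stochastic payouts.
payouts = (
    rng.rand(n_sessions, n_trials, n_actions) < np.array([0.7, 0.3])
).astype(float)
# Instruct the first 10 trials of each session; nan means a free choice.
instructed = np.full((n_sessions, n_trials), np.nan)
instructed[:, :10] = rng.choice(n_actions, size=(n_sessions, 10))
env = EnvironmentPayoutMatrix(payouts, instructed_matrix=instructed)

try:
  while True:  # Consume sessions until the dataset is exhausted
    for _ in range(n_trials):
      env.step(attempted_choice=0)
    env.new_sess()
except NoMoreSessionsInDatasetError:
  pass  # All n_sessions sessions have been used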


##########
# AGENTS #
##########


class AgentQ:
"""An agent that runs "vanilla" Q-learning for the y-maze tasks.

@@ -130,7 +287,7 @@ def get_choice(self) -> int:

def update(self,
choice: int,
reward: float):
"""Update the agent after one step of the task.

Args:
Expand Down Expand Up @@ -185,7 +342,7 @@ def get_choice(self) -> int:
choice = np.random.choice(2, p=choice_probs)
return choice

def update(self, choice: int, reward: float):
"""Update the agent after one step of the task.

Args:
@@ -280,7 +437,7 @@ def run_experiment(agent: Agent,
n_steps: The number of steps in the session you'd like to generate

Returns:
experiment: A SessData object holding choices and rewards from the session
"""
choices = np.zeros(n_steps)
rewards = np.zeros(n_steps)
@@ -290,9 +447,9 @@
# First record environment reward probs
reward_probs[step] = environment.reward_probs
# First agent makes a choice
attempted_choice = agent.get_choice()
# Then environment computes a reward
choice, reward, _ = environment.step(attempted_choice)
# Finally agent learns
agent.update(choice, reward)
# Log choice and reward
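The loop above is the new interaction protocol in miniature: the agent proposes a choice, the environment may overrule it, and the agent learns from the choice actually taken. A hedged usage sketch follows; AgentQ's constructor arguments are collapsed in this diff, so defaults are assumed, as is the (agent, environment, n_steps) argument order:

agent = AgentQ()  # Assumed default hyperparameters; signature not shown here
env = EnvironmentBanditsDrift(sigma=0.1, seed=0)
experiment = run_experiment(agent, env, n_steps=300)
print(experiment.choices[:10])  # SessData holds the logged choices (see below)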
@@ -335,6 +492,7 @@ def create_dataset(agent: Agent,
np.concatenate(([prev_choices], [prev_rewards]), axis=0), 0, 1
)
ys[:, sess_i] = np.expand_dims(experiment.choices, 1)
environment.new_sess()

dataset = rnn_utils.DatasetRNN(
xs=xs,
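create_dataset's full signature is collapsed in this diff. A hedged calling sketch, with keyword names assumed rather than taken from the source; note that the newly added environment.new_sess() call means the environment must supply a fresh session per iteration (EnvironmentBanditsDrift always can, while EnvironmentPayoutMatrix raises once its sessions run out):

agent = AgentQ()  # Assumed defaults
env = EnvironmentBanditsDrift(sigma=0.1, seed=0)
# Keyword names below are assumptions; only the agent parameter is visible above.
dataset = create_dataset(agent, env, n_steps_per_session=200, n_sessions=100)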