diff --git a/disentangled_rnns/library/two_armed_bandits.py b/disentangled_rnns/library/two_armed_bandits.py
index 232c96a..199c02c 100644
--- a/disentangled_rnns/library/two_armed_bandits.py
+++ b/disentangled_rnns/library/two_armed_bandits.py
@@ -14,6 +14,7 @@
 """Two armed bandit experiments. Generate synthetic data, plot data."""
 
+import abc
 from collections.abc import Callable
 from typing import NamedTuple, Optional, Union
 
@@ -24,7 +25,50 @@
 import numpy as np
 
 
-class EnvironmentBanditsDrift:
+abstractmethod = abc.abstractmethod
+
+################
+# ENVIRONMENTS #
+################
+
+
+class BaseEnvironment(abc.ABC):
+  """Base class for two-armed bandit environments.
+
+  Subclasses must implement the following methods:
+  - new_sess()
+  - step(choice)
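+
+  Example (an illustrative sketch; EnvironmentBanditsFixed is hypothetical
+  and not defined in this module):
+
+    class EnvironmentBanditsFixed(BaseEnvironment):
+      """Two-armed bandit whose reward probabilities never change."""
+
+      def __init__(self, reward_probs=(0.8, 0.2), seed=None):
+        super().__init__(seed)
+        self._reward_probs = np.array(reward_probs)
+
+      def new_sess(self):
+        pass  # Nothing to reset between sessions.
+
+      def step(self, attempted_choice):
+        # This environment never instructs choices, so instructed is 0.
+        reward = self._random_state.rand() < self._reward_probs[attempted_choice]
+        return attempted_choice, float(reward), 0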
+  """
+
+  def __init__(self, seed: Optional[int] = None):
+    self._random_state = np.random.RandomState(seed)
+    self.n_arms = 2  # For now we only support 2-armed bandits
+
+  @abstractmethod
+  def new_sess(self):
+    """Starts a new session (e.g., resets environment parameters).
+
+    This method should be implemented by subclasses to initialize or
+    reset the environment's state at the beginning of a new session or episode.
+    """
+
+  @abstractmethod
+  def step(self, attempted_choice: int) -> tuple[int, float, int]:
+    """Executes a single step in the environment.
+
+    Args:
+      attempted_choice: The action chosen by the agent.
+
+    Returns:
+      choice: The action actually taken. May be different from the attempted
+        choice if the environment decides the choice should be instructed on
+        that trial.
+      reward: The reward received after taking the action.
+      instructed: 1 if the choice was instructed, 0 otherwise
+    """
+
+
+class EnvironmentBanditsDrift(BaseEnvironment):
   """Environment for a drifting two-armed bandit task.
 
   Reward probabilities on each arm are sampled randomly between 0 and
@@ -38,8 +82,10 @@ class EnvironmentBanditsDrift:
 
   def __init__(self,
                sigma: float,
+               p_instructed: float = 0.0,
                seed: Optional[int] = None,
                ):
+    super().__init__()
 
     # Check inputs
     if sigma < 0:
@@ -47,6 +93,7 @@ def __init__(self,
       raise ValueError(msg.format(sigma))
     # Initialize persistent properties
     self._sigma = sigma
+    self._p_instructed = p_instructed
     self._random_state = np.random.RandomState(seed)
 
     # Sample new reward probabilities
@@ -58,20 +105,37 @@ def new_sess(self):
     self._reward_probs = self._random_state.rand(2)
 
   def step(self,
-           choice: int) -> int:
+           attempted_choice: int) -> tuple[int, float, int]:
     """Run a single trial of the task.
 
     Args:
-      choice: The choice made by the agent. 0 or 1
+      attempted_choice: The choice made by the agent. 0 or 1
 
     Returns:
+      choice: The action actually taken. May be different from the attempted
+        choice if the environment decides the choice should be instructed on
+        that trial.
       reward: The reward to be given to the agent. 0 or 1.
+      instructed: 1 if the choice was instructed, 0 otherwise
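+
+    Example (an illustrative sketch; all values are hypothetical):
+
+      env = EnvironmentBanditsDrift(sigma=0.1, p_instructed=0.2, seed=0)
+      choice, reward, instructed = env.step(attempted_choice=0)
+      # On instructed trials (instructed == 1), choice is drawn uniformly
+      # at random and may differ from the attempted choice.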
     """
+    if attempted_choice == -1:
+      choice = -1
+      reward = -1
+      instructed = -1
+      return choice, reward, instructed
+
     # Check inputs
-    if not np.logical_or(choice == 0, choice == 1):
+    if not np.logical_or(attempted_choice == 0, attempted_choice == 1):
       msg = ('choice given was {}, but must be either 0 or 1')
-      raise ValueError(msg.format(choice))
+      raise ValueError(msg.format(attempted_choice))
+
+    # If choice was instructed, overrule it and decide randomly
+    instructed = self._random_state.rand() < self._p_instructed
+    if instructed:
+      choice = self._random_state.choice(2)
+    else:
+      choice = attempted_choice
 
     # Sample reward with the probability of the chosen side
     reward = self._random_state.rand() < self._reward_probs[choice]
@@ -82,13 +146,106 @@ def step(self,
     # Fix reward probs that've drifted below 0 or above 1
     self._reward_probs = np.clip(self._reward_probs, 0, 1)
 
-    return reward
+    return choice, float(reward), int(instructed)
 
   @property
   def reward_probs(self) -> np.ndarray:
     return self._reward_probs.copy()
 
 
+class NoMoreTrialsInSessionError(ValueError):
+  pass
+
+
+class NoMoreSessionsInDatasetError(ValueError):
+  pass
+
+
+class EnvironmentPayoutMatrix(BaseEnvironment):
+  """Environment for a two-armed bandit task with a specified payout matrix."""
+
+  def __init__(
+      self,
+      payout_matrix: np.ndarray,
+      instructed_matrix: Optional[np.ndarray] = None,
+  ):
+    """Initialize the environment.
+
+    Args:
+      payout_matrix: A numpy array of shape (n_sessions, n_trials, n_actions)
+        giving the reward for each session, trial, and action. These are
+        deterministic, i.e. for the same session_num, trial_num, and action,
+        the reward will always be the same. (If you'd like stochastic rewards
+        you can populate this matrix ahead of time.)
+      instructed_matrix: A numpy array of shape (n_sessions, n_trials) giving
+        the choice that should be made, if any, for each session and trial.
+        Elements should be ints or nan. If nan, the choice is not instructed.
+        If None, no choices are instructed.
+    """
+    super().__init__()
+    self._payout_matrix = payout_matrix
+    self._n_sessions = payout_matrix.shape[0]
+    self._n_trials = payout_matrix.shape[1]
+    self._n_actions = payout_matrix.shape[2]
+
+    if instructed_matrix is not None:
+      self._instructed_matrix = instructed_matrix
+    else:
+      # Instructions are indexed by (session, trial), so the all-nan
+      # placeholder must be 2D rather than shaped like the 3D payout matrix.
+      self._instructed_matrix = np.full(
+          (self._n_sessions, self._n_trials), np.nan
+      )
+
+    self._current_session = 0
+    self._current_trial = 0
+
+  def new_sess(self):
+    self._current_session += 1
+    if self._current_session >= self._n_sessions:
+      raise NoMoreSessionsInDatasetError(
+          'No more sessions in dataset. '
+          f'Current session {self._current_session} is out of range '
+          f'[0, {self._n_sessions})'
+      )
+    self._current_trial = 0
+
+  def step(self, attempted_choice: int) -> tuple[int, float, int]:
+    if attempted_choice == -1:
+      choice = -1
+      reward = -1
+      instructed = -1
+      return choice, reward, instructed
+    if attempted_choice > self._n_actions - 1:
+      raise ValueError(
+          'Choice given was {}, but must be less than {}'.format(
+              attempted_choice, self._n_actions
+          )
+      )
+    if self._current_trial >= self._n_trials:
+      raise NoMoreTrialsInSessionError(
+          'No more trials in session. '
+          f'Current trial {self._current_trial} is out of range '
+          f'[0, {self._n_trials})'
+      )
+    # If choice was instructed, overrule and replace with the instructed choice
+    instruction = self._instructed_matrix[
+        self._current_session, self._current_trial
+    ]
+    instructed = not np.isnan(instruction)
+    if instructed:
+      choice = int(instruction)
+    else:
+      choice = attempted_choice
+
+    reward = self._payout_matrix[
+        self._current_session, self._current_trial, choice
+    ]
+    self._current_trial += 1
+    return choice, float(reward), int(instructed)
+
+
+##########
+# AGENTS #
+##########
+
+
 class AgentQ:
   """An agent that runs "vanilla" Q-learning for the y-maze tasks.
 
@@ -130,7 +287,7 @@ def get_choice(self) -> int:
 
   def update(self,
              choice: int,
-             reward: int):
+             reward: float):
     """Update the agent after one step of the task.
 
     Args:
@@ -185,7 +342,7 @@ def get_choice(self) -> int:
     choice = np.random.choice(2, p=choice_probs)
     return choice
 
-  def update(self, choice: int, reward: int):
+  def update(self, choice: int, reward: float):
     """Update the agent after one step of the task.
 
     Args:
@@ -280,7 +437,7 @@ def run_experiment(agent: Agent,
     n_steps: The number of steps in the session you'd like to generate
 
   Returns:
-    experiment: A YMazeSession holding choices and rewards from the session
+    experiment: A SessData object holding choices and rewards from the session
   """
   choices = np.zeros(n_steps)
   rewards = np.zeros(n_steps)
@@ -290,9 +447,9 @@
     # First record environment reward probs
     reward_probs[step] = environment.reward_probs
     # First agent makes a choice
-    choice = agent.get_choice()
+    attempted_choice = agent.get_choice()
     # Then environment computes a reward
-    reward = environment.step(choice)
+    choice, reward, _ = environment.step(attempted_choice)
     # Finally agent learns
     agent.update(choice, reward)
     # Log choice and reward
@@ -335,6 +492,7 @@ def create_dataset(agent: Agent,
         np.concatenate(([prev_choices], [prev_rewards]), axis=0), 0, 1
     )
     ys[:, sess_i] = np.expand_dims(experiment.choices, 1)
+    environment.new_sess()
 
   dataset = rnn_utils.DatasetRNN(
       xs=xs,
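
Usage sketch (illustrative only, not part of the patch; all values below are
hypothetical). Stepping EnvironmentPayoutMatrix directly, with payouts shaped
(n_sessions, n_trials, n_actions) as documented above:

    import numpy as np

    # Deterministic payouts for 2 sessions of 5 trials on 2 arms.
    payouts = np.random.default_rng(0).binomial(1, 0.7, size=(2, 5, 2)).astype(float)
    # Instruct choice 1 on the first trial of each session; nan = free choice.
    instructed = np.full((2, 5), np.nan)
    instructed[:, 0] = 1
    env = EnvironmentPayoutMatrix(payouts, instructed)

    choice, reward, was_instructed = env.step(attempted_choice=0)
    assert (choice, was_instructed) == (1, 1)  # Attempted 0, overruled to 1.
    env.new_sess()  # Advance to the next session; raises
                    # NoMoreSessionsInDatasetError once all sessions are used.

Note that run_experiment reads environment.reward_probs, which
EnvironmentPayoutMatrix does not define, so the sketch steps the environment
directly rather than going through run_experiment.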