diff --git a/disentangled_rnns/library/two_armed_bandits.py b/disentangled_rnns/library/two_armed_bandits.py
index 199c02c..12c00e5 100644
--- a/disentangled_rnns/library/two_armed_bandits.py
+++ b/disentangled_rnns/library/two_armed_bandits.py
@@ -38,11 +38,14 @@ class BaseEnvironment(abc.ABC):
   Subclasses must implement the following methods:
     - new_sess()
     - step(choice)
+
+  Attributes:
+    n_arms: The number of arms in the environment.
   """
 
-  def __init__(self, seed: Optional[int] = None):
+  def __init__(self, seed: Optional[int] = None, n_arms: int = 2):
     self._random_state = np.random.RandomState(seed)
-    self.n_arms = 2  # For now we only support 2-armed bandits
+    self._n_arms = n_arms
 
   @abstractmethod
   def new_sess(self):
@@ -67,6 +70,11 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:
       instructed: 1 if the choice was instructed, 0 otherwise
     """
 
+  @property
+  def n_arms(self) -> int:
+    """Returns the number of arms in the environment."""
+    return self._n_arms
+
 
 class EnvironmentBanditsDrift(BaseEnvironment):
   """Environment for a drifting two-armed bandit task.
@@ -78,23 +86,25 @@ class EnvironmentBanditsDrift(BaseEnvironment):
   Attributes:
     sigma: A float, between 0 and 1, giving the magnitude of the drift
     reward_probs: Probability of reward associated with each action
+    n_arms: The number of arms in the environment.
   """
 
   def __init__(self,
                sigma: float,
                p_instructed: float = 0.0,
                seed: Optional[int] = None,
+               n_arms: int = 2,
                ):
-    super().__init__()
+    super().__init__(seed=seed, n_arms=n_arms)
     # Check inputs
     if sigma < 0:
       msg = ('sigma was {}, but must be greater than 0')
       raise ValueError(msg.format(sigma))
 
+    # Initialize persistent properties
     self._sigma = sigma
     self._p_instructed = p_instructed
-    self._random_state = np.random.RandomState(seed)
 
     # Sample new reward probabilities
     self.new_sess()
 
@@ -102,7 +112,7 @@ def __init__(self,
   def new_sess(self):
     # Pick new reward probabilities.
     # Sample randomly between 0 and 1
-    self._reward_probs = self._random_state.rand(2)
+    self._reward_probs = self._random_state.rand(self.n_arms)
 
   def step(self,
            attempted_choice: int) -> tuple[int, float, int]:
@@ -126,21 +136,23 @@ def step(self,
       return choice, reward, instructed
 
     # Check inputs
-    if not np.logical_or(attempted_choice == 0, attempted_choice == 1):
-      msg = ('choice given was {}, but must be either 0 or 1')
-      raise ValueError(msg.format(attempted_choice))
+    if attempted_choice not in list(range(self.n_arms)):
+      msg = (f'choice given was {attempted_choice}, but must be one of '
+             f'{list(range(self.n_arms))}.')
+      raise ValueError(msg)
 
     # If choice was instructed, overrule it and decide randomly
     instructed = self._random_state.rand() < self._p_instructed
     if instructed:
-      choice = self._random_state.choice(2)
+      choice = self._random_state.choice(self.n_arms)
     else:
       choice = attempted_choice
 
     # Sample reward with the probability of the chosen side
     reward = self._random_state.rand() < self._reward_probs[choice]
     # Add gaussian noise to reward probabilities
-    drift = self._random_state.normal(loc=0, scale=self._sigma, size=2)
+    drift = self._random_state.normal(
+        loc=0, scale=self._sigma, size=self.n_arms)
     self._reward_probs += drift
 
     # Fix reward probs that've drifted below 0 or above 1
@@ -182,11 +194,12 @@ def __init__(
         Elements should be ints or nan. If nan, the choice is not instructed.
         If None, no choices are instructed.
""" - super().__init__() + n_arms = payout_matrix.shape[2] + super().__init__(seed=None, n_arms=n_arms) + self._payout_matrix = payout_matrix self._n_sessions = payout_matrix.shape[0] self._n_trials = payout_matrix.shape[1] - self._n_actions = payout_matrix.shape[2] if instructed_matrix is not None: self._instructed_matrix = instructed_matrix @@ -207,23 +220,26 @@ def new_sess(self): self._current_trial = 0 def step(self, attempted_choice: int) -> tuple[int, float, int]: + # If agent choice is default empty value -1, return -1 for all outputs. if attempted_choice == -1: choice = -1 reward = -1 instructed = -1 return choice, reward, instructed - if attempted_choice > self._n_actions - 1: - raise ValueError( - 'Choice given was {}, but must be less than {}'.format( - attempted_choice, self._n_actions - 1 - ) - ) + + # Check inputted choice is valid. + if attempted_choice not in list(range(self.n_arms)): + msg = (f'choice given was {attempted_choice}, but must be one of ' + f'{list(range(self.n_arms))}.') + raise ValueError(msg) + if self._current_trial >= self._n_trials: raise NoMoreTrialsInSessionError( 'No more trials in session. ' f'Current trial {self._current_trial} is out of range ' f'[0, {self._n_trials})' ) + # If choice was instructed, overrule and replace with the instructed choice instruction = self._instructed_matrix[ self._current_session, self._current_trial @@ -240,6 +256,11 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]: self._current_trial += 1 return choice, float(reward), int(instructed) + @property + def payout(self) -> np.ndarray: + """Get possible payouts for current session, trial across actions.""" + return self._payout_matrix[ + self._current_session, self._current_trial, :].copy() ########## # AGENTS #