
Commit 6fc868f

kevin-j-miller authored and copybara-github committed
Internal
PiperOrigin-RevId: 718144984
1 parent 4150bec commit 6fc868f

File tree: 1 file changed (+169 additions, −11 deletions)

disentangled_rnns/library/two_armed_bandits.py

@@ -14,6 +14,7 @@
 
 """Two armed bandit experiments. Generate synthetic data, plot data."""
 
+import abc
 from collections.abc import Callable
 from typing import NamedTuple, Optional, Union
 
@@ -24,7 +25,50 @@
 import numpy as np
 
 
-class EnvironmentBanditsDrift:
+abstractmethod = abc.abstractmethod
+
+
+################
+# ENVIRONMENTS #
+################
+
+
+class BaseEnvironment(abc.ABC):
+  """Base class for two-armed bandit environments.
+
+  Subclasses must implement the following methods:
+    - new_sess()
+    - step(choice)
+  """
+
+  def __init__(self, seed: Optional[int] = None):
+    self._random_state = np.random.RandomState(seed)
+    self.n_arms = 2  # For now we only support 2-armed bandits
+
+  @abstractmethod
+  def new_sess(self):
+    """Starts a new session (e.g., resets environment parameters).
+
+    This method should be implemented by subclasses to initialize or
+    reset the environment's state at the beginning of a new session or episode.
+    """
+
+  @abstractmethod
+  def step(self, attempted_choice: int) -> tuple[int, float, int]:
+    """Executes a single step in the environment.
+
+    Args:
+      attempted_choice: The action chosen by the agent.
+
+    Returns:
+      choice: The action actually taken. May be different from the attempted
+        choice if the environment decides the choice should be instructed on
+        that trial.
+      reward: The reward received after taking the action.
+      instructed: 1 if the choice was instructed, 0 otherwise
+    """
+
+
+class EnvironmentBanditsDrift(BaseEnvironment):
   """Environment for a drifting two-armed bandit task.
 
   Reward probabilities on each arm are sampled randomly between 0 and
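For orientation, here is a minimal sketch of how a concrete task could subclass the new BaseEnvironment. It assumes the module is importable as disentangled_rnns.library.two_armed_bandits; the FixedProbabilityBandit class and its reward probabilities are illustrative, not part of this commit.

import numpy as np

from disentangled_rnns.library import two_armed_bandits


class FixedProbabilityBandit(two_armed_bandits.BaseEnvironment):
  """Toy environment with static reward probabilities and no instructed trials."""

  def __init__(self, reward_probs=(0.8, 0.2), seed=None):
    super().__init__(seed=seed)
    self._reward_probs = np.array(reward_probs)

  def new_sess(self):
    pass  # Nothing to reset: reward probabilities are fixed across sessions.

  def step(self, attempted_choice: int) -> tuple[int, float, int]:
    choice = attempted_choice  # This toy task never instructs choices.
    reward = float(self._random_state.rand() < self._reward_probs[choice])
    return choice, reward, 0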
@@ -38,15 +82,18 @@ class EnvironmentBanditsDrift:
 
   def __init__(self,
                sigma: float,
+               p_instructed: float = 0.0,
                seed: Optional[int] = None,
                ):
+    super().__init__()
 
     # Check inputs
     if sigma < 0:
       msg = ('sigma was {}, but must be greater than 0')
       raise ValueError(msg.format(sigma))
     # Initialize persistent properties
     self._sigma = sigma
+    self._p_instructed = p_instructed
     self._random_state = np.random.RandomState(seed)
 
     # Sample new reward probabilities
@@ -58,20 +105,37 @@ def new_sess(self):
     self._reward_probs = self._random_state.rand(2)
 
   def step(self,
-           choice: int) -> int:
+           attempted_choice: int) -> tuple[int, float, int]:
     """Run a single trial of the task.
 
     Args:
-      choice: The choice made by the agent. 0 or 1
+      attempted_choice: The choice made by the agent. 0 or 1
 
     Returns:
+      choice: The action actually taken. May be different from the attempted
+        choice if the environment decides the choice should be instructed on
+        that trial.
       reward: The reward to be given to the agent. 0 or 1.
+      instructed: 1 if the choice was instructed, 0 otherwise
 
     """
+    if attempted_choice == -1:
+      choice = -1
+      reward = -1
+      instructed = -1
+      return choice, reward, instructed
+
     # Check inputs
-    if not np.logical_or(choice == 0, choice == 1):
+    if not np.logical_or(attempted_choice == 0, attempted_choice == 1):
       msg = ('choice given was {}, but must be either 0 or 1')
-      raise ValueError(msg.format(choice))
+      raise ValueError(msg.format(attempted_choice))
+
+    # If choice was instructed, overrule it and decide randomly
+    instructed = self._random_state.rand() < self._p_instructed
+    if instructed:
+      choice = self._random_state.choice(2)
+    else:
+      choice = attempted_choice
 
     # Sample reward with the probability of the chosen side
     reward = self._random_state.rand() < self._reward_probs[choice]
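With the new interface, callers unpack three values from step() rather than a single reward. A minimal usage sketch, assuming the same import path as above; the sigma and p_instructed values are arbitrary:

from disentangled_rnns.library import two_armed_bandits

env = two_armed_bandits.EnvironmentBanditsDrift(sigma=0.1, p_instructed=0.2, seed=0)
for _ in range(5):
  choice, reward, instructed = env.step(attempted_choice=0)
  # On roughly 20% of trials `instructed` is 1 and `choice` may differ from
  # the attempted choice; `reward` is 0.0 or 1.0.
  print(choice, reward, instructed)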
@@ -82,13 +146,106 @@ def step(self,
     # Fix reward probs that've drifted below 0 or above 1
     self._reward_probs = np.clip(self._reward_probs, 0, 1)
 
-    return reward
+    return choice, float(reward), int(instructed)
 
   @property
   def reward_probs(self) -> np.ndarray:
     return self._reward_probs.copy()
 
 
+class NoMoreTrialsInSessionError(ValueError):
+  pass
+
+
+class NoMoreSessionsInDatasetError(ValueError):
+  pass
+
+
+class EnvironmentPayoutMatrix(BaseEnvironment):
+  """Environment for a two-armed bandit task with a specified payout matrix."""
+
+  def __init__(
+      self,
+      payout_matrix: np.ndarray,
+      instructed_matrix: Optional[np.ndarray] = None,
+  ):
+    """Initialize the environment.
+
+    Args:
+      payout_matrix: A numpy array of shape (n_sessions, n_trials, n_actions)
+        giving the reward for each session, trial, and action. These are
+        deterministic, i.e. for the same trial_num, session_num, and action, the
+        reward will always be the same. (If you'd like stochastic rewards you
+        can populate this matrix ahead of time).
+      instructed_matrix: A numpy array of shape (n_sessions, n_trials) giving
+        the choice that should be made, if any, for each session and trial.
+        Elements should be ints or nan. If nan, the choice is not instructed. If
+        None, no choices are instructed.
+    """
+    super().__init__()
+    self._payout_matrix = payout_matrix
+    self._n_sessions = payout_matrix.shape[0]
+    self._n_trials = payout_matrix.shape[1]
+    self._n_actions = payout_matrix.shape[2]
+
+    if instructed_matrix is not None:
+      self._instructed_matrix = instructed_matrix
+    else:
+      self._instructed_matrix = np.nan * np.zeros_like(payout_matrix)
+
+    self._current_session = 0
+    self._current_trial = 0
+
+  def new_sess(self):
+    self._current_session += 1
+    if self._current_session >= self._n_sessions:
+      raise NoMoreSessionsInDatasetError(
+          'No more sessions in dataset. '
+          f'Current session {self._current_session} is out of range '
+          f'[0, {self._n_sessions - 1})'
+      )
+    self._current_trial = 0
+
+  def step(self, attempted_choice: int) -> tuple[int, float, int]:
+    if attempted_choice == -1:
+      choice = -1
+      reward = -1
+      instructed = -1
+      return choice, reward, instructed
+    if attempted_choice > self._n_actions - 1:
+      raise ValueError(
+          'Choice given was {}, but must be less than {}'.format(
+              attempted_choice, self._n_actions - 1
+          )
+      )
+    if self._current_trial >= self._n_trials:
+      raise NoMoreTrialsInSessionError(
+          'No more trials in session. '
+          f'Current trial {self._current_trial} is out of range '
+          f'[0, {self._n_trials})'
+      )
+    # If choice was instructed, overrule and replace with the instructed choice
+    instruction = self._instructed_matrix[
+        self._current_session, self._current_trial
+    ]
+    instructed = not np.isnan(instruction)
+    if instructed:
+      choice = int(instruction)
+    else:
+      choice = attempted_choice
+
+    reward = self._payout_matrix[
+        self._current_session, self._current_trial, choice
+    ]
+    self._current_trial += 1
+    return choice, float(reward), int(instructed)
+
+
+##########
+# AGENTS #
+##########
+
+
 class AgentQ:
   """An agent that runs "vanilla" Q-learning for the y-maze tasks.
 
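A hedged sketch of driving the new EnvironmentPayoutMatrix, again assuming the import path above; the payout and instruction values are made up for illustration. The payout matrix is indexed as (session, trial, action):

import numpy as np

from disentangled_rnns.library import two_armed_bandits

# Two sessions, three trials, two actions; rewards are fixed ahead of time.
payouts = np.array(
    [[[1., 0.], [0., 1.], [1., 0.]],
     [[0., 1.], [1., 0.], [0., 1.]]])
# Instruct action 1 on the first trial of each session; nan means free choice.
instructions = np.array([[1., np.nan, np.nan],
                         [1., np.nan, np.nan]])

env = two_armed_bandits.EnvironmentPayoutMatrix(payouts,
                                                instructed_matrix=instructions)
choice, reward, instructed = env.step(attempted_choice=0)
# choice == 1 and instructed == 1 here: the trial was instructed, overruling
# the attempted choice. reward is payouts[0, 0, 1] == 0.0.
env.new_sess()  # Advance to the next row of the payout matrix.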
@@ -130,7 +287,7 @@ def get_choice(self) -> int:
 
   def update(self,
              choice: int,
-             reward: int):
+             reward: float):
     """Update the agent after one step of the task.
 
     Args:
@@ -185,7 +342,7 @@ def get_choice(self) -> int:
     choice = np.random.choice(2, p=choice_probs)
     return choice
 
-  def update(self, choice: int, reward: int):
+  def update(self, choice: int, reward: float):
     """Update the agent after one step of the task.
 
     Args:
@@ -280,7 +437,7 @@ def run_experiment(agent: Agent,
     n_steps: The number of steps in the session you'd like to generate
 
   Returns:
-    experiment: A YMazeSession holding choices and rewards from the session
+    experiment: A SessData object holding choices and rewards from the session
   """
   choices = np.zeros(n_steps)
   rewards = np.zeros(n_steps)
@@ -290,9 +447,9 @@ def run_experiment(agent: Agent,
     # First record environment reward probs
     reward_probs[step] = environment.reward_probs
     # First agent makes a choice
-    choice = agent.get_choice()
+    attempted_choice = agent.get_choice()
     # Then environment computes a reward
-    reward = environment.step(choice)
+    choice, reward, _ = environment.step(attempted_choice)
     # Finally agent learns
     agent.update(choice, reward)
     # Log choice and reward
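The loop now feeds the environment's returned choice back into agent.update, so on instructed trials the agent learns from the action it actually executed rather than the one it attempted. The same pattern written out by hand, assuming the import path above and a hypothetical GreedyAgent standing in for the library's agents:

from disentangled_rnns.library import two_armed_bandits


class GreedyAgent:
  """Hypothetical agent: always attempts arm 0 and tracks its mean reward."""

  def __init__(self):
    self.mean_reward = 0.0
    self.n = 0

  def get_choice(self) -> int:
    return 0

  def update(self, choice: int, reward: float):
    self.n += 1
    self.mean_reward += (reward - self.mean_reward) / self.n


env = two_armed_bandits.EnvironmentBanditsDrift(sigma=0.1, p_instructed=0.5, seed=1)
agent = GreedyAgent()
for _ in range(100):
  attempted_choice = agent.get_choice()
  choice, reward, _ = env.step(attempted_choice)
  agent.update(choice, reward)  # Learn from the executed choice, not the attempt.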
@@ -335,6 +492,7 @@ def create_dataset(agent: Agent,
         np.concatenate(([prev_choices], [prev_rewards]), axis=0), 0, 1
     )
     ys[:, sess_i] = np.expand_dims(experiment.choices, 1)
+    environment.new_sess()
 
   dataset = rnn_utils.DatasetRNN(
       xs=xs,
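create_dataset now advances the environment to a fresh session after each simulated session. With EnvironmentPayoutMatrix the supply of sessions is finite, and the new exception makes running past the end explicit. A small sketch, assuming the import path above and an illustrative all-zero payout matrix:

import numpy as np

from disentangled_rnns.library import two_armed_bandits

payouts = np.zeros((2, 10, 2))  # 2 sessions, 10 trials, 2 actions.
env = two_armed_bandits.EnvironmentPayoutMatrix(payouts)

env.new_sess()  # Fine: moves from session 0 to session 1.
try:
  env.new_sess()  # There is no session 2.
except two_armed_bandits.NoMoreSessionsInDatasetError:
  print('Dataset exhausted; stop generating sessions.')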
