"""Two armed bandit experiments. Generate synthetic data, plot data."""

+ import abc
from collections.abc import Callable
from typing import NamedTuple, Optional, Union

import numpy as np

- class EnvironmentBanditsDrift:
+ abstractmethod = abc.abstractmethod
+
+
+ ################
+ # ENVIRONMENTS #
+ ################
+
+
+ class BaseEnvironment(abc.ABC):
+   """Base class for two-armed bandit environments.
+
+   Subclasses must implement the following methods:
+     - new_sess()
+     - step(attempted_choice)
+   """
+
+   def __init__(self, seed: Optional[int] = None):
+     self._random_state = np.random.RandomState(seed)
+     self.n_arms = 2  # For now we only support 2-armed bandits
+
+   @abstractmethod
+   def new_sess(self):
+     """Starts a new session (e.g., resets environment parameters).
+
+     This method should be implemented by subclasses to initialize or reset
+     the environment's state at the beginning of a new session or episode.
+     """
+
+   @abstractmethod
+   def step(self, attempted_choice: int) -> tuple[int, float, int]:
+     """Executes a single step in the environment.
+
+     Args:
+       attempted_choice: The action chosen by the agent.
+
+     Returns:
+       choice: The action actually taken. May be different from the attempted
+         choice if the environment decides the choice should be instructed on
+         that trial.
+       reward: The reward received after taking the action.
+       instructed: 1 if the choice was instructed, 0 otherwise
+     """
+
+
+ class EnvironmentBanditsDrift(BaseEnvironment):
  """Environment for a drifting two-armed bandit task.

  Reward probabilities on each arm are sampled randomly between 0 and
@@ -38,15 +82,18 @@ class EnvironmentBanditsDrift:

  def __init__(self,
               sigma: float,
+              p_instructed: float = 0.0,
               seed: Optional[int] = None,
               ):
+   super().__init__()

    # Check inputs
    if sigma < 0:
      msg = ('sigma was {}, but must be greater than 0')
      raise ValueError(msg.format(sigma))
    # Initialize persistent properties
    self._sigma = sigma
+   self._p_instructed = p_instructed
    self._random_state = np.random.RandomState(seed)

    # Sample new reward probabilities
@@ -58,20 +105,37 @@ def new_sess(self):
    self._reward_probs = self._random_state.rand(2)

  def step(self,
-          choice: int) -> int:
+          attempted_choice: int) -> tuple[int, float, int]:
    """Run a single trial of the task.

    Args:
-     choice: The choice made by the agent. 0 or 1
+     attempted_choice: The choice made by the agent. 0 or 1

    Returns:
+     choice: The action actually taken. May be different from the attempted
+       choice if the environment decides the choice should be instructed on
+       that trial.
      reward: The reward to be given to the agent. 0 or 1.
+     instructed: 1 if the choice was instructed, 0 otherwise

    """
+   if attempted_choice == -1:
+     choice = -1
+     reward = -1
+     instructed = -1
+     return choice, reward, instructed
+
    # Check inputs
-   if not np.logical_or(choice == 0, choice == 1):
+   if not np.logical_or(attempted_choice == 0, attempted_choice == 1):
      msg = ('choice given was {}, but must be either 0 or 1')
-     raise ValueError(msg.format(choice))
+     raise ValueError(msg.format(attempted_choice))
+
+   # If choice was instructed, overrule it and decide randomly
+   instructed = self._random_state.rand() < self._p_instructed
+   if instructed:
+     choice = self._random_state.choice(2)
+   else:
+     choice = attempted_choice

    # Sample reward with the probability of the chosen side
    reward = self._random_state.rand() < self._reward_probs[choice]
@@ -82,13 +146,106 @@ def step(self,
    # Fix reward probs that've drifted below 0 or above 1
    self._reward_probs = np.clip(self._reward_probs, 0, 1)

-   return reward
+   return choice, float(reward), int(instructed)

  @property
  def reward_probs(self) -> np.ndarray:
    return self._reward_probs.copy()


+ class NoMoreTrialsInSessionError(ValueError):
+   pass
+
+
+ class NoMoreSessionsInDatasetError(ValueError):
+   pass
+
+
+ class EnvironmentPayoutMatrix(BaseEnvironment):
+   """Environment for a two-armed bandit task with a specified payout matrix."""
+
+   def __init__(
+       self,
+       payout_matrix: np.ndarray,
+       instructed_matrix: Optional[np.ndarray] = None,
+   ):
+     """Initialize the environment.
+
+     Args:
+       payout_matrix: A numpy array of shape (n_sessions, n_trials, n_actions)
+         giving the reward for each session, trial, and action. These are
+         deterministic, i.e. for the same session, trial, and action, the
+         reward will always be the same. (If you'd like stochastic rewards,
+         populate this matrix with sampled values ahead of time.)
+       instructed_matrix: A numpy array of shape (n_sessions, n_trials) giving
+         the choice that should be made, if any, for each session and trial.
+         Elements should be ints or nan. If nan, the choice is not instructed.
+         If None, no choices are instructed.
+     """
+     super().__init__()
+     self._payout_matrix = payout_matrix
+     self._n_sessions = payout_matrix.shape[0]
+     self._n_trials = payout_matrix.shape[1]
+     self._n_actions = payout_matrix.shape[2]
+
+     if instructed_matrix is not None:
+       self._instructed_matrix = instructed_matrix
+     else:
+       # Default: no instructed trials, i.e. all (session, trial) entries nan
+       self._instructed_matrix = np.nan * np.zeros(
+           (self._n_sessions, self._n_trials))
+
+     self._current_session = 0
+     self._current_trial = 0
+
+   def new_sess(self):
+     self._current_session += 1
+     if self._current_session >= self._n_sessions:
+       raise NoMoreSessionsInDatasetError(
+           'No more sessions in dataset. '
+           f'Current session {self._current_session} is out of range '
+           f'[0, {self._n_sessions})'
+       )
+     self._current_trial = 0
+
+   def step(self, attempted_choice: int) -> tuple[int, float, int]:
+     if attempted_choice == -1:
+       choice = -1
+       reward = -1
+       instructed = -1
+       return choice, reward, instructed
+     if attempted_choice > self._n_actions - 1:
+       raise ValueError(
+           'Choice given was {}, but must be less than {}'.format(
+               attempted_choice, self._n_actions
+           )
+       )
+     if self._current_trial >= self._n_trials:
+       raise NoMoreTrialsInSessionError(
+           'No more trials in session. '
+           f'Current trial {self._current_trial} is out of range '
+           f'[0, {self._n_trials})'
+       )
+     # If choice was instructed, overrule and replace with the instructed choice
+     instruction = self._instructed_matrix[
+         self._current_session, self._current_trial
+     ]
+     instructed = not np.isnan(instruction)
+     if instructed:
+       choice = int(instruction)
+     else:
+       choice = attempted_choice
+
+     reward = self._payout_matrix[
+         self._current_session, self._current_trial, choice
+     ]
+     self._current_trial += 1
+     return choice, float(reward), int(instructed)
+
+
+ ##########
+ # AGENTS #
+ ##########
+
+
class AgentQ:
  """An agent that runs "vanilla" Q-learning for the y-maze tasks.

@@ -130,7 +287,7 @@ def get_choice(self) -> int:

  def update(self,
             choice: int,
-            reward: int):
+            reward: float):
    """Update the agent after one step of the task.

    Args:
@@ -185,7 +342,7 @@ def get_choice(self) -> int:
    choice = np.random.choice(2, p=choice_probs)
    return choice

- def update(self, choice: int, reward: int):
+ def update(self, choice: int, reward: float):
    """Update the agent after one step of the task.

    Args:
@@ -280,7 +437,7 @@ def run_experiment(agent: Agent,
    n_steps: The number of steps in the session you'd like to generate

  Returns:
-   experiment: A YMazeSession holding choices and rewards from the session
+   experiment: A SessData object holding choices and rewards from the session
  """
  choices = np.zeros(n_steps)
  rewards = np.zeros(n_steps)
@@ -290,9 +447,9 @@ def run_experiment(agent: Agent,
    # First record environment reward probs
    reward_probs[step] = environment.reward_probs
    # First agent makes a choice
-   choice = agent.get_choice()
+   attempted_choice = agent.get_choice()
    # Then environment computes a reward
-   reward = environment.step(choice)
+   choice, reward, _ = environment.step(attempted_choice)
    # Finally agent learns
    agent.update(choice, reward)
    # Log choice and reward
@@ -335,6 +492,7 @@ def create_dataset(agent: Agent,
        np.concatenate(([prev_choices], [prev_rewards]), axis=0), 0, 1
    )
    ys[:, sess_i] = np.expand_dims(experiment.choices, 1)
+   environment.new_sess()

  dataset = rnn_utils.DatasetRNN(
      xs=xs,