None #26

Open · wants to merge 1 commit into base: main

57 changes: 39 additions & 18 deletions disentangled_rnns/library/two_armed_bandits.py
@@ -38,11 +38,14 @@ class BaseEnvironment(abc.ABC):
Subclasses must implement the following methods:
- new_sess()
- step(choice)

Attributes:
n_arms: The number of arms in the environment.
"""

def __init__(self, seed: Optional[int] = None):
def __init__(self, seed: Optional[int] = None, n_arms: int = 2):
self._random_state = np.random.RandomState(seed)
self.n_arms = 2 # For now we only support 2-armed bandits
self._n_arms = n_arms

@abstractmethod
def new_sess(self):
@@ -67,6 +70,11 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:
instructed: 1 if the choice was instructed, 0 otherwise
"""

@property
def n_arms(self) -> int:
"""Returns the current reward probabilities for each arm."""
return self._n_arms


class EnvironmentBanditsDrift(BaseEnvironment):
"""Environment for a drifting two-armed bandit task.
@@ -78,31 +86,33 @@ class EnvironmentBanditsDrift(BaseEnvironment):
Attributes:
sigma: A float, between 0 and 1, giving the magnitude of the drift
reward_probs: Probability of reward associated with each action
n_arms: The number of arms in the environment.
"""

def __init__(self,
sigma: float,
p_instructed: float = 0.0,
seed: Optional[int] = None,
n_arms: int = 2,
):
super().__init__()
super().__init__(seed=seed, n_arms=n_arms)

# Check inputs
if sigma < 0:
msg = ('sigma was {}, but must be greater than 0')
raise ValueError(msg.format(sigma))

# Initialize persistent properties
self._sigma = sigma
self._p_instructed = p_instructed
self._random_state = np.random.RandomState(seed)

# Sample new reward probabilities
self.new_sess()

def new_sess(self):
# Pick new reward probabilities.
# Sample randomly between 0 and 1
self._reward_probs = self._random_state.rand(2)
self._reward_probs = self._random_state.rand(self.n_arms)

def step(self,
attempted_choice: int) -> tuple[int, float, int]:
@@ -126,21 +136,23 @@ def step(self,
return choice, reward, instructed

# Check inputs
if not np.logical_or(attempted_choice == 0, attempted_choice == 1):
msg = ('choice given was {}, but must be either 0 or 1')
raise ValueError(msg.format(attempted_choice))
if attempted_choice not in list(range(self.n_arms)):
msg = (f'choice given was {attempted_choice}, but must be one of '
f'{list(range(self.n_arms))}.')
raise ValueError(msg)

# If choice was instructed, overrule it and decide randomly
instructed = self._random_state.rand() < self._p_instructed
if instructed:
choice = self._random_state.choice(2)
choice = self._random_state.choice(self.n_arms)
else:
choice = attempted_choice

# Sample reward with the probability of the chosen side
reward = self._random_state.rand() < self._reward_probs[choice]
# Add gaussian noise to reward probabilities
drift = self._random_state.normal(loc=0, scale=self._sigma, size=2)
drift = self._random_state.normal(
loc=0, scale=self._sigma, size=self.n_arms)
self._reward_probs += drift

# Fix reward probs that've drifted below 0 or above 1
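
For reviewers wanting to try the change, a minimal usage sketch of the generalized drift environment, using only the constructor arguments and methods shown in this diff (the import path follows the file path above; variable names are illustrative):

```python
import numpy as np

from disentangled_rnns.library import two_armed_bandits

# Three-armed drifting bandit; previously only n_arms=2 was supported.
env = two_armed_bandits.EnvironmentBanditsDrift(sigma=0.1, seed=0, n_arms=3)

rng = np.random.RandomState(0)
for _ in range(5):
  # Any arm index in range(env.n_arms) is a valid choice.
  attempted_choice = int(rng.randint(env.n_arms))
  choice, reward, instructed = env.step(attempted_choice)

env.new_sess()  # Resamples one reward probability per arm.
```
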
@@ -182,11 +194,12 @@ def __init__(
Elements should be ints or nan. If nan, the choice is not instructed. If
None, no choices are instructed.
"""
super().__init__()
n_arms = payout_matrix.shape[2]
super().__init__(seed=None, n_arms=n_arms)

self._payout_matrix = payout_matrix
self._n_sessions = payout_matrix.shape[0]
self._n_trials = payout_matrix.shape[1]
self._n_actions = payout_matrix.shape[2]

if instructed_matrix is not None:
self._instructed_matrix = instructed_matrix
@@ -207,23 +220,26 @@ def new_sess(self):
self._current_trial = 0

def step(self, attempted_choice: int) -> tuple[int, float, int]:
# If agent choice is default empty value -1, return -1 for all outputs.
if attempted_choice == -1:
choice = -1
reward = -1
instructed = -1
return choice, reward, instructed
if attempted_choice > self._n_actions - 1:
raise ValueError(
'Choice given was {}, but must be less than {}'.format(
attempted_choice, self._n_actions - 1
)
)

# Check inputted choice is valid.
if attempted_choice not in list(range(self.n_arms)):
msg = (f'choice given was {attempted_choice}, but must be one of '
f'{list(range(self.n_arms))}.')
raise ValueError(msg)

if self._current_trial >= self._n_trials:
raise NoMoreTrialsInSessionError(
'No more trials in session. '
f'Current trial {self._current_trial} is out of range '
f'[0, {self._n_trials})'
)

# If choice was instructed, overrule and replace with the instructed choice
instruction = self._instructed_matrix[
self._current_session, self._current_trial
@@ -240,6 +256,11 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:
self._current_trial += 1
return choice, float(reward), int(instructed)

@property
def payout(self) -> np.ndarray:
"""Get possible payouts for current session, trial across actions."""
return self._payout_matrix[
self._current_session, self._current_trial, :].copy()
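
A companion sketch for the payout-matrix environment changed above. The class name `EnvironmentPayoutMatrix` is a placeholder (the hunk headers in this diff do not show the class name), and the constructor is assumed to take `payout_matrix` plus an optional `instructed_matrix`, as suggested by the docstring and attributes above:

```python
import numpy as np

from disentangled_rnns.library import two_armed_bandits

# Hypothetical payout table: 1 session, 4 trials, 3 arms.
# n_arms is now inferred from payout_matrix.shape[2].
payout_matrix = np.array([[[1., 0., 0.],
                           [0., 1., 0.],
                           [0., 0., 1.],
                           [1., 1., 0.]]])

# Placeholder class name; see note above.
env = two_armed_bandits.EnvironmentPayoutMatrix(payout_matrix)
assert env.n_arms == 3

env.new_sess()
print(env.payout)  # Possible payouts for each arm on the current trial.
choice, reward, instructed = env.step(attempted_choice=2)
# env.step(attempted_choice=3) would now raise a ValueError listing range(3).
```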

##########
# AGENTS #