-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpolicy.py
172 lines (128 loc) · 5.25 KB
/
policy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""Implementation of different types of policies"""
import numpy as np
import random
from abc import ABC, abstractmethod
from collections import defaultdict
from env import Action, Easy21, State
from typing import Callable, List
class Policy(ABC):
"""A base policy. This class should not be instantiated."""
def __init__(self):
self._pi = np.full(Easy21.state_space, fill_value=Action.hit)
def __setitem__(self, s: State, a: Action) -> None:
"""
Sets the given action for the given state.
:param State s: The state to update
:param Action a: The action to assign to the state
"""
self._pi[s.dealer_first_card, s.player_sum] = a
@abstractmethod
def __getitem__(self, s: State) -> Action:
"""
Retrieves the action for the given state.
:param State s: The state to retrieve an action for
:return: The action
:rtype: Action
"""
raise NotImplementedError
@abstractmethod
def greedy_prob(self, s: State) -> float:
"""
Returns the probability of selecting a greedy action under this policy in state `s`.
:param State s: The state to compute the greedy probability for
:return: The greedy probability
:rtype: float
"""
raise NotImplementedError
@abstractmethod
def prob(self, a: Action, s: State) -> float:
"""
Returns the probability of selecting action `a` in state `s` under this policy.
:param Action a: The action
:param State s: The state
:return: The probability
:rtype: float
"""
raise NotImplementedError
class ApproximationPolicy(ABC):
"""A base policy that uses function approximation. This class should not be instantiated."""
def __init__(self, epsilon: float, approximator: Callable[[State], List[float]]):
self._epsilon = epsilon
self._approximator = approximator
@abstractmethod
def __getitem__(self, s: State) -> Action:
"""
Retrieves the action for the given state.
:param State s: The state to retrieve an action for
:return: The action
:rtype: Action
"""
raise NotImplementedError
class RandomPolicy(Policy):
"""A policy that selects actions randomly with equal probability"""
def __init__(self, seed: int = None):
"""
:param int seed: The seed to use for the random number generator
"""
super().__init__()
random.seed(seed)
self._actions = [Action.hit, Action.stick]
def __getitem__(self, s: State) -> Action:
return random.choice(self._actions)
def greedy_prob(self, s: State) -> float:
return 0.0
def prob(self, a: Action, s: State) -> float:
return 1.0 / len(self._actions)
class GreedyPolicy(Policy):
"""A greedy policy that selects actions based on its current mapping"""
def __getitem__(self, s: State) -> Action:
# Picks the action based on the current policy
return self._pi[s.dealer_first_card, s.player_sum]
def greedy_prob(self, s: State) -> float:
return 1.0
def prob(self, a: Action, s: State) -> float:
return 1.0 if a == self._pi[s.dealer_first_card, s.player_sum] else 0.0
class EpsilonGreedyPolicy(Policy):
"""
An epsilon greedy policy that selects random actions with probability epsilon.
It follows the exploration strategy described in the Easy21 assignment instructions.
"""
def __init__(self, seed: int = None):
"""
:param int seed: The seed to use for the random number generator
"""
super().__init__()
random.seed(seed)
self._n0 = 100.0
# Number of times a state has been visited
self._n = defaultdict(int)
def __getitem__(self, s: State) -> Action:
# Compute epsilon following the strategy outlined in the Easy21 assignment instructions
self._n[s] += 1
epsilon = self._epsilon(s)
if random.random() < epsilon:
return random.choice([Action.hit, Action.stick])
return self._pi[s.dealer_first_card, s.player_sum]
def greedy_prob(self, s: State) -> float:
return 1.0 - self._epsilon(s)
def prob(self, a: Action, s: State) -> float:
eps = self._epsilon(s)
return 1.0 - eps if a == self._pi[s.dealer_first_card, s.player_sum] else eps
def _epsilon(self, s: State) -> float:
return self._n0 / (self._n0 + self._n[s])
class EpsilonGreedyApproximationPolicy(ApproximationPolicy):
"""
An epsilon greedy policy that selects random actions with probability epsilon.
"""
def __init__(self, epsilon: float, approximator: Callable[[State], List[float]], seed: int = None):
"""
:param float epsilon: The exploration factor
:param Callable approximator: The approximator that generates values for all actions in the given state
:param int seed: The seed to use for the random number generator
"""
super().__init__(epsilon, approximator)
random.seed(seed)
def __getitem__(self, s: State) -> Action:
if random.random() < self._epsilon:
return random.choice([Action.hit, Action.stick])
return np.argmax(self._approximator(s))