Skip to content

Commit 5d674c9

Browse files
author
izhigal
committed
yet more formatting
1 parent f4b7ec1 commit 5d674c9

File tree

9 files changed

+80
-98
lines changed

9 files changed

+80
-98
lines changed

examples/human/leduc_holdem_human.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
"""A toy example of playing against pretrained AI on Leduc Hold'em
2-
"""
1+
"""A toy example of playing against pretrained AI on Leduc Hold'em"""
32

43
import rlcard
54
from rlcard import models

examples/human/limit_holdem_human.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
"""A toy example of playing against a random agent on Limit Hold'em
2-
"""
1+
"""A toy example of playing against a random agent on Limit Hold'em"""
32

43
import rlcard
54
from rlcard.agents import LimitholdemHumanAgent as HumanAgent

examples/run_dmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
import rlcard
88
from rlcard.agents.dmc_agent import DMCTrainer
99

10-
def train(args):
1110

11+
def train(args):
1212
# Make the environment
1313
env = rlcard.make(args.env)
1414

@@ -28,6 +28,7 @@ def train(args):
2828
# Train DMC Agents
2929
trainer.start()
3030

31+
3132
if __name__ == '__main__':
3233
parser = argparse.ArgumentParser("DMC example in RLCard")
3334
parser.add_argument(
@@ -94,4 +95,3 @@ def train(args):
9495

9596
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
9697
train(args)
97-

examples/run_rl.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
plot_curve,
1616
)
1717

18-
def train(args):
1918

19+
def train(args):
2020
# Check whether gpu is available
2121
device = get_device()
22-
22+
2323
# Seed numpy, torch, random
2424
set_seed(args.seed)
2525

@@ -40,7 +40,7 @@ def train(args):
4040
agent = DQNAgent(
4141
num_actions=env.num_actions,
4242
state_shape=env.state_shape[0],
43-
mlp_layers=[64,64],
43+
mlp_layers=[64, 64],
4444
device=device,
4545
save_path=args.log_dir,
4646
save_every=args.save_every
@@ -54,8 +54,8 @@ def train(args):
5454
agent = NFSPAgent(
5555
num_actions=env.num_actions,
5656
state_shape=env.state_shape[0],
57-
hidden_layers_sizes=[64,64],
58-
q_mlp_layers=[64,64],
57+
hidden_layers_sizes=[64, 64],
58+
q_mlp_layers=[64, 64],
5959
device=device,
6060
save_path=args.log_dir,
6161
save_every=args.save_every
@@ -105,6 +105,7 @@ def train(args):
105105
torch.save(agent, save_path)
106106
print('Model saved in', save_path)
107107

108+
108109
if __name__ == '__main__':
109110
parser = argparse.ArgumentParser("DQN/NFSP example in RLCard")
110111
parser.add_argument(
@@ -162,13 +163,13 @@ def train(args):
162163
type=str,
163164
default='experiments/leduc_holdem_dqn_result/',
164165
)
165-
166+
166167
parser.add_argument(
167168
"--load_checkpoint_path",
168169
type=str,
169170
default="",
170171
)
171-
172+
172173
parser.add_argument(
173174
"--save_every",
174175
type=int,
@@ -178,4 +179,3 @@ def train(args):
178179

179180
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda
180181
train(args)
181-

rlcard/agents/dmc_agent/model.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,17 @@
1818
import torch
1919
from torch import nn
2020

21+
2122
class DMCNet(nn.Module):
22-
def __init__(
23-
self,
24-
state_shape,
25-
action_shape,
26-
mlp_layers=None
27-
):
23+
def __init__(self, state_shape, action_shape, mlp_layers=None):
2824
super().__init__()
2925
if mlp_layers is None:
3026
mlp_layers = [512, 512, 512, 512, 512]
3127
input_dim = np.prod(state_shape) + np.prod(action_shape)
3228
layer_dims = [input_dim] + mlp_layers
3329
fc = []
34-
for i in range(len(layer_dims)-1):
35-
fc.append(nn.Linear(layer_dims[i], layer_dims[i+1]))
30+
for i in range(len(layer_dims) - 1):
31+
fc.append(nn.Linear(layer_dims[i], layer_dims[i + 1]))
3632
fc.append(nn.ReLU())
3733
fc.append(nn.Linear(layer_dims[-1], 1))
3834
self.fc_layers = nn.Sequential(*fc)
@@ -44,19 +40,13 @@ def forward(self, obs, actions):
4440
values = self.fc_layers(x).flatten()
4541
return values
4642

43+
4744
class DMCAgent:
48-
def __init__(
49-
self,
50-
state_shape,
51-
action_shape,
52-
mlp_layers=None,
53-
exp_epsilon=0.01,
54-
device="0",
55-
):
45+
def __init__(self, state_shape, action_shape, mlp_layers=None, exp_epsilon=0.01, device="0"):
5646
if mlp_layers is None:
5747
mlp_layers = [512, 512, 512, 512, 512]
5848
self.use_raw = False
59-
self.device = 'cuda:'+device if device != "cpu" else "cpu"
49+
self.device = 'cuda:' + device if device != "cpu" else "cpu"
6050
self.net = DMCNet(state_shape, action_shape, mlp_layers).to(self.device)
6151
self.exp_epsilon = exp_epsilon
6252
self.action_shape = action_shape
@@ -78,8 +68,7 @@ def eval_step(self, state):
7868
action_idx = np.argmax(values)
7969
action = action_keys[action_idx]
8070

81-
info = {}
82-
info['values'] = {state['raw_legal_actions'][i]: float(values[i]) for i in range(len(action_keys))}
71+
info = {'values': {state['raw_legal_actions'][i]: float(values[i]) for i in range(len(action_keys))}}
8372

8473
return action, info
8574

@@ -125,15 +114,9 @@ def state_dict(self):
125114
def set_device(self, device):
126115
self.device = device
127116

117+
128118
class DMCModel:
129-
def __init__(
130-
self,
131-
state_shape,
132-
action_shape,
133-
mlp_layers=None,
134-
exp_epsilon=0.01,
135-
device=0
136-
):
119+
def __init__(self, state_shape, action_shape, mlp_layers=None, exp_epsilon=0.01, device=0):
137120
if mlp_layers is None:
138121
mlp_layers = [512, 512, 512, 512, 512]
139122
self.agents = []

rlcard/agents/dmc_agent/pettingzoo_model.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,7 @@ def feed(self, ts):
2020

2121

2222
class DMCModelPettingZoo:
23-
def __init__(
24-
self,
25-
env,
26-
mlp_layers=None,
27-
exp_epsilon=0.01,
28-
device="0"
29-
):
23+
def __init__(self, env, mlp_layers=None, exp_epsilon=0.01, device="0"):
3024
if mlp_layers is None:
3125
mlp_layers = [512, 512, 512, 512, 512]
3226

0 commit comments

Comments
 (0)