Skip to content

Commit 865270f

Browse files
committed
Add a processing time limit, and various cleanups
1 parent 9695560 commit 865270f

5 files changed

Lines changed: 69 additions & 67 deletions

File tree

Coach.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313

1414
log = logging.getLogger(__name__)
1515

16-
1716
class Coach():
1817
"""
1918
This class executes the self-play + learning. It uses the functions defined
@@ -77,6 +76,8 @@ def learn(self):
7776
only if it wins >= updateThreshold fraction of games.
7877
"""
7978

79+
start_time = time.time()
80+
8081
for i in range(1, self.args.numIters + 1):
8182
# bookkeeping
8283
log.info(f'Starting Iter #{i} ...')
@@ -92,8 +93,6 @@ def learn(self):
9293
self.trainExamplesHistory.append(iterationTrainExamples)
9394

9495
if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
95-
log.warning(
96-
f"Removing the oldest entry in trainExamples. len(trainExamplesHistory) = {len(self.trainExamplesHistory)}")
9796
self.trainExamplesHistory.pop(0)
9897
# backup history to a file
9998
# NB! the examples were collected using the model from the previous iteration, so (i-1)
@@ -127,6 +126,10 @@ def learn(self):
127126
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename=self.getCheckpointFile(i))
128127
self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='best.pt')
129128

129+
if self.args.timeIters > 0:
130+
if time.time() - start_time > self.args.timeIters*3600:
131+
log.info(f'Above timelimit, stopping here after {i} iterations')
132+
break
130133

131134
def getCheckpointFile(self, iteration):
132135
return 'checkpoint_' + str(iteration) + '.pt'

main.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def main():
6666
# rollout = affects both playing strength and running time...
6767
# learn_rate = 0.001? or start at 0.02 and divide it after each failed attempt?
6868
parser.add_argument('--numIters' , '-N' , action='store', default=1000 , type=int , help='')
69-
# parser.add_argument('--timeIters' , '-T' , action='store', default=0. , type=float, help='')
69+
parser.add_argument('--timeIters' , '-T' , action='store', default=0. , type=float, help='')
7070
parser.add_argument('--numEps' , '-s' , action='store', default=100 , type=int , help='Number of complete self-play games to simulate during a new iteration')
7171
parser.add_argument('--tempThreshold' , '-t' , action='store', default=15 , type=int , help='')
7272
parser.add_argument('--updateThreshold' , '-u' , action='store', default=0.6 , type=float, help='During arena playoff, new neural net will be accepted if threshold or more of games are won')
@@ -88,8 +88,8 @@ def main():
8888
args = parser.parse_args()
8989
args.arenaCompare = 40
9090
# args.maxlenOfQueue = int(2e6/(1.1*args.numItersForTrainExamplesHistory)) # at most 2GB per process, with each example weighing 1.1kB
91-
# if args.timeIters > 0:
92-
# args.numIters = 1000
91+
if args.timeIters > 0:
92+
args.numIters = 1000
9393

9494
args.load_model = (args.load_folder_file is not None)
9595
if args.profile:

othello/OthelloPlayers.py

Lines changed: 45 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -1,58 +1,55 @@
11
import numpy as np
2-
2+
import random
33

44
class RandomPlayer():
5-
def __init__(self, game):
6-
self.game = game
5+
def __init__(self, game):
6+
self.game = game
77

8-
def play(self, board):
9-
a = np.random.randint(self.game.getActionSize())
10-
valids = self.game.getValidMoves(board, 1)
11-
while valids[a]!=1:
12-
a = np.random.randint(self.game.getActionSize())
13-
return a
8+
def play(self, board):
9+
valids = self.game.getValidMoves(board, 1)
10+
return random.choices(range(self.game.getActionSize()), weights=valids, k=1)[0]
1411

1512

1613
class HumanOthelloPlayer():
17-
def __init__(self, game):
18-
self.game = game
19-
20-
def play(self, board):
21-
# display(board)
22-
valid = self.game.getValidMoves(board, 1)
23-
for i in range(len(valid)):
24-
if valid[i]:
25-
print("[", int(i/self.game.n), int(i%self.game.n), end="] ")
26-
while True:
27-
input_move = input()
28-
input_a = input_move.split(" ")
29-
if len(input_a) == 2:
30-
try:
31-
x,y = [int(i) for i in input_a]
32-
if ((0 <= x) and (x < self.game.n) and (0 <= y) and (y < self.game.n)) or \
33-
((x == self.game.n) and (y == 0)):
34-
a = self.game.n * x + y if x != -1 else self.game.n ** 2
35-
if valid[a]:
36-
break
37-
except ValueError:
38-
# Input needs to be an integer
39-
'Invalid integer'
40-
print('Invalid move')
41-
return a
14+
def __init__(self, game):
15+
self.game = game
16+
17+
def play(self, board):
18+
# display(board)
19+
valid = self.game.getValidMoves(board, 1)
20+
for i in range(len(valid)):
21+
if valid[i]:
22+
print("[", int(i/self.game.n), int(i%self.game.n), end="] ")
23+
while True:
24+
input_move = input()
25+
input_a = input_move.split(" ")
26+
if len(input_a) == 2:
27+
try:
28+
x,y = [int(i) for i in input_a]
29+
if ((0 <= x) and (x < self.game.n) and (0 <= y) and (y < self.game.n)) or \
30+
((x == self.game.n) and (y == 0)):
31+
a = self.game.n * x + y if x != -1 else self.game.n ** 2
32+
if valid[a]:
33+
break
34+
except ValueError:
35+
# Input needs to be an integer
36+
'Invalid integer'
37+
print('Invalid move')
38+
return a
4239

4340

4441
class GreedyOthelloPlayer():
45-
def __init__(self, game):
46-
self.game = game
47-
48-
def play(self, board):
49-
valids = self.game.getValidMoves(board, 1)
50-
candidates = []
51-
for a in range(self.game.getActionSize()):
52-
if valids[a]==0:
53-
continue
54-
nextBoard, _ = self.game.getNextState(board, 1, a)
55-
score = self.game.getScore(nextBoard, 1)
56-
candidates += [(-score, a)]
57-
candidates.sort()
58-
return candidates[0][1]
42+
def __init__(self, game):
43+
self.game = game
44+
45+
def play(self, board):
46+
valids = self.game.getValidMoves(board, 1)
47+
candidates = []
48+
for a in range(self.game.getActionSize()):
49+
if valids[a]==0:
50+
continue
51+
nextBoard, _ = self.game.getNextState(board, 1, a)
52+
score = self.game.getScore(nextBoard, 1)
53+
candidates += [(-score, a)]
54+
candidates.sort()
55+
return candidates[0][1]

othello/pytorch/NNet.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -108,20 +108,30 @@ def loss_v(self, targets, outputs):
108108
def save_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar'):
109109
filepath = os.path.join(folder, filename)
110110
if not os.path.exists(folder):
111-
print("Checkpoint Directory does not exist! Making directory {}".format(folder))
111+
# print("Checkpoint Directory does not exist! Making directory {}".format(folder))
112112
os.mkdir(folder)
113-
else:
114-
print("Checkpoint Directory exists! ")
113+
# else:
114+
# print("Checkpoint Directory exists! ")
115+
current_uptime = get_uptime()
115116
torch.save({
116117
'state_dict': self.nnet.state_dict(),
118+
'full_model': self.nnet,
119+
'cumulated_uptime': self.cumulated_uptime + current_uptime-self.begin_uptime,
120+
'end_uptime': current_uptime,
121+
'begin': self.begin_time,
117122
}, filepath)
118123
# print(f'SAVE: {self.cumulated_uptime=} {self.begin_uptime=} ==> cumulated_uptime={self.cumulated_uptime + current_uptime-self.begin_uptime}')
119124

120125
def load_checkpoint(self, folder='checkpoint', filename='checkpoint.pth.tar', ongoing_experiment=False):
121126
# https://github.com/pytorch/examples/blob/master/imagenet/main.py#L98
122127
filepath = os.path.join(folder, filename)
123128
if not os.path.exists(filepath):
124-
raise ("No model in path {}".format(filepath))
129+
print("No model in path {}".format(filepath))
130+
return
125131
map_location = None if self.args['cuda'] else 'cpu'
126132
checkpoint = torch.load(filepath, map_location=map_location)
127-
self.nnet.load_state_dict(checkpoint['state_dict'])
133+
self.nnet = checkpoint['full_model']
134+
self.cumulated_uptime = checkpoint.get('cumulated_uptime', 0)
135+
self.begin_time = checkpoint.get('begin', int(time.time()))
136+
self.begin_uptime = checkpoint.get('end_uptime', 0) if ongoing_experiment else get_uptime()
137+

othello/pytorch/OthelloNNet.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,6 @@
1-
import sys
2-
sys.path.append('..')
3-
from utils import *
4-
5-
import argparse
61
import torch
72
import torch.nn as nn
83
import torch.nn.functional as F
9-
import torch.optim as optim
10-
from torchvision import datasets, transforms
11-
from torch.autograd import Variable
124

135
class OthelloNNet(nn.Module):
146
def __init__(self, game, args):

0 commit comments

Comments
 (0)