Skip to content

Commit ed47caf

Browse files
committed
Add support for parsing replays to gbx.py; sort tracks based on replays; fix important bug with data augumentation; new position model; apply weights to block model's output; add gui.py
1 parent 4e7bbdf commit ed47caf

22 files changed

+836
-133
lines changed

README.md

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ This will generate a track using the provided block and position models that wil
1818
* Keras
1919
* python-lzo (through pip)
2020
* numpy
21-
* Not required: pygame or Gtk+3 and GLib for track visualization
21+
* Not required: Gtk+3 and GLib for track visualization
2222

2323
## Dataset
2424
This repo doesn't contain the dataset itself used to train the models in the `models/` directory as it is unusual to provide entire datasets with code in one repo. There is however a preprocessed version of the dataset used in the `data/train_data.pkl` file that you can use for futher training.
2525

26-
The file contains roughly 3000 tech tracks downloaded directly from [TMX](https://tmnforever.tm-exchange.com/) and preprocessed such that they contain only the simplified version of tracks of each map. The maps themselves were downloaded using these filters: type: tech, order: awards (most), length: ~= 1m.
26+
The file contains roughly 3000 tech tracks downloaded directly from [TMX](https://tmnforever.tm-exchange.com/) and preprocessed such that they contain only the simplified version of tracks of each map. The maps themselves were downloaded using these filters: style: tech, order: awards (most), length: ~= 1m.
2727

2828
## Neural Network Architecture
2929
We can represent a track as a sequence of block placements. Each block consists of 3 main features:
@@ -44,8 +44,6 @@ Since we want to predict the next block in the sequence we ask the block model t
4444
The position model's output is two features: the vector to add to the position of last block to get a new position of the new block and the rotation of the new block.
4545
Their loss function is mean squared error and softmax respectively.
4646

47-
![Visualization](/docs/TMTrackArch.png)
48-
4947
## Training
5048
It is recommended to have a dedicated GPU for training the nets, otherwise training process will be very slow.
5149

@@ -62,4 +60,4 @@ python3 -i train_pos.py -g -l models/position_model_1024_512.h5
6260
Invoking either `train_blocks.py` or `train_pos.py` with the `-l` option will automatically
6361
use model checkpointing to save new models with the model filename that was loaded.
6462

65-
`livebuild.py` allows to dynamically generate tracks, it has a Gtk+3 UI to visualize how the track currently looks like. To fully evaluate model's performance, it's recommended to use `build.py` and see the tracks generated in the game itself.
63+
`livebuild.py` allows to dynamically generate tracks, it has a Gtk+3 UI to visualize how the track currently looks like. To fully evaluate model's performance, it's recommended to use `build.py` and see the tracks generated in the game itself.

blocks.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,9 @@
361361

362362
# 'StadiumRoadDirtHighToRoad': 358,
363363
# 'StadiumDirtHill': 359,
364-
# 'StadiumPlatformBiSlope2StartSmall': 360
364+
# 'StadiumPlatformBiSlope2StartSmall': 360,
365+
# 'StadiumWaterClip': 361,
366+
# 'StadiumDirtClip': 362
365367
}
366368

367369
# TODO: needs a separate dict / list for TM2 Stadium
@@ -711,10 +713,12 @@ def is_multilap(name):
711713

712714
BID, BX, BY, BZ, BROT, BFLAGS = range(6)
713715
EMPTY_BLOCK = (0, 0, 0, 0, 0)
714-
BASE_BLOCKS = list(range(6, 97+1)) + \
715-
list(range(105, 186+1)) + list(range(196, 233+1))
716-
GROUND_BLOCKS = [20] + list(range(196, 233+1))
716+
BASE_BLOCKS = list(range(6, 98+1)) + \
717+
list(range(105, 120+1)) + list(range(127, 224+1)) + list(range(227, 233+1))
718+
GROUND_BLOCKS = [20] + list(range(196, 233+1)) + \
719+
list(range(253, 263+1)) + list(range(272, 277+1)) + list(range(287, 292+1))
717720
TRANSITION_BLOCKS = [130, 214, 215]
721+
ROAD_BLOCKS = [6, 198, 264, 265, 266, 270, 271]
718722
START_LINE_BLOCK = BLOCKS['StadiumRoadMainStartLine']
719723
FINISH_LINE_BLOCK = BLOCKS['StadiumRoadMainFinishLine']
720724
MULTILAP_LINE_BLOCK = BLOCKS['StadiumRoadMainStartFinishLine']

build.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from sklearn.preprocessing import MinMaxScaler
2424

2525
from blocks import BID, BLOCKS, BROT, BX, BY, BZ
26-
from track_utils import fit_position_scaler
26+
from track_utils import fit_data_scaler
2727
from config import NET_CONFIG
2828
from builder import Builder
2929
from savegbx import save_gbx
@@ -40,7 +40,7 @@
4040
pattern_data_file = open(NET_CONFIG['patterns_fname'], 'rb')
4141
pattern_data = pickle.load(pattern_data_file)
4242

43-
scaler = fit_position_scaler(train_data)
43+
scaler = fit_data_scaler(train_data)
4444

4545

4646
def progress_callback(completed, total):
@@ -56,7 +56,7 @@ def progress_callback(completed, total):
5656
pos_model = load_model(args.pos_model)
5757

5858
builder = Builder(block_model, pos_model,
59-
NET_CONFIG['lookback'], train_data, pattern_data, scaler, temperature=args.temperature)
59+
NET_CONFIG['lookback'], train_data, pattern_data, scaler, temperature=args.temperature)
6060

6161
track = builder.build(args.length, verbose=False, save=False,
6262
progress_callback=progress_callback)

builder.py

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import numpy as np
55

66
import blocks as bl
7-
from track_utils import intersects, occupied_track_positions, rotate_track_tuples, dist
7+
from track_utils import intersects, occupied_track_vectors, rotate_track_tuples, dist
88
from blocks import BID, BROT, BX, BY, BZ
9+
from tech_block_weights import TECH_BLOCK_WEIGHTS
910

1011
POS_LEN = 3
1112
ROTATE_LEN = 4
@@ -22,15 +23,15 @@ def __init__(self, block_model, position_model, lookback, seed_data, pattern_dat
2223
self.inp_len = len(bl.BLOCKS) + POS_LEN + ROTATE_LEN
2324
self.temperature = temperature
2425
self.pattern_data = pattern_data
26+
self.max_map_size = (32, 32, 32)
27+
self.running = False
2528

2629
@staticmethod
2730
def random_start_block():
2831
return (bl.START_LINE_BLOCK, 0, 0, 0, random.randrange(0, 4))
2932

30-
# Source:
31-
# https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py#L66
32-
# Helper function to sample an index from a probability array
3333
def sample(self, preds):
34+
# helper function to sample an index from a probability array
3435
preds = np.asarray(preds).astype('float64')
3536
preds = np.log(preds) / self.temperature
3637
exp_preds = np.exp(preds)
@@ -68,8 +69,7 @@ def decoded_track(track, start_pos=(0, 0, 0)):
6869
block[BROT])
6970
return d
7071

71-
@staticmethod
72-
def unpack_position_preds_vector(preds):
72+
def unpack_position_preds_vector(self, preds):
7373
pos_vec = [int(round(axis)) for axis in preds[0][0]]
7474
pos_rot = np.argmax(preds[1][0])
7575
return pos_vec, pos_rot
@@ -79,6 +79,7 @@ def predict_next_block(self, X_block, X_position, block_override=-1, blacklist=[
7979
next_block = block_override
8080
else:
8181
block_preds = self.block_model.predict(X_block)[0]
82+
block_preds = block_preds * TECH_BLOCK_WEIGHTS
8283
block_preds = np.delete(
8384
block_preds, [bid - 1 for bid in blacklist])
8485

@@ -136,7 +137,7 @@ def prepare_inputs(self):
136137
return (X_block, X_position)
137138

138139
def position_track(self, track):
139-
occ = occupied_track_positions(track)
140+
occ = occupied_track_vectors(track)
140141
min_x = min(occ, key=lambda pos: pos[0])[0]
141142
min_y = min(occ, key=lambda pos: pos[1])[1] - 1
142143
min_z = min(occ, key=lambda pos: pos[2])[2]
@@ -145,6 +146,13 @@ def position_track(self, track):
145146
max_y = max(occ, key=lambda pos: pos[1])[1] - 1
146147
max_z = max(occ, key=lambda pos: pos[2])[2]
147148

149+
cx = 32 - (max_x - min_x + 1)
150+
if cx > 0:
151+
cx = random.randrange(0, cx)
152+
cz = 32 - (max_z - min_z + 1)
153+
if cz > 0:
154+
cz = random.randrange(0, cz)
155+
148156
min_x = 0 if min_x >= 0 else min_x
149157
min_y = 0 if min_y >= 0 else min_y
150158
min_z = 0 if min_z >= 0 else min_z
@@ -153,26 +161,44 @@ def position_track(self, track):
153161
max_y = 0 if max_y < 32 else max_y - 31
154162
max_z = 0 if max_z < 32 else max_z - 31
155163

164+
xoff = min_x - max_x
165+
yoff = min_y - max_y
166+
zoff = min_z - max_z
167+
156168
p = []
157169
for block in track:
158-
p.append((block[BID], block[BX] - min_x - max_x, block[BY] -
159-
min_y - max_y, block[BZ] - min_z - max_z, block[BROT]))
170+
p.append((block[BID], block[BX] - xoff + cx, block[BY] -
171+
yoff, block[BZ] - zoff + cz, block[BROT]))
160172

161173
return p
162174

163175
def exceeds_map_size(self, track):
164-
occ = occupied_track_positions(track)
176+
occ = occupied_track_vectors(track)
165177
min_x = min(occ, key=lambda pos: pos[0])[0]
166-
min_y = min(occ, key=lambda pos: pos[1])[1] - 1
178+
min_y = min(occ, key=lambda pos: pos[1])[1]
167179
min_z = min(occ, key=lambda pos: pos[2])[2]
168180

169181
max_x = max(occ, key=lambda pos: pos[0])[0]
170-
max_y = max(occ, key=lambda pos: pos[1])[1] - 1
182+
max_y = max(occ, key=lambda pos: pos[1])[1]
171183
max_z = max(occ, key=lambda pos: pos[2])[2]
172184

173-
return max_x - min_x > 32 or max_y - min_y > 32 or max_z - min_z > 32
185+
return max_x - min_x + 1 > self.max_map_size[0] or max_y - min_y + 1 > self.max_map_size[1] or max_z - min_z + 1 > self.max_map_size[2]
186+
187+
def stop(self):
188+
self.running = False
189+
190+
def get_y_locked(self):
191+
for block in self.track:
192+
if block[BID] in bl.GROUND_BLOCKS:
193+
return True
194+
195+
return False
174196

175197
def build(self, track_len, use_seed=False, failsafe=True, verbose=True, save=True, progress_callback=None):
198+
self.running = True
199+
200+
# self.max_map_size = (random.randrange(
201+
# 12, 32+1), random.randrange(5, 10), random.randrange(12, 32+1))
176202
if use_seed:
177203
self.track = self.sample_seed(3)
178204
else:
@@ -182,18 +208,21 @@ def build(self, track_len, use_seed=False, failsafe=True, verbose=True, save=Tru
182208
end = False
183209
current_min_y = 0
184210
while len(self.track) < track_len:
185-
if len(blacklist) >= 5 or (len(blacklist) == 1 and end):
211+
if not self.running:
212+
return None
213+
214+
if len(blacklist) >= 10 or (len(blacklist) == 1 and end):
186215
if verbose:
187216
print('More than 10 fails, going back.')
188217

189218
if end:
190-
back = 10
219+
back = 5
191220
else:
192221
back = random.randrange(1, 4)
193-
# Remove some last blocks
194-
for _ in range(back):
195-
if len(self.track) > 1:
196-
del self.track[-1]
222+
223+
end_idx = min(len(self.track) - 1, back)
224+
if end_idx > 0:
225+
del self.track[-end_idx:len(self.track)]
197226

198227
end = False
199228
blacklist = []
@@ -210,35 +239,42 @@ def build(self, track_len, use_seed=False, failsafe=True, verbose=True, save=Tru
210239
decoded = self.decoded_track(
211240
self.track + [next_block], start_pos=(0, 0, 0))
212241

213-
# Do not exceed map size
214242
if failsafe:
243+
# Do not exceed map size
215244
if self.exceeds_map_size(decoded):
216245
blacklist.append(next_block[BID])
217246
continue
218247

219248
if decoded[-1][BY] > current_min_y:
249+
# TODO: encode ground bit in the position network
220250
if decoded[-1][BID] == 6 and decoded[-2][BID] == 6 and dist(decoded[-1][BX:BZ+1], decoded[-2][BX:BZ+1]) > 1:
221251
blacklist.append(next_block[BID])
222252
continue
223253

254+
# Wants to put a ground block higher than ground
224255
if next_block[BID] in bl.GROUND_BLOCKS:
225-
blacklist.append(next_block[BID])
256+
blacklist.extend(bl.GROUND_BLOCKS)
226257
continue
227258

228259
if (intersects(decoded[:-1], decoded[-1]) or # Overlaps the track
229-
(next_block[BID] in range(99, 104+1) or next_block[BID] in range(121, 126+1)) or
230260
(next_block[BID] == bl.FINISH_LINE_BLOCK and not end)): # Tries to put finish before desired track length
231261
blacklist.append(next_block[BID])
232262
continue
233263

234-
if self.score_prediction(self.track[-1], next_block) < 4:
264+
if self.score_prediction(self.track[-1], next_block) < 5:
235265
blacklist.append(next_block[BID])
236266
continue
237267

238268
blacklist = []
239269

240-
if decoded[-1][BY] < current_min_y:
241-
current_min_y = decoded[-1][BY]
270+
occ = occupied_track_vectors([decoded[-1]])
271+
min_y_block = min(occ, key=lambda x: x[BY])[BY]
272+
if min_y_block < current_min_y:
273+
if self.get_y_locked():
274+
blacklist.append(next_block[BID])
275+
continue
276+
277+
current_min_y = min_y_block
242278

243279
self.track.append(next_block)
244280
if len(self.track) >= track_len - 1:
@@ -252,5 +288,4 @@ def build(self, track_len, use_seed=False, failsafe=True, verbose=True, save=Tru
252288

253289
result_track = self.position_track(
254290
self.decoded_track(self.track, (0, 0, 0)))
255-
pickle.dump(result_track, open('generated-track.bin', 'wb+'))
256291
return result_track

bytereader.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,27 @@ def get_bytes_file(self, num_bytes):
4949
def get_bytes_generic(self, num_bytes):
5050
return self.data[self.pos:self.pos + num_bytes]
5151

52+
def read_int32(self):
53+
return self.read(4, 'i')
54+
5255
def read_uint32(self):
5356
return self.read(4, 'I')
5457

55-
def read_int32(self):
56-
return self.read(4, 'i')
58+
def read_int16(self):
59+
return self.read(2, 'h')
5760

5861
def read_uint16(self):
5962
return self.read(2, 'H')
6063

64+
def read_int8(self):
65+
return self.read(1, 'b')
66+
67+
def read_float(self):
68+
return self.read(4, 'f')
69+
70+
def read_vec3(self):
71+
return (self.read_float(), self.read_float(), self.read_float())
72+
6173
def read_string(self):
6274
strlen = self.read_uint32()
6375
return self.read(strlen, str(strlen) + 's').decode('utf-8')
@@ -77,7 +89,7 @@ def read_info(self, info, f):
7789
def skip(self, num_bytes):
7890
self.pos += num_bytes
7991

80-
def read_string_loopback(self):
92+
def read_string_lookback(self):
8193
if not self.seen_loopback:
8294
self.read_uint32()
8395

config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
NET_CONFIG = {
22
'batch_size': 128,
33
'lookback': 20,
4-
'train_fname': 'data/train_data.pkl',
5-
'patterns_fname': 'data/pattern_data.pkl'
4+
'train_fname': 'data/replay_train_data.pkl',
5+
'patterns_fname': 'data/pattern_data_replay.pkl'
66
}

data/Template.Challenge.Gbx

100755100644
File mode changed.

data/pattern_data_replay.pkl

4.26 MB
Binary file not shown.

data/replay_train_data.pkl

5.27 MB
Binary file not shown.

data/train_data.pkl

-8.83 MB
Binary file not shown.

0 commit comments

Comments
 (0)