44import numpy as np
55
66import blocks as bl
7- from track_utils import intersects , occupied_track_vectors , rotate_track_tuples , dist
8- from blocks import BID , BROT , BX , BY , BZ
7+ from track_utils import intersects , occupied_track_vectors , rotate_track_tuples , dist , is_on_ground , vectorize_track
8+ from blocks import BLOCKS , BID , BROT , BX , BY , BZ , BFLAGS
99from tech_block_weights import TECH_BLOCK_WEIGHTS
1010
1111POS_LEN = 3
@@ -28,7 +28,7 @@ def __init__(self, block_model, position_model, lookback, seed_data, pattern_dat
2828
2929 @staticmethod
3030 def random_start_block ():
31- return (bl .START_LINE_BLOCK , 0 , 0 , 0 , random .randrange (0 , 4 ))
31+ return (bl .START_LINE_BLOCK , 0 , 0 , 0 , random .randrange (0 , 4 ), 0 )
3232
3333 # Source:
3434 # https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py#L66
@@ -58,7 +58,7 @@ def block_to_vec(inp_len, block, scaler, encode_pos):
5858
5959 @staticmethod
6060 def decoded_track (track , start_pos = (0 , 0 , 0 )):
61- d = track
61+ d = track [:]
6262 d [0 ] = (d [0 ][BID ], start_pos [0 ], start_pos [1 ],
6363 start_pos [2 ], d [0 ][BROT ])
6464 for i in range (1 , len (d )):
@@ -76,14 +76,16 @@ def unpack_position_preds_vector(self, preds):
7676 pos_rot = np .argmax (preds [1 ][0 ])
7777 return pos_vec , pos_rot
7878
79- def predict_next_block (self , X_block , X_position , block_override = - 1 , blacklist = []):
79+ def predict_next_block (self , X_block , X_position , block_override = - 1 , blacklist = [], block_preds = None ):
8080 if block_override != - 1 :
8181 next_block = block_override
8282 else :
83- block_preds = self .block_model .predict (X_block )[0 ]
84- block_preds = block_preds * TECH_BLOCK_WEIGHTS
85- block_preds = np .delete (
86- block_preds , [bid - 1 for bid in blacklist ])
83+ if block_preds is None :
84+ block_preds = self .block_model .predict (X_block )[0 ]
85+ block_preds = block_preds * TECH_BLOCK_WEIGHTS
86+
87+ for bid in blacklist :
88+ block_preds [bid - 1 ] = 0
8789
8890 next_block = self .sample (block_preds ) + 1
8991
@@ -94,8 +96,9 @@ def predict_next_block(self, X_block, X_position, block_override=-1, blacklist=[
9496 (next_block , 0 , 0 , 0 , 0 ), self .scaler , False )
9597
9698 pos_preds = self .position_model .predict (X_position )
97- pos_vec , pos_rot = self .unpack_position_preds_vector (pos_preds )
98- return (next_block , pos_vec [0 ], pos_vec [1 ], pos_vec [2 ], pos_rot )
99+ pos_vec , pos_rot = self .unpack_position_preds_vector (
100+ pos_preds )
101+ return (next_block , pos_vec [0 ], pos_vec [1 ], pos_vec [2 ], pos_rot ), block_preds
99102
100103 def sample_seed (self , seed_len ):
101104 seed_idx = random .randrange (0 , len (self .seed_data ))
@@ -112,8 +115,7 @@ def score_prediction(self, prev_block, next_block):
112115 next_block = (next_block [BID ], next_block [BX ] - prev_block [BX ], next_block [BY ] -
113116 prev_block [BY ], next_block [BZ ] - prev_block [BZ ], next_block [BROT ])
114117
115- prev_block = (prev_block [BID ], 0 , 0 , 0 , prev_block [BROT ])
116-
118+ prev_block = prev_block [BID ]
117119 target = (prev_block , next_block )
118120 try :
119121 return self .pattern_data [target ]
@@ -150,10 +152,10 @@ def position_track(self, track):
150152
151153 cx = 32 - (max_x - min_x + 1 )
152154 if cx > 0 :
153- cx = random . randrange ( 0 , cx )
155+ cx = int ( cx / 2 )
154156 cz = 32 - (max_z - min_z + 1 )
155157 if cz > 0 :
156- cz = random . randrange ( 0 , cz )
158+ cz = int ( cz / 2 )
157159
158160 min_x = 0 if min_x >= 0 else min_x
159161 min_y = 0 if min_y >= 0 else min_y
@@ -164,13 +166,12 @@ def position_track(self, track):
164166 max_z = 0 if max_z < 32 else max_z - 31
165167
166168 xoff = min_x - max_x
167- yoff = min_y - max_y
168169 zoff = min_z - max_z
169170
170171 p = []
171172 for block in track :
172- p .append ((block [BID ], block [BX ] - xoff + cx , block [BY ] -
173- yoff , block [BZ ] - zoff + cz , block [BROT ]))
173+ p .append ((block [BID ], block [BX ] - xoff + cx , block [BY ],
174+ block [BZ ] - zoff + cz , block [BROT ]))
174175
175176 return p
176177
@@ -184,31 +185,25 @@ def exceeds_map_size(self, track):
184185 max_y = max (occ , key = lambda pos : pos [1 ])[1 ]
185186 max_z = max (occ , key = lambda pos : pos [2 ])[2 ]
186187
187- return max_x - min_x + 1 > self .max_map_size [0 ] or max_y - min_y + 1 > self .max_map_size [1 ] or max_z - min_z + 1 > self .max_map_size [2 ]
188+ return max_x - min_x + 1 > self .max_map_size [0 ] or max_y - min_y + 1 > self .max_map_size [1 ] or max_z - min_z + 1 > self .max_map_size [2 ] or min_y < 1
188189
189190 def stop (self ):
190191 self .running = False
191192
192- def get_y_locked (self ):
193- for block in self .track :
194- if block [BID ] in bl .GROUND_BLOCKS :
195- return True
196-
197- return False
198-
199193 def build (self , track_len , use_seed = False , failsafe = True , verbose = True , save = True , progress_callback = None ):
200194 self .running = True
201195
202- # self.max_map_size = (random.randrange(
203- # 12, 32+1), random.randrange(5, 10), random.randrange(12, 32+1))
204- if use_seed :
196+ self .max_map_size = (20 , 8 , 20 )
197+ if use_seed and self .seed_data :
205198 self .track = self .sample_seed (3 )
206199 else :
207200 self .track = [self .random_start_block ()]
208201
202+ fixed_y = random .randrange (1 , 5 )
203+
209204 blacklist = []
210205 end = False
211- current_min_y = 0
206+ current_block_preds = None
212207 while len (self .track ) < track_len :
213208 if not self .running :
214209 return None
@@ -217,43 +212,51 @@ def build(self, track_len, use_seed=False, failsafe=True, verbose=True, save=Tru
217212 if verbose :
218213 print ('More than 10 fails, going back.' )
219214
220- if end :
215+ if len ( self . track ) > track_len - 5 :
221216 back = 5
217+ elif end :
218+ back = 10
222219 else :
223- back = random .randrange (1 , 4 )
220+ back = random .randrange (2 , 6 )
224221
225222 end_idx = min (len (self .track ) - 1 , back )
226223 if end_idx > 0 :
227224 del self .track [- end_idx :len (self .track )]
228225
229226 end = False
230227 blacklist = []
228+ current_block_preds = None
231229
232230 X_block , X_position = self .prepare_inputs ()
233231
234232 override_block = - 1
235233 if end :
236234 override_block = bl .FINISH_LINE_BLOCK
237235
238- next_block = self .predict_next_block (
239- X_block [:], X_position [:], override_block , blacklist = blacklist )
236+ next_block , current_block_preds = self .predict_next_block (
237+ X_block [:], X_position [:], override_block , blacklist = blacklist , block_preds = current_block_preds )
240238
241239 decoded = self .decoded_track (
242- self .track + [next_block ], start_pos = (0 , 0 , 0 ))
240+ self .track + [next_block ], start_pos = (0 , fixed_y , 0 ))
243241
244242 if failsafe :
245243 # Do not exceed map size
246244 if self .exceeds_map_size (decoded ):
247245 blacklist .append (next_block [BID ])
248246 continue
249247
250- if decoded [- 1 ][BY ] > current_min_y :
251- # TODO: encode ground bit in the position network
252- if decoded [- 1 ][BID ] == 6 and decoded [- 2 ][BID ] == 6 and dist (decoded [- 1 ][BX :BZ + 1 ], decoded [- 2 ][BX :BZ + 1 ]) > 1 :
248+ occ = occupied_track_vectors ([decoded [- 1 ]])
249+ if len (occ ) > 0 :
250+ min_y_block = min (occ , key = lambda x : x [1 ])[1 ]
251+ else :
252+ min_y_block = decoded [- 1 ][BY ]
253+
254+ # If we are above the ground
255+ if min_y_block > 1 :
256+ if next_block [BID ] == BLOCKS ['StadiumGrass' ]:
253257 blacklist .append (next_block [BID ])
254258 continue
255259
256- # Wants to put a ground block higher than ground
257260 if next_block [BID ] in bl .GROUND_BLOCKS :
258261 blacklist .extend (bl .GROUND_BLOCKS )
259262 continue
@@ -268,16 +271,10 @@ def build(self, track_len, use_seed=False, failsafe=True, verbose=True, save=Tru
268271 continue
269272
270273 blacklist = []
274+ current_block_preds = None
271275
272- occ = occupied_track_vectors ([decoded [- 1 ]])
273- min_y_block = min (occ , key = lambda x : x [BY ])[BY ]
274- if min_y_block < current_min_y :
275- if self .get_y_locked ():
276- blacklist .append (next_block [BID ])
277- continue
278-
279- current_min_y = min_y_block
280-
276+ next_block = (next_block [BID ], next_block [BX ], next_block [BY ],
277+ next_block [BZ ], next_block [BROT ])
281278 self .track .append (next_block )
282279 if len (self .track ) >= track_len - 1 :
283280 end = True
@@ -289,5 +286,8 @@ def build(self, track_len, use_seed=False, failsafe=True, verbose=True, save=Tru
289286 print (len (self .track ))
290287
291288 result_track = self .position_track (
292- self .decoded_track (self .track , (0 , 0 , 0 )))
289+ self .decoded_track (self .track , (0 , fixed_y , 0 )))
290+
291+ result_track = [
292+ block for block in result_track if block [BID ] != BLOCKS ['StadiumGrass' ]]
293293 return result_track
0 commit comments