3636import six .moves .cPickle
3737import tensorflow as tf
3838
39+ REMOVED_FRAME = 'removed'
40+
3941try :
4042 import cv2
4143except ImportError :
@@ -48,13 +50,11 @@ class DumpConfig(object):
4850 def __init__ (self ,
4951 max_length = 200 ,
5052 max_count = 1 ,
51- skip_visuals = False ,
5253 snapshot_delay = 0 ,
5354 min_frequency = 10 ):
5455 self ._max_length = max_length
5556 self ._max_count = max_count
5657 self ._last_dump = 0
57- self ._skip_visuals = skip_visuals
5858 self ._snapshot_delay = snapshot_delay
5959 self ._file_name = None
6060 self ._result = None
@@ -86,53 +86,54 @@ def write(self, text, scale_factor=1):
8686def get_frame (trace ):
8787 if 'frame' in trace ._trace ['observation' ]:
8888 frame = trace ._trace ['observation' ]['frame' ]
89- else :
90- frame = np .uint8 (np .zeros ((600 , 800 , 3 )))
91- corner1 = (0 , 0 )
92- corner2 = (799 , 0 )
93- corner3 = (799 , 599 )
94- corner4 = (0 , 599 )
95- line_color = (0 , 255 , 255 )
96- cv2 .line (frame , corner1 , corner2 , line_color )
97- cv2 .line (frame , corner2 , corner3 , line_color )
98- cv2 .line (frame , corner3 , corner4 , line_color )
99- cv2 .line (frame , corner4 , corner1 , line_color )
100- cv2 .line (frame , (399 , 0 ), (399 , 799 ), line_color )
89+ if frame != REMOVED_FRAME :
90+ return frame
91+ frame = np .uint8 (np .zeros ((600 , 800 , 3 )))
92+ corner1 = (0 , 0 )
93+ corner2 = (799 , 0 )
94+ corner3 = (799 , 599 )
95+ corner4 = (0 , 599 )
96+ line_color = (0 , 255 , 255 )
97+ cv2 .line (frame , corner1 , corner2 , line_color )
98+ cv2 .line (frame , corner2 , corner3 , line_color )
99+ cv2 .line (frame , corner3 , corner4 , line_color )
100+ cv2 .line (frame , corner4 , corner1 , line_color )
101+ cv2 .line (frame , (399 , 0 ), (399 , 799 ), line_color )
102+ writer = TextWriter (
103+ frame ,
104+ trace ['ball' ][0 ],
105+ trace ['ball' ][1 ],
106+ field_coords = True ,
107+ color = (255 , 0 , 0 ))
108+ writer .write ('B' )
109+ for player_idx , player_coord in enumerate (trace ['left_team' ]):
110+ writer = TextWriter (
111+ frame ,
112+ player_coord [0 ],
113+ player_coord [1 ],
114+ field_coords = True ,
115+ color = (0 , 255 , 0 ))
116+ letter = 'H'
117+ if 'active' in trace and player_idx in trace ['active' ]:
118+ letter = 'X'
119+ elif 'left_agent_controlled_player' in trace and player_idx in trace [
120+ 'left_agent_controlled_player' ]:
121+ letter = 'X'
122+ writer .write (letter )
123+ for player_idx , player_coord in enumerate (trace ['right_team' ]):
101124 writer = TextWriter (
102125 frame ,
103- trace [ 'ball' ] [0 ],
104- trace [ 'ball' ] [1 ],
126+ player_coord [0 ],
127+ player_coord [1 ],
105128 field_coords = True ,
106- color = (255 , 0 , 0 ))
107- writer .write ('B' )
108- for player_idx , player_coord in enumerate (trace ['left_team' ]):
109- writer = TextWriter (
110- frame ,
111- player_coord [0 ],
112- player_coord [1 ],
113- field_coords = True ,
114- color = (0 , 255 , 0 ))
115- letter = 'H'
116- if 'active' in trace and player_idx in trace ['active' ]:
117- letter = 'X'
118- elif 'left_agent_controlled_player' in trace and player_idx in trace [
119- 'left_agent_controlled_player' ]:
120- letter = 'X'
121- writer .write (letter )
122- for player_idx , player_coord in enumerate (trace ['right_team' ]):
123- writer = TextWriter (
124- frame ,
125- player_coord [0 ],
126- player_coord [1 ],
127- field_coords = True ,
128- color = (0 , 0 , 255 ))
129- letter = 'A'
130- if 'opponent_active' in trace and player_idx in trace ['opponent_active' ]:
131- letter = 'Y'
132- elif 'right_agent_controlled_player' in trace and player_idx in trace [
133- 'right_agent_controlled_player' ]:
134- letter = 'Y'
135- writer .write (letter )
129+ color = (0 , 0 , 255 ))
130+ letter = 'A'
131+ if 'opponent_active' in trace and player_idx in trace ['opponent_active' ]:
132+ letter = 'Y'
133+ elif 'right_agent_controlled_player' in trace and player_idx in trace [
134+ 'right_agent_controlled_player' ]:
135+ letter = 'Y'
136+ writer .write (letter )
136137 return frame
137138
138139
@@ -216,7 +217,7 @@ def write_dump(name, trace, skip_visuals=False, config={}):
216217 for o in trace :
217218 if 'frame' in o ._trace ['observation' ]:
218219 temp_frames .append (o ._trace ['observation' ]['frame' ])
219- o ._trace ['observation' ]['frame' ] = 'removed'
220+ o ._trace ['observation' ]['frame' ] = REMOVED_FRAME
220221 to_pickle .append (o ._trace )
221222 with tf .io .gfile .GFile (name + '.dump' , 'wb' ) as f :
222223 six .moves .cPickle .dump (to_pickle , f )
@@ -229,14 +230,6 @@ def write_dump(name, trace, skip_visuals=False, config={}):
229230 return True
230231
231232
232- def logging_write_dump (name , trace , skip_visuals = False , config = {}):
233- try :
234- write_dump (name , trace , skip_visuals = skip_visuals , config = config )
235- except Exception as e :
236- logging .info (traceback .format_exc ())
237- raise
238-
239-
240233class ObservationState (object ):
241234
242235 def __init__ (self , trace ):
@@ -245,7 +238,6 @@ def __init__(self, trace):
245238 self ._additional_frames = []
246239 self ._debugs = []
247240 self ._time = timeit .default_timer ()
248- self ._right_defence_max_x = - 10
249241
250242 def __getitem__ (self , key ):
251243 if key in self ._trace :
@@ -290,21 +282,17 @@ def __init__(self, config):
290282 max_length = 200 ,
291283 max_count = (100000 if config ['dump_scores' ] else 0 ),
292284 min_frequency = 600 ,
293- snapshot_delay = 10 ,
294- skip_visuals = not config ['write_video' ])
285+ snapshot_delay = 10 )
295286 self ._dump_config ['lost_score' ] = DumpConfig (
296287 max_length = 200 ,
297288 max_count = (100000 if config ['dump_scores' ] else 0 ),
298289 min_frequency = 600 ,
299- snapshot_delay = 10 ,
300- skip_visuals = not config ['write_video' ])
290+ snapshot_delay = 10 )
301291 self ._dump_config ['episode_done' ] = DumpConfig (
302292 max_length = (200 if HIGH_RES else 10000 ),
303- max_count = (100000 if config ['dump_full_episodes' ] else 0 ),
304- skip_visuals = not config ['write_video' ])
293+ max_count = (100000 if config ['dump_full_episodes' ] else 0 ))
305294 self ._dump_config ['shutdown' ] = DumpConfig (
306- max_length = (200 if HIGH_RES else 10000 ),
307- skip_visuals = not config ['write_video' ])
295+ max_length = (200 if HIGH_RES else 10000 ))
308296 self ._thread_pool = None
309297 self ._dump_directory = None
310298 self ._config = config
@@ -331,13 +319,13 @@ def __getitem__(self, key):
331319 return self ._trace [key ]
332320
333321 def add_frame (self , frame ):
334- if len (self ._trace ) > 0 :
322+ if len (self ._trace ) > 0 and self . _config [ 'write_video' ] :
335323 self ._trace [- 1 ].add_frame (frame )
336324
337325 @cfg .log
338326 def update (self , trace ):
339327 self ._frame += 1
340- if not self ._config ['write_video' ] and 'frame' in trace :
328+ if not self ._config ['write_video' ] and 'frame' in trace [ 'observation' ] :
341329 # Don't record frame in the trace if we don't write video - full episode
342330 # consumes over 8G.
343331 no_video_trace = trace
@@ -385,11 +373,11 @@ def process_pending_dumps(self, finish):
385373 if finish or config ._trigger_step <= self ._frame :
386374 logging .info ('Start dump %s' , name )
387375 trace = list (self ._trace )[- config ._max_length :]
388- write_dump (config ._file_name , trace , config . _skip_visuals ,
376+ write_dump (config ._file_name , trace , self . _config [ 'write_video' ] ,
389377 self ._config )
390378 config ._file_name = None
391379 if config ._result :
392380 assert not config ._file_name
393381 if config ._result .ready () or finish :
394382 config ._result .get ()
395- config ._result = None
383+ config ._result = None
0 commit comments