@@ -33,19 +33,23 @@ def add_arguments_to_parser(parser):
3333 )
3434
3535 @staticmethod
36- def get_time_mem_list (process ):
37- return [time .time (), process .memory_info ().rss ]
36+ def get_time_mem_list (processes ):
37+ if processes is None :
38+ return [time .time (), float ("nan" )]
39+ if len (processes ) == 0 :
40+ return [time .time (), 0 ]
41+ return [time .time (), sum ([process .memory_info ().rss for process in processes ])]
3842
3943 def __init__ (self , parser_arg_value ):
4044 super ().__init__ ()
4145 self .status_stats += [fs .Keys .MEMORY_USAGE_PLOT ]
4246 self .track_memory_interval = parser_arg_value
43- self .process_being_tracked = None
4447 self .build_dir = None
4548 self .queue = None
4649 self .tracker_process = None
4750 self .tracking_active = False
4851 self .yaml_path = None
52+ self .processes_being_tracked = None
4953
5054 def start (self , build_dir ):
5155 if self .tracking_active :
@@ -54,10 +58,6 @@ def start(self, build_dir):
5458 # Save the folder where data and plot will be stored
5559 self .build_dir = build_dir
5660
57- # Get the process being tracked
58- track_pid = os .getpid ()
59- self .process_being_tracked = psutil .Process (track_pid )
60-
6161 # Create queue for passing messages to the tracker
6262 self .queue = Queue ()
6363
@@ -68,21 +68,31 @@ def start(self, build_dir):
6868 self .tracker_process = Process (
6969 target = self ._memory_tracker_ ,
7070 args = (
71- track_pid ,
7271 self .queue ,
7372 self .yaml_path ,
7473 self .track_memory_interval ,
7574 ),
7675 )
7776 self .tracker_process .start ()
7877 self .tracking_active = True
78+ # Set start of track and log a zero memory usage
7979 self .set_label ("start" )
80- self .sample ()
80+ self .queue .put (MemoryTracker .get_time_mem_list ([]))
81+
82+ def add_pid_to_track (self , pid ):
83+ if self .tracking_active :
84+ self .processes_being_tracked .append (psutil .Process (pid ))
85+ self .queue .put (pid )
8186
8287 def tool_starting (self , tool_name ):
8388 self .set_label (tool_name )
8489
85- def tool_stopping (self ):
90+ def tool_stopping (self , state ):
91+ # Check it the tool as created the inference_processes attribute to state
92+ if self .processes_being_tracked is None and hasattr (state , "inference_pids" ):
93+ self .processes_being_tracked = []
94+ for pid in state .inference_pids :
95+ self .add_pid_to_track (pid )
8696 self .sample ()
8797
8898 def set_label (self , label ):
@@ -91,7 +101,12 @@ def set_label(self, label):
91101
92102 def sample (self ):
93103 if self .tracking_active :
94- self .queue .put (MemoryTracker .get_time_mem_list (self .process_being_tracked ))
104+ if self .processes_being_tracked is None :
105+ self .queue .put (MemoryTracker .get_time_mem_list ([]))
106+ else :
107+ self .queue .put (
108+ MemoryTracker .get_time_mem_list (self .processes_being_tracked )
109+ )
95110
96111 def stop (self ):
97112 if self .tracking_active :
@@ -136,8 +151,8 @@ def generate_results(self, state, timestamp, _):
136151
137152 # last_t and last_y are used to draw a line between the last point of the prior
138153 # track and the first point of the current track
139- last_t = None
140- last_y = None
154+ last_t = 0
155+ last_y = track [ - 1 ][ 1 ]
141156
142157 plt .figure ()
143158 for k , v in memory_tracks [1 :]:
@@ -174,7 +189,6 @@ def generate_results(self, state, timestamp, _):
174189
175190 @staticmethod
176191 def _memory_tracker_ (
177- tracked_pid ,
178192 input_queue : Queue ,
179193 yaml_path : str ,
180194 track_memory_interval : float ,
@@ -191,17 +205,14 @@ def _memory_tracker_(
191205 3) None - This indicates that the tracker should stop tracking, save its data to a file
192206 and end
193207 """
208+ tracked_processes = None
194209 memory_tracks = []
195210 current_track = []
196211 track_name = None
197212 tracker_exit = False
198213
199214 try :
200- tracked_process = psutil .Process (tracked_pid )
201- while (
202- not tracker_exit and tracked_process .status () == psutil .STATUS_RUNNING
203- ):
204-
215+ while not tracker_exit :
205216 time .sleep (track_memory_interval )
206217
207218 # Read any messages from the parent process
@@ -227,6 +238,10 @@ def _memory_tracker_(
227238 "Track name must be passed to memory tracker prior to "
228239 "sending data"
229240 )
241+ elif isinstance (message , int ):
242+ if tracked_processes is None :
243+ tracked_processes = []
244+ tracked_processes .append (psutil .Process (message ))
230245 else :
231246 raise TypeError (
232247 "Unrecognized message type in memory_tracker input queue: "
@@ -240,7 +255,7 @@ def _memory_tracker_(
240255 if not tracker_exit and track_name is not None :
241256 # Save current time and memory usage
242257 current_track .append (
243- MemoryTracker .get_time_mem_list (tracked_process )
258+ MemoryTracker .get_time_mem_list (tracked_processes )
244259 )
245260
246261 # Save the collected memory tracks
0 commit comments