@@ -40,71 +40,109 @@ def __init__(self, xml_path, num_requests=1):
4040
4141 # 3. 创建异步队列
4242 # 对于游戏辅助,jobs 建议设为 1 或 2,以保证最低延迟
43- self .infer_queue = AsyncInferQueue ( self . compiled_model , jobs = num_requests )
44- self .infer_queue . set_callback ( self ._callback )
43+ self .num_requests = num_requests
44+ self .infer_queue = self ._create_infer_queue ( )
4545
4646 # 内部状态
4747 self .latest_results = None
4848 self .latest_image = None
4949 self .class_names = ["target" ] # 可根据 data.yaml 修改
5050 self .latency = 0.0 # 单次推理总耗时 (秒)
5151 self .job_id = 0
52+ self ._retired_infer_queues = []
53+ self ._active_queue_jobs = {}
54+
55+ def _create_infer_queue (self ):
56+ infer_queue = AsyncInferQueue (self .compiled_model , jobs = self .num_requests )
57+ infer_queue .set_callback (self ._callback )
58+ return infer_queue
59+
60+ def _cleanup_retired_infer_queues (self ):
61+ self ._retired_infer_queues = [
62+ infer_queue
63+ for infer_queue in self ._retired_infer_queues
64+ if self ._active_queue_jobs .get (id (infer_queue ), 0 ) > 0
65+ ]
66+
67+ def _mark_queue_job_started (self , infer_queue ):
68+ queue_id = id (infer_queue )
69+ self ._active_queue_jobs [queue_id ] = (
70+ self ._active_queue_jobs .get (queue_id , 0 ) + 1
71+ )
72+ return queue_id
73+
74+ def _mark_queue_job_finished (self , queue_id ):
75+ pending_jobs = self ._active_queue_jobs .get (queue_id , 0 ) - 1
76+ if pending_jobs > 0 :
77+ self ._active_queue_jobs [queue_id ] = pending_jobs
78+ else :
79+ self ._active_queue_jobs .pop (queue_id , None )
80+
81+ def _queue_has_active_jobs (self , infer_queue ):
82+ return self ._active_queue_jobs .get (id (infer_queue ), 0 ) > 0
5283
5384 def _callback (self , infer_request , user_data ):
5485 """异步推理完成后的回调函数"""
55- job_id = user_data .get ("job_id" , 0 )
56- if job_id < self .job_id :
57- return
58-
59- start_time = user_data ["start_time" ]
60- self .latency = time .time () - start_time
61-
62- detections = infer_request .get_output_tensor ().data [0 ]
63-
64- box = user_data ["box" ]
65- threshold = user_data ["threshold" ]
66- target_label = user_data ["label" ]
67- pad_x = user_data ["pad_x" ]
68- pad_y = user_data ["pad_y" ]
69-
70- # 1. 画布相较于模型的缩放比例
71- scale = user_data ["target_w" ] / self .model_w
72-
73- tmp_results = []
74- for x1 , y1 , x2 , y2 , conf , cls_id in detections :
75- if conf < threshold :
76- continue
77-
78- name = (
79- self .class_names [int (cls_id )] if int (cls_id ) < len (self .class_names ) else "unknown"
80- )
81- if target_label and name != target_label :
82- continue
83-
84- # 2. 从 AI 的坐标还原到带灰边的 Canvas 坐标
85- canvas_x1 = x1 * scale
86- canvas_y1 = y1 * scale
87- canvas_w = (x2 - x1 ) * scale
88- canvas_h = (y2 - y1 ) * scale
89-
90- # 3. 减去灰边的偏移量,得到在输入 input_crop 中的坐标
91- # 再加上外面传进来的 Box 原图坐标,直接映射到全屏
92- abs_x = int (canvas_x1 - pad_x + box .x )
93- abs_y = int (canvas_y1 - pad_y + box .y )
94-
95- tmp_results .append (
96- Box (
97- x = abs_x ,
98- y = abs_y ,
99- width = int (canvas_w ),
100- height = int (canvas_h ),
101- confidence = float (conf ),
102- name = name ,
86+ queue_id = user_data .get ("queue_id" )
87+ try :
88+ job_id = user_data .get ("job_id" , 0 )
89+ if job_id < self .job_id :
90+ return
91+
92+ start_time = user_data ["start_time" ]
93+ self .latency = time .time () - start_time
94+
95+ detections = infer_request .get_output_tensor ().data [0 ]
96+
97+ box = user_data ["box" ]
98+ threshold = user_data ["threshold" ]
99+ target_label = user_data ["label" ]
100+ pad_x = user_data ["pad_x" ]
101+ pad_y = user_data ["pad_y" ]
102+
103+ # 1. 画布相较于模型的缩放比例
104+ scale = user_data ["target_w" ] / self .model_w
105+
106+ tmp_results = []
107+ for x1 , y1 , x2 , y2 , conf , cls_id in detections :
108+ if conf < threshold :
109+ continue
110+
111+ name = (
112+ self .class_names [int (cls_id )]
113+ if int (cls_id ) < len (self .class_names )
114+ else "unknown"
115+ )
116+ if target_label and name != target_label :
117+ continue
118+
119+ # 2. 从 AI 的坐标还原到带灰边的 Canvas 坐标
120+ canvas_x1 = x1 * scale
121+ canvas_y1 = y1 * scale
122+ canvas_w = (x2 - x1 ) * scale
123+ canvas_h = (y2 - y1 ) * scale
124+
125+ # 3. 减去灰边的偏移量,得到在输入 input_crop 中的坐标
126+ # 再加上外面传进来的 Box 原图坐标,直接映射到全屏
127+ abs_x = int (canvas_x1 - pad_x + box .x )
128+ abs_y = int (canvas_y1 - pad_y + box .y )
129+
130+ tmp_results .append (
131+ Box (
132+ x = abs_x ,
133+ y = abs_y ,
134+ width = int (canvas_w ),
135+ height = int (canvas_h ),
136+ confidence = float (conf ),
137+ name = name ,
138+ )
103139 )
104- )
105140
106- self .latest_results = tmp_results
107- self .latest_image = user_data .get ("image" )
141+ self .latest_results = tmp_results
142+ self .latest_image = user_data .get ("image" )
143+ finally :
144+ if queue_id is not None :
145+ self ._mark_queue_job_finished (queue_id )
108146
109147 def detect (
110148 self ,
@@ -121,13 +159,20 @@ def detect(
121159 :param box: 指定检测区域的 Box 实例。如果为 None, 则检测全图。
122160 :param threshold: 置信度阈值
123161 :param label: 指定检测的类别名称
124- :param force: 如果为 True,即使队列满也会阻塞提交新任务
162+ :param force: 如果为 True,即使队列满也会丢弃旧结果并立刻提交新任务
125163 :param mask_regions: 需要屏蔽的全图归一化区域列表,格式为
126164 [(x1, y1, x2, y2), ...]。屏蔽会应用到推理画布,不修改原图。
127165 :return: list[Box] (返回的是上一帧或最近一次完成的结果)
128166 """
129167
168+ self ._cleanup_retired_infer_queues ()
130169 if force or self .infer_queue .is_ready ():
170+ if force and not self .infer_queue .is_ready ():
171+ # Keep the busy queue alive so replacing it does not wait for its running request.
172+ self ._retired_infer_queues .append (self .infer_queue )
173+ self .infer_queue = self ._create_infer_queue ()
174+ self .job_id += 1
175+
131176 h , w = image .shape [:2 ]
132177
133178 if box is None :
@@ -175,21 +220,28 @@ def detect(
175220 self .job_id += 1
176221 current_job_id = self .job_id
177222
178- self .infer_queue .start_async (
179- {0 : input_tensor },
180- {
181- "box" : box ,
182- "threshold" : threshold ,
183- "label" : label ,
184- "start_time" : time .time (),
185- # 传给回调函数,用于减去补边的偏移
186- "pad_x" : pad_x ,
187- "pad_y" : pad_y ,
188- "target_w" : target_w , # 记录画布的总宽用于还原缩放
189- "job_id" : current_job_id ,
190- "image" : image ,
191- },
192- )
223+ infer_queue = self .infer_queue
224+ queue_id = self ._mark_queue_job_started (infer_queue )
225+ try :
226+ infer_queue .start_async (
227+ {0 : input_tensor },
228+ {
229+ "box" : box ,
230+ "threshold" : threshold ,
231+ "label" : label ,
232+ "start_time" : time .time (),
233+ # 传给回调函数,用于减去补边的偏移
234+ "pad_x" : pad_x ,
235+ "pad_y" : pad_y ,
236+ "target_w" : target_w , # 记录画布的总宽用于还原缩放
237+ "job_id" : current_job_id ,
238+ "queue_id" : queue_id ,
239+ "image" : image ,
240+ },
241+ )
242+ except Exception :
243+ self ._mark_queue_job_finished (queue_id )
244+ raise
193245
194246 return self .latest_results
195247
@@ -236,3 +288,7 @@ def clear_cache(self):
236288 self .latest_results = None
237289 self .latest_image = None
238290 self .job_id += 1 # 增加 epoch,所有正在运行的旧任务的回调都会失效
291+ if self ._queue_has_active_jobs (self .infer_queue ):
292+ self ._retired_infer_queues .append (self .infer_queue )
293+ self .infer_queue = self ._create_infer_queue ()
294+ self ._cleanup_retired_infer_queues ()
0 commit comments