Skip to content

Commit c8f2a80

Browse files
authored
Frames skipping (open-edge-platform#819)
1 parent 45004cf commit c8f2a80

7 files changed

Lines changed: 538 additions & 201 deletions

File tree

application/backend/app/runtime/components.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ class ComponentFactory(ABC):
2525
def create_source(self, project_id: UUID) -> Source: ...
2626

2727
@abstractmethod
28-
def create_processor(
29-
self, project_id: UUID, reference_batch: Batch, category_id_to_label_id: dict[int, str]
30-
) -> Processor: ...
28+
def create_processor(self, project_id: UUID, reference_batch: Batch) -> Processor: ...
3129

3230
@abstractmethod
3331
def create_sink(self, project_id: UUID) -> Sink: ...
@@ -46,18 +44,17 @@ def create_source(self, project_id: UUID) -> Source:
4644
cfg = svc.get_pipeline_config(project_id)
4745
return Source(StreamReaderFactory.create(cfg.reader))
4846

49-
def create_processor(
50-
self, project_id: UUID, reference_batch: Batch, category_id_to_label_id: dict[int, str]
51-
) -> Processor:
47+
def create_processor(self, project_id: UUID, reference_batch: Batch) -> Processor:
5248
with self._session_factory() as session:
5349
project_svc = ProjectService(session)
5450
cfg = project_svc.get_pipeline_config(project_id)
55-
logger.info("Creating processor with model config: %s", cfg.processor)
56-
51+
logger.info("Creating processor with model config: %s", cfg.processor)
52+
settings = get_settings()
5753
return Processor(
5854
model_handler=ModelFactory.create(reference_batch, cfg.processor),
59-
batch_size=get_settings().processor_batch_size,
60-
category_id_to_label_id=category_id_to_label_id,
55+
batch_size=settings.processor_batch_size,
56+
frame_skip_interval=settings.processor_frame_skip_interval,
57+
frame_skip_amount=settings.processor_frame_skip_amount,
6158
)
6259

6360
def create_sink(self, project_id: UUID) -> Sink:

application/backend/app/runtime/core/components/pipeline.py

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -115,20 +115,35 @@ def stop(self) -> None:
115115
logger.debug(f"Pipeline stopped for project_id={self._project_id}")
116116

117117
def set_source(self, source: Source, start: bool = False) -> Self:
118-
source.setup(self._inbound_broadcaster)
119-
self._register_component(source, start)
118+
with self._lock:
119+
self._stop_component(Source)
120+
source.setup(self._inbound_broadcaster)
121+
self._register_component(source, start)
120122
return self
121123

122124
def set_sink(self, sink: Sink, start: bool = False) -> Self:
123-
sink.setup(self._outbound_broadcaster)
124-
self._register_component(sink, start)
125+
with self._lock:
126+
self._stop_component(Sink)
127+
sink.setup(self._outbound_broadcaster)
128+
self._register_component(sink, start)
125129
return self
126130

127131
def set_processor(self, processor: Processor, start: bool = False) -> Self:
128-
processor.setup(self._inbound_broadcaster, self._outbound_broadcaster)
129-
self._register_component(processor, start)
132+
with self._lock:
133+
self._stop_component(Processor)
134+
processor.setup(self._inbound_broadcaster, self._outbound_broadcaster)
135+
self._register_component(processor, start)
130136
return self
131137

138+
def _stop_component(self, component_cls: type[PipelineComponent]) -> None:
139+
"""Stop and join the existing component of the given type, if any."""
140+
current = self._components.get(component_cls)
141+
if current:
142+
current.stop()
143+
thread = self._threads.get(component_cls)
144+
if thread and thread.is_alive():
145+
thread.join(timeout=5)
146+
132147
def _register_component(self, new_component: PipelineComponent, start: bool = True) -> None:
133148
"""
134149
A method to replace a component with a new one.
@@ -139,27 +154,14 @@ def _register_component(self, new_component: PipelineComponent, start: bool = Tr
139154
new_component: The new component instance.
140155
"""
141156
component_cls = new_component.__class__
142-
143-
with self._lock:
144-
# Stop the current component if one exists
145-
current_component = self._components.get(component_cls)
146-
if current_component:
147-
current_component.stop()
148-
thread = self._threads.get(component_cls)
149-
if thread and thread.is_alive():
150-
thread.join(timeout=5)
151-
if thread.is_alive():
152-
logger.warning(f"{component_cls.__name__} thread did not stop cleanly")
153-
154-
self._inbound_broadcaster.clear()
155-
self._outbound_broadcaster.clear()
156-
157-
self._components[component_cls] = new_component
158-
if start:
159-
thread = Thread(target=new_component, daemon=False)
160-
thread.start()
161-
self._threads[component_cls] = thread
162-
logger.debug(f"Started new {component_cls.__name__}")
157+
self._inbound_broadcaster.clear()
158+
self._outbound_broadcaster.clear()
159+
self._components[component_cls] = new_component
160+
if start:
161+
thread = Thread(target=new_component, daemon=False)
162+
thread.start()
163+
self._threads[component_cls] = thread
164+
logger.debug(f"Started new {component_cls.__name__}")
163165

164166
def seek(self, index: int) -> None:
165167
"""Seek to a specific frame in the source."""

application/backend/app/runtime/core/components/processor.py

Lines changed: 87 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,67 +15,130 @@
1515
EMPTY_RESULT: dict[str, np.ndarray] = {}
1616

1717

18+
class FrameSkipPolicy:
19+
"""
20+
Decides whether to skip (drop) a frame based on a cyclic counter.
21+
22+
Skip pattern with interval N (N >= 1):
23+
Process frames 1..N-1, drop frame N, repeat.
24+
Example (N=3): process, process, DROP, process, process, DROP, ...
25+
26+
If interval is 0, no frames are ever skipped.
27+
28+
Args:
29+
interval: Total cycle length (process and skip). 0 disables skipping.
30+
skip_amount: Number of consecutive frames to skip per cycle. Must be < interval.
31+
32+
Raises:
33+
ValueError: If frame_skip_interval is negative.
34+
"""
35+
36+
def __init__(self, interval: int = 3, skip_amount: int = 1) -> None:
37+
if interval < 0 or interval == 1:
38+
raise ValueError(f"frame_skip_interval must be > 1 or 0 for no skipping, got {interval}")
39+
if interval > 0 and (skip_amount < 0 or skip_amount >= interval):
40+
raise ValueError(f"skip_amount must be >= 0 and < interval, got {skip_amount} and {interval}")
41+
self._interval = interval
42+
self._skip_amount = skip_amount
43+
self._counter = 0
44+
45+
@property
46+
def interval(self) -> int:
47+
return self._interval
48+
49+
@property
50+
def skip_amount(self) -> int:
51+
return self._skip_amount
52+
53+
def should_skip(self) -> bool:
54+
"""Return True if the current frame should be dropped. Advances the internal counter on every call."""
55+
if self._interval == 0 or self._skip_amount == 0:
56+
return False
57+
58+
position = self._counter % self._interval
59+
self._counter += 1
60+
61+
# process the first (interval - skip_count) frames, skip the rest
62+
process_count = self._interval - self._skip_amount
63+
return position >= process_count
64+
65+
def reset(self) -> None:
66+
"""Reset the internal counter."""
67+
self._counter = 0
68+
69+
1870
class Processor(PipelineComponent):
1971
"""
2072
A job component responsible for retrieving raw frames from the inbound broadcaster,
2173
sending them to a processor for inference, and broadcasting the processed results to subscribed consumers.
74+
75+
Supports frame skipping to align model throughput with source frame rate.
2276
"""
2377

2478
def __init__(
25-
self,
26-
model_handler: ModelHandler,
27-
batch_size: int = 3,
28-
category_id_to_label_id: dict[int, str] | None = None,
79+
self, model_handler: ModelHandler, batch_size: int = 1, frame_skip_interval: int = 3, frame_skip_amount: int = 1
2980
) -> None:
3081
super().__init__()
3182
self._model_handler = model_handler
3283
self._batch_size = batch_size
33-
self._category_id_to_label_id = category_id_to_label_id or {}
84+
self._skip_policy = FrameSkipPolicy(interval=frame_skip_interval, skip_amount=frame_skip_amount)
85+
self._initialized = False
3486

3587
def setup(
36-
self,
37-
inbound_broadcaster: FrameBroadcaster[InputData],
38-
outbound_broadcaster: FrameBroadcaster[OutputData],
88+
self, inbound_broadcaster: FrameBroadcaster[InputData], outbound_broadcaster: FrameBroadcaster[OutputData]
3989
) -> None:
4090
self._inbound_broadcaster = inbound_broadcaster
4191
self._outbound_broadcaster = outbound_broadcaster
4292
self._in_queue: Queue[InputData] = inbound_broadcaster.register(self.__class__.__name__)
4393
self._initialized = True
4494

45-
def run(self) -> None:
95+
def run(self) -> None: # noqa: C901
96+
if not self._initialized:
97+
raise RuntimeError("Processor must be set up before running")
4698
logger.debug("Starting a pipeline runner loop")
99+
47100
self._model_handler.initialise()
48-
logger.info("Pipeline model handler initialized")
101+
logger.info(
102+
"Pipeline model handler initialized, batch size: %d, frame skip interval: %d, skip amount: %d",
103+
self._batch_size,
104+
self._skip_policy.interval,
105+
self._skip_policy.skip_amount,
106+
)
49107

50108
while not self._stop_event.is_set():
51109
try:
52110
batch_data: list[InputData] = []
53-
for _ in range(self._batch_size):
111+
while len(batch_data) < self._batch_size and not self._stop_event.is_set():
54112
try:
55-
input_data = self._in_queue.get(timeout=0.1)
113+
input_data: InputData = self._in_queue.get(timeout=0.1)
56114
if input_data.trace:
57115
input_data.trace.record_start("processor")
58-
batch_data.append(input_data)
59-
60-
if input_data.context.get("requires_manual_control", False):
61-
break
62116
except Empty:
117+
if batch_data: # if we have partial batch data, process what we have
118+
break
119+
continue
120+
121+
is_manual = input_data.context.get("requires_manual_control", False)
122+
123+
if not is_manual and self._skip_policy.should_skip():
124+
logger.debug("Frame skipped (timestamp=%s)", input_data.timestamp)
125+
continue
126+
127+
batch_data.append(input_data)
128+
129+
if is_manual:
63130
break
64131

65-
if not batch_data:
132+
if not batch_data or self._stop_event.is_set():
66133
continue
67134

68-
batch_results = self._model_handler.predict(batch_data)
135+
results = self._model_handler.predict(batch_data)
69136

70137
for i, data in enumerate(batch_data):
71-
results: dict[str, np.ndarray] = batch_results[i] if i < len(batch_results) else EMPTY_RESULT
138+
result = results[i] if i < len(results) else EMPTY_RESULT
72139
if data.trace:
73140
data.trace.record_end("processor")
74-
output_data = OutputData(
75-
frame=data.frame,
76-
results=[results],
77-
trace=data.trace,
78-
)
141+
output_data = OutputData(frame=data.frame, results=[result] if result else [], trace=data.trace)
79142
self._outbound_broadcaster.broadcast(output_data)
80143

81144
except Exception as e:

application/backend/app/runtime/pipeline_manager.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _create_pipeline(self, project_id: UUID) -> Pipeline:
171171
"""
172172
source = self._component_factory.create_source(project_id)
173173
reference_batch, category_id_to_label_id = self.get_reference_batch(project_id, PromptType.VISUAL) or (None, {})
174-
processor = self._component_factory.create_processor(project_id, reference_batch, category_id_to_label_id)
174+
processor = self._component_factory.create_processor(project_id, reference_batch)
175175
sink = self._component_factory.create_sink(project_id)
176176

177177
return (
@@ -206,9 +206,7 @@ def _update_pipeline_components(self, project_id: UUID, component_type: Componen
206206
None,
207207
{},
208208
)
209-
processor = self._component_factory.create_processor(
210-
project_id, reference_batch, category_id_to_label_id
211-
)
209+
processor = self._component_factory.create_processor(project_id, reference_batch)
212210
self._pipeline.set_processor(processor, True)
213211
case ComponentType.SINK:
214212
sink = self._component_factory.create_sink(project_id)

application/backend/app/settings.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def log_file(self) -> str:
9999
thumbnail_jpeg_quality: int = 85
100100

101101
# Processor configuration
102-
processor_batch_size: int = Field(default=3, alias="PROCESSOR_BATCH_SIZE")
102+
processor_batch_size: int = Field(default=1, alias="PROCESSOR_BATCH_SIZE")
103+
processor_frame_skip_interval: int = Field(default=3, ge=0, alias="PROCESSOR_FRAME_SKIP_INTERVAL")
104+
processor_frame_skip_amount: int = Field(default=1, ge=0, alias="PROCESSOR_FRAME_SKIP_AMOUNT")
103105
processor_inference_enabled: bool = Field(default=True, alias="PROCESSOR_INFERENCE_ENABLED")
104106

105107
# WebRTC

0 commit comments

Comments
 (0)