Skip to content

Commit a8e0cc6

Browse files
Run Pre/Post processing for components in a separate thread (#13168)
1 parent 835e4bd commit a8e0cc6

15 files changed

Lines changed: 378 additions & 80 deletions

.changeset/wild-hounds-laugh.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
"gradio": patch
3+
---
4+
5+
feat:Run Pre/Post processing for components in a separate thread

gradio/blocks.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1771,11 +1771,14 @@ async def preprocess_data(
17711771
inputs[i].get("value", None) if is_prop_input else inputs[i]
17721772
)
17731773

1774-
inputs_cached = await processing_utils.async_move_files_to_cache(
1775-
value_to_process,
1776-
block,
1777-
check_in_upload_folder=not explicit_call,
1778-
)
1774+
from gradio.profiling import trace_phase
1775+
1776+
async with trace_phase("preprocess_move_to_cache"):
1777+
inputs_cached = await processing_utils.async_move_files_to_cache(
1778+
value_to_process,
1779+
block,
1780+
check_in_upload_folder=not explicit_call,
1781+
)
17791782
if getattr(block, "data_model", None) and inputs_cached is not None:
17801783
data_model = cast(
17811784
Union[GradioModel, GradioRootModel], block.data_model
@@ -1793,7 +1796,9 @@ async def preprocess_data(
17931796
state._update_value_in_config(block._id, inputs_serialized)
17941797

17951798
if block_fn.preprocess:
1796-
processed_value = block.preprocess(inputs_cached)
1799+
processed_value = await anyio.to_thread.run_sync(
1800+
block.preprocess, inputs_cached, limiter=self.limiter
1801+
)
17971802
else:
17981803
processed_value = inputs_serialized
17991804

@@ -1855,6 +1860,8 @@ async def postprocess_data(
18551860
predictions: list | dict,
18561861
state: SessionState | None,
18571862
) -> list[Any]:
1863+
from gradio.profiling import trace_phase
1864+
18581865
state = state or SessionState(self)
18591866
if (
18601867
isinstance(predictions, dict)
@@ -1943,33 +1950,36 @@ async def postprocess_data(
19431950
)
19441951
if block._id in state:
19451952
block = state[block._id]
1946-
prediction_value = block.postprocess(prediction_value)
1953+
prediction_value = await anyio.to_thread.run_sync(
1954+
block.postprocess, prediction_value, limiter=self.limiter
1955+
)
19471956
if isinstance(prediction_value, (GradioModel, GradioRootModel)):
19481957
prediction_value_serialized = prediction_value.model_dump()
19491958
else:
19501959
prediction_value_serialized = prediction_value
1951-
prediction_value_serialized = (
1952-
await processing_utils.async_move_files_to_cache(
1953-
prediction_value_serialized,
1954-
block,
1955-
postprocess=True,
1960+
async with trace_phase("postprocess_update_state_in_config"):
1961+
prediction_value_serialized = (
1962+
await processing_utils.async_move_files_to_cache(
1963+
prediction_value_serialized,
1964+
block,
1965+
postprocess=True,
1966+
)
1967+
)
1968+
if block._id not in state:
1969+
state[block._id] = block
1970+
state._update_value_in_config(
1971+
block._id, prediction_value_serialized
19561972
)
1957-
)
1958-
if block._id not in state:
1959-
state[block._id] = block
1960-
state._update_value_in_config(
1961-
block._id, prediction_value_serialized
1962-
)
19631973
elif not block_fn.postprocess:
19641974
if block._id not in state:
19651975
state[block._id] = block
19661976
state._update_value_in_config(block._id, prediction_value)
1967-
1968-
outputs_cached = await processing_utils.async_move_files_to_cache(
1969-
prediction_value,
1970-
block,
1971-
postprocess=True,
1972-
)
1977+
async with trace_phase("postprocess_move_to_cache"):
1978+
outputs_cached = await processing_utils.async_move_files_to_cache(
1979+
prediction_value,
1980+
block,
1981+
postprocess=True,
1982+
)
19731983
output.append(outputs_cached)
19741984

19751985
return output

gradio/components/audio.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import dataclasses
66
import io
7+
import json
78
from collections.abc import Callable, Sequence
89
from pathlib import Path
910
from typing import TYPE_CHECKING, Any, Literal
@@ -414,9 +415,6 @@ def _process_json_subtitles(
414415
def _process_subtitle_file(
415416
self, subtitle_file: str | Path
416417
) -> FileData | list[dict[str, Any]]:
417-
import json
418-
from pathlib import Path
419-
420418
file_path = Path(subtitle_file)
421419
if file_path.suffix.lower() == ".json":
422420
try:

gradio/components/video.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from gradio.data_classes import FileData, MediaStreamChunk
2323
from gradio.events import Events
2424
from gradio.i18n import I18nData
25+
from gradio.profiling import trace_phase_sync, traced_sync
2526
from gradio.utils import get_upload_folder, set_default_buttons
2627

2728
if TYPE_CHECKING:
@@ -189,6 +190,7 @@ def __init__(
189190
)
190191
self._value_description = "a string filepath to a video"
191192

193+
@traced_sync("preprocess_video")
192194
def preprocess(self, payload: FileData | None) -> str | None:
193195
"""
194196
Parameters:
@@ -250,6 +252,7 @@ def preprocess(self, payload: FileData | None) -> str | None:
250252
else:
251253
return str(file_name)
252254

255+
@traced_sync("postprocess_video")
253256
def postprocess(self, value: str | Path | None) -> FileData | None:
254257
"""
255258
Parameters:
@@ -272,6 +275,7 @@ def _format_video(self, video: str | Path | None) -> FileData | None:
272275
Processes a video to ensure that it is in the correct format
273276
and adds a watermark if requested.
274277
"""
278+
275279
if video is None:
276280
return None
277281
video = str(video)
@@ -301,7 +305,8 @@ def _format_video(self, video: str | Path | None) -> FileData | None:
301305
warnings.warn(
302306
"Video does not have browser-compatible container or codec. Converting to mp4."
303307
)
304-
video = processing_utils.convert_video_to_playable_mp4(video)
308+
with trace_phase_sync("postprocess_video_convert_video_to_playable_mp4"):
309+
video = processing_utils.convert_video_to_playable_mp4(video)
305310
# Recalculate the format in case convert_video_to_playable_mp4 already made it the selected format
306311
returned_format = utils.get_extension_from_file_path_or_url(video).lower()
307312
if (
@@ -404,8 +409,6 @@ def _format_subtitles(self, subtitle: str | Path | None) -> FileData | None:
404409
"""
405410
Convert subtitle format to VTT and process the video to ensure it meets the HTML5 requirements.
406411
"""
407-
import json
408-
from pathlib import Path
409412

410413
def srt_to_vtt(srt_file_path, vtt_file_path):
411414
"""Convert an SRT subtitle file to a VTT subtitle file"""

gradio/image_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from gradio.components.image_editor import WatermarkOptions
1818
from gradio.data_classes import ImageData
1919
from gradio.exceptions import Error
20+
from gradio.profiling import traced_sync
2021

2122
PIL.Image.init() # fixes https://github.com/gradio-app/gradio/issues/2843 (remove when requiring Pillow 9.4+)
2223

@@ -258,6 +259,7 @@ def extract_svg_content(image_file: str | Path) -> str:
258259
return svg_content
259260

260261

262+
@traced_sync("preprocess_format_image")
261263
def preprocess_image(
262264
payload: ImageData | None,
263265
cache_dir: str,
@@ -311,6 +313,7 @@ def preprocess_image(
311313
warnings.simplefilter("ignore")
312314
if image_mode is not None:
313315
im = im.convert(image_mode)
316+
314317
return format_image(
315318
im,
316319
type=type,

gradio/processing_utils.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from gradio.context import LocalContext
3131
from gradio.data_classes import FileData, GradioModel, GradioRootModel, JsonData
3232
from gradio.exceptions import Error, InvalidPathError
33+
from gradio.profiling import traced_sync
3334
from gradio.route_utils import API_PREFIX
3435
from gradio.utils import abspath, get_hash_seed, get_upload_folder, is_in_or_equal
3536

@@ -155,6 +156,7 @@ def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str:
155156
return sha.hexdigest()
156157

157158

159+
@traced_sync("postprocess_save_pil_to_cache")
158160
def save_pil_to_cache(
159161
img: Image.Image,
160162
cache_dir: str,
@@ -169,13 +171,15 @@ def save_pil_to_cache(
169171
return filename
170172

171173

174+
@traced_sync("postprocess_save_img_array_to_cache")
172175
def save_img_array_to_cache(
173176
arr: np.ndarray, cache_dir: str, format: str = "webp"
174177
) -> str:
175178
pil_image = Image.fromarray(_convert(arr, np.uint8, force_copy=False))
176179
return save_pil_to_cache(pil_image, cache_dir, format=format)
177180

178181

182+
@traced_sync("postprocess_save_audio_to_cache")
179183
def save_audio_to_cache(
180184
data: np.ndarray, sample_rate: int, format: str, cache_dir: str
181185
) -> str:
@@ -207,6 +211,7 @@ def detect_audio_format(data: bytes) -> str:
207211
return ""
208212

209213

214+
@traced_sync("postprocess_save_bytes_to_cache")
210215
def save_bytes_to_cache(data: bytes, file_name: str, cache_dir: str) -> str:
211216
path = Path(cache_dir) / hash_bytes(data)
212217
path.mkdir(exist_ok=True, parents=True)
@@ -218,6 +223,7 @@ def save_bytes_to_cache(data: bytes, file_name: str, cache_dir: str) -> str:
218223
return str(path.resolve())
219224

220225

226+
@traced_sync("save_file_to_cache")
221227
def save_file_to_cache(file_path: str | Path, cache_dir: str) -> str:
222228
"""Returns a temporary file path for a copy of the given file path if it does
223229
not already exist. Otherwise returns the path to the existing temp file."""
@@ -644,28 +650,31 @@ def resize_and_crop(img, size, crop_type="center"):
644650
def audio_from_file(
645651
filename: str, crop_min: float = 0, crop_max: float = 100
646652
) -> tuple[int, np.ndarray]:
647-
try:
648-
audio = AudioSegment.from_file(filename)
649-
except FileNotFoundError as e:
650-
isfile = Path(filename).is_file()
651-
msg = (
652-
f"Cannot load audio from file: `{'ffprobe' if isfile else filename}` not found."
653-
+ " Please install `ffmpeg` in your system to use non-WAV audio file formats"
654-
" and make sure `ffprobe` is in your PATH."
655-
if isfile
656-
else ""
657-
)
658-
raise RuntimeError(msg) from e
659-
except OSError as e:
660-
raise e
661-
if crop_min != 0 or crop_max != 100:
662-
audio_start = len(audio) * crop_min / 100
663-
audio_end = len(audio) * crop_max / 100
664-
audio = audio[audio_start:audio_end]
665-
data = np.array(audio.get_array_of_samples())
666-
if audio.channels > 1:
667-
data = data.reshape(-1, audio.channels)
668-
return audio.frame_rate, data
653+
from gradio.profiling import trace_phase_sync
654+
655+
with trace_phase_sync("preprocess_audio_from_file"):
656+
try:
657+
audio = AudioSegment.from_file(filename)
658+
except FileNotFoundError as e:
659+
isfile = Path(filename).is_file()
660+
msg = (
661+
f"Cannot load audio from file: `{'ffprobe' if isfile else filename}` not found."
662+
+ " Please install `ffmpeg` in your system to use non-WAV audio file formats"
663+
" and make sure `ffprobe` is in your PATH."
664+
if isfile
665+
else ""
666+
)
667+
raise RuntimeError(msg) from e
668+
except OSError as e:
669+
raise e
670+
if crop_min != 0 or crop_max != 100:
671+
audio_start = len(audio) * crop_min / 100
672+
audio_end = len(audio) * crop_max / 100
673+
audio = audio[audio_start:audio_end]
674+
data = np.array(audio.get_array_of_samples())
675+
if audio.channels > 1:
676+
data = data.reshape(-1, audio.channels)
677+
return audio.frame_rate, data
669678

670679

671680
def audio_to_file(sample_rate, data, filename, format="wav"):

0 commit comments

Comments
 (0)