Skip to content

Commit 24f35c1

Browse files
committed
changes
1 parent a859617 commit 24f35c1

2 files changed

Lines changed: 34 additions & 2 deletions

File tree

gradio/caching.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import copy
56
import functools
67
import hashlib
78
import inspect
@@ -173,7 +174,7 @@ def sync_gen_wrapper(*args, **kwargs):
173174
return
174175
all_yields = []
175176
for value in func(**normalized):
176-
all_yields.append(value)
177+
all_yields.append(copy.deepcopy(value))
177178
yield value
178179
if all_yields:
179180
store.put(key_hash, yields=all_yields)
@@ -194,7 +195,7 @@ async def async_gen_wrapper(*args, **kwargs):
194195
return
195196
all_yields = []
196197
async for value in func(**normalized):
197-
all_yields.append(value)
198+
all_yields.append(copy.deepcopy(value))
198199
yield value
199200
if all_yields:
200201
store.put(key_hash, yields=all_yields)

test/test_caching.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,37 @@ async def afn(x):
217217
assert inspect.isasyncgenfunction(agen)
218218
assert inspect.iscoroutinefunction(afn)
219219

220+
def test_sync_generator_mutable_yields(self):
221+
"""Cached replay should snapshot each yield, not alias mutable objects."""
222+
223+
@cache
224+
def streamer(n):
225+
result = []
226+
for i in range(1, n + 1):
227+
result.append(i)
228+
yield result
229+
230+
list(streamer(3))
231+
cached_run = list(streamer(3))
232+
assert cached_run == [[1], [1, 2], [1, 2, 3]]
233+
234+
def test_async_generator_mutable_yields(self):
235+
"""Async variant: cached replay should snapshot each yield."""
236+
237+
@cache
238+
async def streamer(n):
239+
result = []
240+
for i in range(1, n + 1):
241+
result.append(i)
242+
yield result
243+
244+
async def run():
245+
return [v async for v in streamer(3)]
246+
247+
asyncio.run(run())
248+
cached_run = asyncio.run(run())
249+
assert cached_run == [[1], [1, 2], [1, 2, 3]]
250+
220251
def test_cache_clear(self):
221252
@cache
222253
def fn(x):

0 commit comments

Comments
 (0)