Skip to content

Commit 4a8baab

Browse files
committed
Publish unit tests.
PiperOrigin-RevId: 779184803
1 parent 1ed0a9d commit 4a8baab

24 files changed

Lines changed: 6425 additions & 0 deletions
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import time
2+
import unittest
3+
from unittest import mock
4+
5+
from absl.testing import parameterized
6+
from genai_processors import content_api
7+
from genai_processors import streams
8+
from genai_processors.core import audio_io
9+
import pyaudio
10+
11+
12+
class PyAudioInTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase):
13+
"""Tests for the PyAudioIn processor."""
14+
15+
def setUp(self):
16+
super().setUp()
17+
self.pyaudio_mock = mock.MagicMock()
18+
self.stream_mock = mock.MagicMock()
19+
self.input_stream = streams.stream_content(
20+
[
21+
content_api.ProcessorPart(
22+
'hello',
23+
),
24+
],
25+
# This delay is added after the text part is returned to ensure the
26+
# stream ends after the audio part is returned.
27+
with_delay_sec=0.1,
28+
)
29+
30+
async def test_py_audio_in(self):
31+
32+
def side_effect(chunk_size, exception_on_overflow=False):
33+
del exception_on_overflow
34+
del chunk_size
35+
# Delay it to return the audio bytes after the text part.
36+
time.sleep(0.05)
37+
return b'audio_bytes'
38+
39+
self.stream_mock.read = mock.MagicMock()
40+
self.stream_mock.read.side_effect = side_effect
41+
self.pyaudio_mock.open.return_value = self.stream_mock
42+
with mock.patch.object(
43+
pyaudio,
44+
'PyAudio',
45+
return_value=self.pyaudio_mock,
46+
):
47+
audio_in = audio_io.PyAudioIn(pya=self.pyaudio_mock)
48+
output = await streams.gather_stream(audio_in(self.input_stream))
49+
self.assertEqual(
50+
output,
51+
[
52+
content_api.ProcessorPart('hello'),
53+
content_api.ProcessorPart(
54+
content_api.ProcessorPart(
55+
b'audio_bytes', mimetype='audio/l16;rate=24000'
56+
),
57+
substream_name='realtime',
58+
role='USER',
59+
),
60+
],
61+
)
62+
63+
async def test_py_audio_in_with_exception(self):
64+
self.stream_mock.read = mock.MagicMock()
65+
self.stream_mock.read.side_effect = IOError('IOError')
66+
self.pyaudio_mock.open.return_value = self.stream_mock
67+
with mock.patch.object(
68+
pyaudio,
69+
'PyAudio',
70+
return_value=self.pyaudio_mock,
71+
):
72+
audio_in = audio_io.PyAudioIn(pya=self.pyaudio_mock)
73+
with self.assertRaises(IOError):
74+
await streams.gather_stream(audio_in(self.input_stream))
75+
76+
77+
class PyAudioOutTest(parameterized.TestCase, unittest.IsolatedAsyncioTestCase):
78+
"""Tests for the PyAudioOut processor."""
79+
80+
def setUp(self):
81+
super().setUp()
82+
self.pyaudio_mock = mock.MagicMock()
83+
self.stream_mock = mock.MagicMock()
84+
self.input_stream = streams.stream_content(
85+
[
86+
content_api.ProcessorPart(
87+
'hello',
88+
),
89+
content_api.ProcessorPart(
90+
b'audio_bytes',
91+
mimetype='audio/l16',
92+
),
93+
],
94+
# This delay is for testing only.
95+
with_delay_sec=0.1,
96+
)
97+
98+
@parameterized.named_parameters(
99+
(
100+
'passthrough_audio',
101+
True,
102+
[
103+
content_api.ProcessorPart('hello'),
104+
content_api.ProcessorPart(b'audio_bytes', mimetype='audio/l16'),
105+
],
106+
),
107+
('no_passthrough_audio', False, [content_api.ProcessorPart('hello')]),
108+
)
109+
async def test_py_audio_out(self, passthrough_audio, expected):
110+
111+
self.pyaudio_mock.open.return_value = self.stream_mock
112+
with mock.patch.object(
113+
pyaudio,
114+
'PyAudio',
115+
return_value=self.pyaudio_mock,
116+
):
117+
audio_out = audio_io.PyAudioOut(
118+
pya=self.pyaudio_mock, passthrough_audio=passthrough_audio
119+
)
120+
output = await streams.gather_stream(audio_out(self.input_stream))
121+
self.stream_mock.write.assert_called_with(b'audio_bytes')
122+
self.assertEqual(output, expected)
123+
124+
async def test_py_audio_out_with_exception(self):
125+
self.stream_mock.write.side_effect = IOError('IOError')
126+
self.pyaudio_mock.open.return_value = self.stream_mock
127+
with mock.patch.object(
128+
pyaudio,
129+
'PyAudio',
130+
return_value=self.pyaudio_mock,
131+
):
132+
audio_out = audio_io.PyAudioOut(
133+
pya=self.pyaudio_mock, passthrough_audio=True
134+
)
135+
with self.assertRaises(IOError):
136+
await streams.gather_stream(audio_out(self.input_stream))
137+
138+
139+
if __name__ == '__main__':
140+
unittest.main()
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
import asyncio
2+
import unittest
3+
from unittest import mock
4+
5+
from genai_processors import cache
6+
from genai_processors import content_api
7+
8+
9+
ProcessorPart = content_api.ProcessorPart
10+
ProcessorContent = content_api.ProcessorContent
11+
12+
13+
class InMemoryCacheTest(unittest.IsolatedAsyncioTestCase):
14+
15+
async def test_put_and_lookup_with_default_hash(self):
16+
"""Tests basic put and lookup using the default hashing."""
17+
mem_cache = cache.InMemoryCache(max_items=10, ttl_hours=1)
18+
query = ProcessorContent(['query_text'])
19+
value_content = ProcessorContent(
20+
[ProcessorPart('response_text', role='model')]
21+
)
22+
await mem_cache.put(query, value_content)
23+
retrieved = await mem_cache.lookup(query)
24+
25+
self.assertIsNot(retrieved, cache.CacheMiss)
26+
self.assertEqual(content_api.as_text(retrieved), 'response_text')
27+
28+
async def test_lookup_miss(self):
29+
"""Tests that a lookup for a non-existent key returns CacheMiss."""
30+
mem_cache = cache.InMemoryCache()
31+
retrieved = await mem_cache.lookup(ProcessorContent(['non_existent_key']))
32+
self.assertIs(retrieved, cache.CacheMiss)
33+
34+
async def test_put_override(self):
35+
mem_cache = cache.InMemoryCache()
36+
query = ProcessorContent(['query_text'])
37+
await mem_cache.put(query, ProcessorContent(['value1']))
38+
await mem_cache.put(query, ProcessorContent(['value2']))
39+
40+
retrieved = await mem_cache.lookup(query)
41+
self.assertIsNot(retrieved, cache.CacheMiss)
42+
self.assertEqual(content_api.as_text(retrieved), 'value2')
43+
44+
async def test_ttl_expiration(self):
45+
"""Tests that an item expires after its TTL."""
46+
ttl_hours = 0.0001 # A very short TTL
47+
mem_cache = cache.InMemoryCache(ttl_hours=ttl_hours, max_items=10)
48+
query = ProcessorContent(['key1'])
49+
await mem_cache.put(query, ProcessorContent(['value1']))
50+
51+
# Should exist immediately after
52+
self.assertIsNot(await mem_cache.lookup(query), cache.CacheMiss)
53+
54+
# Wait for more than the TTL
55+
await asyncio.sleep(ttl_hours * 3600 * 1.1)
56+
self.assertIs(await mem_cache.lookup(query), cache.CacheMiss)
57+
58+
async def test_max_items_eviction(self):
59+
mem_cache = cache.InMemoryCache(max_items=2)
60+
query1 = ProcessorContent(['key1'])
61+
query2 = ProcessorContent(['key2'])
62+
query3 = ProcessorContent(['key3'])
63+
64+
await mem_cache.put(query1, ProcessorContent(['value1']))
65+
await mem_cache.put(query2, ProcessorContent(['value2']))
66+
await mem_cache.put(
67+
query3, ProcessorContent(['value3'])
68+
) # This should evict query1
69+
70+
# query1 should be gone
71+
self.assertIs(await mem_cache.lookup(query1), cache.CacheMiss)
72+
# query2 and query3 should still be present
73+
self.assertIsNot(await mem_cache.lookup(query2), cache.CacheMiss)
74+
self.assertIsNot(await mem_cache.lookup(query3), cache.CacheMiss)
75+
76+
async def test_remove(self):
77+
mem_cache = cache.InMemoryCache()
78+
query = ProcessorContent(['key1'])
79+
await mem_cache.put(query, ProcessorContent(['value1']))
80+
self.assertIsNot(await mem_cache.lookup(query), cache.CacheMiss)
81+
82+
await mem_cache.remove(query)
83+
self.assertIs(await mem_cache.lookup(query), cache.CacheMiss)
84+
85+
async def test_uses_custom_hash_function(self):
86+
custom_hash_fn = mock.Mock(return_value='custom_key_123')
87+
mem_cache = cache.InMemoryCache(hash_fn=custom_hash_fn)
88+
query = ProcessorContent(['query_text'])
89+
90+
await mem_cache.put(query, ProcessorContent(['value']))
91+
custom_hash_fn.assert_called_once_with(query)
92+
93+
custom_hash_fn.reset_mock()
94+
await mem_cache.lookup(query)
95+
custom_hash_fn.assert_called_once_with(query)
96+
97+
async def test_with_key_prefix_isolates_caches(self):
98+
cache1 = cache.InMemoryCache(max_items=10)
99+
cache2 = cache1.with_key_prefix('p_')
100+
101+
self.assertIsNot(cache1, cache2)
102+
self.assertIsNot(
103+
cache1._cache, cache2._cache
104+
) # Important: they have different TTLCache instances
105+
self.assertIsInstance(cache2, cache.InMemoryCache)
106+
self.assertEqual(cache1._max_items, cache2._max_items)
107+
108+
query = ProcessorContent(['shared_query'])
109+
await cache1.put(query, ProcessorContent(['value1']))
110+
await cache2.put(query, ProcessorContent(['value2']))
111+
112+
# Check that each cache has its own value for the same query
113+
retrieved1 = await cache1.lookup(query)
114+
retrieved2 = await cache2.lookup(query)
115+
116+
self.assertIsNot(retrieved1, cache.CacheMiss)
117+
self.assertIsNot(retrieved2, cache.CacheMiss)
118+
self.assertEqual(content_api.as_text(retrieved1), 'value1')
119+
self.assertEqual(content_api.as_text(retrieved2), 'value2')
120+
121+
async def test_hash_fn_returns_none(self):
122+
mem_cache = cache.InMemoryCache(hash_fn=lambda q: None)
123+
query = ProcessorContent(['any_query'])
124+
125+
await mem_cache.put(query, ProcessorContent(['some_response']))
126+
self.assertIs(await mem_cache.lookup(query), cache.CacheMiss)
127+
128+
async def test_hash_fn_raises_exception(self):
129+
mock_hash_fn = mock.Mock(side_effect=ValueError('Hashing failed!'))
130+
mem_cache = cache.InMemoryCache(hash_fn=mock_hash_fn)
131+
query = ProcessorContent(['query_that_will_fail_hash'])
132+
133+
await mem_cache.put(query, ProcessorContent(['irrelevant']))
134+
self.assertIs(await mem_cache.lookup(query), cache.CacheMiss)
135+
self.assertEqual(mock_hash_fn.call_count, 2)
136+
137+
async def test_put_with_serialization_error_propagates(self):
138+
mem_cache = cache.InMemoryCache()
139+
query = ProcessorContent(['query'])
140+
141+
with mock.patch.object(
142+
mem_cache, '_serialize_fn', side_effect=RuntimeError('Unexpected!')
143+
):
144+
with self.assertRaises(RuntimeError):
145+
await mem_cache.put(query, ProcessorContent(['irrelevant']))
146+
147+
async def test_invalid_init_with_zero_max_items(self):
148+
with self.assertRaisesRegex(
149+
ValueError, 'max_items must be positive, got: 0'
150+
):
151+
cache.InMemoryCache(max_items=0)
152+
153+
with self.assertRaisesRegex(
154+
ValueError, 'max_items must be positive, got: -1'
155+
):
156+
cache.InMemoryCache(max_items=-1)
157+
158+
async def test_default_hash_is_part_order_sensitive(self):
159+
"""Tests the default hash function's sensitivity to part order."""
160+
mem_cache = cache.InMemoryCache(max_items=10)
161+
part1 = ProcessorPart('Hello', role='user')
162+
part2 = ProcessorPart('World', role='user')
163+
value = ProcessorContent(['Response'])
164+
165+
query_order1 = ProcessorContent([part1, part2])
166+
query_order2 = ProcessorContent([part2, part1])
167+
168+
# Hashes should be different
169+
hash1 = cache.default_processor_content_hash(query_order1)
170+
hash2 = cache.default_processor_content_hash(query_order2)
171+
self.assertNotEqual(hash1, hash2)
172+
173+
# Caching one should not allow lookup of the other
174+
await mem_cache.put(query_order1, value)
175+
self.assertIsNot(await mem_cache.lookup(query_order1), cache.CacheMiss)
176+
self.assertIs(await mem_cache.lookup(query_order2), cache.CacheMiss)
177+
178+
async def test_default_hash_is_key_order_insensitive(self):
179+
"""Tests the hash is insensitive to a part's internal dict key order."""
180+
mem_cache = cache.InMemoryCache(max_items=10)
181+
part_dict_1 = {'role': 'model', 'part': {'text': 'response text'}}
182+
part_dict_2 = {'part': {'text': 'response text'}, 'role': 'model'}
183+
184+
query1 = ProcessorContent([ProcessorPart.from_dict(data=part_dict_1)])
185+
query2 = ProcessorContent([ProcessorPart.from_dict(data=part_dict_2)])
186+
value = ProcessorContent(['Value'])
187+
188+
# Hashes should be identical because json.dumps uses sort_keys=True.
189+
hash1 = cache.default_processor_content_hash(query1)
190+
hash2 = cache.default_processor_content_hash(query2)
191+
self.assertEqual(hash1, hash2)
192+
193+
await mem_cache.put(query1, value)
194+
retrieved = await mem_cache.lookup(query2)
195+
self.assertIsNot(retrieved, cache.CacheMiss)
196+
self.assertEqual(content_api.as_text(retrieved), 'Value')
197+
198+
199+
if __name__ == '__main__':
200+
unittest.main()

0 commit comments

Comments
 (0)