|
| 1 | +import asyncio |
| 2 | +import unittest |
| 3 | +from unittest import mock |
| 4 | + |
| 5 | +from genai_processors import cache |
| 6 | +from genai_processors import content_api |
| 7 | + |
| 8 | + |
# Short module-level aliases for the content_api types used by the tests below.
ProcessorContent = content_api.ProcessorContent
ProcessorPart = content_api.ProcessorPart
| 11 | + |
| 12 | + |
class InMemoryCacheTest(unittest.IsolatedAsyncioTestCase):
  """Async tests for cache.InMemoryCache and the default content hash."""

  async def test_put_and_lookup_with_default_hash(self):
    """Stores a value and reads it back without supplying a custom hash_fn."""
    store = cache.InMemoryCache(max_items=10, ttl_hours=1)
    key = ProcessorContent(['query_text'])
    response_part = ProcessorPart('response_text', role='model')
    response = ProcessorContent([response_part])

    await store.put(key, response)
    hit = await store.lookup(key)

    self.assertIsNot(hit, cache.CacheMiss)
    self.assertEqual(content_api.as_text(hit), 'response_text')

  async def test_lookup_miss(self):
    """Looking up a query that was never stored yields CacheMiss."""
    store = cache.InMemoryCache()
    unknown = ProcessorContent(['non_existent_key'])
    outcome = await store.lookup(unknown)
    self.assertIs(outcome, cache.CacheMiss)

  async def test_put_override(self):
    """A second put for the same query replaces the earlier value."""
    store = cache.InMemoryCache()
    key = ProcessorContent(['query_text'])
    for text in ('value1', 'value2'):
      await store.put(key, ProcessorContent([text]))

    latest = await store.lookup(key)
    self.assertIsNot(latest, cache.CacheMiss)
    self.assertEqual(content_api.as_text(latest), 'value2')

  async def test_ttl_expiration(self):
    """An entry is served right after the put and is a miss once TTL passes."""
    short_ttl_hours = 0.0001  # 0.0001 h * 3600 s/h = 0.36 s.
    store = cache.InMemoryCache(ttl_hours=short_ttl_hours, max_items=10)
    key = ProcessorContent(['key1'])
    await store.put(key, ProcessorContent(['value1']))

    # Immediately after the put the entry must still be available.
    fresh = await store.lookup(key)
    self.assertIsNot(fresh, cache.CacheMiss)

    # Sleep 10% longer than the TTL so the entry has had time to expire.
    ttl_seconds = short_ttl_hours * 3600
    await asyncio.sleep(ttl_seconds * 1.1)
    expired = await store.lookup(key)
    self.assertIs(expired, cache.CacheMiss)

  async def test_max_items_eviction(self):
    """With max_items=2, a third put evicts the first-inserted entry."""
    store = cache.InMemoryCache(max_items=2)
    queries = [ProcessorContent([k]) for k in ('key1', 'key2', 'key3')]

    # The third put exceeds max_items and should push out the first query.
    for query, text in zip(queries, ('value1', 'value2', 'value3')):
      await store.put(query, ProcessorContent([text]))

    evicted, *kept = queries
    self.assertIs(await store.lookup(evicted), cache.CacheMiss)
    # The two most recently inserted queries must still be present.
    for query in kept:
      self.assertIsNot(await store.lookup(query), cache.CacheMiss)

  async def test_remove(self):
    """remove() turns a previously cached query into a miss."""
    store = cache.InMemoryCache()
    key = ProcessorContent(['key1'])
    await store.put(key, ProcessorContent(['value1']))
    before = await store.lookup(key)
    self.assertIsNot(before, cache.CacheMiss)

    await store.remove(key)
    after = await store.lookup(key)
    self.assertIs(after, cache.CacheMiss)

  async def test_uses_custom_hash_function(self):
    """put() and lookup() each invoke the custom hash_fn exactly once."""
    hasher = mock.Mock(return_value='custom_key_123')
    store = cache.InMemoryCache(hash_fn=hasher)
    key = ProcessorContent(['query_text'])

    await store.put(key, ProcessorContent(['value']))
    hasher.assert_called_once_with(key)

    # Clear the recorded calls so the lookup is counted on its own.
    hasher.reset_mock()
    await store.lookup(key)
    hasher.assert_called_once_with(key)

  async def test_with_key_prefix_isolates_caches(self):
    """with_key_prefix() returns a separate cache holding its own entries."""
    base = cache.InMemoryCache(max_items=10)
    prefixed = base.with_key_prefix('p_')

    self.assertIsNot(base, prefixed)
    # Important: they have different TTLCache instances.
    self.assertIsNot(base._cache, prefixed._cache)
    self.assertIsInstance(prefixed, cache.InMemoryCache)
    self.assertEqual(base._max_items, prefixed._max_items)

    shared = ProcessorContent(['shared_query'])
    await base.put(shared, ProcessorContent(['value1']))
    await prefixed.put(shared, ProcessorContent(['value2']))

    # Each cache must return its own value for the very same query.
    from_base = await base.lookup(shared)
    from_prefixed = await prefixed.lookup(shared)

    self.assertIsNot(from_base, cache.CacheMiss)
    self.assertIsNot(from_prefixed, cache.CacheMiss)
    self.assertEqual(content_api.as_text(from_base), 'value1')
    self.assertEqual(content_api.as_text(from_prefixed), 'value2')

  async def test_hash_fn_returns_none(self):
    """When hash_fn returns None, a put followed by a lookup is a miss."""
    store = cache.InMemoryCache(hash_fn=lambda _: None)
    key = ProcessorContent(['any_query'])

    await store.put(key, ProcessorContent(['some_response']))
    outcome = await store.lookup(key)
    self.assertIs(outcome, cache.CacheMiss)

  async def test_hash_fn_raises_exception(self):
    """A raising hash_fn does not escape put(); lookup() reports a miss."""
    failing_hasher = mock.Mock(side_effect=ValueError('Hashing failed!'))
    store = cache.InMemoryCache(hash_fn=failing_hasher)
    key = ProcessorContent(['query_that_will_fail_hash'])

    await store.put(key, ProcessorContent(['irrelevant']))
    outcome = await store.lookup(key)
    self.assertIs(outcome, cache.CacheMiss)
    # The hash function is attempted on both the put and the lookup.
    self.assertEqual(failing_hasher.call_count, 2)

  async def test_put_with_serialization_error_propagates(self):
    """A RuntimeError raised by _serialize_fn surfaces from put()."""
    store = cache.InMemoryCache()
    key = ProcessorContent(['query'])

    broken_serializer = mock.patch.object(
        store, '_serialize_fn', side_effect=RuntimeError('Unexpected!')
    )
    with broken_serializer, self.assertRaises(RuntimeError):
      await store.put(key, ProcessorContent(['irrelevant']))

  async def test_invalid_init_with_zero_max_items(self):
    """Constructing with max_items of 0 or -1 raises a ValueError."""
    rejected = (
        (0, 'max_items must be positive, got: 0'),
        (-1, 'max_items must be positive, got: -1'),
    )
    for bad_value, message in rejected:
      with self.assertRaisesRegex(ValueError, message):
        cache.InMemoryCache(max_items=bad_value)

  async def test_default_hash_is_part_order_sensitive(self):
    """Swapping the order of parts changes the default hash and cache key."""
    store = cache.InMemoryCache(max_items=10)
    hello = ProcessorPart('Hello', role='user')
    world = ProcessorPart('World', role='user')
    response = ProcessorContent(['Response'])

    forward = ProcessorContent([hello, world])
    backward = ProcessorContent([world, hello])

    # The two orderings must not hash to the same key.
    forward_hash = cache.default_processor_content_hash(forward)
    backward_hash = cache.default_processor_content_hash(backward)
    self.assertNotEqual(forward_hash, backward_hash)

    # Caching one ordering must not make the other ordering a hit.
    await store.put(forward, response)
    self.assertIsNot(await store.lookup(forward), cache.CacheMiss)
    self.assertIs(await store.lookup(backward), cache.CacheMiss)

  async def test_default_hash_is_key_order_insensitive(self):
    """Key order inside a part's source dict does not affect the hash."""
    store = cache.InMemoryCache(max_items=10)
    role_first = {'role': 'model', 'part': {'text': 'response text'}}
    part_first = {'part': {'text': 'response text'}, 'role': 'model'}

    query_a = ProcessorContent([ProcessorPart.from_dict(data=role_first)])
    query_b = ProcessorContent([ProcessorPart.from_dict(data=part_first)])
    cached = ProcessorContent(['Value'])

    # Hashes should be identical because json.dumps uses sort_keys=True.
    hash_a = cache.default_processor_content_hash(query_a)
    hash_b = cache.default_processor_content_hash(query_b)
    self.assertEqual(hash_a, hash_b)

    # A value stored under one query must be found through the other.
    await store.put(query_a, cached)
    found = await store.lookup(query_b)
    self.assertIsNot(found, cache.CacheMiss)
    self.assertEqual(content_api.as_text(found), 'Value')
| 197 | + |
| 198 | + |
if __name__ == '__main__':
  # Run the tests in this module when the file is executed as a script.
  unittest.main()
0 commit comments