|
12 | 12 | import re |
13 | 13 | import requests |
14 | 14 | import functools |
| 15 | +from collections import OrderedDict, defaultdict |
15 | 16 |
|
16 | 17 |
|
17 | 18 | if is_torch_available(): |
@@ -113,6 +114,93 @@ def __setattr__(self, key, value): |
113 | 114 | KwargsConfig = DottableDict |
114 | 115 |
|
115 | 116 |
|
| 117 | +class CachedDictBySize(OrderedDict): |
| 118 | + '''按照size来缓存的字典, 超过maxsize会pop最老的item''' |
| 119 | + def __init__(self, *args, maxsize=100, **kwargs): |
| 120 | + super().__init__(*args, **kwargs) |
| 121 | + self.__maxsize = maxsize |
| 122 | + # 如果初始化时字典已经超过了maxsize,则移除多余的项 |
| 123 | + while len(self) > self.__maxsize: |
| 124 | + self.popitem(last=False) |
| 125 | + |
| 126 | + def __setitem__(self, key, value): |
| 127 | + # 检查字典是否已满,如果满了则移除最早的项 |
| 128 | + if len(self) >= self.__maxsize: |
| 129 | + self.popitem(last=False) |
| 130 | + # 使用super()来避免无限递归 |
| 131 | + super().__setitem__(key, value) |
| 132 | + |
| 133 | + |
| 134 | +class CachedDictByFreq(dict): |
| 135 | + '''按照freq来缓存的字典, 超过maxsize会先pop频次最低的item''' |
| 136 | + def __init__(self, *args, maxsize=100, **kwargs): |
| 137 | + super().__init__(*args, **kwargs) |
| 138 | + self.__maxsize = maxsize |
| 139 | + self.__frequency = defaultdict(int) # 用于跟踪每个键的访问频次 |
| 140 | + |
| 141 | + # 如果初始化时字典已经超过了maxsize,则移除多余的项 |
| 142 | + self._trim_to_maxsize(self.__maxsize) |
| 143 | + |
| 144 | + def __setitem__(self, key, value): |
| 145 | + if key not in self: |
| 146 | + # 先去掉最低频的,如果先set_item则可能会把当前key,value去掉 |
| 147 | + self._trim_to_maxsize(self.__maxsize-1) |
| 148 | + |
| 149 | + super().__setitem__(key, value) # 这里如果value不一致,则直接替换,仅以key为准 |
| 150 | + |
| 151 | + # 更新频次 |
| 152 | + self.__frequency[key] += 1 |
| 153 | + |
| 154 | + def __getitem__(self, key): |
| 155 | + # 每当项被访问时,更新其频次, 仅当key存在的时候,如果key不存在,则freq不变 |
| 156 | + value = super().__getitem__(key) |
| 157 | + self.__frequency[key] += 1 |
| 158 | + return value |
| 159 | + |
| 160 | + def _trim_to_maxsize(self, maxsize): |
| 161 | + # 辅助方法,用于在初始化或需要时修剪到maxsize |
| 162 | + while len(self) > maxsize: |
| 163 | + min_freq_key = min(self.__frequency, key=self.__frequency.get) |
| 164 | + del self[min_freq_key] |
| 165 | + del self.__frequency[min_freq_key] |
| 166 | + |
| 167 | + |
| 168 | +class CachedDictByTimeout(dict): |
| 169 | + '''按照time来缓存的字典, 超过最大时长会pop最老的item''' |
| 170 | + def __init__(self, *args, timeout=60, **kwargs): |
| 171 | + super().__init__(*args, **kwargs) |
| 172 | + self.__timeout = timeout # 超时时间,以秒为单位 |
| 173 | + |
| 174 | + def __setitem__(self, key, value): |
| 175 | + super().__setitem__(key, (value, time.time())) # 存储值和创建时间(或更新时间) |
| 176 | + |
| 177 | + def __getitem__(self, key): |
| 178 | + value, last_access = super().__getitem__(key) |
| 179 | + if time.time() - last_access > self.__timeout: |
| 180 | + # 如果key已过期,则从字典中删除它并抛出KeyError |
| 181 | + del self[key] |
| 182 | + raise KeyError(f"Key `{key}` has expired") |
| 183 | + return value |
| 184 | + |
| 185 | + def get(self, key, default=None): |
| 186 | + try: |
| 187 | + return self[key] |
| 188 | + except KeyError: |
| 189 | + return default |
| 190 | + |
| 191 | + # 注意:由于我们重写了__getitem__,所以pop方法不会自动检查过期项 |
| 192 | + # 但我们可以提供一个自定义的pop方法,如果需要的话 |
| 193 | + def pop(self, key, default=None): |
| 194 | + if key in self: |
| 195 | + value, last_access = super().__getitem__(key) |
| 196 | + if time.time() - last_access > self.__timeout: |
| 197 | + del self[key] |
| 198 | + return default |
| 199 | + else: |
| 200 | + return super().pop(key)[0] # 只返回值,不返回时间戳 |
| 201 | + return default |
| 202 | + |
| 203 | + |
116 | 204 | class JsonConfig(DottableDict): |
117 | 205 | '''读取json配置文件/字符串/字典并返回可.操作符的字典 |
118 | 206 | 1. json文件路径 |
|
0 commit comments