Skip to content

Commit 0c9d627

Browse files
committed
增加cachedict
1 parent 791ac1f commit 0c9d627

File tree

1 file changed

+88
-0
lines changed

1 file changed

+88
-0
lines changed

torch4keras/snippets/misc.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import re
1313
import requests
1414
import functools
15+
from collections import OrderedDict, defaultdict
1516

1617

1718
if is_torch_available():
@@ -113,6 +114,93 @@ def __setattr__(self, key, value):
113114
# KwargsConfig is a second name bound to the same class object as DottableDict.
KwargsConfig = DottableDict
114115

115116

117+
class CachedDictBySize(OrderedDict):
    '''Size-bounded cache dict; once more than ``maxsize`` items are held,
    the oldest inserted items are evicted first (FIFO).

    Overwriting an existing key neither evicts anything nor changes that
    key's position in the eviction order.

    :param maxsize: int, maximum number of items kept; values <= 0 mean
        nothing is retained
    '''
    def __init__(self, *args, maxsize=100, **kwargs):
        # The limit must exist before OrderedDict.__init__ runs, because that
        # constructor inserts the initial items through self.__setitem__.
        self.__maxsize = maxsize
        super().__init__(*args, **kwargs)
        # Safety net in case the initial items were not routed through
        # __setitem__ by the underlying implementation.
        self._evict_overflow()

    def __setitem__(self, key, value):
        # Store first, then trim: updating an existing key leaves the size
        # unchanged, so nothing is evicted in that case.
        super().__setitem__(key, value)
        self._evict_overflow()

    def _evict_overflow(self):
        '''Drop the oldest items until at most ``maxsize`` items remain.'''
        limit = max(self.__maxsize, 0)  # never try to pop from an empty dict
        while len(self) > limit:
            self.popitem(last=False)
132+
133+
134+
class CachedDictByFreq(dict):
    '''Frequency-bounded cache dict; when a new key would exceed ``maxsize``,
    the least frequently used items are evicted first.

    A key's frequency is increased by ``d[key] = value`` and by ``d[key]``.
    Items passed to the constructor start with a frequency of 1. Ties are
    broken in favour of evicting the earliest inserted key.

    :param maxsize: int, maximum number of items kept; values <= 0 mean
        nothing is retained
    '''
    def __init__(self, *args, maxsize=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.__maxsize = maxsize
        self.__frequency = defaultdict(int)  # key -> number of accesses

        # dict.__init__ does not go through __setitem__, so register the
        # initial items here; otherwise they could never be picked for eviction.
        for key in self:
            self.__frequency[key] = 1

        # Drop surplus items if the initial data already exceeds maxsize.
        self._trim_to_maxsize(self.__maxsize)

    def __setitem__(self, key, value):
        if key not in self:
            if self.__maxsize <= 0:
                # A zero-capacity cache cannot hold the new item.
                return
            # Make room first: trimming after the insert could evict the
            # item that was just added, since its frequency is the lowest.
            self._trim_to_maxsize(self.__maxsize - 1)

        # An existing key is simply overwritten; only the key matters.
        super().__setitem__(key, value)
        self.__frequency[key] += 1

    def __getitem__(self, key):
        # Count the access only when the key exists (KeyError propagates
        # before the counter is touched).
        value = super().__getitem__(key)
        self.__frequency[key] += 1
        return value

    def __delitem__(self, key):
        super().__delitem__(key)
        # Keep the counter table in sync with the stored keys.
        self.__frequency.pop(key, None)

    def pop(self, key, *default):
        '''Same contract as ``dict.pop``; also forgets the key's frequency.'''
        value = super().pop(key, *default)
        self.__frequency.pop(key, None)
        return value

    def _trim_to_maxsize(self, maxsize):
        '''Evict least frequently used items until ``len(self) <= maxsize``.'''
        frequency = self.__frequency
        limit = max(maxsize, 0)  # a negative target can never be reached
        while len(self) > limit:
            # Rank the keys that are actually stored, so a missing or stale
            # counter entry can never cause a failing delete.
            victim = min(self, key=lambda k: frequency.get(k, 0))
            del self[victim]
166+
167+
168+
class CachedDictByTimeout(dict):
    '''Time-bounded cache dict; an item expires ``timeout`` seconds after it
    was last stored.

    Expiry is lazy: an expired item is removed when it is looked up through
    ``d[key]``, ``get`` or ``pop``. Values are stored internally as
    ``(value, stored_at)`` tuples, so methods that are not overridden here
    (``values``, ``items``, ...) return those tuples.

    :param timeout: int/float, lifetime of an item in seconds
    '''
    def __init__(self, *args, timeout=60, **kwargs):
        super().__init__()
        self.__timeout = timeout  # lifetime in seconds
        # dict.__init__ would store the initial values without a timestamp,
        # which breaks __getitem__; stamp them explicitly instead.
        self._store_many(*args, **kwargs)

    def __setitem__(self, key, value):
        # Keep the value together with the time it was stored/updated.
        super().__setitem__(key, (value, time.time()))

    def __getitem__(self, key):
        value, stored_at = super().__getitem__(key)
        if time.time() - stored_at > self.__timeout:
            # Expired: drop the item and behave as if the key were missing.
            del self[key]
            raise KeyError(f"Key `{key}` has expired")
        return value

    def get(self, key, default=None):
        '''Return the live value for ``key``, or ``default`` if the key is
        missing or expired.'''
        try:
            return self[key]
        except KeyError:
            return default

    def update(self, *args, **kwargs):
        '''Same arguments as ``dict.update``; every new item is timestamped.'''
        self._store_many(*args, **kwargs)

    def pop(self, key, default=None):
        '''Remove ``key`` and return its value, without the timestamp.

        Returns ``default`` instead of raising when the key is missing or
        expired; an expired item is removed as well.
        '''
        try:
            value, stored_at = super().pop(key)
        except KeyError:
            return default
        if time.time() - stored_at > self.__timeout:
            return default
        return value

    def _store_many(self, *args, **kwargs):
        '''Insert items given in ``dict(...)`` constructor form, adding timestamps.'''
        if args and isinstance(args[0], CachedDictByTimeout):
            # Entries of another timeout cache already carry their own
            # timestamps; copy them unchanged instead of wrapping them again.
            super().update(args[0])
            args = args[1:]
        now = time.time()
        for key, value in dict(*args, **kwargs).items():
            super().__setitem__(key, (value, now))
202+
203+
116204
class JsonConfig(DottableDict):
117205
'''读取json配置文件/字符串/字典并返回可.操作符的字典
118206
1. json文件路径

0 commit comments

Comments
 (0)