forked from Zhalslar/astrbot_plugin_qqadmin
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.py
More file actions
260 lines (201 loc) · 8.09 KB
/
data.py
File metadata and controls
260 lines (201 loc) · 8.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import asyncio
import json
from pathlib import Path
import aiosqlite
from astrbot.api import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from .utils import parse_bool
class QQAdminDB:
    """
    SQLite-backed per-group settings store for the QQ admin plugin.

    Each group's settings are stored as one JSON document in the ``groups``
    table, keyed by group id. ``init()`` loads every row into an in-memory
    cache; each write updates the cache and is immediately persisted. Fields
    present in ``default_cfg`` but missing from a group's document are
    back-filled by ``all()``.
    """

    # ====================== English <-> Chinese field names ======================
    # Used to render settings with Chinese labels and to parse such text back.
    FIELD_MAP = {
        "builtin_ban": "启用内置禁词",
        "custom_ban_words": "自定义违禁词",
        "word_ban_time": "禁词禁言时长",
        "spamming_ban_time": "刷屏禁言时长",
    }
    REVERSE_FIELD_MAP = {v: k for k, v in FIELD_MAP.items()}
    # ==========================================================================

    def __init__(self, config: AstrBotConfig, db_path: Path):
        """
        Args:
            config: Plugin config; its ``"default"`` entry is the template used
                for new groups and for back-filling missing fields.
            db_path: Location of the SQLite database file.

        No I/O happens here; call ``init()`` before reading or writing.
        """
        self.db_path = db_path
        # Default fields (the template that drives dynamic configuration).
        self.default_cfg: dict = config["default"]
        self._conn = None
        # gid -> settings dict; in-memory copy of the rows in ``groups``.
        self._cache = {}
        self._initialized = False
        # Serializes concurrent init() calls so setup runs only once.
        self._init_lock = asyncio.Lock()

    # ============================== Helpers ==============================
    @staticmethod
    def _clone(value):
        """Return a deep copy of a JSON-compatible value via a JSON round-trip."""
        return json.loads(json.dumps(value))

    @staticmethod
    def _split_key_value(line: str) -> tuple[str, str] | None:
        """Split ``line`` at its first ASCII ``:`` or full-width ``：``; None if neither occurs."""
        positions = [i for i in (line.find(":"), line.find("：")) if i >= 0]
        if not positions:
            return None
        idx = min(positions)
        # Both separators are a single character, so the value starts at idx + 1.
        return line[:idx], line[idx + 1 :]

    # ============================== Initialization ==============================
    async def init(self):
        """
        Open the database, create the ``groups`` table if needed and load all
        rows into the cache. Calls after the first successful one are no-ops.
        """
        async with self._init_lock:
            if self._initialized:
                return
            self.db_path.parent.mkdir(parents=True, exist_ok=True)
            self._conn = await aiosqlite.connect(str(self.db_path))
            self._conn.row_factory = aiosqlite.Row
            await self._conn.execute(
                """
                CREATE TABLE IF NOT EXISTS groups (
                    group_id TEXT PRIMARY KEY,
                    data TEXT NOT NULL
                );
                """
            )
            await self._conn.commit()
            # Load the cache; rows whose data cannot be decoded are logged and skipped.
            async with self._conn.execute("SELECT group_id, data FROM groups;") as cur:
                async for row in cur:
                    try:
                        self._cache[row["group_id"]] = json.loads(row["data"])
                    except (ValueError, TypeError):
                        logger.exception("解析 group 数据失败: %s", row["group_id"])
            self._initialized = True
            logger.info("QQAdminDB initialized (%d groups)", len(self._cache))

    async def _save_to_db(self, gid: str, data):
        """
        Upsert ``data`` (serialized as JSON) as the row for ``gid`` and commit.

        Raises:
            RuntimeError: if ``init()`` has not been called.
        """
        if not self._conn:
            raise RuntimeError("请先 init()")
        await self._conn.execute(
            """
            INSERT INTO groups(group_id, data)
            VALUES (?, ?)
            ON CONFLICT(group_id) DO UPDATE SET data=excluded.data;
            """,
            (gid, json.dumps(data, ensure_ascii=False)),
        )
        await self._conn.commit()

    # ============================== Ensure a group exists ==============================
    async def ensure_group(self, gid: str):
        """Create and persist a copy of ``default_cfg`` for ``gid`` if it has no settings yet."""
        if gid not in self._cache:
            self._cache[gid] = self._clone(self.default_cfg)
            await self._save_to_db(gid, self._cache[gid])

    # ============================== Minimal API ==============================
    async def all(self, gid: str) -> dict:
        """
        Return the group's whole settings dict (the cached object itself),
        first back-filling any ``default_cfg`` field it lacks.
        """
        await self.ensure_group(gid)
        data = self._cache[gid]
        changed = False
        for k, v in self.default_cfg.items():
            if k not in data:
                data[k] = self._clone(v)
                changed = True
        if changed:
            await self._save_to_db(gid, data)
        return data

    async def get(self, gid: str, field: str, default=None):
        """
        Return ``field``'s value. If the field is absent, a deep copy of
        ``default`` is stored (and persisted) under it and returned.
        """
        await self.ensure_group(gid)
        data = self._cache[gid]
        if field not in data:
            data[field] = self._clone(default)
            await self._save_to_db(gid, data)
        return data[field]

    async def set(self, gid: str, field: str, value):
        """Write ``value`` to ``field`` and persist it."""
        await self.ensure_group(gid)
        self._cache[gid][field] = value
        await self._save_to_db(gid, self._cache[gid])

    async def add(self, gid: str, field: str, value):
        """Append ``value`` to a list field (created empty if absent) unless already present."""
        # Copy so the cached list is only replaced via set(), never mutated in place.
        lst = list(await self.get(gid, field, []))
        if value not in lst:
            lst.append(value)
        await self.set(gid, field, lst)

    async def remove(self, gid: str, field: str, value):
        """Remove every occurrence of ``value`` from a list field (created empty if absent)."""
        lst = [i for i in await self.get(gid, field, []) if i != value]
        await self.set(gid, field, lst)

    # ============================== Delete a group ==============================
    async def delete_group(self, gid: str):
        """Delete the group's row (if connected) and drop it from the cache."""
        if self._conn:
            await self._conn.execute("DELETE FROM groups WHERE group_id = ?", (gid,))
            await self._conn.commit()
        self._cache.pop(gid, None)

    # ============================== Shutdown ==============================
    async def close(self):
        """Close the connection; a later ``init()`` will reopen it."""
        if self._conn:
            await self._conn.close()
            self._conn = None
        self._initialized = False

    # ====================== Chinese-labelled export / import ======================
    async def export_cn_lines(self, gid: str) -> str:
        """
        Render the group's settings as ``<Chinese name>: <value>`` lines.

        - list: items joined with spaces
        - bool: 开 / 关
        - anything else: ``str(value)``

        Fields without a Chinese name keep their English key.
        """
        data = await self.all(gid)
        lines = []
        for eng_key, value in data.items():
            cn_key = self.FIELD_MAP.get(eng_key, eng_key)
            if isinstance(value, list):
                val_str = " ".join(map(str, value))
            elif isinstance(value, bool):
                val_str = "开" if value else "关"
            else:
                val_str = str(value)
            lines.append(f"{cn_key}: {val_str}")
        return "\n".join(lines)

    async def import_cn_lines(self, gid: str, text: str) -> dict:
        """
        Parse text in the ``export_cn_lines`` format and write it back to the DB.

        Each line is ``<Chinese name>: <value>``; an ASCII ``:`` or full-width
        ``：`` separator is accepted (whichever appears first). Lines with no
        separator or an unknown name are ignored. Values are converted based on
        the type of the field's current value:

        - bool: ``parse_bool``; if it returns None, the value falls through to
          the rules below (bool is a subclass of int, so the int rule applies)
        - list: split on whitespace
        - int: ``int(value)``; the old value is kept if conversion fails
        - anything else: stored as the stripped string

        Returns:
            The group's updated settings dict (the cached object itself).
        """
        # all() (not just ensure_group) back-fills fields missing from the
        # stored document, so old_val has the right type for every known field.
        data = await self.all(gid)

        for line in text.splitlines():
            parts = self._split_key_value(line)
            if parts is None:
                continue
            cn_key, raw_v = parts[0].strip(), parts[1].strip()

            eng_key = self.REVERSE_FIELD_MAP.get(cn_key)
            if not eng_key:
                continue

            old_val = data.get(eng_key)

            # Bool fields: try boolean parsing first.
            if isinstance(old_val, bool):
                parsed = parse_bool(raw_v)
                if parsed is not None:
                    data[eng_key] = parsed
                    continue
                # Parsing failed: fall back to the literal handling below.

            if isinstance(old_val, list):
                value = raw_v.split()
            elif isinstance(old_val, int):
                try:
                    value = int(raw_v)
                except ValueError:
                    value = old_val  # keep the previous value on bad input
            else:
                value = raw_v

            data[eng_key] = value

        await self._save_to_db(gid, data)
        return data

    async def reset_to_default(self, gid: str | None = None):
        """
        Restore one group's settings, or every cached group's when ``gid`` is
        None, to a fresh copy of ``default_cfg`` and persist them.
        """
        # Explicit None check: an empty-string gid must not mean "reset all".
        targets = [gid] if gid is not None else list(self._cache)
        for g in targets:
            self._cache[g] = self._clone(self.default_cfg)
            await self._save_to_db(g, self._cache[g])
            # Log the group actually reset (gid is None when resetting all).
            logger.info("群聊%s的群管配置已重置为默认值", g)