-
Notifications
You must be signed in to change notification settings - Fork 103
Expand file tree
/
Copy pathconfig_manager.py
More file actions
1789 lines (1466 loc) · 70.5 KB
/
Copy pathconfig_manager.py
File metadata and controls
1789 lines (1466 loc) · 70.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import json
import csv
import tempfile
import shutil
import folder_paths
class ConfigManager:
def __init__(self):
# 插件目录
self.dir_path = os.path.dirname(os.path.abspath(__file__))
# 获取 ComfyUI 用户目录
try:
user_dir = folder_paths.get_user_directory()
if user_dir and os.path.isdir(user_dir):
# 使用 user/default/prompt-assistant 作为基础目录
self.base_dir = os.path.join(user_dir, "default", "prompt-assistant")
# self._log(f"使用用户配置目录: {self.base_dir}")
else:
# 回退到插件目录
self.base_dir = self.dir_path
self._log(f"回退到插件配置目录: {self.base_dir}")
except Exception as e:
# 异常处理,回退到插件目录
self.base_dir = self.dir_path
self._log(f"无法获取用户目录({str(e)}),使用插件配置目录")
# 定义各个子目录
self.config_dir = os.path.join(self.base_dir, "config")
self.rules_dir = os.path.join(self.base_dir, "rules")
self.tags_dir = os.path.join(self.base_dir, "tags")
# 确保目录存在
os.makedirs(self.config_dir, exist_ok=True)
os.makedirs(self.rules_dir, exist_ok=True)
os.makedirs(self.tags_dir, exist_ok=True)
# 配置文件路径(用户配置和选择)
self.config_path = os.path.join(self.config_dir, "config.json")
self.active_prompts_path = os.path.join(self.config_dir, "active_prompts.json")
self.tags_user_path = os.path.join(self.config_dir, "tags_user.json")
self.tags_selection_path = os.path.join(self.config_dir, "tags_selection.json")
# 规则文件路径(规则定义和模板)
self.system_prompts_path = os.path.join(self.rules_dir, "system_prompts.json")
self.kontext_presets_path = os.path.join(self.rules_dir, "kontext_presets.json")
# ---模板目录(插件内置)---
self.templates_dir = os.path.join(self.dir_path, "config")
# 存储模板版本号(用于版本比对)
self._template_versions = {}
# ---加载默认配置(从模板文件)---
self.default_config = self._load_template("config", {"version": "2.0", "model_services": []})
self.default_system_prompts = self._load_template("system_prompts", {})
self.default_kontext_presets = self._load_template("kontext_presets", {})
# ---简单默认配置(无需模板,直接定义)---
self.default_active_prompts = {
"expand": "expand_扩写-通用",
"vision_zh": "vision_zh_图像描述-Tag风格",
"vision_en": "vision_en_Detail_Caption"
}
self.default_user_tags = {"favorites": []}
# 默认标签选择
self.default_tags_selection = {"selected_file": "用户标签.csv"}
# 执行数据迁移和配置文件初始化
# migration_tool 统一处理:确保文件存在 -> CSV标签迁移 -> 旧版迁移 -> 增量更新
self._run_migrations()
# 验证并修复激活提示词(静默模式,仅异常时修复)
self.validate_and_fix_active_prompts()
# 验证并修复模型参数配置
self.validate_and_fix_model_params()
# --- 数据迁移 ---
def _run_migrations(self):
"""
执行数据迁移(按需调用,不影响性能)
仅在需要时才导入和运行迁移工具
"""
try:
from .utils.migration_tool import run_migrations
# 准备默认配置数据用于增量更新
default_configs = {
'config': self.default_config,
'system_prompts': self.default_system_prompts,
'active_prompts': self.default_active_prompts,
'tags_user': self.default_user_tags,
'kontext_presets': self.default_kontext_presets
}
# 运行迁移
results = run_migrations(
plugin_dir=self.dir_path,
user_base_dir=self.base_dir,
logger=self._log,
default_configs=default_configs
)
# 记录迁移结果
if results.get('tags_migration'):
self._log("[用户标签.csv] 数据迁移完成")
except Exception as e:
self._log(f"数据迁移失败: {str(e)}")
# 迁移失败不影响正常运行,仅记录日志
# --- 统一日志输出 ---
def _log(self, msg: str):
"""统一控制台日志前缀"""
from .utils.common import _ANSI_CLEAR_EOL
print(f"\r{_ANSI_CLEAR_EOL}✨ {msg}", flush=True)
# ---模板加载---
def _load_template(self, template_name: str, fallback: dict = None) -> dict:
"""
从模板文件加载默认配置
参数:
template_name: 模板名称(不含扩展名和_template后缀)
fallback: 加载失败时的回退默认值
返回:
配置字典(包含 __config_version 用于版本管理)
"""
template_path = os.path.join(self.templates_dir, f"{template_name}_template.json")
try:
with open(template_path, "r", encoding="utf-8") as f:
data = json.load(f)
# 获取版本号并保存,用于后续比对
template_version = data.get("__config_version", "2.0")
self._template_versions[template_name] = template_version
return data
except Exception as e:
self._log(f"加载模板 {template_name} 失败: {str(e)},使用回退值")
# 确保 fallback 也包含版本号
if fallback is None:
fallback = {}
# 如果 fallback 没有版本号,添加默认版本号
if "__config_version" not in fallback:
fallback = {"__config_version": "2.0", **fallback}
self._template_versions[template_name] = "2.0"
return fallback
def _get_config_version(self, config: dict) -> str:
"""
获取配置版本号(兼容新旧两种版本字段)
版本字段优先级:
1. __config_version (新版本字段,如 "2.0.0")
2. version (旧版本字段,如 "2.0" 或 "1.0")
3. 默认返回 "1.0"(无版本字段视为最旧版本)
返回:
版本字符串,如 "2.0.0"、"2.0" 或 "1.0"
"""
# 优先使用新版本字段
if "__config_version" in config:
return config["__config_version"]
# 兼容旧版本字段
return config.get("version", "1.0")
def _is_v2_config(self, config: dict) -> bool:
"""
检查配置是否为 v2.0 或更高版本
返回:
True 表示 v2.0 或更高版本 (1.9 也视为 v2 格式,用于增量测试)
"""
version = self._get_config_version(config)
try:
v_float = float(version)
return v_float >= 1.9
except ValueError:
# 如果不是数字(如 "2.0.0"),取主版本号比较
major_version = version.split(".")[0]
try:
return int(major_version) >= 2
except ValueError:
return False
# --- 注意:以下方法已迁移到 migration_tool.py ---
# - _apply_migrated_api_keys
# - _migrate_provider_to_service
# - _create_or_update_custom_service
# - _match_service_by_provider
# - _check_and_add_missing_services
# 配置文件的创建、迁移和增量更新统一由 migration_tool 处理
def _atomic_write_json(self, file_path: str, data: dict) -> bool:
"""
原子性写入 JSON 文件
采用"写临时文件 + 原子性重命名"的策略,确保文件写入的原子性:
- 如果写入成功,新文件会替换旧文件
- 如果写入失败或被中断,旧文件保持不变
参数:
file_path: 目标文件路径
data: 要保存的数据字典
返回:
bool: 保存成功返回 True,失败返回 False
"""
temp_fd = None
temp_path = None
try:
# ---步骤1:写入临时文件---
# 在同一目录下创建临时文件(确保在同一文件系统,rename 才是原子的)
temp_fd, temp_path = tempfile.mkstemp(
dir=os.path.dirname(file_path),
suffix='.tmp',
prefix='.tmp_'
)
# 完整写入新配置到临时文件
with os.fdopen(temp_fd, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
temp_fd = None # 文件已关闭,避免重复关闭
# ---步骤2:原子性替换---
# rename 操作是原子的,要么成功替换,要么失败不变
shutil.move(temp_path, file_path)
temp_path = None # 已移动,避免清理时删除
return True
except Exception as e:
self._log(f"原子性写入 JSON 文件失败 [{os.path.basename(file_path)}]: {str(e)}")
return False
finally:
# 清理临时文件(如果写入失败)
if temp_fd is not None:
try:
os.close(temp_fd)
except:
pass
if temp_path is not None and os.path.exists(temp_path):
try:
os.unlink(temp_path)
except:
pass
def load_config(self):
"""加载配置文件"""
try:
with open(self.config_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
self._log(f"加载配置文件失败: {str(e)}")
return self.default_config
def save_config(self, config):
"""保存配置文件"""
return self._atomic_write_json(self.config_path, config)
def load_system_prompts(self):
"""加载系统提示词配置"""
try:
with open(self.system_prompts_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
self._log(f"加载系统提示词配置失败: {str(e)}")
return self.default_system_prompts
def save_system_prompts(self, system_prompts):
"""保存系统提示词配置"""
return self._atomic_write_json(self.system_prompts_path, system_prompts)
def load_active_prompts(self):
"""加载激活的提示词配置"""
try:
with open(self.active_prompts_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
self._log(f"加载激活的提示词配置失败: {str(e)}")
return self.default_active_prompts
def save_active_prompts(self, active_prompts):
"""保存激活的提示词配置"""
return self._atomic_write_json(self.active_prompts_path, active_prompts)
def load_user_tags(self):
"""加载用户标签配置"""
try:
with open(self.tags_user_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
self._log(f"加载用户标签配置失败: {str(e)}")
return self.default_user_tags
def save_user_tags(self, user_tags):
"""保存用户标签配置"""
return self._atomic_write_json(self.tags_user_path, user_tags)
def load_kontext_presets(self):
"""加载Kontext预设配置"""
try:
with open(self.kontext_presets_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
self._log(f"加载Kontext预设配置失败: {str(e)}")
return {}
def save_kontext_presets(self, kontext_presets):
"""保存Kontext预设配置"""
return self._atomic_write_json(self.kontext_presets_path, kontext_presets)
# --- 注意:ensure_tags_csv_exists 和 CSV 标签迁移已迁移到 migration_tool.py ---
def list_tags_files(self) -> list:
"""列出tags目录下所有CSV文件"""
try:
files = []
for filename in os.listdir(self.tags_dir):
if filename.endswith(".csv"):
files.append(filename)
return sorted(files)
except Exception as e:
self._log(f"列出标签文件失败: {str(e)}")
return []
def load_tags_csv(self, filename: str) -> dict:
"""加载CSV标签文件,返回嵌套字典结构"""
csv_path = os.path.join(self.tags_dir, filename)
if not os.path.exists(csv_path):
self._log(f"CSV文件不存在: {filename}")
return {}
# 尝试多种编码,优先尝试 utf-8-sig (Excel默认UTF-8),然后是 gbk (Excel默认ANSI),最后是 utf-8
encodings = ['utf-8-sig', 'gbk', 'gb18030', 'utf-8']
for encoding in encodings:
try:
result = {}
with open(csv_path, "r", encoding=encoding, newline="") as f:
reader = csv.reader(f)
try:
header = next(reader, None) # 跳过表头
except StopIteration:
return {} # 空文件
for row in reader:
# 过滤无效行
if not row or not any(cell.strip() for cell in row):
continue
# 至少需要两列:标签名, 标签值
if len(row) < 2:
continue
tag_name = row[0].strip()
tag_value = row[1].strip()
if not tag_name:
continue
# 分类路径:从第3列开始,过滤空值
categories = [c.strip() for c in row[2:] if c.strip()]
# 构建嵌套结构
current = result
for cat in categories:
if cat not in current or not isinstance(current[cat], dict):
current[cat] = {}
current = current[cat]
# 处理空分类占位符:只创建分类结构,不添加标签
if tag_name == "__empty__" or tag_name == "__placeholder__":
continue
# 添加标签
current[tag_name] = tag_value
return result
except UnicodeDecodeError:
continue
except Exception as e:
self._log(f"加载CSV标签失败 ({encoding}): {str(e)}")
continue
self._log(f"无法加载CSV文件: {filename},尝试了所有编码均失败")
return {}
def save_tags_csv(self, filename: str, tags: dict) -> bool:
"""保存标签数据到CSV文件"""
csv_path = os.path.join(self.tags_dir, filename)
try:
rows = []
max_depth = 0
def extract_tags(obj, path: list):
nonlocal max_depth
# 确保 obj 是字典类型
if not isinstance(obj, dict):
return
# 如果是空分类(空字典),添加占位行
if len(obj) == 0 and path:
# 使用 __empty__ 作为占位符标记空分类
rows.append(["__empty__", ""] + path)
max_depth = max(max_depth, len(path))
return
for key, value in obj.items():
if isinstance(value, str):
rows.append([key, value] + path)
max_depth = max(max_depth, len(path))
elif isinstance(value, dict):
extract_tags(value, path + [key])
# 提取所有标签
extract_tags(tags, [])
if not rows:
self._log(f"保存CSV标签: 数据为空")
# 如果数据为空,写入只含表头的文件或保持现状?
# 通常为了防止误删,如果 tags 为空暂不操作或清空文件。
# 这里选择写入表头:
with open(csv_path, "w", encoding="utf-8-sig", newline="") as f:
writer = csv.writer(f)
writer.writerow(["标签名", "标签值"])
return True
# 动态构建表头
header = ["标签名", "标签值"]
for i in range(max_depth):
num_zh = ["一", "二", "三", "四", "五", "六", "七", "八", "九", "十"]
suffix = num_zh[i] if i < len(num_zh) else str(i + 1)
header.append(f"{suffix}级分类")
with open(csv_path, "w", encoding="utf-8-sig", newline="") as f:
writer = csv.writer(f)
writer.writerow(header)
for row in rows:
# 补齐长度以匹配表头
while len(row) < len(header):
row.append("")
# 确保 row 长度不超过表头(防御性)
writer.writerow(row[:len(header)])
return True
except Exception as e:
self._log(f"保存CSV标签失败: {str(e)}")
return False
def get_tags_selection(self) -> dict:
"""获取用户选择的标签文件"""
try:
if os.path.exists(self.tags_selection_path):
with open(self.tags_selection_path, "r", encoding="utf-8") as f:
return json.load(f)
return self.default_tags_selection
except Exception as e:
self._log(f"读取标签选择失败: {str(e)}")
return self.default_tags_selection
def save_tags_selection(self, selection: dict) -> bool:
"""保存用户选择的标签文件"""
try:
with open(self.tags_selection_path, "w", encoding="utf-8") as f:
json.dump(selection, f, ensure_ascii=False, indent=2)
return True
except Exception as e:
self._log(f"保存标签选择失败: {str(e)}")
return False
def get_favorites(self) -> dict:
"""获取收藏列表"""
user_tags = self.load_user_tags()
favorites = user_tags.get("favorites", {})
# 兼容性处理:如果是列表,转换为字典
if isinstance(favorites, list):
new_favorites = {}
for item in favorites:
if isinstance(item, str):
new_favorites[item] = item
elif isinstance(item, dict):
name = item.get("name", item.get("value"))
value = item.get("value")
if name and value:
new_favorites[name] = value
return new_favorites
return favorites
def add_favorite(self, tag_value: str, tag_name: str = None, category: str = "默认") -> bool:
"""添加收藏"""
try:
user_tags = self.load_user_tags()
favorites = user_tags.get("favorites", {})
# 兼容性迁移:如果是一维字典 {name: value},无需强制迁移,但新添加的会放入 category
# 如果是列表,先迁移为字典
if isinstance(favorites, list):
favorites = self.get_favorites()
name = tag_name if tag_name else tag_value
# 使用嵌套结构 {分类: {名称: 值}}
if category not in favorites:
# 检查是否存在旧的平铺结构,如果有,且category是默认,可能混杂
# 这里简单处理:如果 favorites 只有键值对且都不是字典,说明是旧版平铺
# 但为了不破坏旧数据,我们在顶层只存储分类字典
# 如果 favorites 中已有非字典的值,说明是旧版平铺结构 {name: value}
# 我们将它们移动到 "默认" 分类
has_legacy = any(not isinstance(v, dict) for v in favorites.values())
if has_legacy:
legacy_items = {k: v for k, v in favorites.items() if not isinstance(v, dict)}
# 清除旧项
for k in legacy_items:
del favorites[k]
# 初始化默认分类
if "默认" not in favorites:
favorites["默认"] = {}
favorites["默认"].update(legacy_items)
if category not in favorites:
favorites[category] = {}
# 如果 favorites[category] 不是字典(防御性编程),初始化为字典
if not isinstance(favorites.get(category), dict):
favorites[category] = {}
favorites[category][name] = tag_value
user_tags["favorites"] = favorites
return self.save_user_tags(user_tags)
except Exception as e:
self._log(f"添加收藏失败: {str(e)}")
return False
def remove_favorite(self, tag_value: str, category: str = None) -> bool:
"""移除收藏"""
try:
user_tags = self.load_user_tags()
favorites = user_tags.get("favorites", {})
# 兼容性迁移
if isinstance(favorites, list):
favorites = self.get_favorites()
removed = False
# 如果指定了分类,只在指定分类中删除
if category:
# 尝试直接匹配分类(完全匹配)
target_categories = [category]
# 如果没找到,尝试模糊匹配(处理文件名后缀差异)
if category not in favorites:
# 比如 category 是 "foo",favorites里有 "foo.csv" 或相反
# 但通常 favorites 里的 key 已经是去后缀的
pass
for cat in target_categories:
if cat in favorites and isinstance(favorites[cat], dict):
# 根据值删除
keys_to_remove = [k for k, v in favorites[cat].items() if v == tag_value]
for k in keys_to_remove:
del favorites[cat][k]
removed = True
# 如果该分类空了,是否删除分类键?暂时保留
else:
# 未指定分类,递归全部删除(旧逻辑)
# 如果是旧版平铺结构
if any(not isinstance(v, dict) for v in favorites.values()):
keys_to_remove = [k for k, v in favorites.items() if not isinstance(v, dict) and v == tag_value]
for k in keys_to_remove:
del favorites[k]
removed = True
# 如果是新版嵌套结构
for cat, items in favorites.items():
if isinstance(items, dict):
keys_to_remove = [k for k, v in items.items() if v == tag_value]
for k in keys_to_remove:
del items[k]
removed = True
if removed:
user_tags["favorites"] = favorites
return self.save_user_tags(user_tags)
return True
except Exception as e:
self._log(f"移除收藏失败: {str(e)}")
return False
def get_system_prompts(self):
"""获取系统提示词配置 (合并提示词定义和激活状态)"""
system_prompts = self.load_system_prompts()
active_prompts = self.load_active_prompts()
system_prompts['active_prompts'] = active_prompts
return system_prompts
def update_system_prompts(self, system_prompts):
"""更新系统提示词配置 (仅更新提示词定义)"""
prompts_to_save = system_prompts.copy()
if 'active_prompts' in prompts_to_save:
del prompts_to_save['active_prompts']
return self.save_system_prompts(prompts_to_save)
def update_active_prompts(self, active_prompts):
"""更新所有激活的提示词"""
return self.save_active_prompts(active_prompts)
def update_active_prompt(self, prompt_type, prompt_id):
"""更新单个激活的提示词"""
active_prompts = self.load_active_prompts()
active_prompts[prompt_type] = prompt_id
return self.save_active_prompts(active_prompts)
def get_baidu_translate_config(self):
"""获取百度翻译配置"""
config = self.load_config()
return config.get("baidu_translate", self.default_config["baidu_translate"])
def get_llm_config(self):
"""获取LLM配置"""
config = self.load_config()
current_service_info = config.get('current_services', {}).get('llm')
# 适配新旧格式:支持字符串(旧)和字典(新)
if isinstance(current_service_info, str):
# 旧格式: "service_id"
current_service_id = current_service_info
current_model_name = None
elif isinstance(current_service_info, dict):
# 新格式: {"service": "service_id", "model": "model_name"}
current_service_id = current_service_info.get('service')
current_model_name = current_service_info.get('model')
else:
# 未设置
current_service_id = None
current_model_name = None
if not current_service_id:
# 没有选中的服务,返回默认结构
return self._get_empty_llm_config()
# 查找对应的服务
service = self._get_service_by_id(current_service_id)
if not service:
return self._get_empty_llm_config()
# 获取LLM模型列表
llm_models = service.get('llm_models', [])
# 如果指定了模型名称,尝试查找
target_model = None
if current_model_name:
target_model = next((m for m in llm_models if m.get('name') == current_model_name), None)
# 如果未找到指定模型,使用默认模型或第一个模型
if not target_model:
target_model = next((m for m in llm_models if m.get('is_default')),
llm_models[0] if llm_models else None)
if not target_model:
return self._get_empty_llm_config()
# 直接获取API Key(明文存储)
api_key = service.get('api_key', '')
# 返回配置
return {
"provider": service.get('id', ''), # 使用service_id作为provider
"model": target_model.get('name', ''),
"base_url": service.get('base_url', ''),
"api_key": api_key,
"temperature": target_model.get('temperature', 0.7),
"max_tokens": target_model.get('max_tokens', 1000),
"top_p": target_model.get('top_p', 0.9),
"auto_unload": service.get('auto_unload', True) if service.get('type') == 'ollama' else None,
"providers": {} # v2.0中不再使用此字段
}
def _get_empty_llm_config(self):
"""返回空的LLM配置"""
return {
"provider": "",
"model": "",
"base_url": "",
"api_key": "",
"temperature": 0.7,
"max_tokens": 1000,
"top_p": 0.9,
"providers": {}
}
def _get_service_by_id(self, service_id: str) -> dict:
"""根据ID获取服务配置"""
config = self.load_config()
services = config.get('model_services', [])
for service in services:
if service.get('id') == service_id:
return service
return None
def get_vision_config(self):
"""获取视觉模型配置"""
config = self.load_config()
current_service_info = config.get('current_services', {}).get('vlm')
# 适配新旧格式:支持字符串(旧)和字典(新)
if isinstance(current_service_info, str):
# 旧格式: "service_id"
current_service_id = current_service_info
current_model_name = None
elif isinstance(current_service_info, dict):
# 新格式: {"service": "service_id", "model": "model_name"}
current_service_id = current_service_info.get('service')
current_model_name = current_service_info.get('model')
else:
# 未设置
current_service_id = None
current_model_name = None
if not current_service_id:
# 没有选中的服务,返回默认结构
return self._get_empty_vision_config()
# 查找对应的服务
service = self._get_service_by_id(current_service_id)
if not service:
return self._get_empty_vision_config()
# 获取VLM模型列表
vlm_models = service.get('vlm_models', [])
# 如果指定了模型名称,尝试查找
target_model = None
if current_model_name:
target_model = next((m for m in vlm_models if m.get('name') == current_model_name), None)
# 如果未找到指定模型,使用默认模型或第一个模型
if not target_model:
target_model = next((m for m in vlm_models if m.get('is_default')),
vlm_models[0] if vlm_models else None)
if not target_model:
return self._get_empty_vision_config()
# 直接获取API Key(明文存储)
api_key = service.get('api_key', '')
# 返回配置
return {
"provider": service.get('id', ''), # 使用service_id作为provider
"model": target_model.get('name', ''),
"base_url": service.get('base_url', ''),
"api_key": api_key,
"temperature": target_model.get('temperature', 0.7),
"max_tokens": target_model.get('max_tokens', 4096),
"top_p": target_model.get('top_p', 0.9),
"auto_unload": service.get('auto_unload', True) if service.get('type') == 'ollama' else None,
"providers": {} # v2.0中不再使用此字段
}
def _get_empty_vision_config(self):
"""返回空的视觉模型配置"""
return {
"provider": "",
"model": "",
"base_url": "",
"api_key": "",
"temperature": 0.7,
"max_tokens": 4096,
"top_p": 0.9,
"providers": {}
}
def get_translate_config(self):
"""获取翻译服务配置(支持百度翻译和LLM翻译)"""
config = self.load_config()
current_service_info = config.get('current_services', {}).get('translate')
# 适配新旧格式:支持字符串(旧)和字典(新)
if isinstance(current_service_info, str):
# 旧格式: "service_id"
current_service_id = current_service_info
current_model_name = None
elif isinstance(current_service_info, dict):
# 新格式: {"service": "service_id", "model": "model_name"}
current_service_id = current_service_info.get('service')
current_model_name = current_service_info.get('model')
else:
# 未设置,默认使用百度翻译
current_service_id = 'baidu'
current_model_name = None
# 百度翻译特殊处理(使用独立的baidu_translate配置)
if current_service_id == 'baidu':
baidu_config = self.get_baidu_translate_config()
return {
"provider": "baidu",
"model": "",
"base_url": "",
"api_key": baidu_config.get('app_id', ''),
"secret_key": baidu_config.get('secret_key', ''),
"temperature": 0.7,
"max_tokens": 1000,
"top_p": 0.9,
"providers": {}
}
# 查找对应的LLM服务
service = self._get_service_by_id(current_service_id)
if not service:
# 服务不存在,回退到百度翻译
baidu_config = self.get_baidu_translate_config()
return {
"provider": "baidu",
"model": "",
"base_url": "",
"api_key": baidu_config.get('app_id', ''),
"secret_key": baidu_config.get('secret_key', ''),
"temperature": 0.7,
"max_tokens": 1000,
"top_p": 0.9,
"providers": {}
}
# 获取LLM模型列表
llm_models = service.get('llm_models', [])
# 如果指定了模型名称,尝试查找
target_model = None
if current_model_name:
target_model = next((m for m in llm_models if m.get('name') == current_model_name), None)
# 如果未找到指定模型,使用默认模型或第一个模型
if not target_model:
target_model = next((m for m in llm_models if m.get('is_default')),
llm_models[0] if llm_models else None)
if not target_model:
# 没有可用模型,回退到百度翻译
baidu_config = self.get_baidu_translate_config()
return {
"provider": "baidu",
"model": "",
"base_url": "",
"api_key": baidu_config.get('app_id', ''),
"secret_key": baidu_config.get('secret_key', ''),
"temperature": 0.7,
"max_tokens": 1000,
"top_p": 0.9,
"providers": {}
}
# 返回LLM翻译配置
api_key = service.get('api_key', '')
return {
"provider": service.get('id', ''),
"model": target_model.get('name', ''),
"base_url": service.get('base_url', ''),
"api_key": api_key,
"temperature": target_model.get('temperature', 0.7),
"max_tokens": target_model.get('max_tokens', 1000),
"top_p": target_model.get('top_p', 0.9),
"auto_unload": service.get('auto_unload', True) if service.get('type') == 'ollama' else None,
"providers": {}
}
def get_settings(self):
"""获取ComfyUI用户设置(从设置文件读取)"""
try:
# ComfyUI的设置文件通常位于 user/default/comfy.settings.json
# 需要找到ComfyUI的根目录
import sys
# 尝试从多个可能的路径查找设置文件
possible_paths = []
# 方法1: 通过当前文件路径向上查找
current_dir = os.path.dirname(os.path.abspath(__file__))
# custom_nodes/comfyui_prompt_assistant -> custom_nodes -> ComfyUI
comfyui_root = os.path.dirname(os.path.dirname(current_dir))
possible_paths.append(os.path.join(comfyui_root, "user", "default", "comfy.settings.json"))
# 方法2: 通过sys.path查找
for path in sys.path:
if 'ComfyUI' in path:
possible_paths.append(os.path.join(path, "user", "default", "comfy.settings.json"))
# 尝试读取设置文件
for settings_path in possible_paths:
if os.path.exists(settings_path):
try:
with open(settings_path, 'r', encoding='utf-8') as f:
settings_data = json.load(f)
# 返回设置数据
return settings_data
except Exception as e:
self._log(f"读取设置文件失败: {settings_path}, 错误: {str(e)}")
continue
# 如果都找不到,返回空字典
return {}
except Exception as e:
# 如果无法获取,返回空字典
self._log(f"获取用户设置失败: {str(e)}")
return {}
def update_baidu_translate_config(self, app_id=None, secret_key=None):
"""更新百度翻译配置"""
config = self.load_config()
if "baidu_translate" not in config:
config["baidu_translate"] = {}
# 仅更新提供的参数
if app_id is not None:
config["baidu_translate"]["app_id"] = app_id
if secret_key is not None:
config["baidu_translate"]["secret_key"] = secret_key
return self.save_config(config)
# --- 注意:validate_and_fix_system_prompts 已迁移到 migration_tool.py ---
# 系统提示词的验证和补全由 migration_tool 的增量更新逻辑统一处理
def validate_and_fix_active_prompts(self):
"""
验证激活提示词是否存在,如果不存在则修复
注意:此方法只修复 active_prompts.json(切换到存在的提示词)
不会恢复 system_prompts.json 中被删除的内容(尊重用户的删除操作)
"""
try:
system_prompts = self.load_system_prompts()
active_prompts = self.load_active_prompts()
# 标记是否需要更新激活提示词
modified = False
# 检查并修复扩写提示词
if "expand" in active_prompts:
expand_id = active_prompts["expand"]
expand_prompts = system_prompts.get("expand_prompts", {})
if expand_id not in expand_prompts:
# 激活的提示词不存在,切换到第一个可用的
if expand_prompts:
first_expand_id = next(iter(expand_prompts))
active_prompts["expand"] = first_expand_id
self._log(f"激活的扩写提示词 '{expand_id}' 不存在,已切换到 '{first_expand_id}'")
modified = True
else:
# 没有可用的扩写提示词,清空激活
active_prompts["expand"] = ""
self._log(f"警告:没有可用的扩写提示词")
modified = True
# 检查并修复中文反推提示词
if "vision_zh" in active_prompts:
vision_zh_id = active_prompts["vision_zh"]
vision_prompts = system_prompts.get("vision_prompts", {})
zh_prompts = {k: v for k, v in vision_prompts.items() if k.startswith("vision_zh_")}
if vision_zh_id not in vision_prompts:
if zh_prompts:
first_id = next(iter(zh_prompts))
active_prompts["vision_zh"] = first_id
self._log(f"激活的中文反推提示词 '{vision_zh_id}' 不存在,已切换到 '{first_id}'")
modified = True
else:
active_prompts["vision_zh"] = ""
self._log(f"警告:没有可用的中文反推提示词")
modified = True
# 检查并修复英文反推提示词
if "vision_en" in active_prompts: