-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig_loader.py
More file actions
104 lines (94 loc) · 3.94 KB
/
config_loader.py
File metadata and controls
104 lines (94 loc) · 3.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
配置加载与合并。
"""
import json
import os
from typing import Any, Dict, Optional
def resolve_config_path(config_arg: Optional[str]) -> Optional[str]:
    """Choose which configuration file to load.

    An explicitly supplied, non-empty path is always returned as-is (it is
    not checked for existence). Otherwise ``config.json`` in the current
    working directory is used when it exists.

    Returns:
        The chosen path, or ``None`` when no config file applies.
    """
    if not config_arg:
        default_path = 'config.json'
        return default_path if os.path.exists(default_path) else None
    return config_arg
def _read_config_file(config_path: Optional[str]) -> Dict[str, Any]:
    """Load the JSON config file at *config_path*; return ``{}`` when no path is given.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the top-level JSON value is not an object -- every
            setting is read from it with ``dict.get``.
    """
    if not config_path:
        return {}
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a JSON object, "
            f"got {type(config).__name__}"
        )
    return config


def _env_number(name: str, cast: type) -> Any:
    """Return env var *name* converted with *cast* (``int``/``float``); ``None`` if unset or empty.

    Raises:
        ValueError: naming the variable when its value cannot be converted.
    """
    raw = os.getenv(name)
    if not raw:
        return None
    try:
        return cast(raw)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} has invalid value {raw!r}"
        ) from err


def _env_flag(name: str) -> Optional[bool]:
    """Parse env var *name* as a boolean; ``None`` when unset or blank.

    ``0``/``false``/``no``/``off`` (case-insensitive, surrounding whitespace
    ignored) mean False; any other non-blank value means True.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return None
    return raw.strip().lower() not in ("0", "false", "no", "off")


def load_config(args) -> Dict[str, Any]:
    """Merge command-line arguments, environment variables and the JSON config file.

    The config file is the one chosen by ``resolve_config_path(args.config)``
    (it may be absent). Precedence depends on the setting:

    * String settings (``api_key``, ``input_dir``, ``prompt``, ``model``,
      ``base_url``, ``output_name_template``, ``prompt_file``,
      ``quality_hint``): CLI value, then environment variable, then config
      file. The first non-empty value wins.
    * All other settings: the config-file value wins over the CLI value.
      NOTE(review): presumably because ``args`` carries argparse defaults,
      so an unset flag cannot be told apart from an explicit one. Confirm.
    * ``timeout``, ``output_long_side`` and ``write_results`` can also be
      overridden by ``OPENAI_TIMEOUT``, ``OUTPUT_LONG_SIDE`` and
      ``WRITE_RESULTS``, which beat both the file and the CLI.
    * ``stream`` is the negation of ``no_stream``, taken from the config
      file when present, else from ``args.no_stream``.

    Args:
        args: Parsed CLI namespace. Attributes read: ``config``, ``api_key``,
            ``input``, ``prompt``, ``model``, ``base_url``, ``timeout``,
            ``delay``, ``output``, ``output_dir``, ``output_name_template``,
            ``output_long_side``, ``write_results``, ``temperature``,
            ``max_tokens``, ``prompt_file``, ``quality_hint``, ``upscale``,
            ``max_side``, ``concurrency``, ``resume``, ``retry``,
            ``debug_save``, ``no_stream``.

    Returns:
        A flat dict with the merged settings (keys listed in the return
        statement below).

    Raises:
        OSError / json.JSONDecodeError: if the config file cannot be read or parsed.
        ValueError: if the config file is not a JSON object, or a numeric
            environment variable or config value cannot be converted.
    """
    config_path = resolve_config_path(args.config)
    config = _read_config_file(config_path)

    # String settings: first non-empty of CLI > environment > config file.
    api_key = args.api_key or os.getenv("OPENAI_API_KEY") or config.get('api_key')
    input_dir = args.input or os.getenv("IMAGE_INPUT_DIR") or config.get('input_dir')
    prompt = args.prompt or os.getenv("IMAGE_PROMPT") or config.get('prompt')
    model = args.model or os.getenv("OPENAI_MODEL") or config.get('model')
    base_url = args.base_url or os.getenv("OPENAI_BASE_URL") or config.get('base_url')
    output_name_template = (
        args.output_name_template
        or os.getenv("OUTPUT_NAME_TEMPLATE")
        or config.get('output_name_template')
    )
    prompt_file = args.prompt_file or os.getenv("IMAGE_PROMPT_FILE") or config.get('prompt_file')
    quality_hint = args.quality_hint or os.getenv("QUALITY_HINT") or config.get('quality_hint')

    # Settings where an environment variable overrides both config file and CLI.
    timeout = config.get('timeout', args.timeout)
    env_timeout = _env_number("OPENAI_TIMEOUT", float)
    if env_timeout is not None:
        timeout = env_timeout

    output_long_side = config.get('output_long_side', args.output_long_side)
    env_output_long_side = _env_number("OUTPUT_LONG_SIDE", int)
    if env_output_long_side is not None:
        output_long_side = env_output_long_side

    write_results = config.get('write_results', args.write_results)
    env_write_results = _env_flag("WRITE_RESULTS")
    if env_write_results is not None:
        write_results = env_write_results

    # Remaining settings: config file value, falling back to the CLI value.
    delay = config.get('delay', args.delay)
    output_file = config.get('output_file', args.output)
    output_dir = config.get('output_dir', args.output_dir)
    temperature = config.get('temperature', args.temperature)
    max_tokens = config.get('max_tokens', args.max_tokens)
    upscale = float(config.get('upscale', args.upscale))
    max_side = int(config.get('max_side', args.max_side))
    concurrency = int(config.get('concurrency', args.concurrency))
    resume = bool(config.get('resume', args.resume))
    retry = int(config.get('retry', args.retry))
    debug_save = bool(config.get('debug_save', args.debug_save))

    # Same fallback rule as above: the file's ``no_stream`` wins when present.
    # Otherwise the CLI flag is honoured, even when a config file was loaded.
    stream = not bool(config.get('no_stream', args.no_stream))

    return {
        "api_key": api_key,
        "input_dir": input_dir,
        "prompt": prompt,
        "model": model,
        "base_url": base_url,
        "timeout": timeout,
        "delay": delay,
        "output_file": output_file,
        "output_dir": output_dir,
        "output_name_template": output_name_template,
        "output_long_side": output_long_side,
        "write_results": write_results,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "prompt_file": prompt_file,
        "quality_hint": quality_hint,
        "upscale": upscale,
        "max_side": max_side,
        "concurrency": concurrency,
        "resume": resume,
        "retry": retry,
        "debug_save": debug_save,
        "stream": stream,
    }
def load_prompt_file(prompt_file: Optional[str], fallback_prompt: Optional[str]) -> Optional[str]:
    """Return the prompt stored in *prompt_file*, stripped of surrounding whitespace.

    The file is read as UTF-8. *fallback_prompt* is returned instead when no
    file is given (``None`` or empty string) or when the file holds only
    whitespace.
    """
    text = None
    if prompt_file:
        with open(prompt_file, 'r', encoding='utf-8') as handle:
            text = handle.read().strip()
    return text if text else fallback_prompt