-
-
Notifications
You must be signed in to change notification settings - Fork 948
Expand file tree
/
Copy pathconfig.py
More file actions
92 lines (80 loc) · 3.77 KB
/
config.py
File metadata and controls
92 lines (80 loc) · 3.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
from getpass import getpass
from pathlib import Path
from tempfile import gettempdir
from typing import Any
from click import UsageError
CONFIG_FOLDER = os.path.expanduser("~/.config")
SHELL_GPT_CONFIG_FOLDER = Path(CONFIG_FOLDER) / "shell_gpt"
SHELL_GPT_CONFIG_PATH = SHELL_GPT_CONFIG_FOLDER / ".sgptrc"
ROLE_STORAGE_PATH = SHELL_GPT_CONFIG_FOLDER / "roles"
FUNCTIONS_PATH = SHELL_GPT_CONFIG_FOLDER / "functions"
CHAT_CACHE_PATH = Path(gettempdir()) / "chat_cache"
CACHE_PATH = Path(gettempdir()) / "cache"

# Default value for every config key. Each entry can be overridden by an
# environment variable of the same name, evaluated once at import time.
# TODO: Refactor ENV variables with SGPT_ prefix.
DEFAULT_CONFIG = {
    # TODO: Refactor it to CHAT_STORAGE_PATH.
    "CHAT_CACHE_PATH": os.getenv("CHAT_CACHE_PATH", str(CHAT_CACHE_PATH)),
    "CACHE_PATH": os.getenv("CACHE_PATH", str(CACHE_PATH)),
    "CHAT_CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")),
    # Previously this key was (mistakenly) read from CHAT_CACHE_LENGTH; keep
    # that as a fallback so existing environments behave the same.
    "CACHE_LENGTH": int(
        os.getenv("CACHE_LENGTH", os.getenv("CHAT_CACHE_LENGTH", "100"))
    ),
    "REQUEST_TIMEOUT": int(os.getenv("REQUEST_TIMEOUT", "60")),
    "DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "gpt-4.1"),
    "DEFAULT_COLOR": os.getenv("DEFAULT_COLOR", "magenta"),
    "ROLE_STORAGE_PATH": os.getenv("ROLE_STORAGE_PATH", str(ROLE_STORAGE_PATH)),
    "DEFAULT_EXECUTE_SHELL_CMD": os.getenv("DEFAULT_EXECUTE_SHELL_CMD", "false"),
    "DISABLE_STREAMING": os.getenv("DISABLE_STREAMING", "false"),
    "CODE_THEME": os.getenv("CODE_THEME", "dracula"),
    "OPENAI_FUNCTIONS_PATH": os.getenv("OPENAI_FUNCTIONS_PATH", str(FUNCTIONS_PATH)),
    "OPENAI_USE_FUNCTIONS": os.getenv("OPENAI_USE_FUNCTIONS", "true"),
    "SHOW_FUNCTIONS_OUTPUT": os.getenv("SHOW_FUNCTIONS_OUTPUT", "false"),
    "API_BASE_URL": os.getenv("API_BASE_URL", "default"),
    "PRETTIFY_MARKDOWN": os.getenv("PRETTIFY_MARKDOWN", "true"),
    "USE_LITELLM": os.getenv("USE_LITELLM", "false"),
    "SHELL_INTERACTION": os.getenv("SHELL_INTERACTION", "true"),
    "OS_NAME": os.getenv("OS_NAME", "auto"),
    "SHELL_NAME": os.getenv("SHELL_NAME", "auto"),
    # New features might add their own config variables here.
}
class Config(dict):  # type: ignore
    """Key/value configuration backed by a plain ``KEY=value`` text file.

    If ``config_path`` exists, it is read and any key from ``defaults`` that
    the file lacks is added and the file rewritten. Otherwise the parent
    directory is created, the user is prompted for an OpenAI API key when
    neither ``defaults`` nor the environment supplies one, and the defaults
    are written out as a new file.

    Values loaded from the file are strings; values taken from ``defaults``
    keep whatever type they were passed with.
    """

    def __init__(self, config_path: Path, **defaults: Any):
        self.config_path = config_path

        if self._exists:
            self._read()
            # Back-fill keys that were added since the file was last written
            # so the on-disk file lists every known key.
            has_new_config = False
            for key, value in defaults.items():
                if key not in self:
                    has_new_config = True
                    self[key] = value
            if has_new_config:
                self._write()
        else:
            config_path.parent.mkdir(parents=True, exist_ok=True)
            # Don't write API key to config file if it is in the environment.
            if not defaults.get("OPENAI_API_KEY") and not os.getenv("OPENAI_API_KEY"):
                api_key = getpass(prompt="Please enter your OpenAI API key: ")
                defaults["OPENAI_API_KEY"] = api_key
            super().__init__(**defaults)
            self._write()

    @property
    def _exists(self) -> bool:
        """Whether the backing config file is present on disk."""
        return self.config_path.exists()

    def _write(self) -> None:
        """Overwrite the config file with one ``KEY=value`` line per entry."""
        # Render everything before opening, so a failure while formatting a
        # value does not leave a truncated file behind.
        content = "".join(f"{key}={value}\n" for key, value in self.items())
        with open(self.config_path, "w", encoding="utf-8") as file:
            file.write(content)

    def _read(self) -> None:
        """Load ``KEY=value`` entries from the config file into this dict.

        Blank lines and lines whose first non-blank character is ``#`` are
        skipped. Only the first ``=`` separates key from value, so values may
        themselves contain ``=``. Surrounding whitespace is stripped from both
        key and value.

        :raises UsageError: if a non-comment line contains no ``=``.
        """
        with open(self.config_path, "r", encoding="utf-8") as file:
            for line_number, raw_line in enumerate(file, start=1):
                line = raw_line.strip()
                # Strip before the comment check so indented comments are
                # recognised instead of being parsed as entries.
                if not line or line.startswith("#"):
                    continue
                key, sep, value = line.partition("=")
                if not sep:
                    raise UsageError(
                        f"Invalid line {line_number} in {self.config_path}: "
                        f"expected KEY=value, got {line!r}"
                    )
                self[key.strip()] = value.strip()

    def get(self, key: str) -> str:  # type: ignore
        """Return the value for ``key``, preferring an environment variable.

        An environment variable named ``key`` takes precedence over the value
        stored in this config; an empty environment variable is ignored.

        :raises UsageError: if neither source yields a truthy value.
        """
        # Prioritize environment variables over config file.
        value = os.getenv(key) or super().get(key)
        if not value:
            raise UsageError(f"Missing config key: {key}")
        return value
# Shared configuration instance, loaded from (or created at)
# SHELL_GPT_CONFIG_PATH with DEFAULT_CONFIG as the defaults. When the file does
# not exist yet and OPENAI_API_KEY is unset, this prompts for an API key.
cfg = Config(config_path=SHELL_GPT_CONFIG_PATH, **DEFAULT_CONFIG)