Skip to content

Commit 2bd5244

Browse files
committed
Loading/saving of previous configuration for plugins.
1 parent 358ba94 commit 2bd5244

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

main.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,14 @@ def get_plugin_info(plugin_name: str):
284284

285285
@app.get("/plugins/get_config/{plugin_name}")
286286
def get_plugin_config(plugin_name: str):
287-
if plugin_name in plugin_list:
287+
if plugin_name in plugin_list:
288+
sleep = 0
289+
while plugin_states[plugin] != "RUNNING":
290+
start_plugin(plugin)
291+
time.sleep(5)
292+
sleep += 5
293+
if sleep > 120:
294+
return {"status": "failed", "error": "Plugin too slow to start"}
288295
if plugin_name in port_mapping.keys():
289296
port = port_mapping[plugin_name]
290297
r = client.get("http://127.0.0.1:" + port + "/get_config")

plugin/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from tqdm import tqdm
1212
import shutil
1313
import threading
14+
import storage_db
1415

1516
if sys.platform == "win32":
1617
storage_folder = os.path.join(os.getenv('APPDATA'),"DeepMake")
@@ -78,8 +79,9 @@ class Plugin():
7879
Generic plugin class
7980
"""
8081

81-
def __init__(self, arguments={}):
82-
self.plugin_name = "default"
82+
def __init__(self, arguments={}, plugin_name="default"):
83+
self.plugin_name = plugin_name
84+
self.db = storage_db.storage_db()
8385
if arguments == {}:
8486
self.plugin = {}
8587
self.config = {}
@@ -88,7 +90,9 @@ def __init__(self, arguments={}):
8890
self.plugin = arguments.plugin
8991
self.config = arguments.config
9092
self.endpoints = arguments.endpoints
91-
93+
config = self.db.retrieve_data(f"plugin_config.{self.plugin_name}")
94+
if config:
95+
self.config = config
9296

9397
# Create a plugin-specific storage path
9498
self.plugin_storage_path = os.path.join(storage_folder, self.plugin_name)
@@ -103,10 +107,9 @@ def get_config(self):
103107

104108
def set_config(self, update: dict):
105109
self.config.update(update) # TODO: Validate config dict are all valid keys
106-
if "model_name" in update or "scheduler" in update or "loras" in update or "inverters" in update:
110+
self.db.store_data(f"plugin_config.{self.plugin_name}", self.config)
111+
if "model_name" in update:
107112
self.set_model()
108-
# if response["status"] == "Failed":
109-
# return response
110113
return self.config
111114

112115
def progress_callback(self, progress, stage):

0 commit comments

Comments
 (0)