Skip to content

Commit 3a9d669

Browse files
committed
Refs #22884: Add model restrain for multiple outputs & improve model selection logic
Signed-off-by: Javier Gil Aviles <[email protected]>
1 parent 22eb6e0 commit 3a9d669

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

ml_model_metadata_node.py

+4
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def task_callback(user_input, node_status, ml_model_metadata):
6363
extra_data_str = ''.join(chr(b) for b in extra_data_bytes)
6464
extra_data_dict = json.loads(extra_data_str)
6565

66+
if "model_restrains" in extra_data_dict:
67+
encoded_data = json.dumps({"model_restrains": extra_data_dict["model_restrains"]}).encode("utf-8")
68+
ml_model_metadata.extra_data(encoded_data)
69+
6670
if "goal" in extra_data_dict and extra_data_dict["goal"] != "":
6771
goal = extra_data_dict["goal"]
6872
ml_model_metadata.ml_model_metadata().append(goal)

ml_model_provider_node.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ def task_callback(ml_model_metadata,
5454
if not ml_model_metadata.ml_model_metadata().empty():
5555

5656
try:
57+
# Model restriction after various outputs
58+
restrained_models = []
59+
extra_data_bytes = ml_model_metadata.extra_data()
60+
if extra_data_bytes:
61+
extra_data_str = ''.join(chr(b) for b in extra_data_bytes)
62+
extra_data_dict = json.loads(extra_data_str)
63+
if "model_restrains" in extra_data_dict:
64+
restrained_models = extra_data_dict["model_restrains"]
65+
print("Restrained models:", restrained_models) #debugging
5766

5867
graph = load_graph(os.path.dirname(__file__)+'/graph_v2.ttl')
5968
metadata = ml_model_metadata.ml_model_metadata()[0]
@@ -67,7 +76,18 @@ def task_callback(ml_model_metadata,
6776
model_names = [model[0] for model in suggested_models]
6877

6978
# Random Model is selected here. In the Final code there should be some sort of selection to choose between Possible Models
70-
chosen_model = model_names[1]
79+
for model_use in model_names:
80+
# Some models can't be downloaded from HF, TODO: Works for all models
81+
if(str(model_use) == "meta-llama/Llama-3.1-8B-Instruct"):
82+
model_use = "openai-community/gpt2"
83+
if(str(model_use) == "mlx-community/Llama-3.2-1B-Instruct-4bit"):
84+
model_use = "openai-community/gpt2-medium"
85+
if str(model_use) not in restrained_models:
86+
chosen_model = model_use
87+
break
88+
else:
89+
print(f"Chosen model: {model_use} is restrained. Choosing the next model.")
90+
7191
print(f"") #Debugging
7292
print(f"Chosen model: {chosen_model}") #Debugging
7393

0 commit comments

Comments
 (0)