@@ -54,6 +54,15 @@ def task_callback(ml_model_metadata,
54
54
if not ml_model_metadata .ml_model_metadata ().empty ():
55
55
56
56
try :
57
+ # Model restriction after various outputs
58
+ restrained_models = []
59
+ extra_data_bytes = ml_model_metadata .extra_data ()
60
+ if extra_data_bytes :
61
+ extra_data_str = '' .join (chr (b ) for b in extra_data_bytes )
62
+ extra_data_dict = json .loads (extra_data_str )
63
+ if "model_restrains" in extra_data_dict :
64
+ restrained_models = extra_data_dict ["model_restrains" ]
65
+ print ("Restrained models:" , restrained_models ) #debugging
57
66
58
67
graph = load_graph (os .path .dirname (__file__ )+ '/graph_v2.ttl' )
59
68
metadata = ml_model_metadata .ml_model_metadata ()[0 ]
@@ -67,7 +76,18 @@ def task_callback(ml_model_metadata,
67
76
model_names = [model [0 ] for model in suggested_models ]
68
77
69
78
# Random Model is selected here. In the Final code there should be some sort of selection to choose between Possible Models
70
- chosen_model = model_names [1 ]
79
+ for model_use in model_names :
80
+ # Some models can't be downloaded from HF, TODO: Works for all models
81
+ if (str (model_use ) == "meta-llama/Llama-3.1-8B-Instruct" ):
82
+ model_use = "openai-community/gpt2"
83
+ if (str (model_use ) == "mlx-community/Llama-3.2-1B-Instruct-4bit" ):
84
+ model_use = "openai-community/gpt2-medium"
85
+ if str (model_use ) not in restrained_models :
86
+ chosen_model = model_use
87
+ break
88
+ else :
89
+ print (f"Chosen model: { model_use } is restrained. Choosing the next model." )
90
+
71
91
print (f"" ) #Debugging
72
92
print (f"Chosen model: { chosen_model } " ) #Debugging
73
93
0 commit comments