11import asyncio
22import re
3- from collections .abc import Awaitable , Callable
43
54import firecrest as f7t
65
@@ -87,11 +86,18 @@ async def _get_partitions(
8786 )
8887
8988
89+ def _split_vendor_model (combined : str ) -> tuple [str , str ]:
90+ vendor , model_name = combined .split ("::" , 1 )
91+ return vendor , model_name
92+
93+
9094async def _get_preconfigured_default (
9195 get_value_from_context : GetValueFn , preconfigured : list [LaunchRequest ], field : str
9296) -> str | None :
93- vendor = get_value_from_context ("model_vendor" )
94- model_name = get_value_from_context ("model_name" )
97+ combined = get_value_from_context ("model_vendor_model" )
98+ if combined is None :
99+ return None
100+ vendor , model_name = _split_vendor_model (combined )
95101 framework = get_value_from_context ("framework" )
96102 match = next (
97103 (
@@ -108,43 +114,36 @@ async def _get_preconfigured_default(
108114 return str (getattr (match , field ))
109115
110116
111- def _make_served_model_name_default (
112- preconfigured : list [LaunchRequest ],
113- ) -> Callable [[GetValueFn ], Awaitable [str ]]:
114- async def _default (get_value : GetValueFn ) -> str :
115- value = await _get_preconfigured_default (
116- get_value , preconfigured , "served_model_name"
117- )
118- if value and value != "None" :
119- return value
120- return f"{ get_value ('model_vendor' )} /{ get_value ('model_name' )} -{ create_salt (4 )} "
121-
122- return _default
117+ async def _get_router_options (get_value : GetValueFn ) -> dict [str , tuple [str , str ]]:
118+ workers = get_value ("workers" )
119+ if workers is not None and int (workers ) > 1 :
120+ return {
121+ "yes" : ("Yes" , "Use router to load balance across workers" ),
122+ "no" : ("No" , "Do not use router" ),
123+ }
124+ return {
125+ "no" : ("No" , "Do not use router" ),
126+ }
123127
124128
125129async def _get_launch_request (launcher : Launcher ) -> LaunchRequest :
126130 preconfigured_launch_requests = await launcher .get_preconfigured_models ()
127131
128- async def _get_vendors () -> dict [str , tuple [str , str ]]:
129- return {
130- lr .vendor : (lr .vendor , lr .vendor ) for lr in preconfigured_launch_requests
131- }
132-
133- async def _get_models (
134- get_value_from_context : GetValueFn ,
135- ) -> dict [str , tuple [str , str ]]:
136- vendor = get_value_from_context ("model_vendor" )
137- return {
138- lr .model_name : (lr .model_name , lr .model_name )
139- for lr in preconfigured_launch_requests
140- if lr .vendor == vendor
141- }
132+ async def _get_vendor_models () -> dict [str , tuple [str , str ]]:
133+ seen : dict [str , tuple [str , str ]] = {}
134+ for lr in preconfigured_launch_requests :
135+ key = f"{ lr .vendor } ::{ lr .model_name } "
136+ if key not in seen :
137+ seen [key ] = (lr .model_name , lr .vendor )
138+ return seen
142139
143140 async def _get_frameworks (
144141 get_value_from_context : GetValueFn ,
145142 ) -> dict [str , tuple [str , str ]]:
146- vendor = get_value_from_context ("model_vendor" )
147- model_name = get_value_from_context ("model_name" )
143+ combined = get_value_from_context ("model_vendor_model" )
144+ if combined is None :
145+ return {}
146+ vendor , model_name = _split_vendor_model (combined )
148147 return {
149148 lr .framework : (lr .framework , lr .framework )
150149 for lr in preconfigured_launch_requests
@@ -155,14 +154,9 @@ async def _get_frameworks(
155154 name = "launcher_request_configuration" ,
156155 chain = [
157156 OptionsConfiguration (
158- name = "model_vendor" ,
159- prompt = "Choose the model vendor." ,
160- options_factory = _get_vendors ,
161- ),
162- OptionsConfiguration (
163- name = "model_name" ,
157+ name = "model_vendor_model" ,
164158 prompt = "Choose the model to launch." ,
165- options_factory = _get_models ,
159+ options_factory = _get_vendor_models ,
166160 ),
167161 OptionsConfiguration (
168162 name = "framework" ,
@@ -177,13 +171,10 @@ async def _get_frameworks(
177171 get_value , preconfigured_launch_requests , "workers"
178172 ),
179173 ),
180- TextConfiguration (
181- name = "nodes_per_worker" ,
182- prompt = "Number of nodes to use per worker for running the model." ,
183- validator = lambda v : v .isdigit () and int (v ) > 0 ,
184- default_factory = lambda get_value : _get_preconfigured_default (
185- get_value , preconfigured_launch_requests , "nodes_per_worker"
186- ),
174+ OptionsConfiguration (
175+ name = "use_router" ,
176+ prompt = "Use router to load balance across workers." ,
177+ options_factory = lambda get_value : _get_router_options (get_value ),
187178 ),
188179 TextConfiguration (
189180 name = "time" ,
@@ -195,26 +186,35 @@ async def _get_frameworks(
195186 get_value , preconfigured_launch_requests , "time"
196187 ),
197188 ),
198- TextConfiguration (
199- name = "served_model_name" ,
200- prompt = "Served model name." ,
201- validator = lambda s : len (s ) > 0 ,
202- default_factory = _make_served_model_name_default (
203- preconfigured_launch_requests
204- ),
205- ),
206189 ],
207190 )
208191 await launch_req_config .aconfigure ()
209192
193+ vendor , model_name = _split_vendor_model (
194+ launch_req_config .get_non_none_value ("model_vendor_model" )
195+ )
196+ framework = launch_req_config .get_non_none_value ("framework" )
197+ preconfigured = next (
198+ (
199+ lr
200+ for lr in preconfigured_launch_requests
201+ if lr .vendor == vendor
202+ and lr .model_name == model_name
203+ and lr .framework == framework
204+ ),
205+ None ,
206+ )
210207 return LaunchRequest (
211- vendor = launch_req_config .get_non_none_value ("model_vendor" ),
212- model_name = launch_req_config .get_non_none_value ("model_name" ),
213- framework = launch_req_config .get_non_none_value ("framework" ),
208+ vendor = vendor ,
209+ model_name = model_name ,
210+ framework = framework ,
211+ environment = preconfigured .environment if preconfigured else None ,
214212 workers = int (launch_req_config .get_non_none_value ("workers" )),
215- nodes_per_worker = int ( launch_req_config . get_non_none_value ( " nodes_per_worker" )) ,
213+ nodes_per_worker = preconfigured . nodes_per_worker if preconfigured else 1 ,
216214 time = launch_req_config .get_non_none_value ("time" ),
217- served_model_name = launch_req_config .get_non_none_value ("served_model_name" ),
215+ served_model_name = f"{ vendor } /{ model_name } -{ create_salt (4 )} " ,
216+ framework_args = preconfigured .framework_args if preconfigured else None ,
217+ use_router = launch_req_config .get_non_none_value ("use_router" ) == "yes" ,
218218 )
219219
220220
@@ -262,3 +262,7 @@ async def _monitor() -> None:
262262
263263def main () -> None :
264264 asyncio .run (_main ())
265+
266+
267+ if __name__ == "__main__" :
268+ main ()
0 commit comments