1111from vllm import LLM , SamplingParams
1212from vllm .lora .request import LoRARequest
1313from vllm .outputs import RequestOutput
14- from vllm .sampling_params import GuidedDecodingParams
14+
15+ try :
16+ from vllm .sampling_params import GuidedDecodingParams as _GuidedDecodingParams # type: ignore[attr-defined]
17+
18+ _STRUCTURED_OUTPUTS_PARAMS = None
19+ except ImportError :
20+ _GuidedDecodingParams = None # type: ignore[assignment,misc]
21+ from vllm .sampling_params import StructuredOutputsParams as _STRUCTURED_OUTPUTS_PARAMS # type: ignore[assignment]
22+
23+
24+ def _make_guided_decoding_params (json_schema : dict , disable_any_whitespace : bool = True ):
25+ if _GuidedDecodingParams is not None :
26+ return _GuidedDecodingParams (json = json_schema , disable_any_whitespace = disable_any_whitespace )
27+ return _STRUCTURED_OUTPUTS_PARAMS (json = json_schema , disable_any_whitespace = disable_any_whitespace ) # type: ignore[misc]
28+
1529
1630from weclone .utils .config import load_config
1731from weclone .utils .config_models import VllmArgs
@@ -134,22 +148,28 @@ def vllm_infer(
134148 template_obj = get_template_and_fix_tokenizer (tokenizer , data_args )
135149 template_obj .mm_plugin .expand_mm_tokens = False # for vllm generate
136150
151+ guided_decoding_params = None
137152 if guided_decoding_class :
138153 json_schema = guided_decoding_class .model_json_schema ()
139- guided_decoding_params = GuidedDecodingParams (json = json_schema , disable_any_whitespace = True )
140-
141- sampling_params = SamplingParams (
142- repetition_penalty = generating_args .repetition_penalty or 1.0 ,
143- temperature = generating_args .temperature ,
144- top_p = generating_args .top_p or 1.0 ,
145- top_k = generating_args .top_k or - 1 ,
146- stop_token_ids = template_obj .get_stop_token_ids (tokenizer ),
147- max_tokens = generating_args .max_new_tokens ,
148- skip_special_tokens = skip_special_tokens ,
149- seed = seed ,
150- bad_words = bad_words ,
151- guided_decoding = guided_decoding_params if guided_decoding_class else None ,
152- )
154+ guided_decoding_params = _make_guided_decoding_params (json_schema )
155+
156+ _sampling_kwargs : dict = {
157+ "repetition_penalty" : generating_args .repetition_penalty or 1.0 ,
158+ "temperature" : generating_args .temperature ,
159+ "top_p" : generating_args .top_p or 1.0 ,
160+ "top_k" : generating_args .top_k or - 1 ,
161+ "stop_token_ids" : template_obj .get_stop_token_ids (tokenizer ),
162+ "max_tokens" : generating_args .max_new_tokens ,
163+ "skip_special_tokens" : skip_special_tokens ,
164+ "seed" : seed ,
165+ "bad_words" : bad_words ,
166+ }
167+ if guided_decoding_params is not None :
168+ if _GuidedDecodingParams is not None :
169+ _sampling_kwargs ["guided_decoding" ] = guided_decoding_params
170+ else :
171+ _sampling_kwargs ["structured_outputs" ] = guided_decoding_params
172+ sampling_params = SamplingParams (** _sampling_kwargs )
153173 if model_args .adapter_name_or_path is not None :
154174 lora_request = LoRARequest ("default" , 1 , model_args .adapter_name_or_path [0 ])
155175 else :
@@ -163,9 +183,10 @@ def vllm_infer(
163183 "disable_log_stats" : True ,
164184 "enable_lora" : model_args .adapter_name_or_path is not None ,
165185 "enable_prefix_caching" : True ,
166- "guided_decoding_backend" : "guidance" ,
167- "guided_decoding_disable_any_whitespace" : True ,
168186 }
187+ if _GuidedDecodingParams is not None :
188+ engine_args ["guided_decoding_backend" ] = "guidance"
189+ engine_args ["guided_decoding_disable_any_whitespace" ] = True
169190
170191 if template_obj .mm_plugin .__class__ .__name__ != "BasePlugin" :
171192 engine_args ["limit_mm_per_prompt" ] = {"image" : 4 , "video" : 2 , "audio" : 2 }
0 commit comments