Skip to content

Commit ab863cf

Browse files
committed
refactor(infer): enhance guided decoding parameter handling
- Updated the guided decoding parameter logic to support dynamic imports, improving compatibility with different configurations. - Simplified the creation of guided decoding parameters by introducing a helper function. - Modified the retry mechanism in the OpenAI API call to retry on all exceptions, enhancing robustness in error handling.
1 parent cfbb4fd commit ab863cf

3 files changed

Lines changed: 46 additions & 37 deletions

File tree

WC-exp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
Subproject commit e62bd446c9147b92ebb2b03dd02f169cbbbf90dd
1+
Subproject commit e0acf48c642b56625832e858ce3a7bf95ac14270

weclone/core/inference/offline_infer.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,21 @@
1111
from vllm import LLM, SamplingParams
1212
from vllm.lora.request import LoRARequest
1313
from vllm.outputs import RequestOutput
14-
from vllm.sampling_params import GuidedDecodingParams
14+
15+
try:
16+
from vllm.sampling_params import GuidedDecodingParams as _GuidedDecodingParams # type: ignore[attr-defined]
17+
18+
_STRUCTURED_OUTPUTS_PARAMS = None
19+
except ImportError:
20+
_GuidedDecodingParams = None # type: ignore[assignment,misc]
21+
from vllm.sampling_params import StructuredOutputsParams as _STRUCTURED_OUTPUTS_PARAMS # type: ignore[assignment]
22+
23+
24+
def _make_guided_decoding_params(json_schema: dict, disable_any_whitespace: bool = True):
25+
if _GuidedDecodingParams is not None:
26+
return _GuidedDecodingParams(json=json_schema, disable_any_whitespace=disable_any_whitespace)
27+
return _STRUCTURED_OUTPUTS_PARAMS(json=json_schema, disable_any_whitespace=disable_any_whitespace) # type: ignore[misc]
28+
1529

1630
from weclone.utils.config import load_config
1731
from weclone.utils.config_models import VllmArgs
@@ -134,22 +148,28 @@ def vllm_infer(
134148
template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
135149
template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate
136150

151+
guided_decoding_params = None
137152
if guided_decoding_class:
138153
json_schema = guided_decoding_class.model_json_schema()
139-
guided_decoding_params = GuidedDecodingParams(json=json_schema, disable_any_whitespace=True)
140-
141-
sampling_params = SamplingParams(
142-
repetition_penalty=generating_args.repetition_penalty or 1.0,
143-
temperature=generating_args.temperature,
144-
top_p=generating_args.top_p or 1.0,
145-
top_k=generating_args.top_k or -1,
146-
stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
147-
max_tokens=generating_args.max_new_tokens,
148-
skip_special_tokens=skip_special_tokens,
149-
seed=seed,
150-
bad_words=bad_words,
151-
guided_decoding=guided_decoding_params if guided_decoding_class else None,
152-
)
154+
guided_decoding_params = _make_guided_decoding_params(json_schema)
155+
156+
_sampling_kwargs: dict = {
157+
"repetition_penalty": generating_args.repetition_penalty or 1.0,
158+
"temperature": generating_args.temperature,
159+
"top_p": generating_args.top_p or 1.0,
160+
"top_k": generating_args.top_k or -1,
161+
"stop_token_ids": template_obj.get_stop_token_ids(tokenizer),
162+
"max_tokens": generating_args.max_new_tokens,
163+
"skip_special_tokens": skip_special_tokens,
164+
"seed": seed,
165+
"bad_words": bad_words,
166+
}
167+
if guided_decoding_params is not None:
168+
if _GuidedDecodingParams is not None:
169+
_sampling_kwargs["guided_decoding"] = guided_decoding_params
170+
else:
171+
_sampling_kwargs["structured_outputs"] = guided_decoding_params
172+
sampling_params = SamplingParams(**_sampling_kwargs)
153173
if model_args.adapter_name_or_path is not None:
154174
lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
155175
else:
@@ -163,9 +183,10 @@ def vllm_infer(
163183
"disable_log_stats": True,
164184
"enable_lora": model_args.adapter_name_or_path is not None,
165185
"enable_prefix_caching": True,
166-
"guided_decoding_backend": "guidance",
167-
"guided_decoding_disable_any_whitespace": True,
168186
}
187+
if _GuidedDecodingParams is not None:
188+
engine_args["guided_decoding_backend"] = "guidance"
189+
engine_args["guided_decoding_disable_any_whitespace"] = True
169190

170191
if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
171192
engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}

weclone/utils/retry.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def retry_openai_api(
9999
):
100100
"""
101101
专门用于OpenAI API调用的重试装饰器
102-
处理OpenAI特有的异常类型
102+
对所有Exception执行重试
103103
"""
104104

105105
def decorator(func):
@@ -110,18 +110,7 @@ def wrapper(*args, **kwargs):
110110
return func(*args, **kwargs)
111111

112112
except Exception as e:
113-
# 检查是否是速率限制或临时错误
114-
error_message = str(e).lower()
115-
should_retry = (
116-
"rate limit" in error_message
117-
or "429" in error_message
118-
or "too many requests" in error_message
119-
or "server error" in error_message
120-
or "timeout" in error_message
121-
or "connection" in error_message
122-
)
123-
124-
if should_retry and attempt < max_retries:
113+
if attempt < max_retries:
125114
delay = _calculate_delay(attempt, base_delay, max_delay, backoff_factor, jitter)
126115
logger.warning(
127116
f"OpenAI API调用失败: {type(e).__name__}: {e},"
@@ -130,12 +119,11 @@ def wrapper(*args, **kwargs):
130119
)
131120
time.sleep(delay)
132121
continue
133-
else:
134-
if attempt >= max_retries:
135-
logger.error(
136-
f"OpenAI API调用在 {max_retries + 1} 次尝试后最终失败: {type(e).__name__}: {e}"
137-
)
138-
raise
122+
123+
logger.error(
124+
f"OpenAI API调用在 {max_retries + 1} 次尝试后最终失败: {type(e).__name__}: {e}"
125+
)
126+
raise
139127

140128
return None
141129

0 commit comments

Comments
 (0)