@@ -83,6 +83,7 @@ def _parse_kwargs(self, **kwargs):
- `repetition_penalty`: float, default 1.05. Repetition penalty
- `max_tokens`: int, default 512. Maximum tokens to generate
- `use_cache`: bool, default True. Whether to use response cache
- `quantization`: str, default None. Quantization method (e.g., 'bitsandbytes', 'awq', 'gptq')
"""

self.model_name = kwargs.get("model", None)
@@ -91,7 +92,8 @@ def _parse_kwargs(self, **kwargs):
self.repetition_penalty = kwargs.get("repetition_penalty", 1.05)
self.max_tokens = kwargs.get("max_tokens", 512)
self.use_cache = kwargs.get("use_cache", True)

self.quantization = kwargs.get("quantization", None)

def inference(self, data):
"""Inference the model

@@ -53,16 +53,19 @@ def _load(self, model):
model : str
Hugging Face style model name. Example: `Qwen/Qwen2.5-0.5B-Instruct`
"""
self.model = LLM(
model=model,
trust_remote_code=True,
dtype="float16",
tensor_parallel_size=self.tensor_parallel_size,
gpu_memory_utilization=self.gpu_memory_utilization,
max_model_len = 8192
#quantization=self.quantization # TODO need to align with vllm API
)

llm_kwargs = {
"model": model,
"trust_remote_code": True,
high

Hardcoding trust_remote_code=True can pose a security risk, as it allows arbitrary code execution from the model's repository. It's highly recommended to make this a configurable parameter that defaults to False. Users should explicitly enable it only when they trust the source of the model.
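The change this comment asks for could look like the following sketch. The class and kwarg names here are assumptions for illustration, not the PR's actual code; the vLLM `LLM` constructor does accept a `trust_remote_code` flag, so a parsed value could be forwarded into `llm_kwargs` directly.

```python
class BaseLLMSketch:
    """Hypothetical stand-in for BaseLLM, illustrating the suggestion only."""

    def _parse_kwargs(self, **kwargs):
        # Default to False: code from a model repository should only run
        # when the user has explicitly opted in for a source they trust.
        self.trust_remote_code = kwargs.get("trust_remote_code", False)


llm = BaseLLMSketch()
llm._parse_kwargs()                        # no opt-in: stays False
llm._parse_kwargs(trust_remote_code=True)  # explicit opt-in
```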

"dtype": "float16",
medium

The dtype is hardcoded to "float16". While this is a common default, some models perform better with "bfloat16" (if supported by the hardware), and certain quantization methods might have specific dtype requirements. To improve flexibility, consider making this a configurable parameter, which could be parsed in BaseLLM with a default of "auto" or "float16".
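One way to implement this suggestion, sketched under the assumption that a `dtype` kwarg is parsed in `BaseLLM` (the class and attribute names are hypothetical):

```python
class DtypeSketch:
    """Hypothetical stand-in illustrating a configurable dtype."""

    # Values vLLM commonly accepts for its dtype argument.
    _ALLOWED = {"auto", "float16", "bfloat16", "float32"}

    def _parse_kwargs(self, **kwargs):
        # "auto" lets the backend derive the dtype from the model config,
        # which also sidesteps quantization-specific dtype requirements.
        dtype = kwargs.get("dtype", "auto")
        if dtype not in self._ALLOWED:
            raise ValueError(f"unsupported dtype: {dtype!r}")
        self.dtype = dtype
```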

"tensor_parallel_size": self.tensor_parallel_size,
"gpu_memory_utilization": self.gpu_memory_utilization,
"max_model_len": 8192
medium

The max_model_len is hardcoded to 8192. This could be restrictive for models with larger context windows or inefficient for models with smaller ones. To improve flexibility, consider making this a configurable parameter. You could add max_model_len to _parse_kwargs in base_llm.py with a sensible default, and then use self.max_model_len here.
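Following this suggestion, `max_model_len` could be parsed alongside the other kwargs and consumed when building the constructor arguments. A sketch (names are hypothetical; the default of 8192 simply mirrors the value currently hardcoded):

```python
class MaxLenSketch:
    """Hypothetical stand-in showing a configurable max_model_len."""

    def _parse_kwargs(self, **kwargs):
        # 8192 mirrors the previously hardcoded value; callers can lower it
        # to save KV-cache memory or raise it for long-context models.
        self.max_model_len = kwargs.get("max_model_len", 8192)

    def build_llm_kwargs(self, model):
        # The parsed value would replace the literal 8192 in _load.
        return {"model": model, "max_model_len": self.max_model_len}
```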

}

if self.quantization:
llm_kwargs["quantization"] = self.quantization

self.model = LLM(**llm_kwargs)
self.sampling_params = SamplingParams(
temperature=self.temperature,
top_p=self.top_p,
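Taken together, the optional-kwarg pattern this PR introduces can be exercised in isolation. The sketch below builds a plain dict rather than calling vLLM's `LLM` constructor, so it runs without vLLM installed:

```python
def build_llm_kwargs(model, quantization=None):
    """Mirror the PR's logic: include 'quantization' only when it is set,
    so vLLM's own default (no quantization) applies otherwise."""
    llm_kwargs = {
        "model": model,
        "trust_remote_code": True,
        "dtype": "float16",
        "max_model_len": 8192,
    }
    if quantization:
        llm_kwargs["quantization"] = quantization
    return llm_kwargs


dense = build_llm_kwargs("Qwen/Qwen2.5-0.5B-Instruct")
quantized = build_llm_kwargs("Qwen/Qwen2.5-0.5B-Instruct", quantization="awq")
```

Because the key is added conditionally, existing callers that never pass `quantization` see no behavior change.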