import types
from typing import Any, Dict, List, Optional, Tuple

import torch

from ...executor.result import CompletionOutput
from ...inputs.registry import DefaultInputProcessor, ExtraProcessedInputs
from ...llmapi.llm import RequestOutput, _TorchLLM
from ...llmapi.tokenizer import TokenizerBase, TransformersTokenizer, tokenizer_factory
from ...sampling_params import SamplingParams
from .distributed import common as dist_ad
from .llm_args import LlmArgs
from .models.factory import ModelFactory
from .shim.demollm import DemoGenerationExecutor


class ADInputProcessor(DefaultInputProcessor):
    """Input processor for AutoDeploy backend.

    This is a wrapper to either support standard TRT-LLM text-only input processing or use HF's
    message chat template system to process multimodal inputs.
    """
    def __init__(self, tokenizer: Optional[TokenizerBase], processor: Optional[Any] = None):
        super().__init__(model_path=None, config=None, tokenizer=tokenizer)
        # NOTE: HF's tokenizer/processor that has the apply_chat_template method
        self.processor = processor or getattr(tokenizer, "tokenizer", None)

    def __call__(
        self, inputs: Dict[str, Any], sampling_params: SamplingParams
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        if self.processor is None:
            raise ValueError("processor is required to tokenize inputs")

        # construct kwargs to reflect DefaultInputProcessor
        kwargs = {
            "add_special_tokens": sampling_params.add_special_tokens,
        }
        if sampling_params.truncate_prompt_tokens is not None:
            kwargs = {
                "truncation": True,
                "max_length": sampling_params.truncate_prompt_tokens,
            }

        # check for messages field and if yes, use the apply_chat_template method
        if "messages" in inputs:
            # multi_modal_data should not be present in the messages field
            assert "multi_modal_data" not in inputs, f"unexpected multi_modal_data key in {inputs=}"

            # TODO: we don't really need this but it makes for a good sanity check. Consider
            # removing this in the future if we need to speed things up.
            prompt = self.processor.apply_chat_template(
                inputs["messages"],
                add_generation_prompt=True,
                tokenize=False,
            )
            inputs["prompt"] = prompt

            all_args = self.processor.apply_chat_template(
                inputs["messages"],
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
                padding=False,  # there shouldn't be a need for padding ever...
                return_attention_mask=False,
                **kwargs,
            )
        # check if multi_modal_data has already been pre-processed/added to the inputs
        # for example, this might be the case when invoking AD via trtllm-serve
        elif "multi_modal_data" in inputs:
            images = inputs["multi_modal_data"]["image"]
            do_rescale = True
            if images is not None and isinstance(images[0], torch.Tensor):
                # The default multimodal input loader will normalize images to [0, 1] when the
                # requested format is "pt" (pytorch tensors), but not for "pil" (PIL images).
                do_rescale = False
            all_args = self.processor(
                text=inputs["prompt"],
                images=images,
                return_dict=True,
                return_tensors="pt",
                do_rescale=do_rescale,
            )
        else:
            all_args = None

        if all_args is not None:
            # TODO: is there a more reliable way to avoid the attention_mask here?
            all_args.pop("attention_mask", None)

            # TODO: can we avoid the extra tolist() here eventually?
            token_ids = all_args.pop("input_ids")
            assert token_ids.shape[0] == 1, "messages should be unbatched at this point."
            if all_args:
                extra_processed_inputs = {"multimodal_data": all_args}
            else:
                extra_processed_inputs = None
            return token_ids[0].tolist(), extra_processed_inputs
        else:
            token_ids = self.tokenizer.encode(inputs["prompt"], **kwargs)
            return token_ids, None


class LLM(_TorchLLM):
    """LLM class is the main class for running an LLM model using the AutoDeploy backend."""

    args: LlmArgs
    _factory: ModelFactory

    @property
    def factory(self) -> ModelFactory:
        if not getattr(self, "_factory", None):
            self._factory = self.args.create_factory()
        return self._factory

    def __init__(self, *args, **kwargs):
        kwargs["backend"] = "_autodeploy"
        super().__init__(*args, **kwargs)

    def _try_load_tokenizer(self) -> Optional[TokenizerBase]:
        if self.args.skip_tokenizer_init:
            return None
        return tokenizer_factory(self.factory.init_tokenizer())

    def _validate_args_for_torch_backend(self, kwargs: dict) -> None:
        """We don't need to validate args for the AutoDeploy backend for now."""
        pass

    def _create_input_processor(self) -> ADInputProcessor:
        processor = self.factory.init_processor()
        base = ADInputProcessor(self.tokenizer, processor)
        if hasattr(self.factory, "init_input_processor"):
            return self.factory.init_input_processor(base)
        return base

    def _prefetch_model(self):
        """Prefetch the model for the LLM."""
        self.factory.prefetch_checkpoint()

    def _build_model(self):
        """Build the model for the LLM.

        This is a wrapper around the regular build model method that prefetches the model with the
        factory.
        """
        # prefetch model with factory
        self._prefetch_model()

        # NOTE (lucaslie): do regular build model, we bypass the regular LLM CachedModelLoader in
        # _autodeploy backend.
        super()._build_model()

        # now correct input processor
        assert isinstance(self.input_processor, DefaultInputProcessor)
        assert self.tokenizer is None or isinstance(self.tokenizer, TransformersTokenizer)
        self.input_processor = self._create_input_processor()
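
# Minimal usage sketch for the AutoDeploy LLM (hedged): the module path is inferred from the
# relative imports at the top of this file, and the model identifier / sampling values are
# illustrative placeholders rather than values taken from this file.
#
#   from tensorrt_llm._torch.auto_deploy.llm import LLM
#   from tensorrt_llm.sampling_params import SamplingParams
#
#   llm = LLM(model="<hf-model-or-local-checkpoint>")  # backend is forced to "_autodeploy" in __init__
#   outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))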


class DemoLLM(LLM):
    """A simple LLM class to demo the LLM interface while debugging the e2e workflow.

    This is a very simple implementation of an LLM class that can be hacked and used for debugging.
    """

    def __init__(self, **kwargs):
        self.args: LlmArgs = LlmArgs(**kwargs)

        self.mpi_session = None
        self.runtime_context = None

        # prefetch model and load tokenizer
        self._prefetch_model()
        self._tokenizer = self._try_load_tokenizer()
        self._hf_model_config = self._try_load_hf_model_config()
        self._generation_config = self._try_load_generation_config()
        self.input_processor = self._create_input_processor()

        # construct demo executor + engine
        self._executor = DemoGenerationExecutor(
            world_size=self.args.world_size,
            tokenizer=self.tokenizer,
            ad_config=self.args,
        )

    def __del__(self):
        """Ensure proper cleanup of distributed resources."""
        if hasattr(self, "_executor") and self._executor is not None:
            self._executor.shutdown()
        # Call cleanup to ensure process group is properly destroyed
        dist_ad.cleanup()

    @staticmethod
    def _handle_response(request_output: RequestOutput, response: List[CompletionOutput]):
        request_output._done = True
        gen_request = request_output._generation_request
        for i, out in enumerate(response):
            out.text = request_output.tokenizer.decode(
                out.token_ids,
                skip_special_tokens=gen_request.sampling_params.skip_special_tokens,
                spaces_between_special_tokens=gen_request.sampling_params.spaces_between_special_tokens,
            )
            request_output._context_logits = out._postprocess_result["context_logits"]
            request_output._outputs[i] = out

    def generate_async(self, *args, **kwargs) -> RequestOutput:
        request_output = super().generate_async(*args, **kwargs)
        # patch the _handle_response method for our use case
        request_output._handle_response = types.MethodType(self._handle_response, request_output)
        return request_output
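
# DemoLLM usage note (hedged): it reuses the generate()/generate_async() interface inherited via
# _TorchLLM, but swaps in the in-process DemoGenerationExecutor and patches _handle_response on
# each RequestOutput, so it is intended for single-process debugging of the e2e AutoDeploy
# workflow rather than deployment. Constructor kwargs are forwarded directly to LlmArgs; the
# values below are placeholders, not defaults taken from this file:
#
#   llm = DemoLLM(model="<hf-model-or-local-checkpoint>")
#   out = llm.generate("Hello, my name is", SamplingParams(max_tokens=32))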