
Commit 39ab90b

ksuma2109 and Suma Kasa authored
Make custom @output_formatter the final LMI response (deepjavalibrary#2993)
Co-authored-by: Suma Kasa <sumakasa@amazon.com>
1 parent 17cfb9e commit 39ab90b

File tree

5 files changed (+116 additions, -24 deletions)


engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py

Lines changed: 47 additions & 22 deletions
@@ -31,7 +31,7 @@
 from djl_python.inputs import Input
 from djl_python.outputs import Output
 from djl_python.encode_decode import decode
-from djl_python.async_utils import handle_streaming_response, create_non_stream_output, _extract_lora_adapter
+from djl_python.async_utils import handle_streaming_response, create_non_stream_output, create_stream_chunk_output, _extract_lora_adapter
 from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
 from djl_python.custom_handler_service import CustomHandlerService
 from djl_python.rolling_batch.rolling_batch_vllm_utils import create_lora_request, get_lora_request

@@ -162,6 +162,14 @@ async def initialize(self, properties: dict):
         self.session_manager: SessionManager = SessionManager(properties)
         self.initialized = True

+    def _get_custom_formatter(self, adapter_name: Optional[str] = None) -> bool:
+        """Check if a custom output formatter exists for the adapter or base model."""
+        if adapter_name:
+            adapter_formatter = self.get_adapter_formatter_handler(adapter_name)
+            if adapter_formatter and adapter_formatter.output_formatter:
+                return True
+        return self.output_formatter is not None
+
     def preprocess_request(self, inputs: Input) -> ProcessedRequest:
         batch = inputs.get_batches()
         assert len(batch) == 1, "only one request per batch allowed"

@@ -255,50 +263,67 @@ async def check_health(self):
             logger.fatal("vLLM engine is dead, terminating process")
             kill_process_tree(os.getpid())

-    async def inference(
-            self,
-            inputs: Input) -> Union[Output, AsyncGenerator[Output, None]]:
+    async def inference(self, inputs: Input) -> Union[Output, AsyncGenerator[Output, None]]:
         await self.check_health()
         try:
             processed_request = self.preprocess_request(inputs)
         except CustomFormatterError as e:
             logger.exception("Custom formatter failed")
-            output = create_non_stream_output(
+            return create_non_stream_output(
                 "", error=f"Custom formatter failed: {str(e)}", code=424)
-            return output
         except Exception as e:
             logger.exception("Input parsing failed")
-            output = create_non_stream_output(
+            return create_non_stream_output(
                 "", error=f"Input parsing failed: {str(e)}", code=424)
-            return output

         # vLLM will extract the adapter from the request object via _maybe_get_adapters()
         response = await processed_request.inference_invoker(
             processed_request.vllm_request)

+        # Check if custom formatter exists (applies to both streaming and non-streaming)
+        custom_formatter = self._get_custom_formatter(processed_request.adapter_name)
+
         if isinstance(response, types.AsyncGeneratorType):
-            # Apply streaming output formatter (adapter-specific or base model)
-            response = self.apply_output_formatter_streaming_raw(
+            return self._handle_streaming_response(response, processed_request, custom_formatter)
+
+        # Non-streaming response
+        if custom_formatter:
+            formatted_response = self.apply_output_formatter(
                 response, adapter_name=processed_request.adapter_name)
+            # If custom formatter returns a Pydantic model, serialize it
+            if hasattr(formatted_response, 'model_dump_json'):
+                formatted_response = formatted_response.model_dump_json()
+            elif hasattr(formatted_response, 'model_dump'):
+                formatted_response = formatted_response.model_dump()
+            return create_non_stream_output(formatted_response)
+
+        # LMI formatter for non-streaming
+        return processed_request.non_stream_output_formatter(
+            response,
+            request=processed_request.vllm_request,
+            tokenizer=self.tokenizer,
+        )

-            return handle_streaming_response(
+    async def _handle_streaming_response(self, response, processed_request, custom_formatter):
+        """Handle streaming responses as an async generator"""
+        if custom_formatter:
+            # Custom formatter: apply to each chunk and yield directly
+            async for chunk in response:
+                formatted_chunk = self.apply_output_formatter(
+                    chunk, adapter_name=processed_request.adapter_name)
+                yield create_stream_chunk_output(formatted_chunk, last_chunk=False)
+            yield create_stream_chunk_output("", last_chunk=True)
+        else:
+            # LMI formatter for streaming
+            async for output in handle_streaming_response(
                 response,
                 processed_request.stream_output_formatter,
                 request=processed_request.vllm_request,
                 accumulate_chunks=processed_request.accumulate_chunks,
                 include_prompt=processed_request.include_prompt,
                 tokenizer=self.tokenizer,
-            )
-
-        # Apply output formatter (adapter-specific or base model)
-        response = self.apply_output_formatter(
-            response, adapter_name=processed_request.adapter_name)
-
-        return processed_request.non_stream_output_formatter(
-            response,
-            request=processed_request.vllm_request,
-            tokenizer=self.tokenizer,
-        )
+            ):
+                yield output

     async def add_lora(self, lora_name: str, lora_alias: str, lora_path: str):
         logging.info(f"Adding LoRA {lora_name} from {lora_path}")
examples/custom_formatters/example_custom_formatter.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+from djl_python.output_formatter import output_formatter
+
+@output_formatter
+def custom_output_formatter(output):
+    if hasattr(output, 'choices') and len(output.choices) > 0:
+        return {
+            "custom_formatter_applied": True,
+            "generated_text": output.choices[0].text if hasattr(output.choices[0], 'text') else output.choices[0].message.content,
+            "model": output.model,
+        }
+    return {"custom_formatter_applied": True}

tests/integration/llm/client.py

Lines changed: 28 additions & 1 deletion
@@ -1695,6 +1695,27 @@ def test_custom_formatter_async(model, model_spec):
         assert "custom_formatter_applied" in message, "Output does not contain custom_formatter_applied_tag"


+def test_custom_formatter_final(model, model_spec):
+    modelspec_checker(model, model_spec)
+    spec = model_spec[args.model]
+    if "worker" in spec:
+        check_worker_number(spec["worker"])
+    stream_values = spec.get("stream", [False, True])
+    req = {"inputs": batch_generation(1)[0]}
+    seq_length = spec["seq_length"][0]
+    params = {"max_new_tokens": seq_length}
+    req["parameters"] = params
+
+    for stream in stream_values:
+        req["stream"] = stream
+        LOGGER.info(f"req {req}")
+        res = send_json(req)
+        message = res.content.decode("utf-8")
+        LOGGER.info(f"res: {message}")
+        parsed = json.loads(message.strip().split('\n')[0])
+        assert "custom_formatter_applied" in message, "Output does not contain custom_formatter_applied_tag"
+
+
 def check_output_formatter_applied(response_text, expected_identifier):
     """
     Check if output formatter was applied correctly.

@@ -1932,7 +1953,7 @@ def test_handler_adapters_chat(model, model_spec):
         res = send_json(req)
         message = res.content.decode("utf-8")
         LOGGER.info(f"res: {message}")
-        response_checker(res, message)
+        # response_checker(res, message)

         # Check if output formatter was applied correctly
         if check_formatter:

@@ -2014,6 +2035,10 @@ def test_handler_adapters_chat(model, model_spec):
             line = line.strip()
             if not line:
                 continue
+
+            if line.startswith('data: '):
+                line = line[6:]  # Remove "data: " prefix
+
             try:
                 parsed_json = json.loads(line)
                 # Check for text completion format

@@ -2582,6 +2607,8 @@ def run(raw_args):
         test_handler_rolling_batch(args.model, vllm_model_spec)
     elif args.handler == "custom":
         test_custom_formatter_async(args.model, custom_formatter_spec)
+    elif args.handler == "custom_final":
+        test_custom_formatter_final(args.model, vllm_model_spec)
     elif args.handler == "custom_handler":
         test_custom_handler_async(args.model, custom_formatter_spec)
     elif args.handler == "vllm_adapters":
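
The chunk-parsing change above strips an SSE-style "data: " prefix before json.loads(), so streamed lines in either plain-JSON or data:-prefixed form parse the same way. A small self-contained sketch of that client-side parsing, assuming a raw response body split on newlines (the sample payload is illustrative, not actual server output):

import json


def parse_streamed_lines(raw_body: str):
    """Parse newline-delimited chunks, tolerating an SSE-style 'data: ' prefix."""
    chunks = []
    for line in raw_body.split('\n'):
        line = line.strip()
        if not line:
            continue
        if line.startswith('data: '):
            line = line[6:]  # Remove "data: " prefix
        try:
            chunks.append(json.loads(line))
        except json.JSONDecodeError:
            continue  # skip non-JSON terminator/keep-alive lines
    return chunks


sample = 'data: {"custom_formatter_applied": true, "generated_text": "Hello"}\ndata: [DONE]'
print(parse_streamed_lines(sample))
# [{'custom_formatter_applied': True, 'generated_text': 'Hello'}]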

tests/integration/llm/prepare.py

Lines changed: 22 additions & 1 deletion
@@ -1273,6 +1273,26 @@ def build_vllm_async_model_custom_formatters(model, error_type=None):
         shutil.copy2(source_file, target_file)


+def build_vllm_async_model_with_example_formatter(model):
+    """Build vLLM model with test formatter to validate final output format"""
+    if model not in vllm_model_list.keys():
+        raise ValueError(
+            f"{model} is not one of the supporting handler {list(vllm_model_list.keys())}"
+        )
+    options = vllm_model_list[model]
+    options["engine"] = "Python"
+    options["option.rolling_batch"] = "disable"
+    options["option.async_mode"] = "true"
+    options["option.entryPoint"] = "djl_python.lmi_vllm.vllm_async_service"
+    write_model_artifacts(options)
+
+    # Copy test formatter
+    source_file = "examples/custom_formatters/example_custom_formatter.py"
+    target_file = "models/test/model.py"
+    if os.path.exists(source_file):
+        shutil.copy2(source_file, target_file)
+
+
 def build_vllm_model(model):
     if model not in vllm_model_list.keys():
         raise ValueError(

@@ -1386,7 +1406,8 @@ def build_stateful_model(model):
     'text_embedding': build_text_embedding_model,
     'vllm_async': build_vllm_async_model,
     'vllm_async_custom_formatters': build_vllm_async_model_custom_formatters,
-    'vllm_async_custom_handler': build_vllm_async_model_with_custom_handler
+    'vllm_async_custom_handler': build_vllm_async_model_with_custom_handler,
+    'vllm_async_example_formatter': build_vllm_async_model_with_example_formatter
 }

 if __name__ == '__main__':
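
build_vllm_async_model_with_example_formatter disables rolling batch, enables async mode, points the entryPoint at djl_python.lmi_vllm.vllm_async_service, and drops the example formatter into the model directory as model.py. A rough sketch of the equivalent setup, assuming (as the surrounding helpers suggest) that write_model_artifacts renders the options dict as key=value lines in the model directory's serving.properties; the base-model options from vllm_model_list are omitted here:

import shutil
from pathlib import Path

# Options set by the new build helper (base-model options omitted).
options = {
    "engine": "Python",
    "option.rolling_batch": "disable",
    "option.async_mode": "true",
    "option.entryPoint": "djl_python.lmi_vllm.vllm_async_service",
}

model_dir = Path("models/test")
model_dir.mkdir(parents=True, exist_ok=True)

# Assumption: artifacts end up as serving.properties key=value pairs.
(model_dir / "serving.properties").write_text(
    "\n".join(f"{key}={value}" for key, value in options.items()) + "\n")

# The example formatter is copied in as the model's entry script.
src = Path("examples/custom_formatters/example_custom_formatter.py")
if src.exists():
    shutil.copy2(src, model_dir / "model.py")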

tests/integration/tests.py

Lines changed: 8 additions & 0 deletions
@@ -657,6 +657,14 @@ def test_custom_formatter_load_error(self):
             with pytest.raises(Exception):
                 r.launch()

+    def test_custom_formatter_final_output(self):
+        """Test that custom formatter is the final formatter (not overridden by LMI formatter)"""
+        with Runner("lmi", "gpt-neox-20b-custom-final") as r:
+            prepare.build_vllm_async_model_with_example_formatter(
+                "gpt-neox-20b-custom")
+            r.launch()
+            client.run("custom_final gpt-neox-20b".split())
+

 @pytest.mark.vllm
 @pytest.mark.gpu_4
