
Commit 6c05e90

Add LoRA accuracy validation to adapter integration tests (deepjavalibrary#2985)
1 parent cc8c9a4 commit 6c05e90

File tree: 2 files changed, +347 −14 lines


tests/integration/llm/client.py

Lines changed: 338 additions & 5 deletions
@@ -195,23 +195,23 @@ def get_model_name():
         "batch_size": [4],
         "seq_length": [16, 32],
         "worker": 1,
-        "adapters": ["french", "spanish"],
+        "adapters": ["medical", "exam"],
         "tokenizer": "unsloth/llama-3-8b-Instruct"
     },
     "llama3-8b-unmerged-lora-with-custom-code": {
         "option.model_id": "s3://djl-llm/llama-3-8b-instruct-hf/",
         "batch_size": [4],
         "seq_length": [16, 32],
         "worker": 1,
-        "adapters": ["french", "spanish"],
+        "adapters": ["medical", "exam"],
         "tokenizer": "unsloth/llama-3-8b-Instruct",
         "add_output_formatter": True,
     },
     "gemma-7b-unmerged-lora": {
         "batch_size": [4],
         "seq_length": [16, 32],
         "worker": 1,
-        "adapters": ["alpaca", "dante"],
+        "adapters": ["chatml", "claude3sonnet"],
         "tokenizer": "unsloth/gemma-7b"
     },
     "phi2-unmerged-lora": {
@@ -767,6 +767,270 @@ def send_json(data, headers={}):
     return resp


+def extract_generated_text(response_content):
+    """
+    Extract generated_text from text completion response content.
+    Handles both streaming (multiple JSON lines) and non-streaming formats.
+
+    Args:
+        response_content: The response content string from the model server
+
+    Returns:
+        The generated_text string if found, None otherwise
+    """
+    if not response_content or response_content.strip() == "":
+        return None
+
+    lines = response_content.strip().split('\n')
+    # Iterate in reverse to find the final response with generated_text
+    for line in reversed(lines):
+        line = line.strip()
+        if not line:
+            continue
+        try:
+            parsed = json.loads(line)
+            if parsed.get("generated_text") is not None:
+                return parsed["generated_text"]
+        except json.JSONDecodeError:
+            continue
+    return None
+
+
+def extract_chat_content(response_content):
+    """
+    Extract content from chat completion response.
+    Handles both streaming and non-streaming chat completion formats.
+
+    Args:
+        response_content: The response content string from the model server
+
+    Returns:
+        The content string if found, None otherwise
+    """
+    if not response_content or response_content.strip() == "":
+        return None
+
+    lines = response_content.strip().split('\n')
+    # Iterate in reverse to find the final response with content
+    for line in reversed(lines):
+        line = line.strip()
+        if not line:
+            continue
+        try:
+            parsed = json.loads(line)
+            # Non-streaming chat completion format
+            if parsed.get("choices") and parsed.get(
+                    "object") == "chat.completion":
+                return parsed["choices"][0].get("message",
+                                                {}).get("content", "")
+            # Streaming chat completion format - look for final chunk with content
+            # For simplicity in accuracy tests, we use non-streaming mode
+        except json.JSONDecodeError:
+            continue
+    return None
+
+
+def validate_determinism(outputs_1, outputs_2, label=""):
+    """
+    Validate that two invocations produce identical outputs.
+
+    Note: LoRA adapters may exhibit non-determinism with vLLM's greedy decoding (temperature=0).
+    See: https://github.com/vllm-project/vllm/issues/7977
+
+    Args:
+        outputs_1: First invocation outputs (dict for adapters, string for base)
+        outputs_2: Second invocation outputs
+        label: Description for logging (e.g., "base model", "adapter french")
+    """
+    if isinstance(outputs_1, dict) and isinstance(outputs_2, dict):
+        # Adapter outputs
+        for adapter in outputs_1.keys():
+            out1 = outputs_1.get(adapter)
+            out2 = outputs_2.get(adapter)
+            if out1 != out2:
+                raise AssertionError(
+                    f"Adapter '{adapter}' not deterministic! Output 1: '{out1[:100] if out1 else None}...' != Output 2: '{out2[:100] if out2 else None}...'"
+                )
+        LOGGER.info(f"✓ Determinism verified for {len(outputs_1)} adapters")
+    else:
+        # Base model output
+        if outputs_1 != outputs_2:
+            raise AssertionError(
+                f"{label} not deterministic! Output 1: '{outputs_1[:100] if outputs_1 else None}...' != Output 2: '{outputs_2[:100] if outputs_2 else None}...'"
+            )
+        LOGGER.info(f"✓ Determinism verified for {label}")
+
+
+def validate_lora_differentiation(base_output, adapter_outputs):
+    """
+    Validate that adapter outputs differ from base model and from each other.
+
+    This validates that LoRA adapters are actually being applied and producing
+    different outputs than the base model.
+    """
+    if not base_output or not adapter_outputs:
+        LOGGER.warning("Missing outputs, skipping differentiation validation")
+        return
+
+    # Check adapters differ from base
+    for name, output in adapter_outputs.items():
+        if output and output == base_output:
+            raise AssertionError(
+                f"Adapter '{name}' produced same output as base model - adapter may not be applied"
+            )
+
+    # Check adapters differ from each other
+    outputs = list(adapter_outputs.items())
+    for i, (n1, o1) in enumerate(outputs):
+        for n2, o2 in outputs[i + 1:]:
+            if o1 and o2 and o1 == o2:
+                raise AssertionError(
+                    f"Adapters '{n1}' and '{n2}' produced identical outputs")
+
+    LOGGER.info(
+        f"✓ Differentiation verified: {len(adapter_outputs)} adapters all different from base and each other"
+    )
+
+
+def collect_lora_outputs(adapters, input_text, seq_length, stream=False):
+    """
+    Collect outputs from adapters and base model using deterministic parameters.
+
+    Args:
+        adapters: List of adapter names
+        input_text: Input prompt to use
+        seq_length: Max new tokens to generate
+        stream: Whether to use streaming mode
+
+    Returns:
+        Tuple of (adapter_outputs dict, base_model_output string)
+    """
+    adapter_outputs = {}
+    deterministic_params = {
+        "do_sample": False,
+        "temperature": 0.0,
+        "top_p": 1.0,
+        "top_k": -1,
+        "seed": 42,
+        "max_new_tokens": seq_length,
+        "details": True
+    }
+
+    if not adapters:
+        LOGGER.warning(
+            "No adapters provided, skipping adapter output collection")
+        return adapter_outputs, None
+
+    # Collect adapter outputs
+    for adapter in adapters:
+        req = {
+            "inputs": input_text,
+            "parameters": deterministic_params,
+            "adapters": adapter,
+            "stream": stream
+        }
+        LOGGER.info(f"LoRA accuracy req for adapter '{adapter}': {req}")
+        res = send_json(req)
+        message = res.content.decode("utf-8")
+        LOGGER.info(f"LoRA accuracy res for adapter '{adapter}': {message}")
+        generated_text = extract_generated_text(message)
+        if generated_text is not None:
+            adapter_outputs[adapter] = generated_text
+            LOGGER.info(
+                f"Collected output for adapter '{adapter}': {generated_text[:100]}..."
+            )
+        else:
+            LOGGER.warning(
+                f"Could not extract generated_text for adapter '{adapter}'")
+
+    # Collect base model output
+    req = {
+        "inputs": input_text,
+        "parameters": deterministic_params,
+        "stream": stream
+    }
+    LOGGER.info(f"LoRA accuracy req for base model (no adapter): {req}")
+    res = send_json(req)
+    message = res.content.decode("utf-8")
+    LOGGER.info(f"LoRA accuracy res for base model: {message}")
+    base_output = extract_generated_text(message)
+    if base_output is not None:
+        LOGGER.info(f"Collected base model output: {base_output[:100]}...")
+    else:
+        LOGGER.warning("Could not extract generated_text for base model")
+
+    return adapter_outputs, base_output
+
+
+def collect_lora_outputs_chat(adapters, messages, seq_length, stream=False):
+    """
+    Collect chat outputs from adapters and base model using deterministic parameters.
+
+    Args:
+        adapters: List of adapter names
+        messages: Chat messages to use
+        seq_length: Max tokens to generate
+        stream: Whether to use streaming mode
+
+    Returns:
+        Tuple of (adapter_outputs dict, base_model_output string)
+    """
+    adapter_outputs = {}
+
+    if not adapters:
+        LOGGER.warning(
+            "No adapters provided, skipping adapter output collection")
+        return adapter_outputs, None
+
+    # Collect adapter outputs
+    for adapter in adapters:
+        req = {
+            "messages": messages,
+            "max_tokens": seq_length,
+            "temperature": 0.0,
+            "top_p": 1.0,
+            "seed": 42,
+            "adapters": adapter,
+            "stream": stream
+        }
+        LOGGER.info(f"LoRA chat accuracy req for adapter '{adapter}': {req}")
+        res = send_json(req)
+        message = res.content.decode("utf-8")
+        LOGGER.info(
+            f"LoRA chat accuracy res for adapter '{adapter}': {message}")
+        content = extract_chat_content(message)
+        if content is not None:
+            adapter_outputs[adapter] = content
+            LOGGER.info(
+                f"Collected chat output for adapter '{adapter}': {content[:100]}..."
+            )
+        else:
+            LOGGER.warning(
+                f"Could not extract chat content for adapter '{adapter}'")
+
+    # Collect base model output
+    req = {
+        "messages": messages,
+        "max_tokens": seq_length,
+        "temperature": 0.0,
+        "top_p": 1.0,
+        "seed": 42,
+        "stream": stream
+    }
+    LOGGER.info(f"LoRA chat accuracy req for base model (no adapter): {req}")
+    res = send_json(req)
+    message = res.content.decode("utf-8")
+    LOGGER.info(f"LoRA chat accuracy res for base model: {message}")
+    base_output = extract_chat_content(message)
+    if base_output is not None:
+        LOGGER.info(
+            f"Collected base model chat output: {base_output[:100]}...")
+    else:
+        LOGGER.warning("Could not extract chat content for base model")
+
+    return adapter_outputs, base_output
+
+
 def find_awscurl():
     command = "./awscurl -h"
     try:
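
Reviewer note: a minimal sketch (not part of the commit) of how the new helpers compose. It assumes the functions above, plus the module's json and LOGGER globals, are in scope (e.g. imported from tests/integration/llm/client.py); the response bodies are illustrative stand-ins for real server output.

# Hedged usage sketch; payloads below are invented for illustration.
streaming_body = '\n'.join([
    '{"generated_text": null, "token": {"text": "Deep"}}',
    '{"generated_text": "Deep learning is ...", "details": {}}',
])
# The reverse scan returns the final chunk that carries generated_text.
assert extract_generated_text(streaming_body) == "Deep learning is ..."

chat_body = ('{"object": "chat.completion", "choices": '
             '[{"message": {"role": "assistant", "content": "Hello"}}]}')
assert extract_chat_content(chat_body) == "Hello"

# Passes only when every adapter output differs from the base output
# and from every other adapter output.
validate_lora_differentiation("base answer", {
    "medical": "medical answer",
    "exam": "exam answer",
})

# String arguments take the base-model branch; dicts take the adapter branch.
validate_determinism("base answer", "base answer", "base model")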
@@ -939,8 +1203,10 @@ def batch_generation_chat(batch_size):
         "role": "assistant",
         "content": "Hi, what can i help you with today?"
     }, {
-        "role": "user",
-        "content": "What is deep learning?"
+        "role":
+        "user",
+        "content":
+        "What is deep learning in your opinion? How does it help human?"
     }],
     [{
         "role": "user",
@@ -1526,6 +1792,39 @@ def test_handler_adapters(model, model_spec):
     LOGGER.info(f"res: {message}")
     response_checker(res, message)

+    # LoRA accuracy validation phase - collect outputs with deterministic parameters
+    # Test both streaming and non-streaming modes
+    for stream in stream_values:
+        LOGGER.info(f"LoRA accuracy validation with stream={stream}")
+
+        # Warm-up call to stabilize vLLM internal state (KV cache, GPU precision)
+        # First invocation can differ due to initialization effects
+        LOGGER.info("Warm-up invocation before determinism validation")
+        collect_lora_outputs(spec.get("adapters"),
+                             inputs[0],
+                             spec["seq_length"][0],
+                             stream=stream)
+
+        # Collect outputs twice to verify determinism and LoRA differentiation
+        adapter_outputs_1, base_output_1 = collect_lora_outputs(
+            spec.get("adapters"),
+            inputs[0],
+            spec["seq_length"][0],
+            stream=stream)
+        adapter_outputs_2, base_output_2 = collect_lora_outputs(
+            spec.get("adapters"),
+            inputs[0],
+            spec["seq_length"][0],
+            stream=stream)
+
+        # Phase 1: Validate determinism
+        validate_determinism(base_output_1, base_output_2, "base model")
+        # validate_determinism(adapter_outputs_1, adapter_outputs_2)
+
+        # Phase 2: Validate differentiation (adapters differ from base and each other)
+        validate_lora_differentiation(base_output_1, adapter_outputs_1)
+        LOGGER.info("LoRA accuracy validation completed successfully")
+
     # awscurl little benchmark phase
     for i, batch_size in enumerate(spec["batch_size"]):
         for seq_length in spec["seq_length"]:
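
To reproduce one deterministic invocation outside the harness, a request like the following should match what collect_lora_outputs sends. The endpoint URL and prompt are assumptions for illustration; the test suite builds and routes requests through send_json elsewhere in client.py.

# Hedged standalone reproduction sketch (not part of the commit).
import requests

payload = {
    "inputs": "What is deep learning?",  # illustrative prompt
    "parameters": {
        "do_sample": False,
        "temperature": 0.0,
        "top_p": 1.0,
        "top_k": -1,
        "seed": 42,
        "max_new_tokens": 16,
        "details": True,
    },
    "adapters": "medical",  # omit this key to target the base model
    "stream": False,
}
# Endpoint is an assumption; adjust host/port to your model server.
resp = requests.post("http://127.0.0.1:8080/invocations", json=payload)
print(resp.text)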
@@ -1624,6 +1923,40 @@ def test_handler_adapters_chat(model, model_spec):
     # Check if output formatter was applied correctly
     if check_formatter:
         check_output_formatter_applied(message, req["adapters"])
+
+    # LoRA accuracy validation phase - collect outputs with deterministic parameters
+    # Test both streaming and non-streaming modes
+    for stream in stream_values:
+        LOGGER.info(f"LoRA chat accuracy validation with stream={stream}")
+
+        # Warm-up call to stabilize vLLM internal state (KV cache, GPU precision)
+        # First invocation can differ due to initialization effects
+        LOGGER.info("Warm-up invocation before determinism validation")
+        collect_lora_outputs_chat(spec.get("adapters"),
+                                  messages[0],
+                                  spec["seq_length"][0],
+                                  stream=stream)
+
+        # Collect outputs twice to verify determinism and LoRA differentiation
+        adapter_outputs_1, base_output_1 = collect_lora_outputs_chat(
+            spec.get("adapters"),
+            messages[0],
+            spec["seq_length"][0],
+            stream=stream)
+        adapter_outputs_2, base_output_2 = collect_lora_outputs_chat(
+            spec.get("adapters"),
+            messages[0],
+            spec["seq_length"][0],
+            stream=stream)
+
+        # Phase 1: Validate determinism
+        validate_determinism(base_output_1, base_output_2, "base model")
+        # validate_determinism(adapter_outputs_1, adapter_outputs_2)
+
+        # Phase 2: Validate differentiation (adapters differ from base and each other)
+        validate_lora_differentiation(base_output_1, adapter_outputs_1)
+        LOGGER.info("LoRA chat accuracy validation completed successfully")
+
     # awscurl little benchmark phase
     for i, batch_size in enumerate(spec["batch_size"]):
         for seq_length in spec["seq_length"]:
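
Note that the adapter-level determinism assertion (validate_determinism(adapter_outputs_1, adapter_outputs_2)) is deliberately left commented out in both handlers: per the vLLM issue linked in validate_determinism's docstring (vllm-project/vllm#7977), LoRA outputs can vary across identical greedy-decoded runs, so only the base model is held to strict determinism.

The chat path also passes sampling fields at the top level of the request rather than nested under "parameters". A hedged sketch of one deterministic chat invocation matching what collect_lora_outputs_chat sends (URL and message content are illustrative assumptions):

# Hedged standalone chat reproduction sketch (not part of the commit).
import requests

chat_payload = {
    "messages": [{
        "role": "user",
        "content": "What is deep learning?"  # illustrative
    }],
    "max_tokens": 16,
    "temperature": 0.0,
    "top_p": 1.0,
    "seed": 42,
    "adapters": "medical",  # omit this key to target the base model
    "stream": False,
}
resp = requests.post("http://127.0.0.1:8080/invocations", json=chat_payload)
print(resp.text)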
