Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion lmms_eval/models/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,29 @@
load_dotenv(verbose=True)


def _content_parts_are_text_only(content) -> bool:
if isinstance(content, str):
return True
if not isinstance(content, list):
return False
return all(isinstance(part, dict) and part.get("type") == "text" for part in content)


def _ctx_to_text_chat_messages(ctx: str, chat_messages_raw):
if not isinstance(ctx, str) or not ctx:
return chat_messages_raw
if not isinstance(chat_messages_raw, list) or len(chat_messages_raw) != 1:
return chat_messages_raw

message = chat_messages_raw[0]
if not isinstance(message, dict) or message.get("role") != "user":
return chat_messages_raw
if not _content_parts_are_text_only(message.get("content")):
return chat_messages_raw

return [{"role": "user", "content": [{"type": "text", "text": ctx}]}]


@register_model("openai")
class OpenAICompatible(OpenAICompatibleSimple):
is_simple = False
Expand Down Expand Up @@ -172,9 +195,10 @@ def maybe_update_concurrency(force: bool = False) -> None:

def build_payload_for_index(global_index: int) -> dict:
req = reordered_requests[global_index]
_, doc_to_messages, gen_kwargs, doc_id, task, split = req.args
ctx, doc_to_messages, gen_kwargs, doc_id, task, split = req.args

chat_messages_raw = doc_to_messages(self.task_dict[task][split][doc_id])
chat_messages_raw = _ctx_to_text_chat_messages(ctx, chat_messages_raw)
chat_messages: ChatMessages = ChatMessages(**{"messages": chat_messages_raw})
request_gen_kwargs = dict(gen_kwargs)
max_new_tokens = min(request_gen_kwargs.get("max_new_tokens", 1024), 4096)
Expand Down
31 changes: 31 additions & 0 deletions test/models/test_openai_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from lmms_eval.models.chat.openai import _ctx_to_text_chat_messages


def test_ctx_replaces_auto_text_only_doc_message():
ctx = "Task instructions.\n\nFew-shot example.\n\nQuestion: What is 2 + 2?"
raw_messages = [{"role": "user", "content": [{"type": "text", "text": "Question: What is 2 + 2?"}]}]

assert _ctx_to_text_chat_messages(ctx, raw_messages) == [{"role": "user", "content": [{"type": "text", "text": ctx}]}]


def test_ctx_does_not_replace_multimodal_messages():
raw_messages = [
{
"role": "user",
"content": [
{"type": "image", "url": "chart.png"},
{"type": "text", "text": "What is shown?"},
],
}
]

assert _ctx_to_text_chat_messages("Use the full benchmark prompt", raw_messages) is raw_messages


def test_ctx_does_not_replace_explicit_multi_message_chat():
raw_messages = [
{"role": "system", "content": [{"type": "text", "text": "Be concise."}]},
{"role": "user", "content": [{"type": "text", "text": "Question"}]},
]

assert _ctx_to_text_chat_messages("Full prompt", raw_messages) is raw_messages