Skip to content

Commit 9c3226a

Browse files
committed
include system message
1 parent 4d7745e commit 9c3226a

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

Diff for: backend/onyx/chat/prompt_builder/answer_prompt_builder.py

+13
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,19 @@ def get_user_message_content(self) -> str:
155155
query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
156156
return query
157157

158+
def get_message_history(self) -> list[PreviousMessage]:
159+
"""
160+
Get the message history as a list of PreviousMessage objects.
161+
"""
162+
ret = []
163+
if self.system_message_and_token_cnt:
164+
tmp = PreviousMessage.from_langchain_msg(*self.system_message_and_token_cnt)
165+
ret.append(tmp)
166+
for i, msg in enumerate(self.message_history):
167+
tmp = PreviousMessage.from_langchain_msg(msg, self.history_token_cnts[i])
168+
ret.append(tmp)
169+
return ret
170+
158171
def build(self) -> list[BaseMessage]:
159172
if not self.user_message_and_token_cnt:
160173
raise ValueError("User message must be set before building prompt")

Diff for: backend/onyx/chat/tool_handling/tool_response_handler.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,10 @@ def get_tool_call_for_non_tool_calling_llm_impl(
163163
llm: LLM,
164164
) -> tuple[Tool, dict] | None:
165165
user_query = prompt_builder.raw_user_query
166+
history = prompt_builder.raw_message_history
166167
if isinstance(prompt_builder, AnswerPromptBuilder):
167168
user_query = prompt_builder.get_user_message_content()
169+
history = prompt_builder.get_message_history()
168170

169171
if force_use_tool.force_use:
170172
# if we are forcing a tool, we don't need to check which tools to run
@@ -175,7 +177,7 @@ def get_tool_call_for_non_tool_calling_llm_impl(
175177
if force_use_tool.args is not None
176178
else tool.get_args_for_non_tool_calling_llm(
177179
query=user_query,
178-
history=prompt_builder.raw_message_history,
180+
history=history,
179181
llm=llm,
180182
force_run=True,
181183
)
@@ -193,7 +195,7 @@ def get_tool_call_for_non_tool_calling_llm_impl(
193195
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
194196
tools=tools,
195197
query=user_query,
196-
history=prompt_builder.raw_message_history,
198+
history=history,
197199
llm=llm,
198200
)
199201

@@ -210,7 +212,7 @@ def get_tool_call_for_non_tool_calling_llm_impl(
210212
chosen_tool_and_args = (
211213
select_single_tool_for_non_tool_calling_llm(
212214
tools_and_args=available_tools_and_args,
213-
history=prompt_builder.raw_message_history,
215+
history=history,
214216
query=user_query,
215217
llm=llm,
216218
)

Diff for: backend/onyx/llm/models.py

+20
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from onyx.configs.constants import MessageType
1010
from onyx.file_store.models import InMemoryChatFile
1111
from onyx.llm.utils import build_content_with_imgs
12+
from onyx.llm.utils import message_to_string
1213
from onyx.tools.models import ToolCallFinalResult
1314

1415
if TYPE_CHECKING:
@@ -59,3 +60,22 @@ def to_langchain_msg(self) -> BaseMessage:
5960
return AIMessage(content=content)
6061
else:
6162
return SystemMessage(content=content)
63+
64+
@classmethod
65+
def from_langchain_msg(
66+
cls, msg: BaseMessage, token_count: int
67+
) -> "PreviousMessage":
68+
message_type = MessageType.SYSTEM
69+
if isinstance(msg, HumanMessage):
70+
message_type = MessageType.USER
71+
elif isinstance(msg, AIMessage):
72+
message_type = MessageType.ASSISTANT
73+
message = message_to_string(msg)
74+
return cls(
75+
message=message,
76+
token_count=token_count,
77+
message_type=message_type,
78+
files=[],
79+
tool_call=None,
80+
refined_answer_improvement=None,
81+
)

0 commit comments

Comments
 (0)