Skip to content

Commit ef54810

Browse files
Authored by Ben Wilson
Add context loading for ChatModel (mlflow#19250)
Signed-off-by: Ben Wilson <[email protected]>
1 parent 65aa828 commit ef54810

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

mlflow/pyfunc/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3172,7 +3172,7 @@ def predict(self, context, model_input: List[str], params=None) -> List[str]:
31723172
)
31733173
elif isinstance(python_model, ChatAgent):
31743174
input_example = _save_model_chat_agent_helper(
3175-
python_model, mlflow_model, signature, input_example
3175+
python_model, mlflow_model, signature, input_example, artifacts, model_config
31763176
)
31773177
elif IS_RESPONSES_AGENT_AVAILABLE and isinstance(python_model, ResponsesAgent):
31783178
input_example = _save_model_responses_agent_helper(
@@ -3754,7 +3754,9 @@ def _save_model_with_loader_module_and_data_path(
37543754
return mlflow_model
37553755

37563756

3757-
def _save_model_chat_agent_helper(python_model, mlflow_model, signature, input_example):
3757+
def _save_model_chat_agent_helper(
3758+
python_model, mlflow_model, signature, input_example, artifacts, model_config
3759+
):
37583760
"""Helper method for save_model for ChatAgent models
37593761
37603762
Returns: a dict input_example
@@ -3792,6 +3794,8 @@ def _save_model_chat_agent_helper(python_model, mlflow_model, signature, input_e
37923794
input_example = CHAT_AGENT_INPUT_EXAMPLE
37933795

37943796
_logger.info("Predicting on input example to validate output")
3797+
context = PythonModelContext(artifacts, model_config)
3798+
python_model.load_context(context)
37953799
request = ChatAgentRequest(**input_example)
37963800
output = python_model.predict(request.messages, request.context, request.custom_inputs)
37973801
try:

tests/pyfunc/test_chat_agent.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,3 +430,41 @@ def test_chat_agent_predict_with_params(tmp_path):
430430
responses = list(loaded_model.predict_stream(CHAT_AGENT_INPUT_EXAMPLE, params=None))
431431
for i, resp in enumerate(responses[:-1]):
432432
assert resp["delta"]["content"] == f"message {i}"
433+
434+
435+
def test_chat_agent_load_context_called_during_save(tmp_path):
436+
class ChatAgentWithArtifacts(ChatAgent):
437+
def __init__(self):
438+
self.prefix = None
439+
440+
def load_context(self, context):
441+
self.prefix = "loaded_prefix"
442+
443+
def predict(
444+
self,
445+
messages: list[ChatAgentMessage],
446+
context: ChatContext,
447+
custom_inputs: dict[str, Any],
448+
) -> ChatAgentResponse:
449+
if self.prefix is None:
450+
raise ValueError("load_context was not called - prefix is None")
451+
return ChatAgentResponse(
452+
messages=[
453+
{
454+
"role": "assistant",
455+
"content": f"{self.prefix}: {messages[0].content}",
456+
"id": str(uuid4()),
457+
}
458+
]
459+
)
460+
461+
model = ChatAgentWithArtifacts()
462+
save_path = tmp_path / "model"
463+
mlflow.pyfunc.save_model(
464+
python_model=model,
465+
path=save_path,
466+
)
467+
468+
loaded_model = mlflow.pyfunc.load_model(save_path)
469+
response = loaded_model.predict({"messages": [{"role": "user", "content": "Hello!"}]})
470+
assert response["messages"][0]["content"] == "loaded_prefix: Hello!"

0 commit comments

Comments (0)