Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions integrations/langchain/src/databricks_langchain/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
return {"role": "user", **message_dict}
elif isinstance(message, AIMessage):
if tool_calls := _get_tool_calls_from_ai_message(message):
print(tool_calls)
message_dict["tool_calls"] = tool_calls # type: ignore[assignment]
# If tool calls present, content null value should be None not empty string.
message_dict["content"] = message_dict["content"] or None # type: ignore[assignment]
Expand Down Expand Up @@ -1196,15 +1197,37 @@ def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]:
for tc in message.invalid_tool_calls
]

"""
thought signature encodes model reasoning
it is required for each tool call to gemini 3 pro - https://arc.net/l/quote/jhoeoqbl
this means we need to encode this info in the responses events in order to fix this bug, in addition to the work on this PR

will have to change _langchain_message_stream_to_responses_stream
"""

if tool_calls or invalid_tool_calls:
return tool_calls + invalid_tool_calls
# Merge thoughtSignature from additional_kwargs if present
all_tool_calls = tool_calls + invalid_tool_calls
additional_tool_calls = message.additional_kwargs.get("tool_calls", [])
if additional_tool_calls:
# Create a mapping of tool call IDs to their thoughtSignature
thought_signatures = {
tc.get("id"): tc.get("thoughtSignature")
for tc in additional_tool_calls
if tc.get("thoughtSignature")
}
# Add thoughtSignature to matching tool calls
for tc in all_tool_calls:
if tc["id"] in thought_signatures:
tc["thoughtSignature"] = thought_signatures[tc["id"]]
return all_tool_calls

# Get tool calls from additional kwargs if present.
return [
{
k: v
for k, v in tool_call.items() # type: ignore[union-attr]
if k in {"id", "type", "function"}
if k in {"id", "type", "function", "thoughtSignature"}
}
for tool_call in message.additional_kwargs.get("tool_calls", [])
]
Expand Down
200 changes: 200 additions & 0 deletions integrations/langchain/tests/integration_tests/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from typing import Annotated, Any, Generator, Optional, Sequence, TypedDict, Union

import mlflow
from langchain.messages import AIMessage, AIMessageChunk, AnyMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
ResponsesAgentRequest,
ResponsesAgentResponse,
ResponsesAgentStreamEvent,
output_to_responses_items_stream,
to_chat_completions_input,
)

from databricks_langchain import (
ChatDatabricks,
UCFunctionToolkit,
)

############################################
# Define your LLM endpoint and system prompt
############################################
# TODO: Replace with your model serving endpoint
# LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5"
LLM_ENDPOINT_NAME = "databricks-gemini-3-pro"
# Chat model client bound to the serving endpoint above.
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# TODO: Update with your system prompt
system_prompt = "You are a helpful assistant that can run Python code."

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html
###############################################################################
tools = []

# You can use UDFs in Unity Catalog as agent tools
# Below, we add the `system.ai.python_exec` UDF, which provides
# a python code interpreter tool to our agent
# You can also add local LangChain python tools. See https://python.langchain.com/docs/concepts/tools

# TODO: Add additional tools
UC_TOOL_NAMES = ["system.ai.python_exec"]
# Wrap the UC function names as LangChain tools and register them.
uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES)
tools.extend(uc_toolkit.tools)

# Use Databricks vector search indexes as tools
# See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html#locally-develop-vector-search-retriever-tools-with-ai-bridge
# List to store vector search tool instances for unstructured retrieval.
VECTOR_SEARCH_TOOLS = []

# To add vector search retriever tools,
# use VectorSearchRetrieverTool and create_tool_info,
# then append the result to TOOL_INFOS.
# Example:
# VECTOR_SEARCH_TOOLS.append(
#     VectorSearchRetrieverTool(
#         index_name="",
#         # filters="..."
#     )
# )

tools.extend(VECTOR_SEARCH_TOOLS)

#####################
## Define agent logic
#####################


class AgentState(TypedDict):
    """Graph state threaded through the LangGraph nodes.

    `messages` accumulates the conversation via the `add_messages` reducer;
    the custom input/output dicts carry optional request/response metadata.
    """

    # Conversation history; add_messages merges new messages into the list.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # Arbitrary caller-supplied metadata (e.g. session_id).
    custom_inputs: Optional[dict[str, Any]]
    # Arbitrary agent-produced metadata returned to the caller.
    custom_outputs: Optional[dict[str, Any]]


def create_tool_calling_agent(
    model: ChatDatabricks,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
):
    """Compile a LangGraph tool-calling agent around *model*.

    The graph loops agent -> tools -> agent until the model answers
    without requesting any tool calls, then terminates.
    """
    bound_model = model.bind_tools(tools)

    def route_after_agent(state: AgentState):
        # Enter the tool node only when the last AI message actually
        # requested tool calls; otherwise the run is finished.
        last = state["messages"][-1]
        if isinstance(last, AIMessage) and last.tool_calls:
            return "continue"
        return "end"

    # Optionally prepend the system prompt before every model call.
    if system_prompt:
        prepend_prompt = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}] + state["messages"]
        )
    else:
        prepend_prompt = RunnableLambda(lambda state: state["messages"])
    chain = prepend_prompt | bound_model

    def invoke_model(
        state: AgentState,
        config: RunnableConfig,
    ):
        # Run the (optionally system-prompted) model and append its reply.
        reply = chain.invoke(state, config)
        return {"messages": [reply]}

    graph = StateGraph(AgentState)
    graph.add_node("agent", RunnableLambda(invoke_model))
    graph.add_node("tools", ToolNode(tools))
    graph.set_entry_point("agent")
    graph.add_conditional_edges(
        "agent",
        route_after_agent,
        {"continue": "tools", "end": END},
    )
    graph.add_edge("tools", "agent")
    return graph.compile()


class LangGraphResponsesAgent(ResponsesAgent):
    """Adapts a compiled LangGraph agent to MLflow's ResponsesAgent API."""

    def __init__(self, agent):
        # `agent` is a compiled LangGraph graph (see create_tool_calling_agent).
        self.agent = agent

    @staticmethod
    def _resolve_session_id(request: ResponsesAgentRequest) -> Optional[str]:
        """Return the session id for *request*, or None if not provided.

        Prefers an explicit `session_id` in custom_inputs, then falls back
        to the conversation id supplied by the serving context.
        """
        if request.custom_inputs and "session_id" in request.custom_inputs:
            return request.custom_inputs.get("session_id")
        if request.context and request.context.conversation_id:
            return request.context.conversation_id
        return None

    def _tag_trace_with_session(self, request: ResponsesAgentRequest) -> None:
        """Attach the session id (if any) to the current MLflow trace.

        This groups related requests under one session in the tracing UI.
        Previously duplicated verbatim in predict() and predict_stream().
        """
        session_id = self._resolve_session_id(request)
        if session_id:
            mlflow.update_current_trace(
                metadata={
                    "mlflow.trace.session": session_id,
                }
            )

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Run the agent to completion and collect the final output items."""
        self._tag_trace_with_session(request)

        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Stream agent events: completed output items plus text deltas."""
        self._tag_trace_with_session(request)

        cc_msgs = to_chat_completions_input([i.model_dump() for i in request.input])

        for event in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates", "messages"]):
            if event[0] == "updates":
                # Node-level updates carry complete messages; convert each
                # batch into Responses output items.
                for node_data in event[1].values():
                    if len(node_data.get("messages", [])) > 0:
                        yield from output_to_responses_items_stream(node_data["messages"])
            # filter the streamed messages to just the generated text messages
            elif event[0] == "messages":
                try:
                    chunk = event[1][0]
                    if isinstance(chunk, AIMessageChunk) and (content := chunk.content):
                        yield ResponsesAgentStreamEvent(
                            **self.create_text_delta(delta=content, item_id=chunk.id),
                        )
                except Exception as e:
                    # Best-effort: a malformed chunk must not abort the stream.
                    print(e)


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
# Enable automatic MLflow tracing for all LangChain/LangGraph calls.
mlflow.langchain.autolog()
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphResponsesAgent(agent)
mlflow.models.set_model(AGENT)
75 changes: 54 additions & 21 deletions integrations/langchain/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,9 @@ class GetWeather(BaseModel):
return

# Models should make at least one tool call when tool_choice is not "none"
assert len(response.tool_calls) >= 1, (
f"Expected at least 1 tool call, got {len(response.tool_calls)}"
)
assert (
len(response.tool_calls) >= 1
), f"Expected at least 1 tool call, got {len(response.tool_calls)}"

# The first tool call should be for GetWeather
first_call = response.tool_calls[0]
Expand All @@ -267,9 +267,9 @@ class GetWeather(BaseModel):
]
)
# Should call GetWeather tool for the followup question
assert len(response.tool_calls) >= 1, (
f"Expected at least 1 tool call, got {len(response.tool_calls)}"
)
assert (
len(response.tool_calls) >= 1
), f"Expected at least 1 tool call, got {len(response.tool_calls)}"
tool_call = response.tool_calls[0]
assert tool_call["name"] == "GetWeather", f"Expected GetWeather tool, got {tool_call['name']}"
assert "location" in tool_call["args"], f"Expected location in args, got {tool_call['args']}"
Expand Down Expand Up @@ -559,12 +559,8 @@ def test_chat_databricks_chatagent_invoke():
):
python_tool_used = True

assert has_tool_calls, (
f"Expected ChatAgent to use tool calls for fibonacci computation. Content: {response.content}"
)
assert python_tool_used, (
f"Expected ChatAgent to use python execution tool for fibonacci computation. Content: {response.content}"
)
assert has_tool_calls, f"Expected ChatAgent to use tool calls for fibonacci computation. Content: {response.content}"
assert python_tool_used, f"Expected ChatAgent to use python execution tool for fibonacci computation. Content: {response.content}"


@pytest.mark.st_endpoints
Expand Down Expand Up @@ -847,9 +843,9 @@ def test_chat_databricks_gpt5_stream_with_usage():
]

# Should have exactly ONE usage chunk from the final usage-only chunk
assert len(usage_chunks) == 1, (
f"Expected exactly 1 usage chunk from GPT-5 final chunk, got {len(usage_chunks)}"
)
assert (
len(usage_chunks) == 1
), f"Expected exactly 1 usage chunk from GPT-5 final chunk, got {len(usage_chunks)}"

# Verify usage chunk has correct metadata structure
usage_chunk = usage_chunks[0]
Expand All @@ -860,12 +856,12 @@ def test_chat_databricks_gpt5_stream_with_usage():
assert "total_tokens" in usage_chunk.usage_metadata

# Verify token counts are positive
assert usage_chunk.usage_metadata["input_tokens"] > 0, (
f"Expected positive input_tokens, got {usage_chunk.usage_metadata['input_tokens']}"
)
assert usage_chunk.usage_metadata["output_tokens"] > 0, (
f"Expected positive output_tokens, got {usage_chunk.usage_metadata['output_tokens']}"
)
assert (
usage_chunk.usage_metadata["input_tokens"] > 0
), f"Expected positive input_tokens, got {usage_chunk.usage_metadata['input_tokens']}"
assert (
usage_chunk.usage_metadata["output_tokens"] > 0
), f"Expected positive output_tokens, got {usage_chunk.usage_metadata['output_tokens']}"

# Verify total_tokens equals sum of input and output
expected_total = (
Expand All @@ -875,3 +871,40 @@ def test_chat_databricks_gpt5_stream_with_usage():
f"Expected total_tokens ({usage_chunk.usage_metadata['total_tokens']}) "
f"to equal input_tokens + output_tokens ({expected_total})"
)


def test_chat_databricks_with_gemini():
    """Round-trip a tool-calling transcript through the Gemini-backed agent.

    Replays a conversation that already contains a function call, its
    output, and the assistant's final message, then checks that
    ``AGENT.predict`` produces a non-empty response. Fixes two debug
    leftovers: a trailing ``assert False`` that made the test always
    fail, and a stray ``print`` of the full response dump.
    """
    os.environ["DATABRICKS_CONFIG_PROFILE"] = "dogfood"
    from .agent import AGENT

    result = AGENT.predict(
        {
            "input": [
                {"role": "user", "content": "What is 6*7 in Python?"},
                {
                    "type": "function_call",
                    "id": "lc_run--e58dec26-ce5d-4597-b4f8-28e6db62cd49",
                    "call_id": "system__ai__python_exec",
                    "name": "system__ai__python_exec",
                    "arguments": '{"code": "print(6 * 7)"}',
                },
                {
                    "type": "function_call_output",
                    "call_id": "system__ai__python_exec",
                    "output": '{"format": "SCALAR", "value": "42\\n"}',
                },
                {
                    "type": "message",
                    "id": "lc_run--dd658def-dfdc-4bc7-b0d9-b6e25d1ecc48",
                    "content": [
                        {"text": "The result of `6 * 7` in Python is 42.", "type": "output_text"}
                    ],
                    "role": "assistant",
                },
            ]
        }
    )
    assert result is not None
    assert result.output is not None
    # The agent should emit at least one output item for the final turn.
    assert len(result.output) > 0
Loading
Loading