Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies = [
"fastapi>=0.115.11",
"langchain>=0.3.20",
"langchain-core>=0.3.41",
"langchain-mcp-adapters>=0.1.7",
"langchain-ollama>=0.2.3",
"langchain-openai>=0.3.7",
"langgraph>=0.3.12",
Expand Down
7 changes: 6 additions & 1 deletion src/dremioai/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
BaseModel,
ConfigDict,
field_serializer,
AliasChoices,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Optional, Union, Annotated, Self, List, Dict, Any, Callable
Expand Down Expand Up @@ -120,7 +121,11 @@ class Dremio(BaseModel):
]
raw_pat: Optional[str] = Field(default=None, alias="pat")
project_id: Optional[str] = None
enable_experimental: Optional[bool] = False # enable experimental tools
enable_experimental: Optional[bool] = Field(
default=False,
alias=AliasChoices("enable_experimental", "enable_search"),
description="enable experimental tools",
)
oauth2: Optional[OAuth2] = None
allow_dml: Optional[bool] = False
model_config = ConfigDict(validate_assignment=True)
Expand Down
1 change: 1 addition & 0 deletions src/dremioai/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def configure(enable_json_logging=None, to_file=False):
log_file_path, maxBytes=10 * 1024 * 1024, backupCount=5 # 10MB
)
file_handler.setLevel(level())
logging.getLogger().handlers.clear()
logging.getLogger().addHandler(file_handler)

renderer = (
Expand Down
147 changes: 88 additions & 59 deletions src/dremioai/servers/frameworks/langchain/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,34 +20,32 @@
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import ToolMessage, AIMessage
from langchain_core.language_models import LanguageModelLike
from langchain_mcp_adapters.client import MultiServerMCPClient

from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.history import FileHistory
from prompt_toolkit.styles import Style
from prompt_toolkit.formatted_text import HTML

from dremioai.tools.tools import get_tools, get_resources
from dremioai.log import logger
from pathlib import Path
from typing import Dict
import sys
import asyncio

from typing import List, Optional, Annotated
from dremioai import log
from dremioai.config import settings
from typer import Typer, Option
from rich.console import Console
from rich.markdown import Markdown
from dremioai.servers.frameworks.langchain.tools import discover_tools, discover_prompt

# def system_prompt():
# return """
# You are helpful AI bot with access to several tools for analyzing Dremio cluster and jobs.
# The information about jobs is stored in a table called sys.project.jobs_recent.
# There is a tool avaialble for each of the following operations:
# """ + "\n".join(
# f"{t.__name__}: {'\n\t'.join(t.invoke.__doc__.splitlines())}"
# for t in (get_tools() + get_resources())
# )

from dremioai.servers.frameworks.langchain.tools import (
discover_tools,
discover_prompt,
StructuredTool,
ChatPromptTemplate,
)

tl = Typer(
context_settings=dict(help_option_names=["-h", "--help"]),
Expand All @@ -56,22 +54,12 @@
)


@tl.command()
def main(
config_file: Annotated[
Optional[Path],
Option(
"-c",
"--cfg",
help="The config yaml for various options",
show_default=True,
),
] = settings.default_config(),
debug: Annotated[bool, Option(help="Enable debug logging")] = False,
async def user_input(
tools: List[StructuredTool],
prompt: ChatPromptTemplate,
llm: LanguageModelLike,
debug: bool = False,
):

settings.configure(config_file)

custom_style = Style.from_dict(
{
"prompt": "ansicyan", # Input prompt color
Expand All @@ -81,61 +69,102 @@ def main(
session = PromptSession(
history=FileHistory(Path.home() / ".mcp.history"), style=custom_style
)

if settings.instance().langchain.ollama is not None:
llm = ChatOllama(
model=settings.instance().langchain.ollama.model,
temperature=0,
verbose=True,
)
else:
llm = ChatOpenAI(
model=settings.instance().langchain.openai.model,
temperature=0,
verbose=True,
api_key=settings.instance().langchain.openai.api_key,
)

executor = create_react_agent(
model=llm,
tools=discover_tools(settings.instance().tools.server_mode),
prompt=discover_prompt(),
debug=debug,
)

# agent = create_openai_tools_agent(
# llm=llm, tools=registered_tools, prompt=chat_prompt
# )

# agent_executor = AgentExecutor(agent=agent, tools=registered_tools, verbose=True)
# chain = chat_prompt | llm | StrOutputParser()
chat_history = []
console = Console()

while True:
try:
user_input = session.prompt(HTML("<prompt>>> </prompt>"))
user_input = await session.prompt_async(HTML("<prompt>>> </prompt>"))
if user_input.lower() == "q":
break
except EOFError:
break
args = {"messages": [("human", user_input)]}
if chat_history:
args["chat_history"] = chat_history
response: Dict[str, List[ToolMessage, AIMessage]] = executor.invoke(args)
# for message in response:
# if isinstance(message, ToolMessage):
# message.pretty_print()
response: Dict[str, List[ToolMessage, AIMessage]] = await executor.ainvoke(args)

console.print(Markdown(str(response["messages"][-1].content)))
chat_history.extend(
[("human", user_input), ("system", response["messages"][-1].content)]
)
# response = agent_executor.invoke({"input": user_input})
# if output := response.get("output"):
# pp(output)
# else:
# pp(response)


async def using_mcp(llm: LanguageModelLike, config_file: Path, debug: bool = False):
client = MultiServerMCPClient(
{
"dremioai": {
"command": sys.executable,
"args": [
"-m",
"dremioai.servers.mcp",
"run",
"-c",
str(config_file),
],
"transport": "stdio",
}
}
)
async with client.session("dremioai") as session:
tools = await session.list_tools()
# prompt = await session.list_prompts()
# logger().info(f"Found prompt={prompt}")
prompt = await session.get_prompt("System Prompt")
prompt = discover_prompt(prompt.messages[0].content.text)
logger().info(f"Found {len(tools.tools)} tools and prompt={prompt}")
await user_input(tools, prompt, llm, debug)


@tl.command()
def main(
config_file: Annotated[
Optional[Path],
Option(
"-c",
"--cfg",
help="The config yaml for various options",
show_default=True,
),
] = settings.default_config(),
debug: Annotated[bool, Option(help="Enable debug logging")] = False,
use_as_mcp: Annotated[
bool,
Option(help="Use the server as a MCP server instead of using tools directly"),
] = True,
):
log.configure(to_file=True)

settings.configure(config_file)

if settings.instance().langchain.ollama is not None:
llm = ChatOllama(
model=settings.instance().langchain.ollama.model,
temperature=0,
verbose=True,
)
else:
llm = ChatOpenAI(
model=settings.instance().langchain.openai.model,
temperature=0,
verbose=True,
api_key=settings.instance().langchain.openai.api_key,
)

if use_as_mcp:
asyncio.run(using_mcp(llm, config_file, debug))
else:
tools = discover_tools(settings.instance().tools.server_mode)
prompt = discover_prompt()
logger().info(f"[non mcp] Found {len(tools.tools)} tools and prompt={prompt}")
asyncio.run(user_input(tools, prompt, llm, debug))


def cli():
Expand Down
16 changes: 9 additions & 7 deletions src/dremioai/servers/frameworks/langchain/tools.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#
#
# Copyright (C) 2017-2025 Dremio Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
from langchain_core.tools.structured import StructuredTool
from langchain_core.tools.base import create_schema_from_function
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
Expand Down Expand Up @@ -40,10 +40,12 @@ def discover_tools(For: ToolType = None) -> List[StructuredTool]:
return [instantiate(tool) for tool in get_tools(For=For)]


def discover_prompt() -> ChatPromptTemplate:
def discover_prompt(with_prompt: str = None) -> ChatPromptTemplate:
if with_prompt is None:
with_prompt = system_prompt()
return ChatPromptTemplate.from_messages(
[
("system", system_prompt()),
("system", with_prompt),
(
"system",
"You must respond in Markdown or tables in tab separated values format",
Expand Down
15 changes: 15 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.