Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dremioai/api/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def common_args(
Option("-c", "--config", help="The config file with Dremio connection details"),
] = None,
):
settings.configure(config_file)
settings.configure(config_file, force=True)


app = Typer(context_settings=dict(help_option_names=["-h", "--help"]))
Expand Down
14 changes: 8 additions & 6 deletions src/dremioai/api/cli/search.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#
#
# Copyright (C) 2017-2025 Dremio Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#

from typing import Annotated, Optional, List
from typer import Option, Argument, Typer
Expand Down Expand Up @@ -42,5 +42,7 @@ def do_search(
args = {"query": query}
if category:
args["filter"] = category
result = asyncio.run(search.get_search_results(search=search.Search(**args)))
result = asyncio.run(
search.get_search_results(search=search.Search(**args), use_df=use_df)
)
pp(result)
18 changes: 16 additions & 2 deletions src/dremioai/api/dremio/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ async def get_schema(
dataset_path_or_id: Optional[Union[List[str], str]],
by_id: Optional[bool] = False,
include_tags: Optional[bool] = False,
flatten: Optional[bool] = False,
) -> Dict[str, Any]:
client = AsyncHttpClient()
project_id = settings.instance().dremio.project_id
Expand Down Expand Up @@ -163,17 +164,30 @@ async def _get(id, suffix):
for i, r in enumerate(extras):
k, v = extras[r]
schema[k] = results[i].get(r, {}).get(v)
if flatten:
flattened_schema = {
"schema": {
c.get("name", ""): c.get("type", {}).get("name", "unknown")
for c in schema.get("fields", [])
}
}
for k in extras.values():
if v := schema.get(k[0]):
flattened_schema[k[0]] = v
schema = flattened_schema

return schema


async def get_schemas(
dataset_path_or_ids: List[Union[List[str], str]],
by_id: Optional[bool] = False,
include_tags: Optional[bool] = False,
) -> Dict[str, Any]:
flatten: Optional[bool] = False,
) -> List[Dict[str, Any]]:

return await run_in_parallel(
[get_schema(p, by_id, include_tags) for p in dataset_path_or_ids]
[get_schema(p, by_id, include_tags, flatten) for p in dataset_path_or_ids]
)


Expand Down
62 changes: 42 additions & 20 deletions src/dremioai/api/dremio/search.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,37 @@
#
#
# Copyright (C) 2017-2025 Dremio Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#

from pydantic import (
BaseModel,
Field,
AfterValidator,
ValidationError,
ConfigDict,
field_validator,
)
from typing import (
Annotated,
List,
Dict,
AnyStr,
Any,
Literal,
Union,
Optional,
)
from dremioai.api.util import UStrEnum
from datetime import datetime
from enum import auto, StrEnum
from enum import auto
from dremioai.config import settings
from dremioai.api.transport import DremioAsyncHttpClient as AsyncHttpClient
from dremioai.api.dremio.catalog import get_schemas
import pandas as pd


class QueryType(UStrEnum):
Expand Down Expand Up @@ -80,7 +75,7 @@ class JobStatus(UStrEnum):
INVALID_STATE = auto()


class Category(StrEnum):
class Category(UStrEnum):
JOB = auto()
VIEW = auto()
TABLE = auto()
Expand All @@ -93,9 +88,9 @@ class Category(StrEnum):


class UserOrRole(UStrEnum):
USER_TYPE_UNSPECIFIED = auto()
USER_TYPE_USER = auto()
USER_TYPE_ROLE = auto()
UNSPECIFIED = auto()
USER = auto()
ROLE = auto()


class EnterpriseDatasetType(UStrEnum):
Expand Down Expand Up @@ -161,6 +156,15 @@ class EnterpriseSearchCatalogObject(BaseModel):
func_sql: Optional[str] = Field(default=None, alias="functionSql")
owner: Optional[EnterpriseSearchUserOrRoleObject] = None

def as_df_dict(self):
return {
"path": self.path,
"name": ".".join(f'"{p}"' for p in self.path),
"type": self.type,
"tags": ",".join(self.labels),
"description": self.wiki,
}


class EnterpriseSearchResultsObject(BaseModel):
category: Optional[Category] = None
Expand Down Expand Up @@ -209,17 +213,24 @@ class EnterpriseSearchResultsWrapper(BaseModel):
results: List[EnterpriseSearchResultsObject] = Field(default_factory=list)


async def get_search_results(search: str | Search) -> EnterpriseSearchResultsWrapper:
async def get_search_results(
search: str | Search, use_df: bool = False
) -> EnterpriseSearchResultsWrapper | pd.DataFrame:
if isinstance(search, str):
search = Search(query=search)

client = AsyncHttpClient(
settings.instance().dremio.uri, settings.instance().dremio.pat
)

endpoint = (
f"/v0/projects/{settings.instance().dremio.project_id}/search"
if settings.instance().dremio.project_id
else "/api/v3/search"
)
result = []
response = await client.post(
"/api/v3/search",
endpoint,
body=search.model_dump(exclude_none=True),
deser=EnterpriseSearchResults,
)
Expand All @@ -229,9 +240,20 @@ async def get_search_results(search: str | Search) -> EnterpriseSearchResultsWra
break
search.next_page_token = response.next_page_token
response = await client.post(
"/api/v3/search",
endpoint,
body=search.model_dump(exclude_none=True),
deser=EnterpriseSearchResults,
)

if use_df:
result = [r for r in result if r.category in (Category.TABLE, Category.VIEW)]

data = [r.catalog.as_df_dict() for r in result]
paths = [p["path"] for p in data]
if schemas := await get_schemas(paths, include_tags=True, flatten=True):
for ix, schema in enumerate(schemas):
data[ix]["schema"] = schema.get("schema")

return pd.DataFrame(data=data)

return EnterpriseSearchResultsWrapper(results=result)
29 changes: 16 additions & 13 deletions src/dremioai/servers/frameworks/langchain/server.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#
#
# Copyright (C) 2017-2025 Dremio Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#


from langchain_ollama import ChatOllama
Expand All @@ -21,7 +21,7 @@
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import ToolMessage, AIMessage

from prompt_toolkit import PromptSession
from prompt_toolkit import PromptSession, print_formatted_text
from prompt_toolkit.history import FileHistory
from prompt_toolkit.styles import Style
from prompt_toolkit.formatted_text import HTML
Expand All @@ -34,7 +34,8 @@
from dremioai import log
from dremioai.config import settings
from typer import Typer, Option
from rich import print as pp
from rich.console import Console
from rich.markdown import Markdown
from dremioai.servers.frameworks.langchain.tools import discover_tools, discover_prompt

# def system_prompt():
Expand Down Expand Up @@ -66,6 +67,7 @@ def main(
show_default=True,
),
] = settings.default_config(),
debug: Annotated[bool, Option(help="Enable debug logging")] = False,
):

settings.configure(config_file)
Expand All @@ -77,7 +79,7 @@ def main(
}
)
session = PromptSession(
history=FileHistory(Path("~/.mcp.history").expanduser()), style=custom_style
history=FileHistory(Path.home() / ".mcp.history"), style=custom_style
)

if settings.instance().langchain.ollama is not None:
Expand All @@ -98,7 +100,7 @@ def main(
model=llm,
tools=discover_tools(settings.instance().tools.server_mode),
prompt=discover_prompt(),
debug=True,
debug=debug,
)

# agent = create_openai_tools_agent(
Expand All @@ -108,6 +110,7 @@ def main(
# agent_executor = AgentExecutor(agent=agent, tools=registered_tools, verbose=True)
# chain = chat_prompt | llm | StrOutputParser()
chat_history = []
console = Console()

while True:
try:
Expand All @@ -120,11 +123,11 @@ def main(
if chat_history:
args["chat_history"] = chat_history
response: Dict[str, List[ToolMessage, AIMessage]] = executor.invoke(args)
for message in response:
if isinstance(message, ToolMessage):
message.pretty_print()
# for message in response:
# if isinstance(message, ToolMessage):
# message.pretty_print()

response["messages"][-1].pretty_print()
console.print(Markdown(str(response["messages"][-1].content)))
chat_history.extend(
[("human", user_input), ("system", response["messages"][-1].content)]
)
Expand Down
Loading