Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/dremioai/api/cli/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def run_catalog(dataset_id: Annotated[str, Option(...)]):
pp(lineage)


@catalog_app.command(name="schema")
def run_catalog(
id_or_path: Annotated[str, Option(help="Dataset ID or path")],
by_id: Annotated[bool, Option(help="Whether the dataset id is an id")] = False,
):
schema = asyncio.run(catalog.get_schema(id_or_path, by_id=by_id))
pp(schema)


# _qg = "Query / Job ID "
@sql_app.command("run")
def run_sql(
Expand Down
4 changes: 2 additions & 2 deletions src/dremioai/api/dremio/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class LineageResponse(BaseModel):
children: List[LineageChildren]


async def get_lineage(dataset_id_or_path: str) -> str:
async def get_lineage(dataset_id_or_path: str) -> Dict[str, Any]:
client = AsyncHttpClient()
if "." in dataset_id_or_path:
response = await get_schema(dataset_id_or_path, by_id=False)
Expand All @@ -122,7 +122,7 @@ async def get_lineage(dataset_id_or_path: str) -> str:
f"{endpoint}/{dataset_id_or_path}/graph",
deser=LineageResponse,
)
return result.model_dump_json()
return result.model_dump()


async def get_schema(
Expand Down
14 changes: 7 additions & 7 deletions src/dremioai/api/dremio/engines.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#
#
# Copyright (C) 2017-2025 Dremio Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#

from pydantic import BaseModel, Field, BeforeValidator
from typing import List, Dict, Union, Optional, Any, Annotated
Expand Down Expand Up @@ -103,7 +103,7 @@ async def get_engines(
use_df: Optional[bool] = False,
add_project_id: Optional[bool] = False,
) -> Union[pd.DataFrame, EngineList]:
client = AsyncHttpClient(uri=uri, pat=pat)
client = AsyncHttpClient()

if isinstance(project_id, list):
result = await run_in_parallel(
Expand All @@ -119,7 +119,7 @@ async def get_engines(
engine_ids = [engine_ids]

async def _fetch_one(eid: str):
client = AsyncHttpClient(uri=uri, pat=pat)
client = AsyncHttpClient()
return await client.get(
f"/v0/projects/{project_id}/engines/{eid}", deser=Engine
)
Expand Down
14 changes: 7 additions & 7 deletions src/dremioai/api/dremio/projects.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
#
#
# Copyright (C) 2017-2025 Dremio Corporation
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#

from pydantic import BaseModel, Field, BeforeValidator
from typing import List, Dict, Union, Optional, Any, Annotated
Expand Down Expand Up @@ -105,14 +105,14 @@ async def get_projects(
project_ids: Optional[Union[List[str], str]] = None,
use_df: Optional[bool] = False,
) -> Union[pd.DataFrame, ProjectsList]:
client = AsyncHttpClient(uri=uri, pat=pat)
client = AsyncHttpClient()

if project_ids:
if isinstance(engine_ids, str):
engine_ids = [engine_ids]

async def _fetch_one(pid: str):
client = AsyncHttpClient(uri=uri, pat=pat)
client = AsyncHttpClient()
return await client.get(f"/v0/projects/{pid}", deser=Project)

pl = await run_in_parallel([_fetch_one(p) for p in project_ids])
Expand Down
5 changes: 1 addition & 4 deletions src/dremioai/api/dremio/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,7 @@ async def get_search_results(
if isinstance(search, str):
search = Search(query=search)

client = AsyncHttpClient(
settings.instance().dremio.uri, settings.instance().dremio.pat
)

client = AsyncHttpClient()
endpoint = (
f"/v0/projects/{settings.instance().dremio.project_id}/search"
if settings.instance().dremio.project_id
Expand Down
4 changes: 2 additions & 2 deletions src/dremioai/api/dremio/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class JobResultsParams(BaseModel):
async def _fetch_results(
uri: str, pat: str, project_id: str, job_id: str, off: int, limit: int
) -> JobResults:
client = AsyncHttpClient(uri=uri, pat=pat)
client = AsyncHttpClient()
params = JobResultsParams(offset=off, limit=limit)
endpoint = f"/v0/projects/{project_id}" if project_id else "/api/v3"
return await client.get(
Expand All @@ -182,7 +182,7 @@ async def get_results(
qs = QuerySubmission(id=qs)

if client is None:
client = AsyncHttpClient(uri=uri, pat=pat)
client = AsyncHttpClient()

endpoint = f"/v0/projects/{project_id}" if project_id else "/api/v3"
job: Job = await client.get(f"{endpoint}/job/{qs.id}", deser=Job)
Expand Down
2 changes: 1 addition & 1 deletion src/dremioai/api/dremio/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ async def get_usage(
if isinstance(project_ids, str):
params.for_project_id(project_ids)

client = AsyncHttpClient(uri=uri, pat=pat)
client = AsyncHttpClient()

async def _get_usage(p: Params) -> Usage:
p = p.model_dump() if p is not None else None
Expand Down
8 changes: 3 additions & 5 deletions src/dremioai/api/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def post(


class DremioAsyncHttpClient(AsyncHttpClient):
def __init__(self, uri: Optional[str] = None, pat: Optional[str] = None):
def __init__(self):
dremio = settings.instance().dremio
if (
dremio.oauth_supported
Expand All @@ -135,10 +135,8 @@ def __init__(self, uri: Optional[str] = None, pat: Optional[str] = None):
oauth = get_oauth2_tokens()
oauth.update_settings()

if uri is None:
uri = dremio.uri
if pat is None:
pat = dremio.pat
uri = dremio.uri
pat = dremio.pat

if uri is None or pat is None:
raise RuntimeError(f"uri={uri} pat={pat} are required")
Expand Down
4 changes: 2 additions & 2 deletions src/dremioai/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def ensure_query_allowed(s: str):
"The query contains a DML statement. Only select queries are allowed"
)

async def invoke(self, s: str) -> Dict[str, List[Any]]:
async def invoke(self, s: str) -> Dict[str, List[Dict[Any, Any]]]:
"""Run a SELECT sql query on the Dremio cluster and return the results.
Ensure that SQL keywords like 'day', 'month', 'count', 'table' etc are enclosed in double quotes
You are premitted to run only SELECT queries. No DML statements are allowed.
Expand All @@ -285,7 +285,7 @@ async def invoke(self, s: str) -> Dict[str, List[Any]]:
try:
s = f"/* dremioai: submitter={self.__class__.__name__} */\n{s}"
df = await sql.run_query(query=s, use_df=True)
return {"results": df.to_dict(orient="records")}
return {"result": df.to_dict(orient="records")}
except RuntimeError as e:
return {
"error": str(e),
Expand Down
1 change: 1 addition & 0 deletions tests/mocks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Mock framework for testing
179 changes: 179 additions & 0 deletions tests/mocks/http_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#
# Copyright (C) 2017-2025 Dremio Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import re
from pathlib import Path
from typing import Dict, Any, Optional, Union
from unittest.mock import MagicMock
from aiohttp import ClientSession
from collections import OrderedDict


class MockResponse:
"""Mock ClientResponse that returns data from files"""

def __init__(self, data: str, status: int = 200, headers: Optional[Dict] = None):
self.data = data
self.status = status
self.headers = headers or {}
self.request_info = MagicMock()
self.request_info.method = "GET"
self.request_info.url = "http://mock.url"

async def text(self) -> str:
"""Return the mock data as text"""
return self.data

async def json(self) -> Dict[str, Any]:
"""Return the mock data as JSON"""
return json.loads(self.data)

def raise_for_status(self):
"""Mock raise_for_status - only raises if status >= 400"""
if self.status >= 400:
raise Exception(f"HTTP {self.status}")

@property
def content(self):
"""Mock content property for streaming reads"""
mock_content = MagicMock()

async def read(chunk_size=1024):
# Return data in chunks for download simulation
if hasattr(self, "_read_position"):
if self._read_position >= len(self.data):
return b""
chunk = self.data[
self._read_position : self._read_position + chunk_size
].encode()
self._read_position += chunk_size
return chunk
else:
self._read_position = 0
chunk = self.data[0:chunk_size].encode()
self._read_position = chunk_size
return chunk

mock_content.read = read
return mock_content

async def __aenter__(self):
"""Async context manager entry"""
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
pass


class HttpMockFramework:
"""Simple HTTP mock framework for testing transport.py"""

def __init__(self, resources_dir: str = "tests/resources"):
self.resources_dir = Path(resources_dir)
self.mock_responses = OrderedDict()
self.original_session = None

def load_mock_data(self, endpoint: str, filename: str) -> "HttpMockFramework":
"""
Load mock data from a file for a specific endpoint

Args:
endpoint: The API endpoint to mock (e.g., "/api/v3/catalog")
filename: The filename in tests/resources (e.g., "catalog/spaces.json")
"""
file_path = self.resources_dir / filename
if not file_path.exists():
raise FileNotFoundError(f"Mock data file not found: {file_path}")

with open(file_path, "r") as f:
self.mock_responses[endpoint] = f.read()

return self

def add_mock_response(
self, endpoint: str, response_data: Union[str, Dict]
) -> "HttpMockFramework":
"""
Add a mock response directly without loading from file

Args:
endpoint: The API endpoint to mock
response_data: The response data (string or dict that will be JSON serialized)
"""
if isinstance(response_data, dict):
response_data = json.dumps(response_data)
self.mock_responses[endpoint] = response_data
return self

def _get_mock_response(self, url: str, method: str = "GET") -> MockResponse:
"""Get mock response for a URL"""
# Extract endpoint from full URL
for endpoint, data in self.mock_responses.items():
if re.search(endpoint, url):
return MockResponse(data)

# Default response if no mock found
return MockResponse('{"error": "No mock data found"}', status=404)

def _mock_get(self, url: str, **kwargs) -> MockResponse:
"""Mock ClientSession.get method"""
return self._get_mock_response(url, "GET")

def _mock_post(self, url: str, **kwargs) -> MockResponse:
"""Mock ClientSession.post method"""
return self._get_mock_response(url, "POST")

def __enter__(self):
"""Context manager entry - start mocking"""
# Store original methods
self.original_get = ClientSession.get
self.original_post = ClientSession.post

# Replace with mocks
ClientSession.get = self._mock_get
ClientSession.post = self._mock_post

return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit - restore original methods"""
# Restore original methods
ClientSession.get = self.original_get
ClientSession.post = self.original_post


# Convenience function for quick setup
def mock_http_client(mock_data: OrderedDict[str, str]) -> HttpMockFramework:
"""
Create and configure an HTTP mock framework

Args:
mock_data: Dictionary mapping endpoints to filenames in tests/resources

Example:
with mock_http_client({
"/api/v3/catalog": "catalog/spaces.json",
"/api/v3/sql": "sql/job_status.json"
}) as mock:
# Your test code here
pass
"""
framework = HttpMockFramework()
for endpoint, filename in mock_data.items():
framework.load_mock_data(endpoint, filename)
return framework
Loading