Skip to content

Commit b4b76e9

Browse files
authored
DX-114411: MCP Search - Adding by default removeCatalogName to true for Catalog API call (#84)
* DX-114411: MCP Search - Adding by default removeCatalogName to true for Catalog API call * updated mcp with suggestions * cleanup * fix
1 parent c0cbb0f commit b4b76e9

3 files changed

Lines changed: 19 additions & 5 deletions

File tree

src/dremioai/api/dremio/catalog.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,21 @@ class LineageResponse(BaseModel):
110110
children: List[LineageChildren]
111111

112112

113-
async def get_lineage(dataset_id_or_path: str) -> Dict[str, Any]:
113+
async def get_lineage(
114+
dataset_id_or_path: str,
115+
remove_catalog_name: Optional[bool] = True
116+
) -> Dict[str, Any]:
114117
client = AsyncHttpClient()
115118
if "." in dataset_id_or_path:
116119
response = await get_schema(dataset_id_or_path, by_id=False)
117120
dataset_id_or_path = response["id"]
118121

119122
project_id = settings.instance().dremio.project_id
120123
endpoint = f"/v0/projects/{project_id}/catalog" if project_id else "/api/v3/catalog"
124+
params = {"removeCatalogName": str(remove_catalog_name).lower()}
121125
result: LineageResponse = await client.get(
122126
f"{endpoint}/{dataset_id_or_path}/graph",
127+
params=params,
123128
deser=LineageResponse,
124129
)
125130
return result.model_dump()
@@ -130,10 +135,12 @@ async def get_schema(
130135
by_id: Optional[bool] = False,
131136
include_tags: Optional[bool] = False,
132137
flatten: Optional[bool] = False,
138+
remove_catalog_name: Optional[bool] = True
133139
) -> Dict[str, Any]:
134140
client = AsyncHttpClient()
135141
project_id = settings.instance().dremio.project_id
136142
endpoint = f"/v0/projects/{project_id}/catalog" if project_id else "/api/v3/catalog"
143+
params = {"removeCatalogName": str(remove_catalog_name).lower()}
137144
if by_id:
138145
endpoint += "/" + dataset_path_or_id
139146
else:
@@ -142,7 +149,7 @@ async def get_schema(
142149
reader(StringIO(dataset_path_or_id), delimiter=".", dialect=excel)
143150
)[0]
144151
endpoint += f'/by-path/{"/".join(dataset_path_or_id)}'
145-
schema = await client.get(endpoint)
152+
schema = await client.get(endpoint, params=params)
146153

147154
if include_tags:
148155

src/dremioai/api/dremio/search.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,8 @@ class EnterpriseSearchResultsWrapper(BaseModel):
214214

215215

216216
async def get_search_results(
217-
search: str | Search, use_df: bool = False
217+
search: str | Search, use_df: bool = False,
218+
remove_catalog_name: Optional[bool] = True
218219
) -> EnterpriseSearchResultsWrapper | pd.DataFrame:
219220
if isinstance(search, str):
220221
search = Search(query=search)
@@ -225,11 +226,15 @@ async def get_search_results(
225226
if settings.instance().dremio.project_id
226227
else "/api/v3/search"
227228
)
229+
230+
params = {"removeCatalogName": str(remove_catalog_name).lower()}
231+
228232
result = []
229233
response = await client.post(
230234
endpoint,
231235
body=search.model_dump(exclude_none=True),
232236
deser=EnterpriseSearchResults,
237+
params=params,
233238
)
234239
while response.results and response.error is None and response.more is None:
235240
result.extend(response.results)
@@ -240,6 +245,7 @@ async def get_search_results(
240245
endpoint,
241246
body=search.model_dump(exclude_none=True),
242247
deser=EnterpriseSearchResults,
248+
params=params,
243249
)
244250

245251
if use_df:

src/dremioai/api/transport.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,12 @@ async def post(
195195
deser: Optional[DeserializationStrategy] = None,
196196
file: Optional[TextIO] = None,
197197
top_level_list: bool = False,
198+
params: Dict[AnyStr, Any] = None,
198199
):
199200
async with ClientSession(middlewares=(retry_middleware,)) as session:
200-
self.log_request("POST", endpoint)
201+
self.log_request("POST", endpoint, params)
201202
async with session.post(
202-
f"{self.uri}{endpoint}", headers=self.headers, json=body, ssl=False
203+
f"{self.uri}{endpoint}", params=params, headers=self.headers, json=body, ssl=False
203204
) as response:
204205
return await self.handle_response(
205206
response, deser, file, top_level_list=top_level_list

0 commit comments

Comments
 (0)