-
Notifications
You must be signed in to change notification settings - Fork 40
Expand file tree
/
Copy pathmcp_server.py
More file actions
2857 lines (2405 loc) · 109 KB
/
mcp_server.py
File metadata and controls
2857 lines (2405 loc) · 109 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""DataHub MCP Server Implementation.
IMPORTANT: This file is kept in sync between two repositories.
When making changes, ensure both versions remain identical. Use relative imports
(e.g., `from ._token_estimator import ...`) instead of absolute imports to maintain
compatibility across both repositories.
"""
import contextlib
import contextvars
import functools
import html
import inspect
import json
import os
import pathlib
import re
import string
import threading
from enum import Enum
from typing import (
Any,
Awaitable,
Callable,
Dict,
Generator,
Iterator,
List,
Literal,
Optional,
ParamSpec,
TypeVar,
)
import asyncer
import cachetools
import jmespath
from datahub.cli.env_utils import get_boolean_env_variable
from datahub.errors import ItemNotFoundError
from datahub.ingestion.graph.client import DataHubGraph
from datahub.metadata.urns import DatasetUrn, SchemaFieldUrn, Urn
from datahub.sdk.main_client import DataHubClient
from datahub.sdk.search_client import compile_filters
from datahub.sdk.search_filters import Filter, FilterDsl, load_filters
from datahub.utilities.ordered_set import OrderedSet
from fastmcp import FastMCP
from fastmcp.tools.tool import Tool as FastMCPTool
from json_repair import repair_json
from loguru import logger
from pydantic import BaseModel
# IMPORTANT: Use relative imports to maintain compatibility across repositories
from ._token_estimator import TokenCountEstimator
from .tools.descriptions import update_description
from .tools.documents import grep_documents, search_documents
from .tools.domains import remove_domains, set_domains
from .tools.get_me import get_me
from .tools.owners import add_owners, remove_owners
from .tools.save_document import is_save_document_enabled, save_document
from .tools.structured_properties import (
add_structured_properties,
remove_structured_properties,
)
from .tools.tags import add_tags, remove_tags
from .tools.terms import (
add_glossary_terms,
remove_glossary_terms,
)
from .version_requirements import TOOL_VERSION_REQUIREMENTS
# Type variables for the async_background decorator and generic helpers below.
_P = ParamSpec("_P")
_R = TypeVar("_R")
T = TypeVar("T")

# Hard caps applied when trimming text for LLM-facing tool responses.
DESCRIPTION_LENGTH_HARD_LIMIT = 1000
QUERY_LENGTH_HARD_LIMIT = 5000
DOCUMENT_CONTENT_CHAR_LIMIT = 8000

# Maximum token count for tool responses to prevent context window issues
# As per telemetry tool result length goes upto
TOOL_RESPONSE_TOKEN_LIMIT = int(os.getenv("TOOL_RESPONSE_TOKEN_LIMIT", 80000))

# Per-entity schema token budget for field truncation
# Assumes ~5 entities per response: 80K total / 5 = 16K per entity
ENTITY_SCHEMA_TOKEN_BUDGET = int(os.getenv("ENTITY_SCHEMA_TOKEN_BUDGET", "16000"))
class ToolType(Enum):
    """Tool type enumeration for different tool types.

    Used as tags when registering tools so clients can filter by category.
    """

    SEARCH = "search"  # Datahub search tools
    MUTATION = "mutation"  # Datahub mutation tools
    USER = "user"  # Datahub user tools
    DEFAULT = "default"  # Fallback tag
def _select_results_within_budget(
    results: Iterator[T],
    fetch_entity: Callable[[T], dict],
    max_results: int = 10,
    token_budget: Optional[int] = None,
) -> Generator[T, None, None]:
    """
    Generator that yields results within token budget.

    Generic helper that works for any result structure. Caller provides a function
    to extract/clean entity for token counting (can mutate the result).

    Yields results until:
      - max_results reached, OR
      - token_budget would be exceeded (and we have at least 1 result)

    Args:
        results: Iterator of result objects of any type T (memory efficient)
        fetch_entity: Function that extracts entity dict from result for token counting.
            Can mutate the result to clean/update entity in place.
            Signature: T -> dict (entity for token counting)
            Example: lambda r: (r.__setitem__("entity", clean(r["entity"])), r["entity"])[1]
        max_results: Maximum number of results to return
        token_budget: Token budget (defaults to 90% of TOOL_RESPONSE_TOKEN_LIMIT)

    Yields:
        Original result objects of type T (possibly mutated by fetch_entity)
    """
    if token_budget is None:
        # Use 90% of limit as safety buffer:
        # - Token estimation is approximate, not exact
        # - Response wrapper adds overhead
        # - Better to return fewer results that fit than exceed limit
        token_budget = int(TOOL_RESPONSE_TOKEN_LIMIT * 0.9)
    total_tokens = 0
    results_count = 0
    # Consume iterator up to max_results
    for i, result in enumerate(results):
        if i >= max_results:
            break
        # Extract (and possibly clean) entity using caller's lambda
        # Note: fetch_entity may mutate result to clean/update entity in place
        entity = fetch_entity(result)
        # Estimate token cost
        entity_tokens = TokenCountEstimator.estimate_dict_tokens(entity)
        # Check if adding this entity would exceed budget
        if total_tokens + entity_tokens > token_budget:
            if results_count == 0:
                # Always yield at least 1 result, even if it alone blows the
                # budget -- an empty response would be useless to the caller.
                logger.warning(
                    f"First result ({entity_tokens:,} tokens) exceeds budget ({token_budget:,}), "
                    "yielding it anyway"
                )
                yield result  # Yield original result structure
                results_count += 1
                total_tokens += entity_tokens
            else:
                # Have at least 1 result, stop here to stay within budget
                logger.info(
                    f"Stopping at {results_count} results (next would exceed {token_budget:,} token budget)"
                )
                break
        else:
            yield result  # Yield original result structure
            results_count += 1
            total_tokens += entity_tokens
    logger.info(
        f"Selected {results_count} results using {total_tokens:,} tokens "
        f"(budget: {token_budget:,})"
    )
def sanitize_html_content(text: str) -> str:
    """
    Strip HTML tags and decode HTML entities from *text*.

    The tag-matching pattern is bounded (at most 100 characters between
    ``<`` and ``>``) so a malicious input such as ``<`` followed by millions
    of characters with no closing ``>`` cannot trigger catastrophic regex
    backtracking (ReDoS).
    """
    if not text:
        return text
    # Bounded pattern: anything that looks like a tag, up to 100 chars inside.
    without_tags = re.sub(r"<[^<>]{0,100}>", "", text)
    # Turn entities such as &amp; back into literal characters, then trim.
    return html.unescape(without_tags).strip()
def truncate_with_ellipsis(text: str, max_length: int, suffix: str = "...") -> str:
    """Truncate text to max_length and add suffix if truncated.

    Args:
        text: Text to truncate; falsy values are returned unchanged.
        max_length: Maximum length of the returned string.
        suffix: Marker appended when truncation happens.

    Returns:
        The original text if it fits, otherwise a string of at most
        max_length characters ending in suffix (or a plain hard cut when
        the suffix itself would not fit).
    """
    if not text or len(text) <= max_length:
        return text
    # Account for suffix length
    actual_max = max_length - len(suffix)
    if actual_max <= 0:
        # Bug fix: a non-positive slice bound would keep the *end* of the
        # string (negative index) and could return more than max_length
        # characters. Fall back to a hard cut without the suffix.
        return text[:max_length]
    return text[:actual_max] + suffix
def sanitize_markdown_content(text: str) -> str:
    """Drop markdown image embeds whose target is a ``data:`` URL.

    The alt text is kept in place of the embed; inline base64 payloads can
    be enormous and are never useful to an LLM consumer.
    """
    if not text:
        return text
    # ![alt](data:...) -> alt
    cleaned = re.sub(r"!\[([^\]]*)\]\(data:[^)]+\)", r"\1", text)
    return cleaned.strip()
def sanitize_and_truncate_description(text: str, max_length: int) -> str:
    """Sanitize HTML and markdown embeds in *text*, then cap at *max_length*.

    Falls back to a plain character cut if any sanitizer raises, so a bad
    description never breaks the whole response.
    """
    if not text:
        return text
    try:
        # HTML first, then markdown (alt text is preserved), then truncate.
        cleaned = sanitize_markdown_content(sanitize_html_content(text))
        return truncate_with_ellipsis(cleaned, max_length)
    except Exception as e:
        logger.warning(f"Error sanitizing and truncating description: {e}")
        return text[:max_length] if len(text) > max_length else text
def truncate_descriptions(
    data: dict | list, max_length: int = DESCRIPTION_LENGTH_HARD_LIMIT
) -> None:
    """
    Recursively truncates values of keys named 'description' in a dictionary in place.

    Args:
        data: Dict or list to walk; modified in place.
        max_length: Maximum description length after sanitization.
    """
    # TODO: path-aware truncate, for different length limits per entity type
    if isinstance(data, dict):
        for key, value in data.items():
            if key == "description" and isinstance(value, str):
                data[key] = sanitize_and_truncate_description(value, max_length)
            elif isinstance(value, (dict, list)):
                # Bug fix: propagate max_length so nested descriptions honor a
                # caller-supplied limit instead of reverting to the default.
                truncate_descriptions(value, max_length)
    elif isinstance(data, list):
        for item in data:
            truncate_descriptions(item, max_length)
def truncate_query(query: str) -> str:
    """
    Truncate a SQL query if it exceeds the maximum length.
    """
    marker = "... [truncated]"
    return truncate_with_ellipsis(query, QUERY_LENGTH_HARD_LIMIT, suffix=marker)
# See https://github.com/jlowin/fastmcp/issues/864#issuecomment-3103678258
# for why we need to wrap sync functions with asyncify.
def async_background(fn: Callable[_P, _R]) -> Callable[_P, Awaitable[_R]]:
    """Wrap a sync function so it runs on a worker thread via asyncer.asyncify.

    Args:
        fn: A synchronous callable. Passing an async function is a
            programming error and raises immediately.

    Returns:
        An async wrapper with fn's metadata (functools.wraps) that awaits
        the threaded execution and re-raises any exception after logging it.
    """
    if inspect.iscoroutinefunction(fn):
        raise RuntimeError("async_background can only be used on non-async functions")

    @functools.wraps(fn)
    async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        try:
            return await asyncer.asyncify(fn)(*args, **kwargs)
        except Exception:
            # Log with full stack trace before FastMCP catches it
            logger.exception(
                f"Tool function {fn.__name__} failed with args={args}, kwargs={kwargs}"
            )
            raise

    return wrapper
def _register_tool(
    mcp_instance: FastMCP,
    name: str,
    fn: Callable,
    *,
    description: Optional[str] = None,
    tags: Optional[set] = None,
) -> None:
    """Register a tool on the MCP instance and capture its version requirement.

    This is a convenience wrapper that:
    1. Wraps the sync function with async_background
    2. Registers it on the MCP instance
    3. Reads the _version_requirement attribute (set by @min_version decorator)
       and populates TOOL_VERSION_REQUIREMENTS

    Args:
        mcp_instance: The FastMCP instance to register on.
        name: The tool name (may differ from fn.__name__).
        fn: The tool function (sync).
        description: Tool description. Defaults to fn.__doc__.
        tags: Optional set of tag strings.
    """
    mcp_instance.tool(
        name=name,
        description=description or fn.__doc__,
        tags=tags,
    )(async_background(fn))
    # The @min_version decorator stamps this attribute on the raw function;
    # record it under the registered tool name (not fn.__name__).
    req = getattr(fn, "_version_requirement", None)
    if req is not None:
        TOOL_VERSION_REQUIREMENTS[name] = req
def create_mcp_server(name: str = "datahub") -> FastMCP[None]:
    """Build and return a fresh FastMCP server instance for DataHub tools."""
    server: FastMCP[None] = FastMCP[None](name=name)
    return server
# Context-local DataHub client: each asyncio task / thread context sees the
# client bound for it via set_datahub_client / with_datahub_client.
_mcp_dh_client = contextvars.ContextVar[DataHubClient]("_mcp_dh_client")
def get_datahub_client() -> DataHubClient:
    """Return the DataHub client bound to the current context.

    Raises:
        LookupError: If no client has been set in this context.
    """
    # Will raise a LookupError if no client is set.
    return _mcp_dh_client.get()
def set_datahub_client(client: DataHubClient) -> None:
    """Bind *client* as the DataHub client for the current context."""
    _mcp_dh_client.set(client)
@contextlib.contextmanager
def with_datahub_client(client: DataHubClient) -> Iterator[None]:
    """Temporarily bind *client* for the duration of the ``with`` block.

    The previous binding (or unset state) is restored on exit, even if the
    body raises.
    """
    token = _mcp_dh_client.set(client)
    try:
        yield
    finally:
        _mcp_dh_client.reset(token)
def _enable_newer_gms_fields(query: str) -> str:
"""
Enable newer GMS fields by removing the #[NEWER_GMS] marker suffix.
Converts:
someField #[NEWER_GMS]
To:
someField
"""
lines = query.split("\n")
cleaned_lines = [
line.replace(" #[NEWER_GMS]", "").replace("\t#[NEWER_GMS]", "")
for line in lines
]
return "\n".join(cleaned_lines)
def _disable_newer_gms_fields(query: str) -> str:
"""
Disable newer GMS fields by commenting out lines with #[NEWER_GMS] marker.
Converts:
someField #[NEWER_GMS]
To:
# someField #[NEWER_GMS]
"""
lines = query.split("\n")
processed_lines = []
for line in lines:
if "#[NEWER_GMS]" in line:
# Comment out the line by prefixing with #
processed_lines.append("# " + line)
else:
processed_lines.append(line)
return "\n".join(processed_lines)
def _enable_cloud_fields(query: str) -> str:
"""
Enable cloud fields by removing the #[CLOUD] marker suffix.
Converts:
someField #[CLOUD]
To:
someField
"""
lines = query.split("\n")
cleaned_lines = [
line.replace(" #[CLOUD]", "").replace("\t#[CLOUD]", "") for line in lines
]
return "\n".join(cleaned_lines)
def _disable_cloud_fields(query: str) -> str:
"""
Disable cloud fields by commenting out lines with #[CLOUD] marker.
Converts:
someField #[CLOUD]
To:
# someField #[CLOUD]
"""
lines = query.split("\n")
processed_lines = []
for line in lines:
if "#[CLOUD]" in line:
# Comment out the line by prefixing with #
processed_lines.append("# " + line)
else:
processed_lines.append(line)
return "\n".join(processed_lines)
# Cache to track whether newer GMS fields are supported for each graph instance
# Key: id(graph), Value: bool indicating if newer GMS fields are supported
# NOTE(review): entries are never evicted and id() values can be reused after
# garbage collection -- acceptable for long-lived graph clients, but confirm.
_newer_gms_fields_support_cache: dict[int, bool] = {}
def _is_datahub_cloud(graph: DataHubGraph) -> bool:
    """Check if the graph instance is DataHub Cloud.

    Heuristic: only DataHub Cloud deployments expose a frontend base URL, and
    Cloud instances typically run newer GMS versions with additional fields.
    Detection can be switched off via DISABLE_NEWER_GMS_FIELD_DETECTION when
    the GMS version doesn't support all newer fields.
    """
    if get_boolean_env_variable("DISABLE_NEWER_GMS_FIELD_DETECTION", default=False):
        logger.debug(
            "Newer GMS field detection disabled via DISABLE_NEWER_GMS_FIELD_DETECTION"
        )
        return False
    try:
        # Accessing frontend_base_url raises ValueError on non-Cloud graphs.
        _ = graph.frontend_base_url
        return True
    except ValueError:
        return False
def _is_field_validation_error(error_msg: str) -> bool:
"""Check if the error is a GraphQL field/type validation or syntax error.
Includes InvalidSyntax because unknown types (like Document on older GMS)
cause syntax errors rather than validation errors.
"""
return (
"FieldUndefined" in error_msg
or "ValidationError" in error_msg
or "InvalidSyntax" in error_msg
)
def execute_graphql(
    graph: DataHubGraph,
    *,
    query: str,
    operation_name: Optional[str] = None,
    variables: Optional[Dict[str, Any]] = None,
) -> Any:
    """Execute a GraphQL query, adapting it to the target GMS's capabilities.

    The query text may carry ``#[CLOUD]`` and ``#[NEWER_GMS]`` markers on
    individual lines. Marked lines are enabled (marker suffix stripped) or
    disabled (line commented out) depending on whether the graph looks like
    DataHub Cloud and on the per-graph result cached in
    _newer_gms_fields_support_cache.

    If the first attempt fails with a field/type validation error while newer
    GMS fields were enabled, the query is retried once with those fields
    disabled and the cache is updated so later calls skip them up front.

    Raises:
        Exception: Whatever graph.execute_graphql raises when the initial
            attempt (and, if applicable, the fallback) fails.
    """
    graph_id = id(graph)
    original_query = query  # Keep original for fallback
    # Detect if this is a DataHub Cloud instance
    is_cloud = _is_datahub_cloud(graph)
    # Process CLOUD tags
    if is_cloud:
        query = _enable_cloud_fields(query)
    else:
        query = _disable_cloud_fields(query)
    # Process NEWER_GMS tags
    # Check if we've already determined newer GMS fields support for this graph
    newer_gms_enabled_for_this_query = False
    if graph_id in _newer_gms_fields_support_cache:
        supports_newer_fields = _newer_gms_fields_support_cache[graph_id]
        if supports_newer_fields:
            query = _enable_newer_gms_fields(query)
            newer_gms_enabled_for_this_query = True
        else:
            query = _disable_newer_gms_fields(query)
    else:
        # First attempt: try with newer GMS fields if it's detected as cloud
        # (Cloud instances typically run newer GMS versions)
        if is_cloud:
            query = _enable_newer_gms_fields(query)
            newer_gms_enabled_for_this_query = True
        else:
            query = _disable_newer_gms_fields(query)
        # Cache the initial detection result
        _newer_gms_fields_support_cache[graph_id] = is_cloud
    logger.debug(
        f"Executing GraphQL {operation_name or 'query'}: "
        f"is_cloud={is_cloud}, newer_gms_enabled={newer_gms_enabled_for_this_query}"
    )
    logger.debug(
        f"GraphQL query for {operation_name or 'query'}:\n{query}\nVariables: {variables}"
    )
    try:
        # Execute the GraphQL query
        result = graph.execute_graphql(
            query=query, variables=variables, operation_name=operation_name
        )
        return result
    except Exception as e:
        error_msg = str(e)
        # Check if this is a field validation error and we tried with newer GMS fields enabled
        # Only retry if we had newer GMS fields enabled in the query that just failed
        if _is_field_validation_error(error_msg) and newer_gms_enabled_for_this_query:
            logger.warning(
                f"GraphQL schema validation error detected for {operation_name or 'query'}. "
                f"Retrying without newer GMS fields as fallback."
            )
            logger.exception(e)
            # Update cache to indicate newer GMS fields are NOT supported
            _newer_gms_fields_support_cache[graph_id] = False
            # Retry with newer GMS fields disabled - process both tags again
            try:
                fallback_query = original_query
                # Reprocess CLOUD tags
                if is_cloud:
                    fallback_query = _enable_cloud_fields(fallback_query)
                else:
                    fallback_query = _disable_cloud_fields(fallback_query)
                # Disable newer GMS fields for fallback
                fallback_query = _disable_newer_gms_fields(fallback_query)
                logger.debug(
                    f"Retry {operation_name or 'query'} with NEWER_GMS fields disabled: "
                    f"is_cloud={is_cloud}"
                )
                result = graph.execute_graphql(
                    query=fallback_query,
                    variables=variables,
                    operation_name=operation_name,
                )
                logger.info(
                    f"Fallback query succeeded without newer GMS fields for operation: {operation_name}"
                )
                return result
            except Exception as fallback_error:
                logger.exception(
                    f"Fallback query also failed for {operation_name or 'query'}: {fallback_error}"
                )
                raise fallback_error
        elif (
            _is_field_validation_error(error_msg)
            and not newer_gms_enabled_for_this_query
        ):
            # Field validation error but NEWER_GMS fields were already disabled
            logger.error(
                f"GraphQL schema validation error for {operation_name or 'query'} "
                f"but NEWER_GMS fields were already disabled (is_cloud={is_cloud}). "
                f"This may indicate a CLOUD-only field being used on a non-cloud instance, "
                f"or a field that's unavailable in this GMS version."
            )
            logger.exception(e)
        # Keep essential error logging for troubleshooting with full stack trace
        logger.exception(
            f"GraphQL {operation_name or 'query'} failed: {e}\n"
            f"Cloud instance: {is_cloud}\n"
            f"Newer GMS fields enabled: {_newer_gms_fields_support_cache.get(graph_id, 'unknown')}\n"
            f"Variables: {variables}"
        )
        raise
def inject_urls_for_urns(
    graph: DataHubGraph, response: Any, json_paths: List[str]
) -> None:
    """Add a browsable ``url`` next to each ``urn`` found at *json_paths*.

    Only applies on DataHub Cloud (the only deployment with a frontend base
    URL). Each matched dict is rewritten in place so that ``urn`` and ``url``
    are its first two keys.
    """
    if not _is_datahub_cloud(graph):
        return
    for path in json_paths:
        items = jmespath.search(path, response) if path else [response]
        for item in items:
            if not (isinstance(item, dict) and item.get("urn")):
                continue
            # Rebuild the dict so urn/url lead, then restore remaining keys.
            reordered = {"urn": item["urn"], "url": graph.url_for(item["urn"])}
            for k, v in item.items():
                if k != "urn":
                    reordered[k] = v
            item.clear()
            item.update(reordered)
def maybe_convert_to_schema_field_urn(urn: str, column: Optional[str]) -> str:
    """Return a schema-field URN for (dataset urn, column); pass through otherwise.

    Raises:
        ValueError: If *column* is given but *urn* is not a dataset URN.
    """
    if not column:
        return urn
    parsed = Urn.from_string(urn)
    if not isinstance(parsed, DatasetUrn):
        raise ValueError(
            f"Input urn should be a dataset urn if column is provided, but got {urn}."
        )
    return str(SchemaFieldUrn(parsed, column))
# GraphQL query documents, loaded once at import time from the sibling gql/ dir.
search_gql = (pathlib.Path(__file__).parent / "gql/search.gql").read_text()
semantic_search_gql = (
    pathlib.Path(__file__).parent / "gql/semantic_search.gql"
).read_text()
smart_search_gql = (pathlib.Path(__file__).parent / "gql/smart_search.gql").read_text()
entity_details_fragment_gql = (
    pathlib.Path(__file__).parent / "gql/entity_details.gql"
).read_text()
queries_gql = (pathlib.Path(__file__).parent / "gql/queries.gql").read_text()
query_entity_gql = (pathlib.Path(__file__).parent / "gql/query_entity.gql").read_text()
related_documents_gql = (
    pathlib.Path(__file__).parent / "gql/related_documents.gql"
).read_text()
def _is_semantic_search_enabled() -> bool:
    """Check if semantic search is enabled via environment variable.

    IMPORTANT: Semantic search is an EXPERIMENTAL feature that is ONLY available on
    DataHub Cloud deployments with specific versions and configurations. This feature
    must be explicitly enabled by the DataHub team for your Cloud instance.

    Note:
        This function only checks the environment variable. Actual feature
        availability is validated when the DataHub client is used.
    """
    return get_boolean_env_variable("SEMANTIC_SEARCH_ENABLED", default=False)
# Global View Configuration
DISABLE_DEFAULT_VIEW = get_boolean_env_variable(
    "DATAHUB_MCP_DISABLE_DEFAULT_VIEW", default=False
)
VIEW_CACHE_TTL_SECONDS = 300  # 5 minutes hardcoded

# Log configuration on startup
if not DISABLE_DEFAULT_VIEW:
    logger.info("Default view application ENABLED (cache TTL: 5 minutes)")
else:
    logger.info("Default view application DISABLED")
# NOTE(review): maxsize=1 means the cache holds a single (graph,) entry, so
# alternating between two graph instances re-fetches every call -- confirm
# that only one graph is used per process.
@cachetools.cached(cache=cachetools.TTLCache(maxsize=1, ttl=VIEW_CACHE_TTL_SECONDS))
def fetch_global_default_view(graph: DataHubGraph) -> Optional[str]:
    """
    Fetch the organization's default global view URN unless disabled.
    Cached for VIEW_CACHE_TTL_SECONDS seconds.
    Returns None if disabled or if no default view is configured.
    """
    # Return None immediately if feature is disabled
    if DISABLE_DEFAULT_VIEW:
        return None
    query = """
    query getGlobalViewsSettings {
        globalViewsSettings {
            defaultView
        }
    }
    """
    result = execute_graphql(graph, query=query)
    settings = result.get("globalViewsSettings")
    if settings:
        view_urn = settings.get("defaultView")
        if view_urn:
            logger.debug(f"Fetched global default view: {view_urn}")
            return view_urn
    logger.debug("No global default view configured")
    return None
def clean_gql_response(response: Any) -> Any:
    """
    Clean GraphQL response by removing metadata and empty values.

    Recursively removes:
    - __typename fields (GraphQL metadata not useful for consumers)
    - None values
    - Empty arrays []
    - Empty dicts {} (after cleaning)
    - Base64-encoded images from description fields (can be huge - 2MB!)

    Args:
        response: Raw GraphQL response (dict, list, or primitive)

    Returns:
        Cleaned response with same structure but without noise
    """
    if isinstance(response, dict):
        banned_keys = {
            "__typename",
        }
        cleaned_response = {}
        for k, v in response.items():
            if k in banned_keys or v is None or v == []:
                continue
            cleaned_v = clean_gql_response(v)
            # Strip base64 images from description fields.
            # Uses the module-level `re` import; the previous inline
            # `import re` inside this loop was redundant.
            if (
                k == "description"
                and isinstance(cleaned_v, str)
                and "base64" in cleaned_v
            ):
                cleaned_v = re.sub(
                    r"data:image/[^;]+;base64,[A-Za-z0-9+/=]+",
                    "[image removed]",
                    cleaned_v,
                )
                cleaned_v = re.sub(
                    r"!\[[^\]]*\]\(data:image/[^)]+\)", "[image removed]", cleaned_v
                )
            # Drop values that cleaned down to nothing.
            if cleaned_v is not None and cleaned_v != {}:
                cleaned_response[k] = cleaned_v
        return cleaned_response
    elif isinstance(response, list):
        return [clean_gql_response(item) for item in response]
    else:
        return response
def _sort_fields_by_priority(fields: List[dict]) -> Iterator[dict]:
"""
Yield schema fields sorted by priority for deterministic truncation.
Priority order:
1. Primary/partition keys (isPartOfKey, isPartitioningKey)
2. Fields with descriptions
3. Fields with tags or glossary terms
4. Alphabetically by fieldPath
Each field gets a score tuple for sorting:
- key_score: 2 if isPartOfKey, 1 if isPartitioningKey, 0 otherwise
- has_description: 1 if description exists, 0 otherwise
- has_tags: 1 if tags or glossary terms exist, 0 otherwise
- fieldPath: for alphabetical tiebreaker
Sorted in descending order by score components, then ascending by fieldPath.
Args:
fields: List of field dicts from GraphQL response
Yields:
Fields in priority order (generator for memory efficiency)
"""
# Score each field with tuple: (key_score, has_description, has_tags, fieldPath, index)
scored_fields = []
for idx, field in enumerate(fields):
# Score key fields (highest priority)
key_score = 0
if field.get("isPartOfKey"):
key_score = 2
elif field.get("isPartitioningKey"):
key_score = 1
# Score fields with descriptions
has_description = 1 if field.get("description") else 0
# Score fields with tags or glossary terms
has_tags_or_terms = 0
if field.get("tags") or field.get("glossaryTerms"):
has_tags_or_terms = 1
# Get fieldPath for alphabetical sorting (tiebreaker)
field_path = field.get("fieldPath", "")
# Store as (score_tuple, original_index, field)
# Sort descending by scores, ascending by fieldPath
score_tuple = (-key_score, -has_description, -has_tags_or_terms, field_path)
scored_fields.append((score_tuple, idx, field))
# Sort by score tuple
scored_fields.sort(key=lambda x: x[0])
# Yield fields in sorted order
for _, _, field in scored_fields:
yield field
def _clean_schema_fields(
    sorted_fields: Iterator[dict], editable_map: dict[str, dict]
) -> Iterator[dict]:
    """
    Clean and normalize schema fields for response.

    Yields cleaned field dicts with only essential properties for SQL generation
    and understanding schema structure. Merges user-edited metadata (descriptions,
    tags, glossary terms) into fields with "edited*" prefix when they differ.

    Note: All fields are expected to have fieldPath (always requested in GraphQL).
    If fieldPath is missing, it indicates a data quality issue.

    Args:
        sorted_fields: Iterator of fields in priority order
        editable_map: Map of fieldPath -> editable field data for merging

    Yields:
        Cleaned field dicts with merged editable data (generator for memory efficiency)
    """
    for f in sorted_fields:
        # fieldPath is required - it's always requested in GraphQL and is essential
        # for identifying the field. If missing, fail fast rather than silently skipping.
        field_dict = {"fieldPath": f["fieldPath"]}
        # Add type if present (essential for SQL)
        if field_type := f.get("type"):
            field_dict["type"] = field_type
        # Add nativeDataType if present (important for SQL type casting)
        if native_type := f.get("nativeDataType"):
            field_dict["nativeDataType"] = native_type
        # Add description if present (truncated to 120 chars to cap token cost)
        if description := f.get("description"):
            field_dict["description"] = description[:120]
        # Add nullable if present (important for SQL NULL handling);
        # `is not None` keeps an explicit False value.
        if f.get("nullable") is not None:
            field_dict["nullable"] = f.get("nullable")
        # Add label if present (useful for human-readable names)
        if label := f.get("label"):
            field_dict["label"] = label
        # Add isPartOfKey only if truthy (important for joins)
        if f.get("isPartOfKey"):
            field_dict["isPartOfKey"] = True
        # Add isPartitioningKey only if truthy (important for query optimization)
        if f.get("isPartitioningKey"):
            field_dict["isPartitioningKey"] = True
        # Add recursive only if truthy
        if f.get("recursive"):
            field_dict["recursive"] = True
        # Add deprecation status if present (warn about deprecated fields)
        if schema_field_entity := f.get("schemaFieldEntity"):
            if deprecation := schema_field_entity.get("deprecation"):
                if deprecation.get("deprecated"):
                    field_dict["deprecated"] = {
                        "deprecated": True,
                        "note": deprecation.get("note", "")[:120],  # Truncate note
                    }
        # Add tags if present (keep minimal info for classification context)
        if tags := f.get("tags"):
            if tag_list := tags.get("tags"):
                # Keep just tag names for context
                field_dict["tags"] = [
                    t["tag"]["properties"]["name"]
                    for t in tag_list
                    if t.get("tag", {}).get("properties")
                    and t["tag"]["properties"].get("name")
                ]
        # Add glossary terms if present (keep minimal info for business context)
        if glossary_terms := f.get("glossaryTerms"):
            if terms_list := glossary_terms.get("terms"):
                # Keep just term names for context
                field_dict["glossaryTerms"] = [
                    t["term"]["properties"]["name"]
                    for t in terms_list
                    if t.get("term", {}).get("properties")
                    and t["term"]["properties"].get("name")
                ]
        # Merge editable metadata if available for this field
        field_path = f["fieldPath"]
        if editable := editable_map.get(field_path):
            # Add editedDescription if it differs from system description
            if editable_desc := editable.get("description"):
                system_desc = field_dict.get("description", "")
                # Only add if different (token optimization)
                if editable_desc[:120] != system_desc:  # Compare truncated versions
                    field_dict["editedDescription"] = editable_desc[:120]
            # Add editedTags if present and different
            if editable_tags := editable.get("tags"):
                if tag_list := editable_tags.get("tags"):
                    edited_tag_names = [
                        t["tag"]["properties"]["name"]
                        for t in tag_list
                        if t.get("tag", {}).get("properties")
                        and t["tag"]["properties"].get("name")
                    ]
                    if edited_tag_names:
                        system_tags = field_dict.get("tags", [])
                        if edited_tag_names != system_tags:
                            field_dict["editedTags"] = edited_tag_names
            # Add editedGlossaryTerms if present and different
            if editable_terms := editable.get("glossaryTerms"):
                if terms_list := editable_terms.get("terms"):
                    edited_term_names = [
                        t["term"]["properties"]["name"]
                        for t in terms_list
                        if t.get("term", {}).get("properties")
                        and t["term"]["properties"].get("name")
                    ]
                    if edited_term_names:
                        system_terms = field_dict.get("glossaryTerms", [])
                        if edited_term_names != system_terms:
                            field_dict["editedGlossaryTerms"] = edited_term_names
        yield field_dict
def clean_get_entities_response(
raw_response: dict,
*,
sort_fn: Optional[Callable[[List[dict]], Iterator[dict]]] = None,
offset: int = 0,
limit: Optional[int] = None,
) -> dict:
"""
Clean and optimize entity responses for LLM consumption.
Performs several transformations to reduce token usage while preserving essential information:
1. **Clean GraphQL artifacts**: Removes __typename, null values, empty objects/arrays
(via clean_gql_response)
2. **Schema field processing** (if schemaMetadata.fields exists):
- Sorts fields using sort_fn (defaults to _sort_fields_by_priority)
- Cleans each field to keep only essential properties (fieldPath, type, description, etc.)
- Merges editableSchemaMetadata into fields with "edited*" prefix (editedDescription,
editedTags, editedGlossaryTerms) - only included when they differ from system values
- Applies pagination (offset/limit) with token budget constraint
- Field selection stops when EITHER limit is reached OR ENTITY_SCHEMA_TOKEN_BUDGET is exceeded
- Adds schemaFieldsTruncated metadata when fields are cut
3. **Remove duplicates**: Deletes editableSchemaMetadata after merging into schemaMetadata
4. **Truncate view definitions**: Limits SQL view logic to QUERY_LENGTH_HARD_LIMIT
The result is optimized for LLM tool responses: reduced token usage, no duplication,
clear distinction between system-generated and user-curated content.
Args:
raw_response: Raw entity dict from GraphQL query
sort_fn: Optional custom function to sort fields. If None, uses _sort_fields_by_priority.
Should take a list of field dicts and return an iterator of sorted fields.
offset: Number of fields to skip after sorting (default: 0)
limit: Maximum number of fields to include after offset (default: None = unlimited)
Returns:
Cleaned entity dict optimized for LLM consumption
"""
response = clean_gql_response(raw_response)
if response and (schema_metadata := response.get("schemaMetadata")):
# Remove empty platformSchema to reduce response clutter
if platform_schema := schema_metadata.get("platformSchema"):
schema_value = platform_schema.get("schema")
if not schema_value or schema_value == "":
del schema_metadata["platformSchema"]
# Clean schemaMetadata.fields to keep important fields while reducing size
# Keep fields essential for SQL generation and understanding schema structure
if fields := schema_metadata.get("fields"):
total_fields = len(fields) # Use original count before any filtering
# Build editable map from editableSchemaMetadata for merging
# Make this safe - if duplicate fieldPaths exist, last one wins (no failure)
editable_map = {}
if editable_schema := response.get("editableSchemaMetadata"):
if editable_fields := editable_schema.get("editableSchemaFieldInfo"):
for editable_field in editable_fields:
if field_path := editable_field.get("fieldPath"):
editable_map[field_path] = editable_field
# Sort fields using custom function or default priority sorting
sort_function = sort_fn if sort_fn is not None else _sort_fields_by_priority
sorted_fields = sort_function(fields)
cleaned_fields = _clean_schema_fields(sorted_fields, editable_map)
# Apply offset, limit, and token budget to select fields
selected_fields: list[dict] = []
accumulated_tokens = 0
fields_remaining = limit # None means unlimited
for idx, field in enumerate(cleaned_fields):
# Skip fields before offset
if idx < offset:
continue
field_tokens = TokenCountEstimator.estimate_dict_tokens(field)