-
Notifications
You must be signed in to change notification settings - Fork 122
Expand file tree
/
Copy pathsemantic_layer.py
More file actions
105 lines (94 loc) · 3.96 KB
/
semantic_layer.py
File metadata and controls
105 lines (94 loc) · 3.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import annotations
from typing import TYPE_CHECKING
from dbt_mcp.config.headers import (
SemanticLayerHeadersProvider,
)
from dbt_mcp.errors import NotFoundError
from .base import (
ConfigProvider,
CredentialsProviderProtocol,
MultiProjectConfigProvider,
SemanticLayerConfig,
)
if TYPE_CHECKING:
from dbt_mcp.dbt_admin.client import DbtAdminAPIClient
class DefaultSemanticLayerConfigProvider(ConfigProvider[SemanticLayerConfig]):
    """Build the Semantic Layer config for the single configured project.

    The production environment id is taken from the credentials' settings
    (``actual_prod_environment_id``); no admin API lookup is performed.
    """

    def __init__(
        self,
        credentials_provider: CredentialsProviderProtocol,
        *,
        metrics_related_max: int = 10,
        max_response_chars: int = 16000,
    ):
        """Store the credentials source and the pass-through limits.

        Args:
            credentials_provider: Supplies ``(settings, token_provider)`` on
                every ``get_config`` call.
            metrics_related_max: Forwarded unchanged to ``SemanticLayerConfig``.
            max_response_chars: Forwarded unchanged to ``SemanticLayerConfig``.
        """
        self.credentials_provider = credentials_provider
        self.metrics_related_max = metrics_related_max
        self.max_response_chars = max_response_chars

    async def get_config(self) -> SemanticLayerConfig:
        """Fetch fresh credentials and build a ``SemanticLayerConfig``.

        Returns:
            A config whose ``url`` always ends in ``/api/graphql``. The scheme
            is ``http`` when the host starts with ``localhost`` and ``https``
            otherwise.

        Raises:
            ValueError: If the settings provide no host or no production
                environment id.
        """
        settings, token_provider = await self.credentials_provider.get_credentials()
        # Explicit checks rather than ``assert`` so validation still runs
        # under ``python -O``.
        if not settings.actual_host:
            raise ValueError("Semantic Layer config requires a dbt host")
        if not settings.actual_prod_environment_id:
            raise ValueError(
                "Semantic Layer config requires a production environment id"
            )

        is_local = settings.actual_host.startswith("localhost")
        if is_local:
            host = settings.actual_host
        elif settings.actual_host_prefix:
            # NOTE(review): this branch uses ``base_host`` rather than
            # ``actual_host``; presumably ``actual_host`` already includes the
            # prefix — confirm against the settings model.
            host = f"{settings.actual_host_prefix}.semantic-layer.{settings.base_host}"
        else:
            host = f"semantic-layer.{settings.actual_host}"

        # Local and remote hosts share the GraphQL path; only the scheme
        # differs. Building the URL in one f-string avoids the precedence trap
        # of ``a if cond else b + path``, where ``+`` binds to ``b`` only.
        scheme = "http" if is_local else "https"
        return SemanticLayerConfig(
            url=f"{scheme}://{host}/api/graphql",
            host=host,
            prod_environment_id=settings.actual_prod_environment_id,
            token_provider=token_provider,
            headers_provider=SemanticLayerHeadersProvider(
                token_provider=token_provider
            ),
            metrics_related_max=self.metrics_related_max,
            max_response_chars=self.max_response_chars,
        )
class MultiProjectSemanticLayerConfigProvider(
    MultiProjectConfigProvider[SemanticLayerConfig]
):
    """Build Semantic Layer configs for any of several dbt projects.

    The production environment id for each project is looked up through the
    admin API client rather than read from the settings.
    """

    def __init__(
        self,
        credentials_provider: CredentialsProviderProtocol,
        admin_client: DbtAdminAPIClient,
        *,
        metrics_related_max: int = 10,
        max_response_chars: int = 16000,
    ):
        """Store the credentials source, admin client and pass-through limits.

        Args:
            credentials_provider: Supplies ``(settings, token_provider)`` on
                every ``get_config`` call.
            admin_client: Used to look up a project's production environment.
            metrics_related_max: Forwarded unchanged to ``SemanticLayerConfig``.
            max_response_chars: Forwarded unchanged to ``SemanticLayerConfig``.
        """
        self.credentials_provider = credentials_provider
        self.admin_client = admin_client
        self.metrics_related_max = metrics_related_max
        self.max_response_chars = max_response_chars

    async def get_config(self, project_id: int) -> SemanticLayerConfig:
        """Build a ``SemanticLayerConfig`` for ``project_id``.

        Args:
            project_id: The dbt project to target. If the settings define
                ``dbt_project_ids``, it must be one of them.

        Returns:
            A config whose ``url`` always ends in ``/api/graphql``. The scheme
            is ``http`` when the host starts with ``localhost`` and ``https``
            otherwise.

        Raises:
            ValueError: If no host is configured, or ``project_id`` is outside
                the selected project ids.
            NotFoundError: If the project has no production environment.
        """
        settings, token_provider = await self.credentials_provider.get_credentials()
        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``.
        if not settings.actual_host:
            raise ValueError("Semantic Layer config requires a dbt host")
        # An empty/unset selection means every project is allowed.
        if settings.dbt_project_ids and project_id not in settings.dbt_project_ids:
            raise ValueError(
                f"Project {project_id} is not in the selected projects. "
                f"Available project IDs: {settings.dbt_project_ids}"
            )

        is_local = settings.actual_host.startswith("localhost")
        if is_local:
            host = settings.actual_host
        elif settings.actual_host_prefix:
            # NOTE(review): this branch uses ``base_host`` rather than
            # ``actual_host``; presumably ``actual_host`` already includes the
            # prefix — confirm against the settings model.
            host = f"{settings.actual_host_prefix}.semantic-layer.{settings.base_host}"
        else:
            host = f"semantic-layer.{settings.actual_host}"

        # Only the first element (the production environment) is needed here.
        prod_env, _ = await self.admin_client.get_environments_for_project(project_id)
        if not prod_env or not prod_env.id:
            raise NotFoundError(
                f"No production environment found for project {project_id}"
            )

        # Local and remote hosts share the GraphQL path; only the scheme
        # differs. Building the URL in one f-string avoids the precedence trap
        # of ``a if cond else b + path``, where ``+`` binds to ``b`` only.
        scheme = "http" if is_local else "https"
        return SemanticLayerConfig(
            url=f"{scheme}://{host}/api/graphql",
            host=host,
            prod_environment_id=prod_env.id,
            token_provider=token_provider,
            headers_provider=SemanticLayerHeadersProvider(
                token_provider=token_provider
            ),
            metrics_related_max=self.metrics_related_max,
            max_response_chars=self.max_response_chars,
        )