-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdata_utils.py
More file actions
358 lines (281 loc) · 12 KB
/
data_utils.py
File metadata and controls
358 lines (281 loc) · 12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
"""Utilities for data loading, Neo4j operations, and Databricks AI services.
This module provides embedding generation using Databricks Foundation Model APIs
(hosted models like BGE and GTE) which are pre-deployed and ready to use.
Available Databricks Embedding Models:
- databricks-bge-large-en: 1024 dimensions, 512 token context
- databricks-gte-large-en: 1024 dimensions, 8192 token context
These models use OpenAI-compatible API format and are accessed via
the MLflow deployments client when running in Databricks.
"""
import asyncio
import concurrent.futures
from pathlib import Path
from typing import Any, List, Optional, Type, Union
import mlflow.deployments
from neo4j import GraphDatabase
from pydantic import BaseModel
from neo4j_graphrag.embeddings.base import Embedder
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import FixedSizeSplitter
from neo4j_graphrag.llm.base import LLMInterfaceV2
from neo4j_graphrag.llm.types import LLMResponse
from neo4j_graphrag.types import LLMMessage
# =============================================================================
# Default Model Configuration
# =============================================================================
DEFAULT_EMBEDDING_MODEL = "databricks-bge-large-en"
DEFAULT_LLM_MODEL = "databricks-meta-llama-3-3-70b-instruct"
# =============================================================================
# Databricks Embeddings
# =============================================================================
class DatabricksEmbeddings(Embedder):
"""Generate embeddings using Databricks Foundation Model APIs.
Databricks provides pre-deployed embedding models as part of the
Foundation Model APIs. These are ready to use without deployment.
Available Models:
- databricks-bge-large-en: 1024 dims, 512 token context
- databricks-gte-large-en: 1024 dims, 8192 token context
API Format (OpenAI-Compatible):
Input: {"input": ["text1", "text2"]}
Output: {"data": [{"embedding": [0.1, ...]}, ...]}
Example:
>>> embedder = DatabricksEmbeddings(model_id="databricks-bge-large-en")
>>> embedding = embedder.embed_query("test text")
>>> len(embedding)
1024
"""
def __init__(self, model_id: str = "databricks-bge-large-en"):
"""Initialize the Databricks embeddings provider.
Args:
model_id: The Databricks Foundation Model endpoint name.
Default: databricks-bge-large-en (1024 dimensions)
"""
self.model_id = model_id
self._client = mlflow.deployments.get_deploy_client("databricks")
def embed_query(self, text: str) -> list[float]:
"""Generate embedding for a single text string.
Uses the MLflow deployments client to call the Databricks
Foundation Model API with OpenAI-compatible format.
Args:
text: The text to embed
Returns:
List of floats representing the embedding vector (1024 dimensions)
"""
response = self._client.predict(
endpoint=self.model_id,
inputs={"input": [text]},
)
return response["data"][0]["embedding"]
# =============================================================================
# Databricks LLM
# =============================================================================
class DatabricksLLM(LLMInterfaceV2):
"""LLM interface using Databricks Foundation Model APIs.
Implements LLMInterfaceV2 for compatibility with GraphRAG and other
neo4j-graphrag components.
Supports Databricks-hosted LLM endpoints like:
- databricks-meta-llama-3-3-70b-instruct
- databricks-dbrx-instruct
- databricks-mixtral-8x7b-instruct
Uses MLflow deployments client for API calls.
"""
def __init__(self, model_id: str = "databricks-meta-llama-3-3-70b-instruct"):
"""Initialize the Databricks LLM provider.
Args:
model_id: The Databricks Foundation Model endpoint name.
"""
super().__init__(model_name=model_id)
self.model_id = model_id
self._client = mlflow.deployments.get_deploy_client("databricks")
def invoke(
self,
input: List[LLMMessage],
response_format: Optional[Union[Type[BaseModel], dict[str, Any]]] = None,
**kwargs: Any,
) -> LLMResponse:
"""Generate a response from the LLM.
Args:
input: List of messages in LLMMessage format (role + content dicts).
response_format: Optional response format (not used by Databricks API).
Returns:
LLMResponse containing the generated text
"""
messages = [{"role": msg["role"], "content": msg["content"]} for msg in input]
response = self._client.predict(
endpoint=self.model_id,
inputs={
"messages": messages,
"max_tokens": 2048,
},
)
content = response["choices"][0]["message"]["content"]
return LLMResponse(content=content)
async def ainvoke(
self,
input: List[LLMMessage],
response_format: Optional[Union[Type[BaseModel], dict[str, Any]]] = None,
**kwargs: Any,
) -> LLMResponse:
"""Async version of invoke (runs synchronously)."""
return self.invoke(input, response_format=response_format)
# =============================================================================
# AI Services Factory Functions
# =============================================================================
def get_embedder(model_id: str = DEFAULT_EMBEDDING_MODEL) -> DatabricksEmbeddings:
"""Get embedder using Databricks Foundation Model APIs.
Args:
model_id: Databricks embedding endpoint name.
Default: databricks-bge-large-en (1024 dimensions)
Returns:
DatabricksEmbeddings configured for the specified model
"""
return DatabricksEmbeddings(model_id=model_id)
def get_llm(model_id: str = DEFAULT_LLM_MODEL) -> DatabricksLLM:
"""Get LLM using Databricks Foundation Model APIs.
Args:
model_id: Databricks LLM endpoint name.
Default: databricks-meta-llama-3-3-70b-instruct
Returns:
DatabricksLLM configured for the specified model
"""
return DatabricksLLM(model_id=model_id)
# =============================================================================
# Neo4j Connection
# =============================================================================
class Neo4jConnection:
"""Manages Neo4j database connection."""
def __init__(self, uri: str, username: str, password: str):
"""Initialize and connect to Neo4j.
Args:
uri: Neo4j URI (e.g., "neo4j+s://xxxxxxxx.databases.neo4j.io")
username: Neo4j username (typically "neo4j")
password: Neo4j password
"""
self.uri = uri
self.username = username
self.password = password
self.driver = GraphDatabase.driver(
self.uri,
auth=(self.username, self.password)
)
def verify(self):
"""Verify the connection is working."""
self.driver.verify_connectivity()
print("Connected to Neo4j successfully!")
return self
def clear_chunks(self):
"""Remove all Document and Chunk nodes (preserves aircraft graph from Lab 5)."""
records, _, _ = self.driver.execute_query("""
MATCH (n) WHERE n:Document OR n:Chunk
DETACH DELETE n
RETURN count(n) as deleted
""")
count = records[0]["deleted"]
print(f"Deleted {count} Document/Chunk nodes")
return self
def get_graph_stats(self):
"""Show current graph statistics."""
records, _, _ = self.driver.execute_query("""
MATCH (n)
WITH labels(n) as nodeLabels
UNWIND nodeLabels as label
RETURN label, count(*) as count
ORDER BY label
""")
print("=== Graph Statistics ===")
for record in records:
print(f" {record['label']}: {record['count']}")
return self
def close(self):
"""Close the database connection."""
self.driver.close()
print("Connection closed.")
# =============================================================================
# Data Loading
# =============================================================================
# Default Volume path for workshop data
DEFAULT_VOLUME_PATH = "/Volumes/aws-databricks-neo4j-lab/lab-schema/lab-volume"
class DataLoader:
"""Handles loading text data from files (local or Unity Catalog Volume)."""
def __init__(self, file_path: str):
"""Initialize with path to data file.
Args:
file_path: Path to the file. Can be:
- Relative path (loaded from current directory)
- Absolute local path
- Volume path (e.g., /Volumes/catalog/schema/volume/file.md)
"""
self.file_path = Path(file_path)
self._text = None
@property
def text(self) -> str:
"""Load and return the text content from the file."""
if self._text is None:
self._text = self.file_path.read_text().strip()
return self._text
def get_metadata(self) -> dict:
"""Return metadata about the loaded file."""
return {
"path": str(self.file_path),
"name": self.file_path.name,
"size": len(self.text)
}
class VolumeDataLoader:
"""Handles loading text data from Unity Catalog Volumes.
Unity Catalog Volumes are accessible as file paths in Databricks:
/Volumes/<catalog>/<schema>/<volume>/<file>
Example:
>>> loader = VolumeDataLoader("maintenance_manual.md")
>>> text = loader.text
"""
def __init__(self, file_name: str, volume_path: str = DEFAULT_VOLUME_PATH):
"""Initialize with file name and optional Volume path.
Args:
file_name: Name of the file in the Volume (e.g., "maintenance_manual.md")
volume_path: Path to the Unity Catalog Volume.
Defaults to /Volumes/aws-databricks-neo4j-lab/lab-schema/lab-volume
"""
self.volume_path = Path(volume_path)
self.file_name = file_name
self.file_path = self.volume_path / file_name
self._text = None
@property
def text(self) -> str:
"""Load and return the text content from the Volume."""
if self._text is None:
self._text = self.file_path.read_text().strip()
return self._text
def get_metadata(self) -> dict:
"""Return metadata about the loaded file."""
return {
"path": str(self.file_path),
"name": self.file_name,
"volume": str(self.volume_path),
"size": len(self.text)
}
# =============================================================================
# Text Splitting
# =============================================================================
def split_text(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> list[str]:
"""Split text into chunks using FixedSizeSplitter.
Args:
text: Text to split
chunk_size: Maximum characters per chunk
chunk_overlap: Characters to overlap between chunks
Returns:
List of chunk text strings
"""
splitter = FixedSizeSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
approximate=True
)
# Run in a separate thread to avoid "asyncio.run() cannot be called from
# a running event loop" in Jupyter/Databricks environments.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
result = pool.submit(asyncio.run, splitter.run(text)).result()
return [chunk.text for chunk in result.chunks]
# =============================================================================
# Embedding Configuration
# =============================================================================
# Databricks BGE and GTE models produce 1024-dimensional vectors
EMBEDDING_DIMENSIONS = 1024