Skip to content

Commit 6763d39

Browse files
Merge pull request #39 from xsqian/main
Removed the dependency on mysql service for Iguazio platform
2 parents 6a18ff0 + 3e1ebbd commit 6763d39

File tree

7 files changed

+897
-118
lines changed

7 files changed

+897
-118
lines changed

env.template

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,2 @@
11
OPENAI_API_BASE=
2-
OPENAI_API_KEY=
3-
#for platform Iguazio, you need to set the MYSQL_URL and remove S3_BUCKET_NAME
4-
MYSQL_URL=
5-
#for platform Mck, you need to set the S3_BUCKET_NAME and remove MYSQL_URL
6-
S3_BUCKET_NAME=
2+
OPENAI_API_KEY=

notebook_1_generation.ipynb

Lines changed: 409 additions & 0 deletions
Large diffs are not rendered by default.

notebook_2_analysis.ipynb

Lines changed: 439 additions & 0 deletions
Large diffs are not rendered by default.

project.yaml

Lines changed: 0 additions & 80 deletions
This file was deleted.

project_setup.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from pathlib import Path
1616
import boto3
1717
import mlrun
18+
import tempfile
1819

1920
from src.calls_analysis.db_management import create_tables
2021
from src.common import ProjectSecrets
@@ -36,15 +37,15 @@ def setup(
3637
openai_key = os.getenv(ProjectSecrets.OPENAI_API_KEY)
3738
openai_base = os.getenv(ProjectSecrets.OPENAI_API_BASE)
3839
mysql_url = os.getenv(ProjectSecrets.MYSQL_URL, "")
39-
4040
# Unpack parameters:
4141
source = project.get_param(key="source")
42-
default_image = project.get_param(key="default_image", default=None)
42+
default_image = project.get_param(key="default_image", default=".mlrun-project-image-call-center-demo")
4343
build_image = project.get_param(key="build_image", default=False)
4444
gpus = project.get_param(key="gpus", default=0)
4545
node_name = project.get_param(key="node_name", default=None)
4646
node_selector = project.get_param(key="node_selector", default=None)
47-
use_sqlite = project.get_param(key="use_sqlite", default=False)
47+
use_sqlite = project.get_param(key="use_sqlite", default=True)
48+
skip_calls_generation = project.get_param(key="skip_calls_generation", default=False)
4849

4950
# Update sqlite data:
5051
if use_sqlite:
@@ -53,15 +54,16 @@ def setup(
5354
s3 = boto3.client("s3") if not os.getenv("AWS_ENDPOINT_URL_S3") else boto3.client('s3', endpoint_url=os.getenv("AWS_ENDPOINT_URL_S3"))
5455
bucket_name = Path(mlrun.mlconf.artifact_path).parts[1]
5556
# Upload the file
56-
s3.upload_file(
57-
Filename="data/sqlite.db",
58-
Bucket=bucket_name,
59-
Key="sqlite.db",
60-
)
57+
if not skip_calls_generation and Path("./data/sqlite.db").exists():
58+
s3.upload_file(
59+
Filename="data/sqlite.db",
60+
Bucket=bucket_name,
61+
Key="sqlite.db",
62+
)
6163
os.environ["S3_BUCKET_NAME"] = bucket_name
6264
else:
63-
os.environ["MYSQL_URL"] = f"sqlite:///{os.path.abspath('.')}/data/sqlite.db"
64-
mysql_url = os.environ["MYSQL_URL"]
65+
if not skip_calls_generation and Path("./data/sqlite.db").exists():
66+
project.log_artifact("sqlite-db", local_path="./data/sqlite.db", upload=True)
6567

6668
# Set the project git source:
6769
if source:
@@ -72,6 +74,7 @@ def setup(
7274
# Set default image:
7375
if default_image:
7476
project.set_default_image(default_image)
77+
print(f"set default image to : {default_image}")
7578

7679
# Build the image:
7780
if build_image:
@@ -117,11 +120,12 @@ def setup(
117120
]
118121
app.save()
119122

120-
# Create the DB tables:
121-
create_tables()
123+
# # Create the DB tables:
124+
# create_tables()
122125

123126
# Save and return the project:
124127
project.save()
128+
125129
return project
126130

127131
def _build_image(project: mlrun.projects.MlrunProject, with_gpu: bool, default_image):
@@ -133,7 +137,8 @@ def _build_image(project: mlrun.projects.MlrunProject, with_gpu: bool, default_i
133137
# Define commands in logical groups while maintaining order
134138
system_commands = [
135139
# Update apt-get to install ffmpeg (support audio file formats):
136-
"apt-get update -y && apt-get install ffmpeg -y"
140+
"apt-get update -y && apt-get install ffmpeg -y",
141+
"python --version"
137142
]
138143

139144
infrastructure_requirements = [
@@ -150,16 +155,17 @@ def _build_image(project: mlrun.projects.MlrunProject, with_gpu: bool, default_i
150155
] if with_gpu else []
151156

152157
other_requirements = [
153-
"pip install mlrun langchain==0.2.17 openai==1.58.1 langchain_community==0.2.19 pydub==0.25.1 streamlit==1.28.0 st-annotated-text==4.0.1 spacy==3.7.1 librosa==0.10.1 presidio-anonymizer==2.2.34 presidio-analyzer==2.2.34 nltk==3.8.1 flair==0.13.0 htbuilder==0.6.2",
158+
"pip install langchain==0.2.17 openai==1.58.1 langchain_community==0.2.19 pydub==0.25.1 streamlit==1.28.0 st-annotated-text==4.0.1 spacy==3.7.1 librosa==0.10.1 presidio-anonymizer==2.2.34 presidio-analyzer==2.2.34 nltk==3.8.1 flair==0.13.0 htbuilder==0.6.2",
154159
"pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_lg-3.7.1/en_core_web_lg-3.7.1.tar.gz",
155160
# "python -m spacy download en_core_web_lg",
156161

157162
"pip install SQLAlchemy==2.0.31 pymysql requests_toolbelt==0.10.1",
158163
"pip uninstall -y onnxruntime-gpu onnxruntime",
159164
f"pip install {config['onnx_package']}",
160-
"pip uninstall -y protobuf",
161-
"pip install protobuf"
162-
]
165+
"pip show protobuf",
166+
]
167+
if not CE_MODE:
168+
other_requirements.extend(["pip uninstall -y protobuf", "pip install protobuf", "pip show protobuf"])
163169

164170
# Combine commands in the required order
165171
commands = (
@@ -179,7 +185,6 @@ def _build_image(project: mlrun.projects.MlrunProject, with_gpu: bool, default_i
179185
overwrite_build_params=True
180186
)
181187

182-
183188
def _set_secrets(
184189
project: mlrun.projects.MlrunProject,
185190
openai_key: str,
@@ -241,21 +246,23 @@ def _set_function(
241246
if not CE_MODE and apply_auto_mount:
242247
# Apply auto mount:
243248
mlrun_function.apply(mlrun.auto_mount())
249+
244250
# Save:
245251
mlrun_function.save()
246252

247253

248254
def _set_calls_generation_functions(
249255
project: mlrun.projects.MlrunProject,
250256
node_name: str = None,
251-
image: str = ".mlrun-project-image"
257+
image: str = ".mlrun-project-image-call-center-demo"
252258
):
253259
# Client and agent data generator
254260
_set_function(
255261
project=project,
256262
func="hub://structured_data_generator",
257263
name="structured-data-generator",
258264
kind="job",
265+
image=image,
259266
node_name=node_name,
260267
apply_auto_mount=True,
261268
)
@@ -355,7 +362,6 @@ def _set_calls_analysis_functions(
355362

356363

357364
def _set_workflows(project: mlrun.projects.MlrunProject, image):
358-
359365
project.set_workflow(
360366
name="calls-generation", workflow_path="./src/workflows/calls_generation.py", image=image
361367
)

src/calls_analysis/db_management.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import os
1616
import tempfile
1717
from typing import List, Optional, Tuple
18-
1918
import boto3
2019
import mlrun
2120
import pandas as pd
@@ -184,6 +183,7 @@ class DBEngine:
184183
def __init__(self):
185184
self.bucket_name = mlrun.get_secret_or_env(key=ProjectSecrets.S3_BUCKET_NAME)
186185
self.db_url = mlrun.get_secret_or_env(key=ProjectSecrets.MYSQL_URL)
186+
self.project = None
187187
self.temp_file = None
188188
self.engine = self._create_engine()
189189

@@ -194,22 +194,27 @@ def update_db(self):
194194
if self.bucket_name:
195195
s3 = boto3.client("s3")
196196
s3.upload_file(self.temp_file.name, self.bucket_name, "sqlite.db")
197+
else:
198+
# register the temporary sqlite.db file as a project artifact
199+
self.project.log_artifact("sqlite-db", local_path=self.temp_file.name, upload=True)
197200

198201
def _create_engine(self):
202+
# Create a temporary file that will persist throughout the object's lifetime
203+
self.temp_file = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False)
204+
self.temp_file.close() # Close the file but keep the name
199205
if self.bucket_name:
200-
# Create a temporary file that will persist throughout the object's lifetime
201-
self.temp_file = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False)
202-
self.temp_file.close() # Close the file but keep the name
203-
204206
s3 = boto3.client("s3")
205207
try:
206208
s3.download_file(self.bucket_name, "sqlite.db", self.temp_file.name)
207209
except Exception as e:
208210
print(f"Warning: Could not download database from S3: {e}")
209-
210-
return create_engine(f"sqlite:///{self.temp_file.name}")
211211
else:
212-
return create_engine(url=self.db_url)
212+
# Iguazio platform: download the registered sqlite-db artifact to a local temp file
213+
self.project = mlrun.get_current_project()
214+
with open(self.temp_file.name, 'wb') as tmpfile:
215+
tmpfile.write(self.project.get_artifact('sqlite-db').to_dataitem().get())
216+
217+
return create_engine(f"sqlite:///{self.temp_file.name}")
213218

214219
def __del__(self):
215220
# Clean up the temporary file when the object is destroyed
@@ -230,10 +235,9 @@ def create_tables():
230235
# Base.metadata.drop_all(engine.engine)
231236
# Create the schema's tables
232237
Base.metadata.create_all(engine.engine)
233-
238+
print('Tables created!')
234239
engine.update_db()
235240

236-
237241
def insert_clients(clients: list):
238242
# Create an engine:
239243
engine = DBEngine()

src/calls_generation/skip.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,18 @@ def skip_and_import_local_data(language: str):
2929
"""
3030
# Get the example data directory:
3131
example_data_dir = Path("data")
32+
3233
# Get the project:
3334
project = mlrun.get_current_project()
3435

3536
# clean and recreate database tables:
3637
engine = DBEngine()
38+
print(f"in skip_and_import_local_data engine: {engine}")
39+
print(f"in skip_and_import_local_data engine.engine: {engine.engine}")
3740
Call.__table__.drop(engine.engine)
3841
Client.__table__.drop(engine.engine)
3942
Agent.__table__.drop(engine.engine)
43+
print(f"in skip_and_import_local_data all tables dropped")
4044
create_tables()
4145
print("- Initialized tables")
4246

@@ -84,7 +88,7 @@ def skip_and_import_local_data(language: str):
8488
clients = yaml.load(clients.get(), Loader=yaml.FullLoader)
8589

8690
# insert agent and client data to database:
87-
_insert_agents_and_clients_to_db(agents, clients)
91+
_insert_agents_and_clients_to_db(engine, agents, clients)
8892
print("- agents and clients inserted")
8993

9094
# log zip files
@@ -137,9 +141,10 @@ def skip_and_import_local_data(language: str):
137141
print("*** first workflow skipped successfully ***")
138142

139143

140-
def _insert_agents_and_clients_to_db(agents: list, clients: list):
144+
def _insert_agents_and_clients_to_db(engine: DBEngine, agents: list, clients: list):
141145
# Create an engine:
142146
engine = DBEngine()
147+
print(f'engine: {engine} created.')
143148

144149
# Initialize a session maker:
145150
session = engine.get_session()
@@ -148,7 +153,7 @@ def _insert_agents_and_clients_to_db(agents: list, clients: list):
148153
with session.begin() as sess:
149154
sess.execute(insert(Agent), agents)
150155
sess.execute(insert(Client), clients)
151-
156+
engine.update_db()
152157

153158
# TODO: change to export the actual data and not the artifacts
154159
def save_current_example_data():

0 commit comments

Comments
 (0)