Commit bdb83ca

feat(plotting): add data plotting capabilities with Plotly.js (#25)

* implement visualization generation
* fix html index slicing
* update styling guidelines
* harmonize default gpt model

1 parent 0c8c933 commit bdb83ca

3 files changed: 304 additions & 1 deletion

llm_engine/src/llm_processor.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -46,7 +46,7 @@ pub async fn process_natural_language_query(
     // Get model name from parameters or environment
     let model_name = model
         .or_else(|| env::var("LLM_MODEL").ok())
-        .unwrap_or_else(|| "gpt-4".to_string());
+        .unwrap_or_else(|| "gpt-3.5-turbo".to_string());
 
     // Generate prompt with query and schema context
     let prompt = generate_prompt(&query, db_schema.as_ref());
```
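This is the "harmonize default gpt model" item from the commit message: the processor's hard-coded `gpt-4` fallback now matches the `gpt-3.5-turbo` default used by the query_router. A minimal standalone sketch of the resulting resolution order (request parameter, then the `LLM_MODEL` environment variable, then the shared default) — illustrative only, not the crate's actual module layout:

```rust
use std::env;

// Model name resolution as established by the hunk above:
// explicit parameter > LLM_MODEL env var > shared default.
fn resolve_model(model: Option<String>) -> String {
    model
        .or_else(|| env::var("LLM_MODEL").ok())
        .unwrap_or_else(|| "gpt-3.5-turbo".to_string())
}

fn main() {
    // Assuming LLM_MODEL is not set, the shared default wins.
    assert_eq!(resolve_model(None), "gpt-3.5-turbo");
    // An explicit request parameter always takes precedence.
    assert_eq!(resolve_model(Some("gpt-4".into())), "gpt-4");
}
```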

llm_engine/src/main.rs

Lines changed: 216 additions & 0 deletions
```diff
@@ -9,6 +9,7 @@ use axum::{
     response::IntoResponse,
 };
 use serde::{Deserialize, Serialize};
+use serde_json::Value;
 use tower_http::cors::{CorsLayer, Any};
 use tracing::{info, error};
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@@ -33,6 +34,14 @@ struct QueryRequest {
     model: Option<String>,
 }
 
+// Request model for visualization generation
+#[derive(Deserialize)]
+struct VisualizationRequest {
+    query: String,
+    results: Value,
+    model: Option<String>,
+}
+
 // Response model
 #[derive(Serialize)]
 struct QueryResponse {
@@ -41,6 +50,14 @@ struct QueryResponse {
     confidence: Option<f64>,
 }
 
+// Response model for visualization generation
+#[derive(Serialize)]
+struct VisualizationResponse {
+    html_code: String,
+    explanation: String,
+    confidence: f64,
+}
+
 // Error handling
 enum AppError {
     InternalError(String),
@@ -110,6 +127,204 @@ async fn process_query(
     }
 }
 
+// Process visualization request
+async fn generate_visualization(
+    State(_state): State<Arc<AppState>>,
+    Json(request): Json<VisualizationRequest>,
+) -> Result<impl IntoResponse, AppError> {
+    info!("Generating visualization for query: {}", request.query);
+
+    // Get the model name from the request or use a default
+    let model = request.model.unwrap_or_else(|| "gpt-3.5-turbo".to_string());
+
+    // Format the results data for the prompt
+    let results_json = serde_json::to_string_pretty(&request.results)
+        .unwrap_or_else(|_| format!("{:?}", request.results));
+
+    // Create the prompt for visualization generation
+    let system_prompt = format!(
+        "You are a data visualization expert. Your task is to create a Plotly.js visualization based on the provided data and query. \
+        You must return a complete, valid HTML file that uses Plotly.js to visualize the data.\
+        \
+        STYLING GUIDELINES:\
+        - Use Plotly.js to create the chart\
+        - Use 14px font for axis labels, 18px for titles\
+        - Use consistent margins and padding\
+        - Use a neobrutalist style\
+        - Include axis titles with units (if known)\
+        - Rotate x-axis labels if they are dates or long strings\
+        - Use tight layout with `autosize: true` and `responsive: true`\
+        - Enable zoom and pan interactivity\
+        - Enable tooltips on hover showing exact values and labels\
+        - Use hovermode: 'closest'\
+        \
+        OUTPUT FORMAT:\
+        Your response should be a complete HTML file that can be directly viewed in a browser.\
+        Return valid HTML that includes the Plotly.js library from a CDN and creates the visualization.\
+        Also include a brief explanation of the visualization choices you made."
+    );
+
+    let user_prompt = format!(
+        "Natural Language Query: {}\n\nData Results:\n{}\n\nBased on this query and data, create a complete HTML file with a Plotly.js visualization.",
+        request.query, results_json
+    );
+
+    // Call the LLM with the specialized prompt
+    let llm_response = call_llm_api(&system_prompt, &user_prompt, &model).await
+        .map_err(|e| {
+            let error_msg = format!("Error calling LLM API: {}", e);
+            error!("{}", error_msg);
+            AppError::InternalError(error_msg)
+        })?;
+
+    // Extract the HTML code and explanation from the response
+    let (html_code, explanation, confidence) = parse_visualization_response(&llm_response);
+
+    Ok(Json(VisualizationResponse {
+        html_code,
+        explanation,
+        confidence,
+    }))
+}
+
+// Call LLM API with system and user prompts
+async fn call_llm_api(system_prompt: &str, user_prompt: &str, model_name: &str) -> Result<String, anyhow::Error> {
+    let api_key = env::var("LLM_API_KEY").map_err(|_| anyhow::anyhow!("LLM_API_KEY environment variable not set"))?;
+
+    let client = reqwest::Client::new();
+
+    #[derive(Serialize, Deserialize)]
+    struct Message {
+        role: String,
+        content: String,
+    }
+
+    #[derive(Serialize)]
+    struct OpenAIRequest {
+        model: String,
+        messages: Vec<Message>,
+        temperature: f64,
+    }
+
+    let request = OpenAIRequest {
+        model: model_name.to_string(),
+        messages: vec![
+            Message {
+                role: "system".to_string(),
+                content: system_prompt.to_string(),
+            },
+            Message {
+                role: "user".to_string(),
+                content: user_prompt.to_string(),
+            },
+        ],
+        temperature: 0.7, // Slightly higher temperature for more creative visualizations
+    };
+
+    let response = client
+        .post("https://api.openai.com/v1/chat/completions")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .header("Content-Type", "application/json")
+        .json(&request)
+        .send()
+        .await?;
+
+    if !response.status().is_success() {
+        let error_text = response.text().await?;
+        return Err(anyhow::anyhow!("LLM API returned error: {}", error_text));
+    }
+
+    #[derive(Deserialize)]
+    struct OpenAIChoice {
+        message: Message,
+    }
+
+    #[derive(Deserialize)]
+    struct OpenAIResponse {
+        choices: Vec<OpenAIChoice>,
+    }
+
+    let response_json: OpenAIResponse = response.json().await?;
+
+    if response_json.choices.is_empty() {
+        return Err(anyhow::anyhow!("LLM API returned empty choices"));
+    }
+
+    Ok(response_json.choices[0].message.content.clone())
+}
+
+// Parse LLM response to extract HTML, explanation and confidence
+fn parse_visualization_response(response: &str) -> (String, String, f64) {
+    // Try to extract HTML content - look for <!DOCTYPE html> or <html>
+    let html_start_patterns = ["<!DOCTYPE html>", "<html>"];
+    let mut html_code = String::new();
+    let mut explanation = String::new();
+    let confidence = 0.8; // Default confidence
+
+    // First, check if the response contains a code block with HTML
+    if let Some(html_block_start) = response.find("```html") {
+        // Find the end of the code block (next ```)
+        if let Some(html_block_end) = response[html_block_start + 6..].find("```") {
+            // Extract HTML content (skip the ```html and end ```)
+            let block_start_pos = html_block_start + "```html".len();
+            let block_end_pos = html_block_start + 6 + html_block_end;
+            html_code = response[block_start_pos..block_end_pos].trim().to_string();
+
+            // Look for explanation after the HTML block
+            if block_end_pos + 3 < response.len() {
+                explanation = response[block_end_pos + 3..].trim().to_string();
+            }
+        }
+    }
+    // If no code block, try to find direct HTML
+    else {
+        for pattern in html_start_patterns.iter() {
+            if let Some(start_idx) = response.find(pattern) {
+                html_code = response[start_idx..].trim().to_string();
+
+                // Everything before HTML is considered explanation
+                if start_idx > 0 {
+                    explanation = response[0..start_idx].trim().to_string();
+                }
+                break;
+            }
+        }
+    }
+
+    // If still no HTML found, look for any content between <script> tags or <div id="plot">
+    if html_code.is_empty() {
+        if let Some(script_start) = response.find("<script>") {
+            if let Some(script_end) = response[script_start..].find("</script>") {
+                // Create a basic HTML wrapper around the script
+                let script_content = &response[script_start..script_start + script_end + 9];
+                html_code = format!(
+                    "<!DOCTYPE html>\n<html>\n<head>\n<title>Visualization</title>\n<script src=\"https://cdn.plot.ly/plotly-latest.min.js\"></script>\n</head>\n<body>\n<div id=\"plot\"></div>\n{}\n</body>\n</html>",
+                    script_content
+                );
+
+                // Everything else is explanation
+                explanation = response.replace(script_content, "").trim().to_string();
+            }
+        }
+    }
+
+    // If still nothing found, return the whole response as HTML with a warning
+    if html_code.is_empty() {
+        html_code = format!(
+            "<!DOCTYPE html>\n<html>\n<head>\n<title>Visualization Error</title>\n</head>\n<body>\n<h1>Could not generate visualization</h1>\n<pre>{}</pre>\n</body>\n</html>",
+            response.replace("<", "&lt;").replace(">", "&gt;")
+        );
+        explanation = "Could not parse LLM response into valid HTML visualization.".to_string();
+    }
+
+    // If explanation is empty, provide a default
+    if explanation.is_empty() {
+        explanation = "Visualization generated from the provided data.".to_string();
+    }
+
+    (html_code, explanation, confidence)
+}
+
 #[tokio::main]
 async fn main() {
     // Load environment variables
@@ -131,6 +346,7 @@ async fn main() {
         .route("/", get(root))
        .route("/health", get(health_check))
        .route("/process-query", post(process_query))
+        .route("/generate", post(generate_visualization))
        .layer(
            CorsLayer::new()
                .allow_origin(Any)
```
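The parsing fallback chain in `parse_visualization_response` (fenced HTML block, then bare HTML, then a lone script tag, then an escaped error page) is easiest to see in a couple of hypothetical unit tests. These are illustrative only and not part of the commit; the expected values follow directly from the logic above:

```rust
// Hypothetical tests for parse_visualization_response; not in the commit.
#[cfg(test)]
mod visualization_parsing_tests {
    use super::parse_visualization_response;

    #[test]
    fn extracts_fenced_html_and_trailing_explanation() {
        let response = "Intro text.\n```html\n<!DOCTYPE html><html></html>\n```\nUsed a bar chart.";
        let (html, explanation, confidence) = parse_visualization_response(response);
        assert_eq!(html, "<!DOCTYPE html><html></html>");
        // In this branch, only text *after* the fenced block is kept as explanation.
        assert_eq!(explanation, "Used a bar chart.");
        assert_eq!(confidence, 0.8);
    }

    #[test]
    fn falls_back_to_bare_html_with_leading_explanation() {
        let response = "Chose a line chart.\n<!DOCTYPE html><html></html>";
        let (html, explanation, _) = parse_visualization_response(response);
        assert_eq!(html, "<!DOCTYPE html><html></html>");
        // Without a fenced block, text *before* the HTML becomes the explanation.
        assert_eq!(explanation, "Chose a line chart.");
    }
}
```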

query_router/src/main.rs

Lines changed: 87 additions & 0 deletions
```diff
@@ -22,6 +22,15 @@ struct TranslateAndExecuteRequest {
     model: String,
 }
 
+// Request model for visualization generation
+#[derive(Debug, Deserialize)]
+struct VisualizationRequest {
+    natural_query: String,
+    results: Value,
+    #[serde(default = "default_model")]
+    model: String,
+}
+
 fn default_model() -> String {
     "gpt-3.5-turbo".to_string()
 }
@@ -43,6 +52,14 @@ struct ResponseMetadata {
     total_time_ms: u64,
 }
 
+// Response model for visualization generation
+#[derive(Debug, Serialize)]
+struct VisualizationResponse {
+    html_code: String,
+    explanation: String,
+    metadata: ResponseMetadata,
+}
+
 // LLM Engine response structure
 #[derive(Debug, Deserialize)]
 struct LlmResponse {
@@ -51,6 +68,14 @@ struct LlmResponse {
     confidence: f64,
 }
 
+// LLM Engine visualization response structure
+#[derive(Debug, Deserialize)]
+struct LlmVisualizationResponse {
+    html_code: String,
+    explanation: String,
+    confidence: f64,
+}
+
 // Application state
 struct AppState {
     client: Client,
@@ -128,6 +153,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let app = Router::new()
         .route("/health", get(health_check))
         .route("/translate-and-execute", post(translate_and_execute))
+        .route("/visualize", post(generate_visualization))
         .with_state(state)
         .layer(middleware);
 
@@ -262,3 +288,64 @@ async fn execute_sql_query(state: &AppState, sql_query: &str) -> Result<Value, AppError> {
 
     Ok(result)
 }
+
+// Endpoint for generating visualizations from natural language and data
+async fn generate_visualization(
+    State(state): State<Arc<AppState>>,
+    Json(request): Json<VisualizationRequest>,
+) -> Result<Json<VisualizationResponse>, AppError> {
+    let start_time = Instant::now();
+
+    // Call LLM engine to generate visualization
+    let llm_start_time = Instant::now();
+    let llm_response = call_llm_visualization_engine(&state, &request).await?;
+    let llm_processing_time = llm_start_time.elapsed().as_millis() as u64;
+
+    // Build the response
+    let total_time = start_time.elapsed().as_millis() as u64;
+
+    let response = VisualizationResponse {
+        html_code: llm_response.html_code,
+        explanation: llm_response.explanation,
+        metadata: ResponseMetadata {
+            confidence: llm_response.confidence,
+            execution_time_ms: 0, // No SQL execution in this flow
+            llm_processing_time_ms: llm_processing_time,
+            total_time_ms: total_time,
+        },
+    };
+
+    Ok(Json(response))
+}
+
+// Call LLM engine to generate visualization HTML/JS
+async fn call_llm_visualization_engine(
+    state: &AppState,
+    request: &VisualizationRequest
+) -> Result<LlmVisualizationResponse, AppError> {
+    let url = format!("{}/generate", state.llm_engine_url);
+
+    let llm_request = json!({
+        "query": request.natural_query,
+        "results": request.results,
+        "model": request.model
+    });
+
+    let response = state.client
+        .post(&url)
+        .json(&llm_request)
+        .send()
+        .await
+        .map_err(AppError::LlmEngineError)?;
+
+    if !response.status().is_success() {
+        let status = response.status();
+        let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
+        return Err(AppError::LlmResponseError(format!("LLM engine returned error ({}): {}", status, error_text)));
+    }
+
+    let llm_response = response.json::<LlmVisualizationResponse>().await
+        .map_err(|e| AppError::LlmResponseError(format!("Failed to parse LLM visualization response: {}", e)))?;
+
+    Ok(llm_response)
+}
```
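To exercise the new route end to end, a hypothetical client might look like the sketch below. The host and port are assumptions (the commit does not show where the query_router binds); the field names match `VisualizationRequest` above, and omitting `model` lets `default_model()` supply `gpt-3.5-turbo`:

```rust
use serde_json::json;

// Hypothetical client for POST /visualize; localhost:8080 is an assumption.
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let request_body = json!({
        "natural_query": "Show monthly revenue as a bar chart",
        "results": [
            { "month": "2024-01", "revenue": 12500 },
            { "month": "2024-02", "revenue": 14100 }
        ]
        // "model" is omitted, so serde's default_model() applies.
    });

    let response = reqwest::Client::new()
        .post("http://localhost:8080/visualize")
        .json(&request_body)
        .send()
        .await?;

    // The router responds with html_code, explanation, and timing metadata.
    let body: serde_json::Value = response.json().await?;
    println!("{}", body["explanation"]);
    Ok(())
}
```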
