@@ -9,6 +9,7 @@ use axum::{
99 response:: IntoResponse ,
1010} ;
1111use serde:: { Deserialize , Serialize } ;
12+ use serde_json:: Value ;
1213use tower_http:: cors:: { CorsLayer , Any } ;
1314use tracing:: { info, error} ;
1415use tracing_subscriber:: { layer:: SubscriberExt , util:: SubscriberInitExt } ;
@@ -33,6 +34,14 @@ struct QueryRequest {
3334 model : Option < String > ,
3435}
3536
37+ // Request model for visualization generation
38+ #[ derive( Deserialize ) ]
39+ struct VisualizationRequest {
40+ query : String ,
41+ results : Value ,
42+ model : Option < String > ,
43+ }
44+
3645// Response model
3746#[ derive( Serialize ) ]
3847struct QueryResponse {
@@ -41,6 +50,14 @@ struct QueryResponse {
4150 confidence : Option < f64 > ,
4251}
4352
53+ // Response model for visualization generation
54+ #[ derive( Serialize ) ]
55+ struct VisualizationResponse {
56+ html_code : String ,
57+ explanation : String ,
58+ confidence : f64 ,
59+ }
60+
4461// Error handling
4562enum AppError {
4663 InternalError ( String ) ,
@@ -110,6 +127,204 @@ async fn process_query(
110127 }
111128}
112129
130+ // Process visualization request
131+ async fn generate_visualization (
132+ State ( _state) : State < Arc < AppState > > ,
133+ Json ( request) : Json < VisualizationRequest > ,
134+ ) -> Result < impl IntoResponse , AppError > {
135+ info ! ( "Generating visualization for query: {}" , request. query) ;
136+
137+ // Get the model name from the request or use a default
138+ let model = request. model . unwrap_or_else ( || "gpt-3.5-turbo" . to_string ( ) ) ;
139+
140+ // Format the results data for the prompt
141+ let results_json = serde_json:: to_string_pretty ( & request. results )
142+ . unwrap_or_else ( |_| format ! ( "{:?}" , request. results) ) ;
143+
144+ // Create the prompt for visualization generation
145+ let system_prompt = format ! (
146+ "You are a data visualization expert. Your task is to create a Plotly.js visualization based on the provided data and query. \
147+ You must return a complete, valid HTML file that uses Plotly.js to visualize the data.\
148+ \
149+ STYLING GUIDELINES:\
150+ - Use Plotly.js to create the chart\
151+ - Use 14px font for axis labels, 18px for titles\
152+ - Use consistent margins and padding\
153+ - Use a neobrutalist style\
154+ - Include axis titles with units (if known)\
155+ - Rotate x-axis labels if they are dates or long strings\
156+ - Use tight layout with `autosize: true` and `responsive: true`\
157+ - Enable zoom and pan interactivity\
158+ - Enable tooltips on hover showing exact values and labels\
159+ - Use hovermode: 'closest'\
160+ \
161+ OUTPUT FORMAT:\
162+ Your response should be a complete HTML file that can be directly viewed in a browser.\
163+ Return valid HTML that includes the Plotly.js library from a CDN and creates the visualization.\
164+ Also include a brief explanation of the visualization choices you made."
165+ ) ;
166+
167+ let user_prompt = format ! (
168+ "Natural Language Query: {}\n \n Data Results:\n {}\n \n Based on this query and data, create a complete HTML file with a Plotly.js visualization." ,
169+ request. query, results_json
170+ ) ;
171+
172+ // Call the LLM with the specialized prompt
173+ let llm_response = call_llm_api ( & system_prompt, & user_prompt, & model) . await
174+ . map_err ( |e| {
175+ let error_msg = format ! ( "Error calling LLM API: {}" , e) ;
176+ error ! ( "{}" , error_msg) ;
177+ AppError :: InternalError ( error_msg)
178+ } ) ?;
179+
180+ // Extract the HTML code and explanation from the response
181+ let ( html_code, explanation, confidence) = parse_visualization_response ( & llm_response) ;
182+
183+ Ok ( Json ( VisualizationResponse {
184+ html_code,
185+ explanation,
186+ confidence,
187+ } ) )
188+ }
189+
190+ // Call LLM API with system and user prompts
191+ async fn call_llm_api ( system_prompt : & str , user_prompt : & str , model_name : & str ) -> Result < String , anyhow:: Error > {
192+ let api_key = env:: var ( "LLM_API_KEY" ) . map_err ( |_| anyhow:: anyhow!( "LLM_API_KEY environment variable not set" ) ) ?;
193+
194+ let client = reqwest:: Client :: new ( ) ;
195+
196+ #[ derive( Serialize , Deserialize ) ]
197+ struct Message {
198+ role : String ,
199+ content : String ,
200+ }
201+
202+ #[ derive( Serialize ) ]
203+ struct OpenAIRequest {
204+ model : String ,
205+ messages : Vec < Message > ,
206+ temperature : f64 ,
207+ }
208+
209+ let request = OpenAIRequest {
210+ model : model_name. to_string ( ) ,
211+ messages : vec ! [
212+ Message {
213+ role: "system" . to_string( ) ,
214+ content: system_prompt. to_string( ) ,
215+ } ,
216+ Message {
217+ role: "user" . to_string( ) ,
218+ content: user_prompt. to_string( ) ,
219+ } ,
220+ ] ,
221+ temperature : 0.7 , // Slightly higher temperature for more creative visualizations
222+ } ;
223+
224+ let response = client
225+ . post ( "https://api.openai.com/v1/chat/completions" )
226+ . header ( "Authorization" , format ! ( "Bearer {}" , api_key) )
227+ . header ( "Content-Type" , "application/json" )
228+ . json ( & request)
229+ . send ( )
230+ . await ?;
231+
232+ if !response. status ( ) . is_success ( ) {
233+ let error_text = response. text ( ) . await ?;
234+ return Err ( anyhow:: anyhow!( "LLM API returned error: {}" , error_text) ) ;
235+ }
236+
237+ #[ derive( Deserialize ) ]
238+ struct OpenAIChoice {
239+ message : Message ,
240+ }
241+
242+ #[ derive( Deserialize ) ]
243+ struct OpenAIResponse {
244+ choices : Vec < OpenAIChoice > ,
245+ }
246+
247+ let response_json: OpenAIResponse = response. json ( ) . await ?;
248+
249+ if response_json. choices . is_empty ( ) {
250+ return Err ( anyhow:: anyhow!( "LLM API returned empty choices" ) ) ;
251+ }
252+
253+ Ok ( response_json. choices [ 0 ] . message . content . clone ( ) )
254+ }
255+
256+ // Parse LLM response to extract HTML, explanation and confidence
257+ fn parse_visualization_response ( response : & str ) -> ( String , String , f64 ) {
258+ // Try to extract HTML content - look for <!DOCTYPE html> or <html>
259+ let html_start_patterns = [ "<!DOCTYPE html>" , "<html>" ] ;
260+ let mut html_code = String :: new ( ) ;
261+ let mut explanation = String :: new ( ) ;
262+ let confidence = 0.8 ; // Default confidence
263+
264+ // First, check if the response contains a code block with HTML
265+ if let Some ( html_block_start) = response. find ( "```html" ) {
266+ // Find the end of the code block (next ```)
267+ if let Some ( html_block_end) = response[ html_block_start + 6 ..] . find ( "```" ) {
268+ // Extract HTML content (skip the ```html and end ```)
269+ let block_start_pos = html_block_start + "```html" . len ( ) ;
270+ let block_end_pos = html_block_start + 6 + html_block_end;
271+ html_code = response[ block_start_pos..block_end_pos] . trim ( ) . to_string ( ) ;
272+
273+ // Look for explanation after the HTML block
274+ if block_end_pos + 3 < response. len ( ) {
275+ explanation = response[ block_end_pos + 3 ..] . trim ( ) . to_string ( ) ;
276+ }
277+ }
278+ }
279+ // If no code block, try to find direct HTML
280+ else {
281+ for pattern in html_start_patterns. iter ( ) {
282+ if let Some ( start_idx) = response. find ( pattern) {
283+ html_code = response[ start_idx..] . trim ( ) . to_string ( ) ;
284+
285+ // Everything before HTML is considered explanation
286+ if start_idx > 0 {
287+ explanation = response[ 0 ..start_idx] . trim ( ) . to_string ( ) ;
288+ }
289+ break ;
290+ }
291+ }
292+ }
293+
294+ // If still no HTML found, look for any content between <script> tags or <div id="plot">
295+ if html_code. is_empty ( ) {
296+ if let Some ( script_start) = response. find ( "<script>" ) {
297+ if let Some ( script_end) = response[ script_start..] . find ( "</script>" ) {
298+ // Create a basic HTML wrapper around the script
299+ let script_content = & response[ script_start..script_start + script_end + 9 ] ;
300+ html_code = format ! (
301+ "<!DOCTYPE html>\n <html>\n <head>\n <title>Visualization</title>\n <script src=\" https://cdn.plot.ly/plotly-latest.min.js\" ></script>\n </head>\n <body>\n <div id=\" plot\" ></div>\n {}\n </body>\n </html>" ,
302+ script_content
303+ ) ;
304+
305+ // Everything else is explanation
306+ explanation = response. replace ( script_content, "" ) . trim ( ) . to_string ( ) ;
307+ }
308+ }
309+ }
310+
311+ // If still nothing found, return the whole response as HTML with a warning
312+ if html_code. is_empty ( ) {
313+ html_code = format ! (
314+ "<!DOCTYPE html>\n <html>\n <head>\n <title>Visualization Error</title>\n </head>\n <body>\n <h1>Could not generate visualization</h1>\n <pre>{}</pre>\n </body>\n </html>" ,
315+ response. replace( "<" , "<" ) . replace( ">" , ">" )
316+ ) ;
317+ explanation = "Could not parse LLM response into valid HTML visualization." . to_string ( ) ;
318+ }
319+
320+ // If explanation is empty, provide a default
321+ if explanation. is_empty ( ) {
322+ explanation = "Visualization generated from the provided data." . to_string ( ) ;
323+ }
324+
325+ ( html_code, explanation, confidence)
326+ }
327+
113328#[ tokio:: main]
114329async fn main ( ) {
115330 // Load environment variables
@@ -131,6 +346,7 @@ async fn main() {
131346 . route ( "/" , get ( root) )
132347 . route ( "/health" , get ( health_check) )
133348 . route ( "/process-query" , post ( process_query) )
349+ . route ( "/generate" , post ( generate_visualization) )
134350 . layer (
135351 CorsLayer :: new ( )
136352 . allow_origin ( Any )
0 commit comments