1- from __future__ import annotations
2-
31import json
42from functools import lru_cache
53from pathlib import Path
6- from typing import Dict , List , Any , Callable
4+ from typing import Dict , List , Any , Callable , Union
75
86
97from src .agents .base_agent import Agent
1816 save_enriched_elements as save_json ,
1917 to_geojson ,
2018)
21- from src .tools .mapping_tools import to_heatmap
19+ from src .tools .mapping_tools import to_heatmap , to_hotspots
2220from src .tools .stat_tools import compute_statistics
21+ from src .tools .chart_tools import (
22+ private_public_pie ,
23+ plot_zone_sensitivity ,
24+ plot_sensitivity_reasons ,
25+ plot_hotspots ,
26+ )
2327from src .utils .decorators import log_action
2428
2529Tool = Callable [..., Any ]
@@ -35,16 +39,21 @@ def __init__(
3539 self ,
3640 name : str ,
3741 memory : MemoryStore ,
38- llm : LocalLLM | None = None ,
39- tools : Dict [str , Tool ] | None = None ,
42+ llm : Union [ LocalLLM , None ] = None ,
43+ tools : Union [ Dict [str , Tool ], None ] = None ,
4044 ):
4145 default_tools : Dict [str , Tool ] = {
4246 "load_json" : load_json ,
4347 "enrich" : self ._enrich_element ,
4448 "save_json" : save_json ,
4549 "to_geojson" : to_geojson ,
4650 "to_heatmap" : to_heatmap ,
51+ "to_hotspots" : to_hotspots ,
52+ "plot_hotspots" : plot_hotspots ,
4753 "report" : compute_statistics ,
54+ "plot_pie" : private_public_pie ,
55+ "plot_zone_sensitivity" : plot_zone_sensitivity ,
56+ "plot_sensitivity_reasons" : plot_sensitivity_reasons ,
4857 }
4958 super ().__init__ (name , tools or default_tools , memory )
5059 self .llm = llm or LocalLLM ()
@@ -58,12 +67,22 @@ def perceive(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
5867 raise FileNotFoundError (path )
5968 generate_geojson = input_data .get ("generate_geojson" , True )
6069 generate_heatmap = input_data .get ("generate_heatmap" , False )
70+ generate_hotspots = input_data .get ("generate_hotspots" , False )
6171 compute_stats = input_data .get ("compute_stats" , True )
72+ generate_chart = input_data .get ("generate_chart" , False )
73+ plot_zone = input_data .get ("plot_zone_sensitivity" , False )
74+ plot_reasons = input_data .get ("plot_sensitivity_reasons" , False )
75+ plot_hotspots = input_data .get ("plot_hotspots" , False )
6276 return {
6377 "path" : path ,
6478 "generate_geojson" : generate_geojson ,
6579 "generate_heatmap" : generate_heatmap ,
80+ "generate_hotspots" : generate_hotspots ,
6681 "compute_stats" : compute_stats ,
82+ "generate_chart" : generate_chart ,
83+ "plot_zone_sensitivity" : plot_zone ,
84+ "plot_sensitivity_reasons" : plot_reasons ,
85+ "plot_hotspots" : plot_hotspots ,
6786 }
6887
6988 def plan (self , observation : Dict [str , Any ]) -> List [str ]:
@@ -72,8 +91,18 @@ def plan(self, observation: Dict[str, Any]) -> List[str]:
7291 steps .append ("to_geojson" )
7392 if observation ["generate_heatmap" ]:
7493 steps .append ("to_heatmap" )
94+ if observation ["generate_hotspots" ]:
95+ steps .append ("to_hotspots" )
7596 if observation ["compute_stats" ]:
7697 steps .append ("report" )
98+ if observation ["generate_chart" ]:
99+ steps .append ("plot_pie" )
100+ if observation ["plot_zone_sensitivity" ]:
101+ steps .append ("plot_zone_sensitivity" )
102+ if observation ["plot_sensitivity_reasons" ]:
103+ steps .append ("plot_sensitivity_reasons" )
104+ if observation ["plot_hotspots" ]:
105+ steps .append ("plot_hotspots" )
77106 return steps
78107
79108 @log_action
@@ -162,11 +191,54 @@ def act(self, action: str, context: Dict[str, Any]) -> Any:
162191 self .remember ("heatmap_cache" , f"{ context ['raw_hash' ]} |{ html_path } " )
163192 return str (html_path )
164193
194+ if action == "to_hotspots" :
195+ geojson_path = Path (context ["geojson_path" ])
196+ hotspots_path = geojson_path .with_name (
197+ geojson_path .stem + "_hotspots.geojson"
198+ )
199+ self .tools ["to_hotspots" ](geojson_path , hotspots_path )
200+ context ["hotspots_path" ] = str (hotspots_path )
201+ cache_val = f"{ context ['raw_hash' ]} |{ hotspots_path } "
202+ self .remember ("hotspot_cache" , cache_val )
203+ return str (hotspots_path )
204+
205+ if action == "plot_hotspots" :
206+ hot = Path (context ["hotspots_path" ])
207+ pic = hot .with_suffix (".png" )
208+ self .tools ["plot_hotspots" ](hot , pic )
209+ context ["hotspots_plot" ] = str (pic )
210+ return str (pic )
211+
165212 if action == "report" :
166213 stats : Dict [str , Any ] = self .tools ["report" ](context ["enriched" ])
167214 self .remember ("report" , json .dumps (stats ))
168215 return stats
169216
217+ if action == "plot_pie" :
218+ stats = context ["stats" ]
219+ src_path = Path (context ["path" ])
220+ chart_path = self .tools ["plot_pie" ](stats , src_path .parent )
221+ context ["chart_path" ] = chart_path
222+ self .remember ("pie_chart" , f"{ context ['raw_hash' ]} |{ chart_path } " )
223+ return str (chart_path )
224+
225+ if action == "plot_zone_sensitivity" :
226+ stats = context ["stats" ]
227+ src_path = Path (context ["path" ])
228+ chart_path = self .tools ["plot_zone_sensitivity" ](stats , src_path .parent )
229+ context ["chart_path" ] = chart_path
230+ self .remember ("chart_zone_sens" , f"{ context ['raw_hash' ]} |{ chart_path } " )
231+ return str (chart_path )
232+
233+ if action == "plot_sensitivity_reasons" :
234+ enriched_path = Path (context ["output_path" ])
235+ chart_path = enriched_path .with_name (
236+ f"{ enriched_path .stem } _sensitivity.png"
237+ )
238+ self .tools ["plot_sensitivity_reasons" ](enriched_path , chart_path )
239+ context ["sensitivity_reasons_chart" ] = str (chart_path )
240+ return str (chart_path )
241+
170242 logger .error (f"Unhandled action: { action } " )
171243 raise NotImplementedError (f"Unhandled action: { action } " )
172244
@@ -188,8 +260,18 @@ def achieve_goal(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
188260 context ["geojson_path" ] = result
189261 elif step == "to_heatmap" :
190262 context ["heatmap_path" ] = result
263+ elif step == "to_hotspots" :
264+ context ["hotspots_path" ] = result
191265 elif step == "report" :
192266 context ["stats" ] = result
267+ elif step == "plot_pie" :
268+ context ["chart_path" ] = result
269+ elif step == "plot_zone_sensitivity" :
270+ context ["chart_zone_sens" ] = result
271+ elif step == "plot_sensitivity_reasons" :
272+ context ["chart_sens_reasons" ] = result
273+ elif step == "plot_hotspots" :
274+ context ["plot_hotspots" ] = result
193275
194276 return context
195277
0 commit comments