55 "smart_format" ,
66 "pretty_bin_labels" ,
77 "add_colorbar" ,
8- "add_legend_panel_colorbar"
8+ "add_legend_panel_colorbar" ,
9+ "clip_data_to_bins"
910]
1011
1112import pandas as pd
3132from idd_forecast_mbp .helper_functions import read_income_paths
3233from idd_forecast_mbp .parquet_functions import read_parquet_with_integer_ids
3334from idd_forecast_mbp .xarray_functions import read_netcdf_with_integer_ids , convert_to_xarray , write_netcdf
35+ from idd_forecast_mbp .color_functions import (
36+ create_change_colormap ,
37+ create_outcome_colormap ,
38+ create_diverging_colors ,
39+ get_colors
40+ )
3441
3542
36- def draw_legend_bins (ax , plot_dict ):
37- map_type = plot_dict ['map_type' ]
38- legend_dict = plot_dict ['legend_dict' ]
39- bin_dict = plot_dict ['bin_dict' ]
43+ def clip_data_to_bins (df , data_column , bins ):
44+ """Clip data values to ensure they fall within bin boundaries."""
45+ df = df .copy ()
46+ min_bin , max_bin = bins [0 ], bins [- 1 ]
47+
48+ # Show clipping stats
49+ original_min , original_max = df [data_column ].min (), df [data_column ].max ()
50+ clipped_count = ((df [data_column ] < min_bin ) | (df [data_column ] > max_bin )).sum ()
51+
52+ if clipped_count > 0 :
53+ print (f"Clipping { clipped_count } values outside bins [{ min_bin } , { max_bin } ]" )
54+ print (f"Original range: [{ original_min :.2f} , { original_max :.2f} ]" )
55+
56+ # Clip the data
57+ df [data_column ] = df [data_column ].clip (lower = min_bin , upper = max_bin )
58+
59+ return df
60+
61+ def draw_legend_bins (ax , map_plot_dict ):
62+ map_type = map_plot_dict ['map_type' ]
63+ legend_dict = map_plot_dict ['legend_dict' ]
64+ bin_dict = map_plot_dict ['bin_dict' ]
4065 legend_panel = legend_dict ['legend_panel' ]
4166 legend_bin_spacing = legend_panel ['legend_bin_spacing' ]
4267 legend_margin = legend_panel ['legend_margin' ]
@@ -49,7 +74,7 @@ def draw_legend_bins(ax, plot_dict):
4974 bin_height = bin_top - bin_bottom
5075 bin_label_y = bin_bottom - bin_label_gap
5176
52- category_labels = plot_dict ['bin_dict' ]['bin_labels' ]
77+ category_labels = map_plot_dict ['bin_dict' ]['bin_labels' ]
5378
5479 bin_dict ['n_bins' ] = len (category_labels )
5580 n_bins = bin_dict ['n_bins' ]
@@ -72,44 +97,41 @@ def draw_legend_bins(ax, plot_dict):
7297 rect = Rectangle ((bin_left [i ], bin_bottom ), bin_width , bin_height , facecolor = bin_colors [i ], edgecolor = 'black' , linewidth = 0.5 )
7398 ax .add_patch (rect )
7499 ax .text (bin_left [i ] + bin_width / 2 , bin_label_y , category_labels [i ], ha = 'center' , va = 'top' ,
75- fontsize = plot_dict ['fontsizes' ]['legend_label_fontsize' ])
100+ fontsize = map_plot_dict ['fontsizes' ]['legend_label_fontsize' ])
76101
77- def add_legend (fig , ax , plot_dict ):
78- if plot_dict ['have_legend_panel' ]:
79- if plot_dict ['legend_dict' ]['use_colorbar' ]:
80- add_legend_panel_colorbar (fig , ax , plot_dict )
102+ def add_legend (fig , ax , map_plot_dict ):
103+ if map_plot_dict ['have_legend_panel' ]:
104+ if map_plot_dict ['legend_dict' ]['use_colorbar' ]:
105+ add_legend_panel_colorbar (fig , ax , map_plot_dict )
81106 else :
82- draw_legend_bins (ax , plot_dict )
107+ draw_legend_bins (ax , map_plot_dict )
83108 else :
84- add_colorbar (fig , ax , plot_dict )
109+ add_colorbar (fig , ax , map_plot_dict )
85110
86- def get_bin_info (plot_dict , valid_data , map_data_masked ):
87- map_type = plot_dict ['map_type' ]
88- bin_dict = plot_dict ['bin_dict' ]
89- # Categorization logic (same as before)
90- bins = bin_dict ['bins' ]
91- n_bins = bin_dict ['n_bins' ]
92- base_colormap = plt .colormaps [plot_dict ['colors_dict' ]['base_cmap' ]]
93- if bins [0 ] == 0 :
94- white_rgba = to_rgba ('white' )
95- bin_colors = np .vstack ([white_rgba , base_colormap (np .linspace (0.1 , 0.9 , n_bins - 1 ))])
111+ def get_bin_info (map_plot_dict , plot_data ):
112+ map_type = map_plot_dict ['map_type' ]
96113
114+ if map_type == 'outcome' :
115+ # Outcome map
116+ create_outcome_colormap (map_plot_dict )
97117 else :
98- bin_colors = base_colormap ( np . linspace ( 0.1 , 0.9 , n_bins ) )
118+ create_change_colormap ( map_plot_dict )
99119
100- cmap = ListedColormap (bin_colors )
101- categorical_data = np .full_like (map_data_masked , np .nan )
120+ map_plot_dict ['bin_dict' ]['bin_labels' ] = pretty_bin_labels (map_plot_dict )
121+ bin_dict = map_plot_dict ['bin_dict' ]
122+ bins = bin_dict ['bins' ]
123+ n_bins = bin_dict ['n_bins' ]
124+ plot_data_values = plot_data .values if hasattr (plot_data , 'values' ) else plot_data
125+ categorical_data = np .full_like (plot_data_values , np .nan )
102126 for i in range (n_bins ):
103- if i == 0 :
104- mask = map_data_masked <= bins [i + 1 ]
127+ if i == 0.0 :
128+ mask = plot_data_values <= bins [i + 1 ]
105129 elif i == n_bins - 1 :
106- mask = map_data_masked > bins [i ]
130+ mask = plot_data_values > bins [i ]
107131 else :
108- mask = (map_data_masked > bins [i ]) & (map_data_masked <= bins [i + 1 ])
132+ mask = (plot_data_values > bins [i ]) & (plot_data_values <= bins [i + 1 ])
109133 categorical_data [mask ] = i
110134
111- bin_dict ['bin_colors' ] = bin_colors
112- bin_dict ['cmap' ] = cmap
113135 bin_dict ['categorical_data' ] = categorical_data
114136
115137def smart_format (val ):
@@ -120,14 +142,14 @@ def smart_format(val):
120142 s = f"{ val :,.2f} " .rstrip ('0' ).rstrip ('.' )
121143 return s
122144
123- def pretty_bin_labels (plot_dict ):
124- bins = plot_dict ['bin_dict' ]['bins' ]
125- le = plot_dict ['bin_dict' ]['le' ]
126- ge = plot_dict ['bin_dict' ]['ge' ]
127- zero_bin = plot_dict ['bin_dict' ]['zero_bin' ]
128- prefix_units = plot_dict ['bin_dict' ]['prefix_units' ]
129- suffix_units = plot_dict ['bin_dict' ]['suffix_units' ]
130- abbreviate_labels = plot_dict ['bin_dict' ]['abbreviate_labels' ]
145+ def pretty_bin_labels (map_plot_dict ):
146+ bins = map_plot_dict ['bin_dict' ]['bins' ]
147+ le = map_plot_dict ['bin_dict' ]['le' ]
148+ ge = map_plot_dict ['bin_dict' ]['ge' ]
149+ zero_bin = map_plot_dict ['bin_dict' ]['zero_bin' ]
150+ prefix_units = map_plot_dict ['bin_dict' ]['prefix_units' ]
151+ suffix_units = map_plot_dict ['bin_dict' ]['suffix_units' ]
152+ abbreviate_labels = map_plot_dict ['bin_dict' ]['abbreviate_labels' ]
131153 # bins: array-like of bin edges
132154 # fmt: format for numbers (default: 2 significant digits)
133155
@@ -188,12 +210,12 @@ def pretty_bin_labels(plot_dict):
188210
189211 return labels
190212
191- def add_colorbar (fig , ax , plot_dict ):
192- figure_dict = plot_dict ['figure_dict' ]
193- legend_dict = plot_dict ['legend_dict' ]
194- bins = plot_dict ['bin_dict' ]['bins' ]
213+ def add_colorbar (fig , ax , map_plot_dict ):
214+ figure_dict = map_plot_dict ['figure_dict' ]
215+ legend_dict = map_plot_dict ['legend_dict' ]
216+ bins = map_plot_dict ['bin_dict' ]['bins' ]
195217 bin_centers = [(bins [i ] + bins [i + 1 ]) / 2 for i in range (len (bins )- 1 )]
196- sm = plt .cm .ScalarMappable (cmap = plot_dict ['bin_dict' ]['cmap' ], norm = plot_dict ['bin_dict' ]['norm' ])
218+ sm = plt .cm .ScalarMappable (cmap = map_plot_dict ['bin_dict' ]['cmap' ], norm = map_plot_dict ['bin_dict' ]['norm' ])
197219 sm .set_array ([])
198220 cbar = fig .colorbar (sm , ax = ax , orientation = 'horizontal' ,
199221 shrink = legend_dict ['color_bar_dict' ]['shrink' ],
@@ -204,22 +226,22 @@ def add_colorbar(fig, ax, plot_dict):
204226 cbar .set_ticklabels (figure_dict ['bin_labels' ], fontsize = figure_dict ['tick_font_size' ])
205227 cbar .set_label (figure_dict ['colorbar_label' ], fontsize = figure_dict ['colorbar_title_font_size' ])
206228
207- def add_legend_panel_colorbar (fig , ax_legend , plot_dict ):
229+ def add_legend_panel_colorbar (fig , ax_legend , map_plot_dict ):
208230 """
209231 Fixed version that uses the calculated parameters properly
210232 """
211233 ax_legend .clear ()
212234 ax_legend .axis ('off' )
213235
214236 # Get colorbar data
215- bin_dict = plot_dict ['bin_dict' ]
237+ bin_dict = map_plot_dict ['bin_dict' ]
216238 cmap = bin_dict ['cmap' ]
217239 norm = bin_dict ['norm' ]
218240 bins = bin_dict .get ('bins' , None )
219241 bin_labels = bin_dict .get ('bin_labels' , None )
220- colorbar_label = plot_dict ['full_outcome_label' ]
242+ colorbar_label = map_plot_dict ['full_outcome_label' ]
221243
222- color_bar_dict = plot_dict ['legend_dict' ]['color_bar_dict' ]
244+ color_bar_dict = map_plot_dict ['legend_dict' ]['color_bar_dict' ]
223245
224246 # Create ScalarMappable
225247 sm = plt .cm .ScalarMappable (cmap = cmap , norm = norm )
@@ -242,7 +264,7 @@ def add_legend_panel_colorbar(fig, ax_legend, plot_dict):
242264
243265 cbar_ax = fig .add_axes ([cbar_left , cbar_bottom , cbar_width , cbar_height ])
244266
245- extend_option = 'max' if plot_dict .get ('extend_colorbar' , False ) else 'neither'
267+ extend_option = 'max' if map_plot_dict .get ('extend_colorbar' , False ) else 'neither'
246268 cbar = fig .colorbar (sm , cax = cbar_ax , orientation = 'horizontal' , extend = extend_option )
247269
248270 if bins is not None and bin_labels is not None :
0 commit comments