4444from ..geopandas_tools .geometry_types import get_geom_type
4545from ..geopandas_tools .geometry_types import to_single_geom_type
4646from ..helpers import _get_file_system
47+ from ..helpers import dict_zip
4748from .wms import WmsLoader
4849
4950try :
@@ -385,12 +386,12 @@ def __init__(
385386
386387 super ().__init__ (column = column , show = show , ** (new_kwargs | new_gdfs ))
387388
388- if self .gdfs is None :
389+ if self ._gdfs is None :
389390 return
390391
391392 # stringify or remove columns not renerable by leaflet (list, geometry etc.)
392- new_gdfs , show_new = [], []
393- for gdf , show in zip (self .gdfs , self .show , strict = True ):
393+ new_gdfs , show_new = {}, {}
394+ for label , gdf , show in dict_zip (self ._gdfs , self .show ):
394395 try :
395396 gdf = gdf .reset_index ()
396397 except Exception :
@@ -420,18 +421,15 @@ def __init__(
420421 gdf .index = gdf .index .astype (str )
421422 except Exception :
422423 pass
423- new_gdfs . append ( to_gdf (gdf ) )
424- show_new . append ( show )
424+ new_gdfs [ label ] = to_gdf (gdf )
425+ show_new [ label ] = show
425426 self ._gdfs = new_gdfs
426427 if self ._gdfs :
427428 self ._gdf = pd .concat (new_gdfs , ignore_index = True )
428429 else :
429- self ._gdf = GeoDataFrame ({ "geometry" : [], self ._column : []} )
430+ self ._gdf = self ._get_gdf_template ( )
430431 self .show = show_new
431432
432- # if self._show_was_none and len(self._gdfs) > 6:
433- # self.show = [False] * len(self._gdfs)
434-
435433 if self ._is_categorical :
436434 if len (self .gdfs ) == 1 :
437435 self ._split_categories ()
@@ -455,7 +453,7 @@ def __len__(self) -> int:
455453 rasters = self .raster_data
456454 except AttributeError :
457455 rasters = self .rasters
458- return len ([gdf for gdf in self ._gdfs if len (gdf )]) + len (rasters )
456+ return len ([gdf for gdf in self ._gdfs . values () if len (gdf )]) + len (rasters )
459457
460458 def __bool__ (self ) -> bool :
461459 """True if any gdfs have rows or there are any raster images."""
@@ -473,7 +471,7 @@ def explore(
473471 self .mask = mask if mask is not None else self .mask
474472 if (
475473 self ._gdfs
476- and not any (len (gdf ) for gdf in self ._gdfs )
474+ and not any (len (gdf ) for gdf in self ._gdfs . values () )
477475 and not len (self .rasters )
478476 ):
479477 warnings .warn ("None of the GeoDataFrames have rows." , stacklevel = 1 )
@@ -498,13 +496,11 @@ def explore(
498496 else center
499497 )
500498
501- gdfs : tuple [GeoDataFrame ] = ()
502- for gdf in self ._gdfs :
499+ for label , gdf in self ._gdfs .items ():
503500 keep_geom_type = False if get_geom_type (gdf ) == "mixed" else True
504501 gdf = gdf .clip (centerpoint .buffer (size ), keep_geom_type = keep_geom_type )
505- gdfs = gdfs + (gdf ,)
506- self ._gdfs = gdfs
507- self ._gdf = pd .concat (gdfs , ignore_index = True )
502+ self ._gdfs [label ] = gdf
503+ self ._gdf = pd .concat (self ._gdfs .values (), ignore_index = True )
508504
509505 self ._get_unique_values ()
510506
@@ -558,19 +554,17 @@ def clipmap(
558554 self ._update_column ()
559555 kwargs .pop ("column" , None )
560556
561- gdfs : tuple [GeoDataFrame ] = ()
562- for gdf in self ._gdfs :
557+ for label , gdf in self ._gdfs .items ():
563558 gdf = gdf .clip (self .mask )
564559 collections = gdf .loc [gdf .geom_type == "GeometryCollection" ]
565560 if len (collections ):
566561 collections = make_all_singlepart (collections )
567562 gdf = pd .concat ([gdf , collections ], ignore_index = False )
568- gdfs = gdfs + (gdf ,)
569- self ._gdfs = gdfs
563+ self ._gdfs [label ] = gdf
570564 if self ._gdfs :
571- self ._gdf = pd .concat (self ._gdfs , ignore_index = True )
565+ self ._gdf = pd .concat (self ._gdfs . values () , ignore_index = True )
572566 else :
573- self ._gdf = GeoDataFrame ({ "geometry" : [], self ._column : []} )
567+ self ._gdf = self ._get_gdf_template ( )
574568
575569 self ._explore (** kwargs )
576570
@@ -638,8 +632,11 @@ def save(self, path: str) -> None:
638632 def _explore (self , ** kwargs ) -> None :
639633 self .kwargs = self .kwargs | kwargs
640634
641- if self ._show_was_none and len ([gdf for gdf in self ._gdfs if len (gdf )]) > 6 :
642- self .show = [False ] * len (self ._gdfs )
635+ if (
636+ self ._show_was_none
637+ and len ([gdf for gdf in self ._gdfs .values () if len (gdf )]) > 6
638+ ):
639+ self .show = {label : False for label in self ._gdfs }
643640
644641 if self ._is_categorical :
645642 self ._create_categorical_map ()
@@ -662,15 +659,13 @@ def _explore(self, **kwargs) -> None:
662659 display (self .map )
663660
664661 def _split_categories (self ) -> None :
665- new_gdfs , new_labels , new_shows = [], [], []
662+ new_gdfs , new_shows = {}, {}
666663 for cat in self ._unique_values :
667664 gdf = self .gdf .loc [self .gdf [self .column ] == cat ]
668- new_gdfs .append (gdf )
669- new_labels .append (cat )
670- new_shows .append (self .show [0 ])
665+ new_gdfs [cat ] = gdf
666+ new_shows [cat ] = next (iter (self .show .values ()))
671667 self ._gdfs = new_gdfs
672668 self ._gdf = pd .concat (new_gdfs , ignore_index = True )
673- self .labels = new_labels
674669 self .show = new_shows
675670
676671 def _to_single_geom_type (self , gdf : GeoDataFrame ) -> GeoDataFrame :
@@ -720,12 +715,11 @@ def _get_bounds(
720715 return gdf .total_bounds
721716
722717 def _create_categorical_map (self ) -> None :
723- self ._make_categories_colors_dict ()
718+ self ._prepare_categorical_plot ()
724719 if self ._gdf is not None and len (self ._gdf ):
725- self ._fix_nans ()
726720 gdf = self ._prepare_gdf_for_map (self ._gdf )
727721 else :
728- gdf = GeoDataFrame ({ "geometry" : [], self ._column : []} )
722+ gdf = self ._get_gdf_template ( )
729723
730724 self ._load_rasters_as_images ()
731725
@@ -742,7 +736,7 @@ def _create_categorical_map(self) -> None:
742736 ** self .kwargs ,
743737 )
744738
745- for gdf , label , show in zip (self ._gdfs , self .labels , self . show , strict = True ):
739+ for label , gdf , show in dict_zip (self ._gdfs , self .show ):
746740 if not len (gdf ):
747741 continue
748742
@@ -798,7 +792,10 @@ def _create_continous_map(self):
798792 if self .scheme :
799793 classified = self ._classify_from_bins (self ._gdf , bins = self .bins )
800794 classified_sequential = self ._push_classification (classified )
801- n_colors = len (np .unique (classified_sequential )) - any (self ._nan_idx )
795+ n_colors = (
796+ len (np .unique (classified_sequential ))
797+ - self ._gdf [self ._column ].isna ().any ()
798+ )
802799 unique_colors = self ._get_continous_colors (n = n_colors )
803800
804801 self ._load_rasters_as_images ()
@@ -824,7 +821,7 @@ def _create_continous_map(self):
824821 index = self .bins ,
825822 )
826823
827- for gdf , label , show in zip (self ._gdfs , self . labels , self .show , strict = True ):
824+ for ( label , gdf ) , show in zip (self ._gdfs . items () , self .show , strict = True ):
828825 if not len (gdf ):
829826 continue
830827
0 commit comments