@@ -2236,46 +2236,105 @@ def plot_haplotype_network(
22362236 edges = np .triu (edges )
22372237 alt_edges = np .triu (alt_edges )
22382238
2239- debug ("setup colors" )
2240- color_values = None
2241- color_values_display = None
2242- color_discrete_map_display = None
2243- ht_color_counts = None
2244- if color is not None :
2245- # sanitise color column - necessary to avoid grey pie chart segments
2246- df_haps ["partition" ] = df_haps [color ].str .replace (r"\W" , "" , regex = True )
2247-
2248- # extract all unique values of the color column
2249- color_values = df_haps ["partition" ].fillna ("<NA>" ).unique ()
2250- color_values_mapping = dict (zip (df_haps ["partition" ], df_haps [color ]))
2251- color_values_mapping ["<NA>" ] = "black"
2252- color_values_display = [color_values_mapping [c ] for c in color_values ]
2253-
2254- # count color values for each distinct haplotype
2255- ht_color_counts = [
2256- df_haps .iloc [list (s )]["partition" ].value_counts ().to_dict ()
2257- for s in ht_distinct_sets
2258- ]
2239+ debug ("setup colors" )
2240+ color_values = None
2241+ color_values_display = None
2242+ color_discrete_map_display = None
2243+ ht_color_counts = None
2244+
2245+ if color is not None :
2246+ # Handle string case (direct column name or cohorts_ prefix)
2247+ if isinstance (color , str ):
2248+ # Try direct column name
2249+ if color in df_haps .columns :
2250+ color_column = color
2251+ # Try with cohorts_ prefix
2252+ elif f"cohorts_{ color } " in df_haps .columns :
2253+ color_column = f"cohorts_{ color } "
2254+ # Neither exists, raise helpful error
2255+ else :
2256+ available_columns = ", " .join (df_haps .columns )
2257+ raise ValueError (
2258+ f"Column '{ color } ' or 'cohorts_{ color } ' not found in sample data. "
2259+ f"Available columns: { available_columns } "
2260+ )
2261+
2262+ # Now use the validated color_column for processing
2263+ df_haps ["_partition" ] = (
2264+ df_haps [color_column ]
2265+ .astype (str )
2266+ .str .replace (r"\W" , "" , regex = True )
2267+ )
22592268
2260- # Set up colors.
2261- (
2262- color_prepped ,
2263- color_discrete_map_prepped ,
2264- category_orders_prepped ,
2265- ) = self ._setup_sample_colors_plotly (
2266- data = df_haps ,
2267- color = "partition" ,
2268- color_discrete_map = color_discrete_map ,
2269- color_discrete_sequence = color_discrete_sequence ,
2270- category_orders = category_orders ,
2271- )
2272- del color_discrete_map
2273- del color_discrete_sequence
2274- del category_orders
2275- color_discrete_map_display = {
2276- color_values_mapping [v ]: c
2277- for v , c in color_discrete_map_prepped .items ()
2278- }
2269+ # extract all unique values of the color column
2270+ color_values = df_haps ["_partition" ].fillna ("<NA>" ).unique ()
2271+ color_values_mapping = dict (
2272+ zip (df_haps ["_partition" ], df_haps [color_column ])
2273+ )
2274+ color_values_mapping ["<NA>" ] = "black"
2275+ color_values_display = [
2276+ color_values_mapping [c ] for c in color_values
2277+ ]
2278+
2279+ # Handle mapping/dictionary case
2280+ elif isinstance (color , Mapping ):
2281+ # For mapping case, we need to create a new column based on the mapping
2282+ # Initialize with None
2283+ df_haps ["_partition" ] = None
2284+
2285+ # Apply each query in the mapping to create the _partition column
2286+ for label , query in color .items ():
2287+ # Apply the query and assign the label to matching rows
2288+ mask = df_haps .eval (query )
2289+ df_haps .loc [mask , "_partition" ] = label
2290+
2291+ # Clean up the _partition column to avoid issues with special characters
2292+ if df_haps ["_partition" ].notna ().any ():
2293+ df_haps ["_partition" ] = df_haps ["_partition" ].str .replace (
2294+ r"\W" , "" , regex = True
2295+ )
2296+
2297+ # extract all unique values of the color column
2298+ color_values = df_haps ["_partition" ].fillna ("<NA>" ).unique ()
2299+ # For mapping case, use _partition values directly as they're already the labels
2300+ color_values_mapping = dict (
2301+ zip (df_haps ["_partition" ], df_haps ["_partition" ])
2302+ )
2303+ color_values_mapping ["<NA>" ] = "black"
2304+ color_values_display = [
2305+ color_values_mapping [c ] for c in color_values
2306+ ]
2307+ else :
2308+ # Invalid type
2309+ raise TypeError (
2310+ f"Expected color parameter to be a string or mapping, got { type (color ).__name__ } "
2311+ )
2312+
2313+ # count color values for each distinct haplotype (same for both string and mapping cases)
2314+ ht_color_counts = [
2315+ df_haps .iloc [list (s )]["_partition" ].value_counts ().to_dict ()
2316+ for s in ht_distinct_sets
2317+ ]
2318+
2319+ # Set up colors (same for both string and mapping cases)
2320+ (
2321+ color_prepped ,
2322+ color_discrete_map_prepped ,
2323+ category_orders_prepped ,
2324+ ) = self ._setup_sample_colors_plotly (
2325+ data = df_haps ,
2326+ color = "_partition" ,
2327+ color_discrete_map = color_discrete_map ,
2328+ color_discrete_sequence = color_discrete_sequence ,
2329+ category_orders = category_orders ,
2330+ )
2331+ del color_discrete_map
2332+ del color_discrete_sequence
2333+ del category_orders
2334+ color_discrete_map_display = {
2335+ color_values_mapping [v ]: c
2336+ for v , c in color_discrete_map_prepped .items ()
2337+ }
22792338
22802339 debug ("construct graph" )
22812340 anon_width = np .sqrt (0.3 * node_size_factor )
@@ -2284,7 +2343,7 @@ def plot_haplotype_network(
22842343 ht_distinct_mjn = ht_distinct_mjn ,
22852344 ht_counts = ht_counts ,
22862345 ht_color_counts = ht_color_counts ,
2287- color = color ,
2346+ color = "_partition" if color is not None else None ,
22882347 color_values = color_values ,
22892348 edges = edges ,
22902349 alt_edges = alt_edges ,
@@ -2336,7 +2395,7 @@ def plot_haplotype_network(
23362395 debug ("create figure legend" )
23372396 if color is not None :
23382397 legend_fig = plotly_discrete_legend (
2339- color = color ,
2398+ color = "_partition" , # Changed from color=color
23402399 color_values = color_values_display ,
23412400 color_discrete_map = color_discrete_map_display ,
23422401 category_orders = category_orders_prepped ,
0 commit comments