Skip to content

Commit 0532dba

Browse files
authored
Merge pull request #755 from malariagen/GH738-mohamed-laarej-fix-plot-haplotype-network-color
Fix color parameter validation in plot_haplotype_network function - Shadow PR
2 parents ff71c3f + 4d4b245 commit 0532dba

File tree

6 files changed

+487
-52
lines changed

6 files changed

+487
-52
lines changed

malariagen_data/anopheles.py

Lines changed: 100 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2236,46 +2236,105 @@ def plot_haplotype_network(
22362236
edges = np.triu(edges)
22372237
alt_edges = np.triu(alt_edges)
22382238

2239-
debug("setup colors")
2240-
color_values = None
2241-
color_values_display = None
2242-
color_discrete_map_display = None
2243-
ht_color_counts = None
2244-
if color is not None:
2245-
# sanitise color column - necessary to avoid grey pie chart segments
2246-
df_haps["partition"] = df_haps[color].str.replace(r"\W", "", regex=True)
2247-
2248-
# extract all unique values of the color column
2249-
color_values = df_haps["partition"].fillna("<NA>").unique()
2250-
color_values_mapping = dict(zip(df_haps["partition"], df_haps[color]))
2251-
color_values_mapping["<NA>"] = "black"
2252-
color_values_display = [color_values_mapping[c] for c in color_values]
2253-
2254-
# count color values for each distinct haplotype
2255-
ht_color_counts = [
2256-
df_haps.iloc[list(s)]["partition"].value_counts().to_dict()
2257-
for s in ht_distinct_sets
2258-
]
2239+
debug("setup colors")
2240+
color_values = None
2241+
color_values_display = None
2242+
color_discrete_map_display = None
2243+
ht_color_counts = None
2244+
2245+
if color is not None:
2246+
# Handle string case (direct column name or cohorts_ prefix)
2247+
if isinstance(color, str):
2248+
# Try direct column name
2249+
if color in df_haps.columns:
2250+
color_column = color
2251+
# Try with cohorts_ prefix
2252+
elif f"cohorts_{color}" in df_haps.columns:
2253+
color_column = f"cohorts_{color}"
2254+
# Neither exists, raise helpful error
2255+
else:
2256+
available_columns = ", ".join(df_haps.columns)
2257+
raise ValueError(
2258+
f"Column '{color}' or 'cohorts_{color}' not found in sample data. "
2259+
f"Available columns: {available_columns}"
2260+
)
2261+
2262+
# Now use the validated color_column for processing
2263+
df_haps["_partition"] = (
2264+
df_haps[color_column]
2265+
.astype(str)
2266+
.str.replace(r"\W", "", regex=True)
2267+
)
22592268

2260-
# Set up colors.
2261-
(
2262-
color_prepped,
2263-
color_discrete_map_prepped,
2264-
category_orders_prepped,
2265-
) = self._setup_sample_colors_plotly(
2266-
data=df_haps,
2267-
color="partition",
2268-
color_discrete_map=color_discrete_map,
2269-
color_discrete_sequence=color_discrete_sequence,
2270-
category_orders=category_orders,
2271-
)
2272-
del color_discrete_map
2273-
del color_discrete_sequence
2274-
del category_orders
2275-
color_discrete_map_display = {
2276-
color_values_mapping[v]: c
2277-
for v, c in color_discrete_map_prepped.items()
2278-
}
2269+
# extract all unique values of the color column
2270+
color_values = df_haps["_partition"].fillna("<NA>").unique()
2271+
color_values_mapping = dict(
2272+
zip(df_haps["_partition"], df_haps[color_column])
2273+
)
2274+
color_values_mapping["<NA>"] = "black"
2275+
color_values_display = [
2276+
color_values_mapping[c] for c in color_values
2277+
]
2278+
2279+
# Handle mapping/dictionary case
2280+
elif isinstance(color, Mapping):
2281+
# For mapping case, we need to create a new column based on the mapping
2282+
# Initialize with None
2283+
df_haps["_partition"] = None
2284+
2285+
# Apply each query in the mapping to create the _partition column
2286+
for label, query in color.items():
2287+
# Apply the query and assign the label to matching rows
2288+
mask = df_haps.eval(query)
2289+
df_haps.loc[mask, "_partition"] = label
2290+
2291+
# Clean up the _partition column to avoid issues with special characters
2292+
if df_haps["_partition"].notna().any():
2293+
df_haps["_partition"] = df_haps["_partition"].str.replace(
2294+
r"\W", "", regex=True
2295+
)
2296+
2297+
# extract all unique values of the color column
2298+
color_values = df_haps["_partition"].fillna("<NA>").unique()
2299+
# For mapping case, use _partition values directly as they're already the labels
2300+
color_values_mapping = dict(
2301+
zip(df_haps["_partition"], df_haps["_partition"])
2302+
)
2303+
color_values_mapping["<NA>"] = "black"
2304+
color_values_display = [
2305+
color_values_mapping[c] for c in color_values
2306+
]
2307+
else:
2308+
# Invalid type
2309+
raise TypeError(
2310+
f"Expected color parameter to be a string or mapping, got {type(color).__name__}"
2311+
)
2312+
2313+
# count color values for each distinct haplotype (same for both string and mapping cases)
2314+
ht_color_counts = [
2315+
df_haps.iloc[list(s)]["_partition"].value_counts().to_dict()
2316+
for s in ht_distinct_sets
2317+
]
2318+
2319+
# Set up colors (same for both string and mapping cases)
2320+
(
2321+
color_prepped,
2322+
color_discrete_map_prepped,
2323+
category_orders_prepped,
2324+
) = self._setup_sample_colors_plotly(
2325+
data=df_haps,
2326+
color="_partition",
2327+
color_discrete_map=color_discrete_map,
2328+
color_discrete_sequence=color_discrete_sequence,
2329+
category_orders=category_orders,
2330+
)
2331+
del color_discrete_map
2332+
del color_discrete_sequence
2333+
del category_orders
2334+
color_discrete_map_display = {
2335+
color_values_mapping[v]: c
2336+
for v, c in color_discrete_map_prepped.items()
2337+
}
22792338

22802339
debug("construct graph")
22812340
anon_width = np.sqrt(0.3 * node_size_factor)
@@ -2284,7 +2343,7 @@ def plot_haplotype_network(
22842343
ht_distinct_mjn=ht_distinct_mjn,
22852344
ht_counts=ht_counts,
22862345
ht_color_counts=ht_color_counts,
2287-
color=color,
2346+
color="_partition" if color is not None else None,
22882347
color_values=color_values,
22892348
edges=edges,
22902349
alt_edges=alt_edges,
@@ -2336,7 +2395,7 @@ def plot_haplotype_network(
23362395
debug("create figure legend")
23372396
if color is not None:
23382397
legend_fig = plotly_discrete_legend(
2339-
color=color,
2398+
color="_partition", # Changed from color=color
23402399
color_values=color_values_display,
23412400
color_discrete_map=color_discrete_map_display,
23422401
category_orders=category_orders_prepped,

0 commit comments

Comments
 (0)