Skip to content
Merged
141 changes: 100 additions & 41 deletions malariagen_data/anopheles.py
Original file line number Diff line number Diff line change
Expand Up @@ -2232,46 +2232,105 @@ def plot_haplotype_network(
edges = np.triu(edges)
alt_edges = np.triu(alt_edges)

debug("setup colors")
color_values = None
color_values_display = None
color_discrete_map_display = None
ht_color_counts = None
if color is not None:
# sanitise color column - necessary to avoid grey pie chart segments
df_haps["partition"] = df_haps[color].str.replace(r"\W", "", regex=True)

# extract all unique values of the color column
color_values = df_haps["partition"].fillna("<NA>").unique()
color_values_mapping = dict(zip(df_haps["partition"], df_haps[color]))
color_values_mapping["<NA>"] = "black"
color_values_display = [color_values_mapping[c] for c in color_values]

# count color values for each distinct haplotype
ht_color_counts = [
df_haps.iloc[list(s)]["partition"].value_counts().to_dict()
for s in ht_distinct_sets
]
debug("setup colors")
color_values = None
color_values_display = None
color_discrete_map_display = None
ht_color_counts = None

if color is not None:
# Handle string case (direct column name or cohorts_ prefix)
if isinstance(color, str):
# Try direct column name
if color in df_haps.columns:
color_column = color
# Try with cohorts_ prefix
elif f"cohorts_{color}" in df_haps.columns:
color_column = f"cohorts_{color}"
# Neither exists, raise helpful error
else:
available_columns = ", ".join(df_haps.columns)
raise ValueError(
f"Column '{color}' or 'cohorts_{color}' not found in sample data. "
f"Available columns: {available_columns}"
)

# Now use the validated color_column for processing
df_haps["partition"] = (
df_haps[color_column]
.astype(str)
.str.replace(r"\W", "", regex=True)
)

# Set up colors.
(
color_prepped,
color_discrete_map_prepped,
category_orders_prepped,
) = self._setup_sample_colors_plotly(
data=df_haps,
color="partition",
color_discrete_map=color_discrete_map,
color_discrete_sequence=color_discrete_sequence,
category_orders=category_orders,
)
del color_discrete_map
del color_discrete_sequence
del category_orders
color_discrete_map_display = {
color_values_mapping[v]: c
for v, c in color_discrete_map_prepped.items()
}
# extract all unique values of the color column
color_values = df_haps["partition"].fillna("<NA>").unique()
color_values_mapping = dict(
zip(df_haps["partition"], df_haps[color_column])
)
color_values_mapping["<NA>"] = "black"
color_values_display = [
color_values_mapping[c] for c in color_values
]

# Handle mapping/dictionary case
elif isinstance(color, Mapping):
# For mapping case, we need to create a new column based on the mapping
# Initialize with None
df_haps["partition"] = None

# Apply each query in the mapping to create the partition column
for label, query in color.items():
# Apply the query and assign the label to matching rows
mask = df_haps.eval(query)
df_haps.loc[mask, "partition"] = label

# Clean up the partition column to avoid issues with special characters
if df_haps["partition"].notna().any():
df_haps["partition"] = df_haps["partition"].str.replace(
r"\W", "", regex=True
)

# extract all unique values of the color column
color_values = df_haps["partition"].fillna("<NA>").unique()
# For mapping case, use partition values directly as they're already the labels
color_values_mapping = dict(
zip(df_haps["partition"], df_haps["partition"])
)
color_values_mapping["<NA>"] = "black"
color_values_display = [
color_values_mapping[c] for c in color_values
]
else:
# Invalid type
raise TypeError(
f"Expected color parameter to be a string or mapping, got {type(color).__name__}"
)

# count color values for each distinct haplotype (same for both string and mapping cases)
ht_color_counts = [
df_haps.iloc[list(s)]["partition"].value_counts().to_dict()
for s in ht_distinct_sets
]

# Set up colors (same for both string and mapping cases)
(
color_prepped,
color_discrete_map_prepped,
category_orders_prepped,
) = self._setup_sample_colors_plotly(
data=df_haps,
color="partition",
color_discrete_map=color_discrete_map,
color_discrete_sequence=color_discrete_sequence,
category_orders=category_orders,
)
del color_discrete_map
del color_discrete_sequence
del category_orders
color_discrete_map_display = {
color_values_mapping[v]: c
for v, c in color_discrete_map_prepped.items()
}

debug("construct graph")
anon_width = np.sqrt(0.3 * node_size_factor)
Expand All @@ -2280,7 +2339,7 @@ def plot_haplotype_network(
ht_distinct_mjn=ht_distinct_mjn,
ht_counts=ht_counts,
ht_color_counts=ht_color_counts,
color=color,
color="partition" if color is not None else None,
color_values=color_values,
edges=edges,
alt_edges=alt_edges,
Expand Down Expand Up @@ -2332,7 +2391,7 @@ def plot_haplotype_network(
debug("create figure legend")
if color is not None:
legend_fig = plotly_discrete_legend(
color=color,
color="partition", # Changed from color=color
color_values=color_values_display,
color_discrete_map=color_discrete_map_display,
category_orders=category_orders_prepped,
Expand Down
Loading
Loading