diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index 0aec9bd9a..e5c81bdf2 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -2232,46 +2232,105 @@ def plot_haplotype_network( edges = np.triu(edges) alt_edges = np.triu(alt_edges) - debug("setup colors") - color_values = None - color_values_display = None - color_discrete_map_display = None - ht_color_counts = None - if color is not None: - # sanitise color column - necessary to avoid grey pie chart segments - df_haps["partition"] = df_haps[color].str.replace(r"\W", "", regex=True) - - # extract all unique values of the color column - color_values = df_haps["partition"].fillna("").unique() - color_values_mapping = dict(zip(df_haps["partition"], df_haps[color])) - color_values_mapping[""] = "black" - color_values_display = [color_values_mapping[c] for c in color_values] - - # count color values for each distinct haplotype - ht_color_counts = [ - df_haps.iloc[list(s)]["partition"].value_counts().to_dict() - for s in ht_distinct_sets - ] + debug("setup colors") + color_values = None + color_values_display = None + color_discrete_map_display = None + ht_color_counts = None + + if color is not None: + # Handle string case (direct column name or cohorts_ prefix) + if isinstance(color, str): + # Try direct column name + if color in df_haps.columns: + color_column = color + # Try with cohorts_ prefix + elif f"cohorts_{color}" in df_haps.columns: + color_column = f"cohorts_{color}" + # Neither exists, raise helpful error + else: + available_columns = ", ".join(df_haps.columns) + raise ValueError( + f"Column '{color}' or 'cohorts_{color}' not found in sample data. " + f"Available columns: {available_columns}" + ) + + # Now use the validated color_column for processing + df_haps["partition"] = ( + df_haps[color_column] + .astype(str) + .str.replace(r"\W", "", regex=True) + ) - # Set up colors. - ( - color_prepped, - color_discrete_map_prepped, - category_orders_prepped, - ) = self._setup_sample_colors_plotly( - data=df_haps, - color="partition", - color_discrete_map=color_discrete_map, - color_discrete_sequence=color_discrete_sequence, - category_orders=category_orders, - ) - del color_discrete_map - del color_discrete_sequence - del category_orders - color_discrete_map_display = { - color_values_mapping[v]: c - for v, c in color_discrete_map_prepped.items() - } + # extract all unique values of the color column + color_values = df_haps["partition"].fillna("").unique() + color_values_mapping = dict( + zip(df_haps["partition"], df_haps[color_column]) + ) + color_values_mapping[""] = "black" + color_values_display = [ + color_values_mapping[c] for c in color_values + ] + + # Handle mapping/dictionary case + elif isinstance(color, Mapping): + # For mapping case, we need to create a new column based on the mapping + # Initialize with None + df_haps["partition"] = None + + # Apply each query in the mapping to create the partition column + for label, query in color.items(): + # Apply the query and assign the label to matching rows + mask = df_haps.eval(query) + df_haps.loc[mask, "partition"] = label + + # Clean up the partition column to avoid issues with special characters + if df_haps["partition"].notna().any(): + df_haps["partition"] = df_haps["partition"].str.replace( + r"\W", "", regex=True + ) + + # extract all unique values of the color column + color_values = df_haps["partition"].fillna("").unique() + # For mapping case, use partition values directly as they're already the labels + color_values_mapping = dict( + zip(df_haps["partition"], df_haps["partition"]) + ) + color_values_mapping[""] = "black" + color_values_display = [ + color_values_mapping[c] for c in color_values + ] + else: + # Invalid type + raise TypeError( + f"Expected color parameter to be a string or mapping, got {type(color).__name__}" + ) + + # count color values for each distinct haplotype (same for both string and mapping cases) + ht_color_counts = [ + df_haps.iloc[list(s)]["partition"].value_counts().to_dict() + for s in ht_distinct_sets + ] + + # Set up colors (same for both string and mapping cases) + ( + color_prepped, + color_discrete_map_prepped, + category_orders_prepped, + ) = self._setup_sample_colors_plotly( + data=df_haps, + color="partition", + color_discrete_map=color_discrete_map, + color_discrete_sequence=color_discrete_sequence, + category_orders=category_orders, + ) + del color_discrete_map + del color_discrete_sequence + del category_orders + color_discrete_map_display = { + color_values_mapping[v]: c + for v, c in color_discrete_map_prepped.items() + } debug("construct graph") anon_width = np.sqrt(0.3 * node_size_factor) @@ -2280,7 +2339,7 @@ def plot_haplotype_network( ht_distinct_mjn=ht_distinct_mjn, ht_counts=ht_counts, ht_color_counts=ht_color_counts, - color=color, + color="partition" if color is not None else None, color_values=color_values, edges=edges, alt_edges=alt_edges, @@ -2332,7 +2391,7 @@ def plot_haplotype_network( debug("create figure legend") if color is not None: legend_fig = plotly_discrete_legend( - color=color, + color="partition", # Changed from color=color color_values=color_values_display, color_discrete_map=color_discrete_map_display, category_orders=category_orders_prepped, diff --git a/notebooks/plot_haplotype_networks.ipynb b/notebooks/plot_haplotype_networks.ipynb index ea3f515a0..38adaa9a8 100644 --- a/notebooks/plot_haplotype_networks.ipynb +++ b/notebooks/plot_haplotype_networks.ipynb @@ -1,9 +1,18 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "f6b86889", + "metadata": {}, + "source": [ + "# Haplotype Network Plotting Examples\n", + "This notebook demonstrates the `plot_haplotype_network` function from the `malariagen_data` package, showcasing different ways to use the `color` parameter to visualize haplotype networks." + ] + }, { "cell_type": "code", "execution_count": null, - "id": "9de62268", + "id": "1cadfacf", "metadata": {}, "outputs": [], "source": [ @@ -26,6 +35,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Initialize Ag3 instance\n", "ag3 = malariagen_data.Ag3(\n", " \"simplecache::gs://vo_agam_release_master_us_central1\",\n", " simplecache=dict(cache_storage=\"../gcs_cache\"),\n", @@ -40,7 +50,16 @@ "id": "e3ffe116", "metadata": {}, "source": [ - "N.B., manually specifying the server_port parameter doesn't seem to be necessary on colab, but is needed when running locally via Jupyter notebook, otherwise get \"Address already in use\" error and cannot run multiple plots in same notebook. " + "N.B., manually specifying the server_port parameter doesn't seem to be necessary on colab, but is needed when running locally via Jupyter notebook, otherwise get \"Address already in use\" error and cannot run multiple plots in same notebook." + ] + }, + { + "cell_type": "markdown", + "id": "e5687f24", + "metadata": {}, + "source": [ + "## Example 1: Direct Column Name (String)\n", + "Use a direct column name like 'country' to color nodes by country." ] }, { @@ -50,22 +69,111 @@ "metadata": {}, "outputs": [], "source": [ + "# Plot haplotype network with country coloring\n", "ag3.plot_haplotype_network(\n", " region=\"2L:2,358,158-2,431,617\",\n", " analysis=\"gamb_colu\",\n", - " sample_query=\"taxon == 'coluzzii'\",\n", " sample_sets=\"3.0\",\n", + " sample_query=\"taxon == 'coluzzii'\",\n", " color=\"country\",\n", " max_dist=2,\n", ")" ] }, + { + "cell_type": "markdown", + "id": "2798b459", + "metadata": {}, + "source": [ + "## Example 2: Cohorts Prefix (String)\n", + "In this example, `\"admin1_iso\"` is used, which the function interprets as `\"cohorts_admin1_iso\"`, a column typically available in cohort-annotated metadata." + ] + }, { "cell_type": "code", "execution_count": null, "id": "3206fc04-1074-4f6c-8130-81dadff05c72", "metadata": {}, "outputs": [], + "source": [ + "ag3.plot_haplotype_network(\n", + " region=\"2L:2,358,158-2,431,617\",\n", + " analysis=\"gamb_colu\",\n", + " sample_query=\"taxon == 'coluzzii'\",\n", + " sample_sets=\"3.0\",\n", + " color=\"admin1_iso\", # Implies \"cohorts_admin1_iso\"\n", + " max_dist=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "44c6a40b", + "metadata": {}, + "source": [ + "This example uses a dictionary to define custom color groups based on conditions applied to the `\"country\"` column." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8236cd99", + "metadata": {}, + "outputs": [], + "source": [ + "color_mapping = {\n", + " \"Ghana\": \"country == 'Ghana'\",\n", + " \"Other\": \"country != 'Ghana'\"\n", + "}\n", + "ag3.plot_haplotype_network(\n", + " region=\"2L:2,358,158-2,431,617\",\n", + " analysis=\"gamb_colu\",\n", + " sample_query=\"taxon == 'coluzzii'\",\n", + " sample_sets=\"3.0\",\n", + " color=color_mapping,\n", + " max_dist=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bd1962ba", + "metadata": {}, + "source": [ + "Setting `color=None` applies the default coloring scheme, typically uniform across all nodes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eab4c6fb", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_haplotype_network(\n", + " region=\"2L:2,358,158-2,431,617\",\n", + " analysis=\"gamb_colu\",\n", + " sample_query=\"taxon == 'coluzzii'\",\n", + " sample_sets=\"3.0\",\n", + " color=None,\n", + " max_dist=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "20b54aa0", + "metadata": {}, + "source": [ + "This replicates Example 1 but uses `server_mode=\"external\"`, useful for rendering plots in certain environments." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "698ab518", + "metadata": {}, + "outputs": [], "source": [ "ag3.plot_haplotype_network(\n", " region=\"2L:2,358,158-2,431,617\",\n", @@ -94,6 +202,7 @@ "metadata": {}, "outputs": [], "source": [ + "# Initialize Af1 instance\n", "af1 = malariagen_data.Af1(\n", " \"simplecache::gs://vo_afun_release_master_us_central1\",\n", " simplecache=dict(cache_storage=\"../gcs_cache\"),\n", @@ -102,6 +211,14 @@ "af1" ] }, + { + "cell_type": "markdown", + "id": "d8aeab21", + "metadata": {}, + "source": [ + "Here, nodes are colored based on the `\"sample_set\"` column." + ] + }, { "cell_type": "code", "execution_count": null, @@ -120,13 +237,81 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "b1cde074", + "metadata": {}, + "source": [ + "Using `\"year\"` implies the function looks for `\"cohorts_year\"` in the metadata." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d7fc155", + "metadata": {}, + "outputs": [], + "source": [ + "af1.plot_haplotype_network(\n", + " region=\"2RL:2,358,158-2,431,617\",\n", + " sample_query=\"country == 'Ghana'\",\n", + " sample_sets=\"1.0\",\n", + " color=\"year\", # Implies \"cohorts_year\"\n", + " max_dist=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e6e60160", + "metadata": {}, + "source": [ + "A dictionary defines custom groups based on the `\"year\"` column (assuming year data is available)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "882e6b8f", + "metadata": {}, + "outputs": [], + "source": [ + "color_mapping = {\n", + " \"2012\": \"year == 2012\",\n", + " \"2014\": \"year == 2014\"\n", + "}\n", + "af1.plot_haplotype_network(\n", + " region=\"2RL:2,358,158-2,431,617\",\n", + " sample_query=\"country == 'Ghana'\",\n", + " sample_sets=\"1.0\",\n", + " color=color_mapping,\n", + " max_dist=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "485becad", + "metadata": {}, + "source": [ + "With `color=None`, the default coloring is applied." + ] + }, { "cell_type": "code", "execution_count": null, - "id": "42af79bc-35a6-4c96-ae5b-62bd46a30ad1", + "id": "bd013c5c", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "af1.plot_haplotype_network(\n", + " region=\"2RL:2,358,158-2,431,617\",\n", + " sample_query=\"country == 'Ghana'\",\n", + " sample_sets=\"1.0\",\n", + " color=None,\n", + " max_dist=2,\n", + ")" + ] } ], "metadata": { @@ -145,12 +330,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" - }, - "vscode": { - "interpreter": { - "hash": "3b9ddb1005cd06989fd869b9e3d566470f1be01faa610bb17d64e58e32302e8b" - } + "version": "3.12.0" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git a/tests/integration/test_af1.py b/tests/integration/test_af1.py index de1b749de..56720244a 100644 --- a/tests/integration/test_af1.py +++ b/tests/integration/test_af1.py @@ -90,3 +90,92 @@ def test_karyotyping(inversion): sample_sets="1229-VO-GH-DADZIE-VMF00095", sample_query=None, ) + + +def test_plot_haplotype_network_string_direct(mocker): + af1 = setup_af1(debug=True) + mocker.patch("dash.Dash.run") + mock_mjn = mocker.patch("malariagen_data.anopheles.mjn_graph") + mock_mjn.return_value = ([{"data": {"id": "n1"}}], []) + + af1.plot_haplotype_network( + region="2RL:24,630,355-24,633,221", + analysis="funestus", + sample_sets="1.0", + sample_query="taxon == 'funestus'", + color="country", + max_dist=2, + server_mode="inline", + ) + + assert mock_mjn.called + call_args = mock_mjn.call_args[1] + assert call_args["color"] == "partition" + assert call_args["ht_color_counts"] is not None + + +def test_plot_haplotype_network_string_cohort(mocker): + af1 = setup_af1(debug=True) + mocker.patch("dash.Dash.run") + mock_mjn = mocker.patch("malariagen_data.anopheles.mjn_graph") + mock_mjn.return_value = ([{"data": {"id": "n1"}}], []) + + af1.plot_haplotype_network( + region="2RL:24,630,355-24,633,221", + analysis="funestus", + sample_sets="1.0", + sample_query="taxon == 'funestus'", + color="year", + max_dist=2, + server_mode="inline", + ) + + assert mock_mjn.called + call_args = mock_mjn.call_args[1] + assert call_args["color"] == "partition" + assert call_args["ht_color_counts"] is not None + + +def test_plot_haplotype_network_mapping(mocker): + af1 = setup_af1(debug=True) + mocker.patch("dash.Dash.run") + mock_mjn = mocker.patch("malariagen_data.anopheles.mjn_graph") + mock_mjn.return_value = ([{"data": {"id": "n1"}}], []) + + color_mapping = {"2012": "year == 2012", "2014": "year == 2014"} + af1.plot_haplotype_network( + region="2RL:24,630,355-24,633,221", + analysis="funestus", + sample_sets="1.0", + sample_query="taxon == 'funestus'", + color=color_mapping, + max_dist=2, + server_mode="inline", + ) + + assert mock_mjn.called + call_args = mock_mjn.call_args[1] + assert call_args["color"] == "partition" + assert call_args["ht_color_counts"] is not None + + +def test_plot_haplotype_network_none(mocker): + af1 = setup_af1(debug=True) + mocker.patch("dash.Dash.run") + mock_mjn = mocker.patch("malariagen_data.anopheles.mjn_graph") + mock_mjn.return_value = ([{"data": {"id": "n1"}}], []) + + af1.plot_haplotype_network( + region="2RL:24,630,355-24,633,221", + analysis="funestus", + sample_sets="1.0", + sample_query="taxon == 'funestus'", + color=None, + max_dist=2, + server_mode="inline", + ) + + assert mock_mjn.called + call_args = mock_mjn.call_args[1] + assert call_args["color"] is None + assert call_args["ht_color_counts"] is None diff --git a/tests/integration/test_ag3.py b/tests/integration/test_ag3.py index 5ee539d2f..cf6a2f17f 100644 --- a/tests/integration/test_ag3.py +++ b/tests/integration/test_ag3.py @@ -186,3 +186,92 @@ def test_karyotyping(inversion): assert set(df.columns) == set(expected_cols) assert all(df[f"karyotype_{inversion}"].isin([0, 1, 2])) assert all(df[f"karyotype_{inversion}_mean"].between(0, 2)) + + +def test_plot_haplotype_network_string_direct(mocker): + ag3 = setup_ag3() + mocker.patch("dash.Dash.run") + mock_mjn = mocker.patch("malariagen_data.anopheles.mjn_graph") + mock_mjn.return_value = ([{"data": {"id": "n1"}}], []) + + ag3.plot_haplotype_network( + region="2L:2,358,158-2,358,258", + analysis="gamb_colu", + sample_sets="3.0", + sample_query="taxon == 'coluzzii'", + color="country", + max_dist=2, + server_mode="inline", + ) + + assert mock_mjn.called + call_args = mock_mjn.call_args[1] + assert call_args["color"] == "partition" + assert call_args["ht_color_counts"] is not None + + +def test_plot_haplotype_network_string_cohort(mocker): + ag3 = setup_ag3() + mocker.patch("dash.Dash.run") + mock_mjn = mocker.patch("malariagen_data.anopheles.mjn_graph") + mock_mjn.return_value = ([{"data": {"id": "n1"}}], []) + + ag3.plot_haplotype_network( + region="2L:2,358,158-2,358,258", + analysis="gamb_colu", + sample_sets="3.0", + sample_query="taxon == 'coluzzii'", + color="admin1_iso", + max_dist=2, + server_mode="inline", + ) + + assert mock_mjn.called + call_args = mock_mjn.call_args[1] + assert call_args["color"] == "partition" + assert call_args["ht_color_counts"] is not None + + +def test_plot_haplotype_network_mapping(mocker): + ag3 = setup_ag3() + mocker.patch("dash.Dash.run") + mock_mjn = mocker.patch("malariagen_data.anopheles.mjn_graph") + mock_mjn.return_value = ([{"data": {"id": "n1"}}], []) + + color_mapping = {"Ghana": "country == 'Ghana'", "Other": "country != 'Ghana'"} + ag3.plot_haplotype_network( + region="2L:2,358,158-2,358,258", + analysis="gamb_colu", + sample_sets="3.0", + sample_query="taxon == 'coluzzii'", + color=color_mapping, + max_dist=2, + server_mode="inline", + ) + + assert mock_mjn.called + call_args = mock_mjn.call_args[1] + assert call_args["color"] == "partition" + assert call_args["ht_color_counts"] is not None + + +def test_plot_haplotype_network_none(mocker): + ag3 = setup_ag3() + mocker.patch("dash.Dash.run") + mock_mjn = mocker.patch("malariagen_data.anopheles.mjn_graph") + mock_mjn.return_value = ([{"data": {"id": "n1"}}], []) + + ag3.plot_haplotype_network( + region="2L:2,358,158-2,358,258", + analysis="gamb_colu", + sample_sets="3.0", + sample_query="taxon == 'coluzzii'", + color=None, + max_dist=2, + server_mode="inline", + ) + + assert mock_mjn.called + call_args = mock_mjn.call_args[1] + assert call_args["color"] is None + assert call_args["ht_color_counts"] is None