|
| 1 | +from pathlib import Path |
| 2 | + |
| 3 | +import matplotlib as mpl |
| 4 | +import plotly.graph_objects as go |
| 5 | +import polars as pl |
| 6 | +from polars import col |
| 7 | + |
| 8 | +from biolit import DATADIR |
| 9 | +from biolit.taxref import TAXREF_HIERARCHY |
| 10 | + |
# CSS "rgb(r, g, b)" strings keyed by integer index 0..19, sampled from
# matplotlib's "tab10" colormap (the alpha channel is discarded).
# NOTE(review): "tab10" is a 10-colour palette, so indices 10..19 presumably
# all resolve to the same colour -- confirm this is intended.
COLOR_MATCHING = {
    index: f"rgb({', '.join(str(int(level * 255)) for level in rgba[:3])})"
    for index, rgba in enumerate(map(mpl.colormaps["tab10"], range(20)))
}
| 15 | + |
| 16 | + |
def _species_colors(frame: pl.DataFrame) -> pl.DataFrame:
    """Assign a colour string to each distinct value of the ``regne`` column.

    The distinct values are sorted, numbered by row position, and that number
    is looked up in ``COLOR_MATCHING``.

    Returns:
        A frame with the columns ``color`` (CSS colour string) and ``regne``.
    """
    kingdoms = frame["regne"].unique().sort()
    indexed = kingdoms.to_frame().with_row_index("color")
    # Strict replacement: a row position without an entry in COLOR_MATCHING
    # is an error rather than being passed through silently.
    return indexed.with_columns(col("color").replace_strict(COLOR_MATCHING))
| 26 | + |
| 27 | + |
def plot_species_distribution(frame: pl.DataFrame, fn: Path):
    """Build the Sankey edge/node tables from ``frame``, persist and plot them.

    Rows without a ``cd_nom`` are ignored. The remaining rows are counted per
    (``nom_scientifique``, ``cd_nom``, hierarchy columns) group and tagged with
    the colour of their ``regne``. The resulting edge and node tables are
    written as parquet files under ``DATADIR`` and the diagram is saved to
    ``fn``.
    """
    kingdom_colors = _species_colors(frame)
    group_keys = ["nom_scientifique", "cd_nom"] + TAXREF_HIERARCHY

    identified = frame.filter(col("cd_nom").is_not_null())
    species_counts = identified.group_by(group_keys).agg(col("id").count())
    species_counts = species_counts.join(kingdom_colors, on="regne")

    raw_edges = _baseline_edges(species_counts)
    nodes = nodes_from_edges(raw_edges)
    edges = enrich_edges(raw_edges, nodes)

    edges.write_parquet(DATADIR / "species_edges.parquet")
    nodes.write_parquet(DATADIR / "species_node.parquet")
    save_sankey_plot(edges, nodes, fn)
| 43 | + |
| 44 | + |
def save_sankey_plot(edges: pl.DataFrame, nodes: pl.DataFrame, fn: Path) -> Path:
    """Render a Sankey diagram to a standalone HTML file.

    Args:
        edges: Link table. Every column is forwarded to Plotly as an attribute
            of the Sankey ``link``.
        nodes: Node table. Only the ``label``, ``color`` and ``customdata``
            columns are used; ``customdata`` is read by the hover template
            (fields ``name``, ``node_id``, ``n_incoming``, ``n_species``).
        fn: Destination of the HTML file.

    Returns:
        ``fn``, the path the figure was written to.
    """
    node_spec = nodes.select("label", "color", "customdata").to_dict(as_series=False)
    node_spec = node_spec | {
        "line": dict(color="lightgrey", width=0.1),
        "hovertemplate": "<b>%{customdata.name}</b><br>"
        "node_id: %{customdata.node_id}<br>"
        "# images: %{value}<br>"
        "# sub level: %{customdata.n_incoming}<br>"
        "# species: %{customdata.n_species}<br>"
        "<extra></extra>",
    }
    sankey = go.Sankey(link=edges.to_dict(as_series=False), node=node_spec)

    figure = go.Figure(sankey)
    # NOTE(review): the title wording ("en selon", "hierarchie") looks like it
    # contains typos; left untouched because it is user-facing text.
    figure.update_layout(
        autosize=False,
        width=1000,
        height=1500,
        title_text="Répartition des images Biolit en selon les différentes strates de la hierarchie",
        font_size=10,
    )
    figure.write_html(fn)
    # The signature promises a Path; the original fell off the end and
    # returned None.
    return fn
| 69 | + |
| 70 | + |
def _baseline_edges(species_counts: pl.DataFrame) -> pl.DataFrame:
    """Build one edge table linking each level to the next one in the chain.

    The chain starts at ``nom_scientifique`` and then walks
    ``TAXREF_HIERARCHY`` without its last entry, in reverse order. For every
    consecutive (source level, target level) pair the counts are aggregated
    per distinct (source, target) value pair.

    Returns:
        The concatenation of all per-pair tables, with the columns ``source``,
        ``target``, ``value`` (sum of ``id``), ``n_species`` (number of rows
        aggregated) and ``color`` (first colour seen in the group).
    """
    levels = ["nom_scientifique"] + list(reversed(TAXREF_HIERARCHY[:-1]))

    def _level_edges(source_level: str, target_level: str) -> pl.DataFrame:
        grouped = species_counts.group_by(source_level, target_level).agg(
            col("id").sum(),
            col("id").count().alias("n_species"),
            col("color").first(),
        )
        return grouped.rename(
            {source_level: "source", target_level: "target", "id": "value"}
        )

    return pl.concat(
        [_level_edges(lower, upper) for lower, upper in zip(levels, levels[1:])]
    )
| 87 | + |
| 88 | + |
def nodes_from_edges(edges: pl.DataFrame) -> pl.DataFrame:
    """Derive the Sankey node table from the edge table.

    Args:
        edges: Edge table with at least the ``source``, ``target``, ``value``
            and ``n_species`` columns.

    Returns:
        One row per node name, sorted by name, with:

        * ``id``: zero-based row position of the node,
        * ``source``: the node name,
        * the per-target statistics computed by ``_node_has_labels``,
        * ``label``: the node name when ``has_label`` is true, else null,
        * ``color``: ``"blue"`` when ``has_label`` is true, else
          ``"lightgrey"``,
        * ``customdata``: struct with ``name``, ``n_incoming``, ``n_species``
          and ``node_id``.
    """
    has_labels = _node_has_labels(edges)
    # The concatenated series keeps the name of its first member, "source".
    names = pl.concat([edges["source"], edges["target"]]).unique().to_frame()
    return (
        names
        # Inner join: only names that occur as a "target" in `edges` are kept
        # (rows with a null name do not match under the default join rules).
        .join(has_labels, left_on="source", right_on="target")
        .sort("source")
        # Number the rows only once the final row set and order are fixed, so
        # `id` always equals the row position. The previous code indexed
        # before the join and then subtracted 1 from the unsigned index, which
        # left gaps for dropped rows and underflowed on index 0.
        .with_row_index("id")
        .with_columns(
            pl.when(col("has_label")).then(col("source")).alias("label"),
            pl.when(col("has_label"))
            .then(pl.lit("blue"))
            .otherwise(pl.lit("lightgrey"))
            .alias("color"),
            pl.struct(
                name=col("source"),
                n_incoming=col("n_incoming"),
                n_species=col("n_species"),
                node_id=col("id"),
            ).alias("customdata"),
        )
    )
| 113 | + |
| 114 | + |
def _node_has_labels(edges: pl.DataFrame) -> pl.DataFrame:
    """Summarise the incoming edges of every target node.

    Returns:
        One row per distinct ``target`` with ``value`` (summed), ``n_incoming``
        (count of ``source`` entries), ``n_species`` (summed), ``has_label``
        (true when the summed value exceeds 300) and ``n_levels`` (number of
        literal ``|`` characters in the target name).
    """
    per_target = edges.group_by("target").agg(
        col("value").sum(),
        col("source").count().alias("n_incoming"),
        col("n_species").sum(),
    )
    is_labelled = (col("value") > 300).alias("has_label")
    depth = col("target").str.count_matches("|", literal=True).alias("n_levels")
    return per_target.with_columns(is_labelled, depth)
| 128 | + |
| 129 | + |
def enrich_edges(edges: pl.DataFrame, nodes: pl.DataFrame) -> pl.DataFrame:
    """Replace the node names on each edge with the matching node ids.

    Args:
        edges: Edge table with ``source``, ``target``, ``value`` and ``color``.
        nodes: Node table providing the ``id`` of every node name (stored in
            its ``source`` column).

    Returns:
        A frame with the columns ``value``, ``color``, ``source`` and
        ``target``, where the last two now hold node ids, sorted by
        (``source``, ``target``). Edges whose endpoints have no row in
        ``nodes`` are dropped by the joins.
    """
    node_ids = nodes.select("id", "source")
    links = edges.select("source", "target", "value", "color")

    # First lookup adds "id" (source node), second adds "id_right" (target).
    with_source_id = links.join(node_ids, left_on="source", right_on="source")
    with_both_ids = with_source_id.join(node_ids, left_on="target", right_on="source")

    indexed = with_both_ids.drop("target", "source").rename(
        {"id": "source", "id_right": "target"}
    )
    return indexed.sort("source", "target")
0 commit comments