diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b8dd309..7a718b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,6 +5,6 @@ repos: - id: black language_version: python3.10 - repo: https://github.com/pycqa/pylint - rev: pylint-2.6.0 + rev: v2.14.4 hooks: - id: pylint diff --git a/geograph/geograph.py b/geograph/geograph.py index 360363f..eae44b8 100644 --- a/geograph/geograph.py +++ b/geograph/geograph.py @@ -422,48 +422,57 @@ def _load_from_dataframe( # Reset index to ensure consistent indices df = df.reset_index(drop=True) - # Using this list and iterating through it is slightly faster than - # iterating through df due to the dataframe overhead - geom: List[shapely.Polygon] = df["geometry"].tolist() - # this dict maps polygon row numbers in df to a list - # of neighbouring polygon row numbers - graph_dict = {} - - if tolerance > 0: - # Expand the borders of the polygons by `tolerance``` - new_polygons: List[shapely.Polygon] = ( - df["geometry"].buffer(tolerance).tolist() - ) + # pylint: disable=protected-access + if gpd._compat.USE_PYGEOS: + if tolerance > 0: + neighbour_arr = df.sindex.query_bulk( + df["geometry"].buffer(tolerance), predicate="intersects" + ).transpose() + else: + neighbour_arr = df.sindex.query_bulk( + df["geometry"], predicate="intersects" + ).transpose() + else: + # this dict maps polygon row numbers in df to a list + # of neighbouring polygon row numbers + graph_dict = {} + if tolerance > 0: + # Expand the borders of the polygons by `tolerance``` + new_polygons: gpd.GeoSeries = df["geometry"].buffer(tolerance) # Creating nodes (=vertices) and finding neighbors for index, polygon in tqdm( - enumerate(geom), + enumerate(df["geometry"]), desc="Step 1 of 2: Creating nodes and finding neighbours", - total=len(geom), + total=len(df), ): - if tolerance > 0: - # find the indexes of all polygons which intersect with this one - neighbours = df.sindex.query( - new_polygons[index], predicate="intersects" - ) - 
else: - neighbours = df.sindex.query(polygon, predicate="intersects") - - graph_dict[index] = neighbours - # add each polygon as a node to the graph with useful attributes - self.graph.add_node( - index, - rep_point=polygon.representative_point(), - area=polygon.area, - perimeter=polygon.length, - bounds=polygon.bounds, - ) + # pylint: disable=protected-access + if not gpd._compat.USE_PYGEOS: + if tolerance > 0: + # find the indexes of all polygons which intersect with this one + neighbours = df.sindex.query( + new_polygons[index], predicate="intersects" + ) + else: + neighbours = df.sindex.query(polygon, predicate="intersects") + + graph_dict[index] = neighbours + # TODO: factor out for use_pygeos + self.graph.add_node(index) # iterate through the dict and add edges between neighbouring polygons - for polygon_id, neighbours in tqdm( - graph_dict.items(), desc="Step 2 of 2: Adding edges" - ): - for neighbour_id in neighbours: - if polygon_id != neighbour_id: - self.graph.add_edge(polygon_id, neighbour_id) + # pylint: disable=protected-access + if gpd._compat.USE_PYGEOS: + for index, neighbour in tqdm( + neighbour_arr, desc="Step 2 of 2: Adding edges" + ): + if index != neighbour: + self.graph.add_edge(index, neighbour) + else: + for polygon_id, neighbours in tqdm( + graph_dict.items(), desc="Step 2 of 2: Adding edges" + ): + for neighbour_id in neighbours: + if polygon_id != neighbour_id: + self.graph.add_edge(polygon_id, neighbour_id) # add index name df.index.name = "node_index" @@ -802,7 +811,9 @@ def get_graph_components( """ components: List[set] = list(nx.connected_components(self.graph)) if calc_polygons: - geom = [self.df["geometry"].loc[comp].unary_union for comp in components] + geom = [ + self.df["geometry"].loc[list(comp)].unary_union for comp in components + ] gdf = gpd.GeoDataFrame( {"geometry": geom, "class_label": -1}, crs=self.df.crs ) @@ -850,7 +861,7 @@ def get_metric( result = metrics._get_metric( name=name, geo_graph=self, 
class_value=class_value, **metric_kwargs ) - if name in self.class_metrics.keys(): + if name in self.class_metrics: self.class_metrics[name][class_value] = result else: self.class_metrics[name] = {class_value: result} @@ -968,14 +979,7 @@ def _add_node( node_data = dict(data.items()) # Add node to graph - self.graph.add_node( - node_id, - rep_point=node_data["geometry"].representative_point(), - area=node_data["geometry"].area, - perimeter=node_data["geometry"].length, - class_label=node_data["class_label"], - bounds=node_data["geometry"].bounds, - ) + self.graph.add_node(node_id) # Add node data to dataframe missing_cols = { @@ -1121,25 +1125,25 @@ class label in `valid_classes`, as long as they are less than f"and {self.graph.number_of_edges()} edges.", ) - def _load_from_graph_path(self, load_path: pathlib.Path) -> None: + def _load_from_graph_path(self, graph_path: pathlib.Path) -> None: """ Load networkx graph and dataframe objects from a pickle file. Args: - load_path (pathlib.Path): Path to a pickle file. Can be compressed + graph_path (pathlib.Path): Path to a pickle file. Can be compressed with gzip or bz2. Returns: gpd.GeoDataFrame: The dataframe containing polygon objects. 
""" - if load_path.suffix == ".bz2": - with bz2.BZ2File(load_path, "rb") as bz2_file: + if graph_path.suffix == ".bz2": + with bz2.BZ2File(graph_path, "rb") as bz2_file: data = pickle.load(bz2_file) - elif load_path.suffix == ".gz": - with gzip.GzipFile(load_path, "rb") as gz_file: + elif graph_path.suffix == ".gz": + with gzip.GzipFile(graph_path, "rb") as gz_file: data = pickle.loads(gz_file.read()) else: - with open(load_path, "rb") as file: + with open(graph_path, "rb") as file: data = pickle.load(file) self.df = data["dataframe"] self.name = data["name"] @@ -1263,16 +1267,6 @@ def _load_from_dataframe( self.graph = nx.complete_graph(len(df)) else: self.graph = nx.empty_graph(len(df)) - # Add node attributes - for node in tqdm( - self.graph.nodes, desc="Constructing graph", total=len(self.graph) - ): - polygon = geom[node] - self.graph.nodes[node]["rep_point"] = polygon.representative_point() - self.graph.nodes[node]["area"] = polygon.area - self.graph.nodes[node]["perimeter"] = polygon.length - self.graph.nodes[node]["bounds"] = polygon.bounds - # Add edge attributes if necessary if self.has_distance_edges: for u, v, attrs in tqdm( diff --git a/geograph/metrics.py b/geograph/metrics.py index 0b02ff4..b17949d 100644 --- a/geograph/metrics.py +++ b/geograph/metrics.py @@ -2,7 +2,8 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional, Union +from itertools import combinations +from typing import TYPE_CHECKING, Any, Dict, Optional, Union import networkx as nx import numpy as np @@ -10,7 +11,9 @@ if TYPE_CHECKING: import geograph - +# TODO: refactor this file to return a tuple with the values to put inside the +# metric such that the metric is only created once by the calling GeoGraph +# since dataclass creation is slow # define a metric dataclass with < <= => > == comparisons that work as you would # expect intuitively @dataclass() @@ -52,6 +55,8 @@ def __ge__(self, o: object) -> bool: 
########################################################################################
 # 1. Landscape level metrics
 ########################################################################################
+
+
 def _num_patches(geo_graph: geograph.GeoGraph) -> Metric:
     """
     Calculate number of patches.
@@ -562,7 +567,7 @@ def _class_effective_mesh_size(
 }
 
 ########################################################################################
-# 3. Habitat componment level metrics
+# 3. Habitat component level metrics
 ########################################################################################
 
 
@@ -655,10 +660,49 @@ def _avg_component_isolation(geo_graph: geograph.GeoGraph) -> Metric:
     )
 
 
+def _habitat_iic(
+    geo_graph: geograph.GeoGraph,
+    get_total_area: bool = False,
+    shortest_path_cutoff: Optional[int] = None,
+) -> Metric:
+    """
+    Sum area products of connected node pairs, damped by their path length.
+
+    Args:
+        geo_graph (geograph.GeoGraph): Graph and dataframe of the patches.
+        get_total_area (bool): Divide by the components' convex hull area.
+        shortest_path_cutoff (Optional[int]): Longest path length searched.
+
+    Returns:
+        Metric: Sum over pairs of area_i * area_j / (1 + path length).
+    """
+    # Areas come from the dataframe, keyed by node index, because node
+    # attributes such as "area" are not guaranteed to be set on the graph.
+    areas = geo_graph.df["geometry"].area.to_dict()
+    path_lengths: Dict = dict(
+        nx.all_pairs_shortest_path_length(geo_graph.graph, cutoff=shortest_path_cutoff)
+    )
+    iic = 0.0
+    for u, v in combinations(geo_graph.df.index.values, 2):
+        if v in path_lengths[u]:
+            iic += areas[u] * areas[v] / (1 + path_lengths[u][v])
+    if get_total_area:
+        # Most efficient way to get area of convex hull of the GeoGraph
+        iic /= geo_graph.components.df.dissolve().convex_hull.values[0].area
+    return Metric(
+        value=iic,
+        name="habitat_iic",
+        description="The habitat IIC metric",
+        variant="component",
+        unit="dimensionless",
+    )
+
+
 COMPONENT_METRICS_DICT = {
     "num_components": _num_components,
     "avg_component_area": _avg_component_area,
     "avg_component_isolation":
_avg_component_isolation, + "habitat_iic": _habitat_iic, } @@ -666,6 +710,7 @@ def _avg_component_isolation(geo_graph: geograph.GeoGraph) -> Metric: # 4. Define access interface for GeoGraph ######################################################################################## + STANDARD_METRICS = ["num_components", "avg_patch_area", "total_area"] diff --git a/pylintrc b/pylintrc index cbac165..c015dac 100644 --- a/pylintrc +++ b/pylintrc @@ -160,12 +160,6 @@ disable=abstract-method, # mypackage.mymodule.MyReporterClass. output-format=text -# Put messages in a separate file for each module / package specified on the -# command line instead of printing them on stdout. Reports (if any) will be -# written in a file name "pylint_global.[txt|html]". This option is deprecated -# and it will be removed in Pylint 2.0. -files-output=no - # Tells whether to display a full report or only the messages reports=no @@ -284,12 +278,6 @@ ignore-long-lines=(?x)( # else. single-line-if-stmt=yes -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check= - # Maximum number of lines in a module max-module-lines=99999