Skip to content

grass.jupyter: fix TimeSeriesMap layer rendering #5632

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions python/grass/jupyter/baseseriesmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ def __init__(self, width=None, height=None, env=None):
else:
self._env = os.environ.copy()

self.baseseries = None
self._base_layer_calls = []
self._base_calls = []
self._calls = []
self._baseseries_added = False
self._layers_rendered = False
self._base_filename_dict = {}
Expand Down Expand Up @@ -91,11 +90,9 @@ def __getattr__(self, name):
def wrapper(**kwargs):
if not self._baseseries_added:
self._base_layer_calls.append((grass_module, kwargs))
elif self._base_calls is not None:
for row in self._base_calls:
row.append((grass_module, kwargs))
else:
self._base_calls.append((grass_module, kwargs))
for row in self._calls:
row.append((grass_module, kwargs))

return wrapper

Expand All @@ -104,20 +101,36 @@ def _render_baselayers(self, img):
for grass_module, kwargs in self._base_layer_calls:
img.run(grass_module, **kwargs)

def _render(self, tasks):
"""
Renders the base image for the dataset.
def _render_worker(self, i):
"""Function to render a single layer."""
filename = os.path.join(self._tmpdir.name, f"{i}.png")
shutil.copyfile(self.base_file, filename)
img = Map(
width=self._width,
height=self._height,
filename=filename,
use_region=True,
env=self._env,
read_file=True,
)
for grass_module, kwargs in self._calls[i]:
if grass_module is not None:
img.run(grass_module, **kwargs)
return self._indices[i], filename

Saves PNGs to a temporary directory.
This method must be run before creating a visualization (e.g., show or save).
It can be time-consuming to run with large space-time datasets.
def render(self):
"""Renders image for each raster in series.

Child classes should override the `render` method
to define specific rendering behaviors, such as:
- Rendering images for each time-step in a space-time dataset (e.g., class1).
- Rendering images for each raster in a series (e.g., class2).
Save PNGs to temporary directory. Must be run before creating a visualization
(i.e. show or save).
"""
# Runtime error in respective classes
if not self._baseseries_added:
msg = (
"Cannot render series since none has been added."
"Use SeriesMap.add_rasters() or SeriesMap.add_vectors()"
)
raise RuntimeError(msg)
tasks = [(i,) for i in range(len(self._indices))]

# Make base image (background and baselayers)
# Random name needed to avoid potential conflict with layer names
Expand Down
24 changes: 14 additions & 10 deletions python/grass/jupyter/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def __init__(self, use_region, saved_region, env):
self._use_region = use_region
self._saved_region = saved_region

def set_region_from_timeseries(self, timeseries):
def set_region_from_timeseries(self, timeseries, element_type="strds"):
"""Sets computational region for rendering.

This function sets the computation region from the extent of
Expand All @@ -353,13 +353,17 @@ def set_region_from_timeseries(self, timeseries):
# use current
return
# Get extent, resolution from space time dataset
info = gs.parse_command("t.info", input=timeseries, flags="g", env=self._env)
# Set grass region from extent
self._env["GRASS_REGION"] = gs.region_env(
n=info["north"],
s=info["south"],
e=info["east"],
w=info["west"],
nsres=info["nsres_min"],
ewres=info["ewres_min"],
info = gs.parse_command(
"t.info", input=timeseries, type=element_type, flags="g", env=self._env
)
# Set grass region from extent
params = {
"n": info["north"],
"s": info["south"],
"e": info["east"],
"w": info["west"],
}
if "nsres_min" in info:
params["nsres"] = info["nsres_min"]
params["ewres"] = info["ewres_min"]
self._env["GRASS_REGION"] = gs.region_env(**params, env=self._env)
75 changes: 21 additions & 54 deletions python/grass/jupyter/seriesmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,8 @@
# for details.
"""Create and display visualizations for a series of rasters."""

import os
import shutil

from grass.grassdb.data import map_exists

from .map import Map
from .region import RegionManagerForSeries
from .baseseriesmap import BaseSeriesMap

Expand Down Expand Up @@ -63,6 +59,8 @@ def __init__(
"""
super().__init__(width, height, env)

self._layer_count = 0

# Handle Regions
self._region_manager = RegionManagerForSeries(
use_region=use_region,
Expand All @@ -82,17 +80,17 @@ def add_rasters(self, rasters, **kwargs):
# Update region to rasters if not use_region or saved_region
self._region_manager.set_region_from_rasters(rasters)
if self._baseseries_added:
assert self.baseseries == len(rasters), _(
"Number of vectors in series must match number of vectors"
)
for i in range(self.baseseries):
if self._layer_count != len(rasters):
msg = _("Number of rasters in series must match")
raise RuntimeError(msg)
for i in range(self._layer_count):
kwargs["map"] = rasters[i]
self._base_calls[i].append(("d.rast", kwargs.copy()))
self._calls[i].append(("d.rast", kwargs.copy()))
else:
self.baseseries = len(rasters)
self._layer_count = len(rasters)
for raster in rasters:
kwargs["map"] = raster
self._base_calls.append([("d.rast", kwargs.copy())])
self._calls.append([("d.rast", kwargs.copy())])
self._baseseries_added = True
if not self._labels:
self._labels = rasters
Expand All @@ -109,59 +107,28 @@ def add_vectors(self, vectors, **kwargs):
# Update region extent to vectors if not use_region or saved_region
self._region_manager.set_region_from_vectors(vectors)
if self._baseseries_added:
assert self.baseseries == len(vectors), _(
"Number of rasters in series must match number of vectors"
)
for i in range(self.baseseries):
if self._layer_count != len(vectors):
msg = _("Number of vectors in series must match")
raise RuntimeError(msg)
for i in range(self._layer_count):
kwargs["map"] = vectors[i]
self._base_calls[i].append(("d.vect", kwargs.copy()))
self._calls[i].append(("d.vect", kwargs.copy()))
else:
self.baseseries = len(vectors)
self._layer_count = len(vectors)
for vector in vectors:
kwargs["map"] = vector
self._base_calls.append([("d.vect", kwargs.copy())])
self._calls.append([("d.vect", kwargs.copy())])
self._baseseries_added = True
if not self._labels:
self._labels = vectors
self._layers_rendered = False
self._indices = range(len(self._labels))
self._indices = list(range(len(self._labels)))

def add_names(self, names):
"""Add list of names associated with layers.
Default will be names of first series added."""
assert self.baseseries == len(names), _(
"Number of vectors in series must match number of vectors"
)
self._labels = names
self._indices = list(range(len(self._labels)))

def _render_worker(self, i):
"""Function to render a single layer."""
filename = os.path.join(self._tmpdir.name, f"{i}.png")
shutil.copyfile(self.base_file, filename)
img = Map(
width=self._width,
height=self._height,
filename=filename,
use_region=True,
env=self._env,
read_file=True,
)
for grass_module, kwargs in self._base_calls[i]:
img.run(grass_module, **kwargs)
return i, filename

def render(self):
"""Renders image for each raster in series.

Save PNGs to temporary directory. Must be run before creating a visualization
(i.e. show or save).
"""
if not self._baseseries_added:
msg = (
"Cannot render series since none has been added."
"Use SeriesMap.add_rasters() or SeriesMap.add_vectors()"
)
if self._layer_count != len(names):
msg = _("Number of names must match number of added layers")
raise RuntimeError(msg)
tasks = [(i,) for i in range(self.baseseries)]
self._render(tasks)
self._labels = names
self._indices = list(range(self._layer_count))
2 changes: 1 addition & 1 deletion python/grass/jupyter/tests/timeseriesmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_default_init(space_time_raster_dataset):
"""Check that TimeSeriesMap init runs with default parameters"""
img = gj.TimeSeriesMap()
img.add_raster_series(space_time_raster_dataset.name)
assert img.baseseries == space_time_raster_dataset.name
assert img._baseseries == space_time_raster_dataset.name


@pytest.mark.needs_solo_run
Expand Down
Loading
Loading