Skip to content

优化PPI #130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cinrad/visualize/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Author: Puyuan Du

FIG_SIZE = (10, 8)
FIG_SIZE_TRANSPARENT = (10, 10)
CBAR_POS = [0.83, 0.06, 0.04, 0.38]
TEXT_AXES_POS = [0.83, 0.06, 0.01, 0.35]
GEOAXES_POS = [0, 0, 0.8, 1]
Expand Down
218 changes: 158 additions & 60 deletions cinrad/visualize/ppi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# Author: Puyuan Du

import os
from pathlib import Path
import warnings
import json
from typing import Union, Optional, Any, List
from datetime import datetime

import matplotlib.pyplot as plt
from matplotlib.colorbar import ColorbarBase
import numpy as np
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
Expand All @@ -19,11 +20,23 @@
from cinrad.io.level3 import StormTrackInfo
from cinrad._typing import Number_T
from cinrad.common import get_dtype, is_radial
from cinrad.visualize.layout import TEXT_AXES_POS, TEXT_SPACING, INIT_TEXT_POS, CBAR_POS
from cinrad.visualize.layout import *
from cartopy.io.shapereader import Reader

__all__ = ["PPI"]


def update_dict(d1: dict, d2: dict):
r"""
Update the content of the first dict with entries in the second,
and return the copy.
"""
d = d1.copy()
for k, v in d2.items():
d[k] = v
return d


class PPI(object):
r"""
Create a figure plotting plan position indicator
Expand All @@ -36,23 +49,28 @@ class PPI(object):

fig (matplotlib.figure.Figure): The figure to plot on. Optional.

norm (matplotlib.colors.Normalize): Customized norm data. Optional.
norm (matplotlib.colors.Normalize): Customized normalize object. Optional.

cmap (matplotlib.colors.Colormap): Customized colormap. Optional.

nlabel (int): Number of labels on the colorbar. Optional.
nlabel (int): Number of labels on the colorbar, will only be used when label is
also passed. Optional.

label (list[str]): Colorbar labels. Optional.

dpi (int): DPI of the figure. Optional.
dpi (int): DPI of the figure. Default 350.

highlight (str, list(str)): Areas to be highlighted. Optional.
highlight (str, list[str]): Areas to be highlighted. Optional.

coastline (bool): Plot coastline on the figure if set to True. Default False.

extent (list(float)): The extent of figure. Optional.
extent (list[float]): The extent of figure. Optional.

add_city_names (bool): Label city names on the figure if set to True. Default True.
add_city_names (bool): Label city names on the figure if set to True. Default False.

plot_labels (bool): Text scan information on the side of the plot. Default True.

text_param (dict): Optional parameters passed to matplotlib text function.
"""

# The CRS of data is believed to be PlateCarree.
Expand All @@ -65,7 +83,7 @@ def __init__(
fig: Optional[Any] = None,
norm: Optional[Any] = None,
cmap: Optional[Any] = None,
nlabel: Optional[int] = None,
nlabel: int = 10,
label: Optional[List[str]] = None,
dpi: Number_T = 350,
highlight: Optional[Union[str, List[str]]] = None,
Expand All @@ -75,6 +93,8 @@ def __init__(
style: str = "black",
add_city_names: bool = False,
plot_labels: bool = True,
text_param: Optional[dict] = None,
add_shps: bool = True,
**kwargs
):
self.data = data
Expand All @@ -93,14 +113,28 @@ def __init__(
"add_city_names": add_city_names,
"plot_labels": plot_labels,
"is_inline": is_inline(),
"add_shps": add_shps,
}
if fig is None:
self.fig = setup_plot(dpi, style=style)
if style == "transparent":
self.fig = plt.figure(figsize=FIG_SIZE_TRANSPARENT, dpi=dpi)
else:
self.fig = plt.figure(figsize=FIG_SIZE, dpi=dpi)
self.fig.patch.set_facecolor(style)
plt.axis("off")
else:
self.fig = fig
# avoid in-place modification
self.text_pos = TEXT_AXES_POS.copy()
self.cbar_pos = CBAR_POS.copy()
self.font_kw = default_font_kw.copy()
if style == "black":
self.font_kw["color"] = "white"
else:
self.font_kw["color"] = "black"
if text_param:
# Override use input setting
self.font_kw = update_dict(self.font_kw, text_param)
self._plot_ctx = dict()
self.rf_flag = "RF" in data
self._fig_init = False
Expand All @@ -110,19 +144,15 @@ def __init__(
# call this action at initialization
self._text_before_save()

def __call__(self, fpath: Optional[str] = None):
if not fpath:
# When the path is not specified, store the picture in home dir.
fpath = os.path.join(str(Path.home()), "PyCINRAD")
def __call__(self, fpath):
ext_name = fpath.split(".")
if len(ext_name) > 1:
all_fmt = self.fig.canvas.get_supported_filetypes()
if ext_name[-1] in all_fmt:
self.settings["path_customize"] = True
else:
ext_name = fpath.split(".")
if len(ext_name) > 1:
all_fmt = self.fig.canvas.get_supported_filetypes()
if ext_name[-1] in all_fmt:
self.settings["path_customize"] = True
else:
if not fpath.endswith(os.path.sep):
fpath += os.path.sep
if not fpath.endswith(os.path.sep):
fpath += os.path.sep
return self._save(fpath)

def _norm(self):
Expand All @@ -132,10 +162,7 @@ def _norm(self):
clabel = self.settings["label"]
else:
nlabel = self.settings["nlabel"]
if nlabel:
clabel = np.linspace(n.vmin, n.vmax, nlabel).astype(str)
else:
clabel = np.linspace(n.vmin, n.vmax, 10).astype(str)
clabel = np.linspace(n.vmin, n.vmax, nlabel).astype(str)
return n, n, clabel
else:
n = norm_plot[self.dtype]
Expand Down Expand Up @@ -199,13 +226,14 @@ def _plot(self, **kwargs):
)
if not self.settings["extent"]:
self._autoscale()
add_shp(
self.geoax,
proj,
coastline=self.settings["coastline"],
style=self.settings["style"],
extent=self.geoax.get_extent(self.data_crs),
)
if self.settings["add_shps"]:
add_shp(
self.geoax,
proj,
coastline=self.settings["coastline"],
style=self.settings["style"],
extent=self.geoax.get_extent(self.data_crs),
)
if self.settings["highlight"]:
draw_highlight_area(self.settings["highlight"])
if self.settings["add_city_names"]:
Expand All @@ -216,6 +244,13 @@ def _plot(self, **kwargs):
self._fig_init = True

def _text(self):
def _draw(ax: Any, y_index: int, text: str):
"""
Draw text on the axes.
"""
y = INIT_TEXT_POS - TEXT_SPACING * y_index
ax.text(0, y, text, **self.font_kw)

# axes used for text which has the same x-position as
# the colorbar axes (for matplotlib 3 compatibility)
var = self._plot_ctx["var"]
Expand All @@ -226,29 +261,27 @@ def _text(self):
ax2.xaxis.set_visible(False)
# Make VCP21 the default scanning strategy
task = self.data.attrs.get("task", "VCP21")
text(
ax2,
self.data.range,
self.data.tangential_reso,
self.data.scan_time,
self.data.site_name,
task,
self.data.elevation,
)
ax2.text(0, INIT_TEXT_POS, prodname[self.dtype], **plot_kw)
ax2.text(
0,
INIT_TEXT_POS - TEXT_SPACING * 8,
if self.data.tangential_reso >= 0.1:
reso = "{:.2f}km".format(self.data.tangential_reso)
else:
reso = "{:.0f}m".format(self.data.tangential_reso * 1000)
s_time = datetime.strptime(self.data.scan_time, "%Y-%m-%d %H:%M:%S")
texts = [
prodname[self.dtype],
"Range: {:.0f}km".format(self.data.range),
"Resolution: {}".format(reso),
"Date: {}".format(s_time.strftime("%Y.%m.%d")),
"Time: {}".format(s_time.strftime("%H:%M")),
"RDA: " + (self.data.site_name or "Unknown"),
"Task: {}".format(task),
"Elev: {:.2f}deg".format(self.data.elevation),
"Max: {:.1f}{}".format(np.nanmax(var), unit[self.dtype]),
**plot_kw
)
]
if self.dtype.startswith("VEL"):
ax2.text(
0,
INIT_TEXT_POS - TEXT_SPACING * 9,
"Min: {:.1f}{}".format(np.nanmin(var), unit[self.dtype]),
**plot_kw
)
min_vel = "Min: {:.1f}{}".format(np.nanmin(var), unit[self.dtype])
texts.append(min_vel)
for i, text in enumerate(texts):
_draw(ax2, i, text)

def _text_before_save(self):
# Finalize texting here
Expand All @@ -258,11 +291,21 @@ def _text_before_save(self):
pcmap, ccmap = self._cmap()
if self.settings["plot_labels"]:
self._text()
cbar = setup_axes(self.fig, ccmap, cnorm, self.cbar_pos)
cax = self.fig.add_axes(self.cbar_pos)
cbar = ColorbarBase(
cax, cmap=ccmap, norm=cnorm, orientation="vertical", drawedges=False
)
cbar.ax.tick_params(
axis="both",
which="both",
length=0,
labelsize=10,
colors=self.font_kw["color"],
)
cbar.outline.set_visible(False)
if not isinstance(clabel, type(None)):
change_cbar_text(
cbar, np.linspace(cnorm.vmin, cnorm.vmax, len(clabel)), clabel
)
cbar.set_ticks(np.linspace(cnorm.vmin, cnorm.vmax, len(clabel)))
cbar.set_ticklabels(clabel, **self.font_kw)

def _save(self, fpath: str):
if not self.settings["is_inline"]:
Expand Down Expand Up @@ -290,7 +333,16 @@ def _save(self, fpath: str):
)
else:
path_string = fpath
save(path_string, self.settings["style"])
save_options = dict(pad_inches=0)
if self.settings["style"] == "transparent":
save_options["transparent"] = True
else:
if self.settings["style"] == "white":
save_options["facecolor"] = "white"
elif self.settings["style"] == "black":
save_options["facecolor"] = "black"
plt.savefig(path_string, **save_options)
# plt.close("all")
return path_string

def plot_range_rings(
Expand All @@ -317,6 +369,52 @@ def plot_range_rings(
**kwargs
)

def plot_ring_rays(
self,
angle: Union[int, float, list],
range: int,
color: str = "white",
linewidth: Number_T = 0.5,
**kwargs
):
r"""Plot ring rays on PPI plot."""
slon, slat = self.data.site_longitude, self.data.site_latitude
if isinstance(angle, (int, float)):
angle = [angle]
for a in angle:
theta = np.deg2rad(a)
x, y = get_coordinate(range, theta, 0, slon, slat, h_offset=False)
self.geoax.plot(
[slon, x],
[slat, y],
color=color,
linewidth=linewidth,
transform=self.data_crs,
**kwargs
)

def add_custom_shp(
self,
shp_path: str,
encoding: str = "gbk",
color: str = "white",
linewidth: Number_T = 0.5,
**kwargs
):
"""
Add custom shapefile to the plot.
"""
reader = Reader(shp_path, encoding=encoding)
self.geoax.add_geometries(
geoms=list(reader.geometries()),
crs=ccrs.PlateCarree(),
edgecolor=color,
facecolor="None",
zorder=3,
linewidth=linewidth,
**kwargs
)

def plot_cross_section(
self,
data: Dataset,
Expand Down Expand Up @@ -479,7 +577,7 @@ def _add_city_names(self):
stlon,
stlat,
nm,
**plot_kw,
**default_font_kw,
color="darkgrey",
transform=self.data_crs,
horizontalalignment="center",
Expand Down
Loading
Loading