Skip to content

Commit 4d02b7a

Browse files
fixed ugrid so it works with Py 3.10, 3.11
1 parent f123650 commit 4d02b7a

File tree

3 files changed

+73
-72
lines changed

3 files changed

+73
-72
lines changed

tests/test_visualization/test_mpl_plotting.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,14 @@ def test_plot_ugrid_start_index_1():
6868
# SGRID tests
6969
#############
7070

71-
# def test_plot_sgrid_only_grid():
72-
# import cftime
71+
# def test_plot_sgrid_and_nodes():
7372
# ds = xr.open_dataset(EXAMPLE_DATA / "wcofs_small_subset.nc", decode_times=False)
7473

7574
# fig, axis = plt.subplots()
7675

77-
# plot_sgrid(axis, ds)
76+
# plot_sgrid(axis, ds, nodes=True)
7877

79-
# fig.savefig(OUTPUT_DIR / "sgrid_just_plot")
78+
# fig.savefig(OUTPUT_DIR / "sgrid_nodes")
8079

8180

8281

xarray_subset_grid/grids/ugrid.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -403,9 +403,11 @@ def assign_ugrid_topology(
403403
mesh.__dict__.update(mesh_attrs)
404404

405405
# Add in the ones passed in:
406-
mesh.__dict__.update({att: vars()[att]
406+
variables = vars()
407+
mesh.__dict__.update({att: variables[att]
407408
for att in ALL_MESH_VARS
408-
if vars()[att] is not None})
409+
if variables[att] is not None})
410+
mesh.start_index = start_index
409411

410412
if mesh.face_node_connectivity is None:
411413
raise ValueError(

xarray_subset_grid/visualization/mpl_plotting.py

+66-66
Original file line numberDiff line numberDiff line change
@@ -113,73 +113,73 @@ def plot_sgrid(axes, ds, nodes=False, rho_points=False, edge_points=False):
113113

114114
raise NotImplementedError("have to port ugrid code to Sgrid")
115115

116-
mesh_defs = ds[ds.cf.cf_roles["grid_topology"][0]].attrs
117-
lon_var, lat_var = mesh_defs["node_coordinates"].split()
118-
nodes_lon, nodes_lat = (ds[n] for n in mesh_defs["node_coordinates"].split())
119-
120-
121-
faces = ds[mesh_defs["face_node_connectivity"]]
122-
123-
if faces.shape[0] == 3:
124-
# swap order for mpl triangulation
125-
faces = faces.T
126-
start_index = faces.attrs.get("start_index")
127-
start_index = 0 if start_index is None else start_index
128-
faces = faces - start_index
129-
130-
mpl_tri = Triangulation(nodes_lon, nodes_lat, faces)
131-
132-
axes.triplot(mpl_tri)
133-
if face_numbers:
134-
try:
135-
face_lon, face_lat = (ds[n] for n in mesh_defs["face_coordinates"].split())
136-
except KeyError:
137-
raise ValueError('"face_coordinates" must be defined to plot the face numbers')
138-
for i, point in enumerate(zip(face_lon, face_lat)):
139-
axes.annotate(
140-
f"{i}",
141-
point,
142-
xytext=(0, 0),
143-
textcoords="offset points",
144-
horizontalalignment="center",
145-
verticalalignment="center",
146-
bbox={
147-
"facecolor": "white",
148-
"alpha": 1.0,
149-
"boxstyle": "round,pad=0.0",
150-
"ec": "white",
151-
},
152-
)
116+
grid_defs = ds[ds.cf.cf_roles["grid_topology"][0]].attrs
117+
lon_var, lat_var = grid_defs["node_coordinates"].split()
118+
nodes_lon, nodes_lat = (ds[n] for n in grid_defs["node_coordinates"].split())
119+
120+
121+
# faces = ds[mesh_defs["face_node_connectivity"]]
122+
123+
# if faces.shape[0] == 3:
124+
# # swap order for mpl triangulation
125+
# faces = faces.T
126+
# start_index = faces.attrs.get("start_index")
127+
# start_index = 0 if start_index is None else start_index
128+
# faces = faces - start_index
129+
130+
# mpl_tri = Triangulation(nodes_lon, nodes_lat, faces)
131+
132+
# axes.triplot(mpl_tri)
133+
# if face_numbers:
134+
# try:
135+
# face_lon, face_lat = (ds[n] for n in mesh_defs["face_coordinates"].split())
136+
# except KeyError:
137+
# raise ValueError('"face_coordinates" must be defined to plot the face numbers')
138+
# for i, point in enumerate(zip(face_lon, face_lat)):
139+
# axes.annotate(
140+
# f"{i}",
141+
# point,
142+
# xytext=(0, 0),
143+
# textcoords="offset points",
144+
# horizontalalignment="center",
145+
# verticalalignment="center",
146+
# bbox={
147+
# "facecolor": "white",
148+
# "alpha": 1.0,
149+
# "boxstyle": "round,pad=0.0",
150+
# "ec": "white",
151+
# },
152+
# )
153153

154154
# plot nodes
155155
if nodes:
156156
axes.plot(nodes_lon, nodes_lat, "o")
157-
# plot node numbers
158-
if node_numbers:
159-
for i, point in enumerate(zip(nodes_lon, nodes_lat)):
160-
axes.annotate(
161-
f"{i}",
162-
point,
163-
xytext=(2, 2),
164-
textcoords="offset points",
165-
bbox={
166-
"facecolor": "white",
167-
"alpha": 1.0,
168-
"boxstyle": "round,pad=0.0",
169-
"ec": "white",
170-
},
171-
)
172-
173-
# boundaries -- if they are there.
174-
if "boundary_node_connectivity" in mesh_defs:
175-
bounds = ds[mesh_defs["boundary_node_connectivity"]]
176-
177-
lines = []
178-
for bound in bounds.data:
179-
line = (
180-
(nodes_lon[bound[0]], nodes_lat[bound[0]]),
181-
(nodes_lon[bound[1]], nodes_lat[bound[1]]),
182-
)
183-
lines.append(line)
184-
lc = LineCollection(lines, linewidths=2, colors=(1, 0, 0, 1))
185-
axes.add_collection(lc)
157+
# # plot node numbers
158+
# if node_numbers:
159+
# for i, point in enumerate(zip(nodes_lon, nodes_lat)):
160+
# axes.annotate(
161+
# f"{i}",
162+
# point,
163+
# xytext=(2, 2),
164+
# textcoords="offset points",
165+
# bbox={
166+
# "facecolor": "white",
167+
# "alpha": 1.0,
168+
# "boxstyle": "round,pad=0.0",
169+
# "ec": "white",
170+
# },
171+
# )
172+
173+
# # boundaries -- if they are there.
174+
# if "boundary_node_connectivity" in mesh_defs:
175+
# bounds = ds[mesh_defs["boundary_node_connectivity"]]
176+
177+
# lines = []
178+
# for bound in bounds.data:
179+
# line = (
180+
# (nodes_lon[bound[0]], nodes_lat[bound[0]]),
181+
# (nodes_lon[bound[1]], nodes_lat[bound[1]]),
182+
# )
183+
# lines.append(line)
184+
# lc = LineCollection(lines, linewidths=2, colors=(1, 0, 0, 1))
185+
# axes.add_collection(lc)

0 commit comments

Comments
 (0)