Skip to content

Commit 965d973

Browse files
committed
fix: mpi tests_007 passing, fixed typo Nz->NZ
1 parent f6180c9 commit 965d973

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

tests/test_007_mpi_lossy_cavity.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import os, sys
1+
import os
2+
import sys
23
import numpy as np
34
import pyvista as pv
45
import matplotlib.pyplot as plt
@@ -15,6 +16,9 @@
1516

1617
import pytest
1718

19+
# Run with:
20+
# mpiexec -n 2 python -m pytest --color=yes -v -s tests/test_007_mpi_lossy_cavity.py
21+
1822
# Turn true when running local
1923
flag_plot_3D = True
2024

@@ -92,7 +96,6 @@ def test_mpi_import(self):
9296
from mpi4py import MPI
9397

9498
comm = MPI.COMM_WORLD # Get MPI communicator
95-
rank = comm.Get_rank() # Process ID
9699
size = comm.Get_size() # Total number of MPI processes
97100
if size > 1:
98101
use_mpi = True
@@ -184,10 +187,12 @@ def test_mpi_simulation(self):
184187
beam.update(solver, n*solver.dt)
185188
solver.mpi_one_step()
186189

187-
Ez = solver.mpi_gather('Ez', x=int(Nx/2), y=int(Ny/2), z=np.s_[::5])
190+
Ez = solver.mpi_gather('Ez', x=int(Nx/2), y=int(Ny/2))
188191
if solver.rank == 0:
189192
#print(Ez)
190-
assert np.allclose(Ez, self.Ez, rtol=0.1), "Electric field Ez samples MPI failed"
193+
print(len(Ez))
194+
assert len(Ez) == NZ, "Electric field Ez samples length mismatch"
195+
assert np.allclose(Ez[np.s_[::5]], self.Ez, rtol=0.1), "Electric field Ez samples MPI failed"
191196
else:
192197
Nt = 3000
193198
for n in tqdm(range(Nt)):
@@ -197,9 +202,9 @@ def test_mpi_simulation(self):
197202

198203
Ez = solver.E[int(Nx/2), int(Ny/2), np.s_[::5], 'z']
199204
#print(Ez)
205+
assert len(solver.E[int(Nx/2), int(Ny/2), :, 'z']) == NZ, "Electric field Ez samples length mismatch"
200206
assert np.allclose(Ez, self.Ez, rtol=0.1), "Electric field Ez samples failed"
201207

202-
203208
def test_mpi_gather_asField(self):
204209
# Plot inspect after mpi gather
205210
global solver
@@ -292,10 +297,13 @@ def test_long_wake_potential(self):
292297
if use_mpi:
293298
if solver.rank == 0:
294299
#print(wake.WP[::50])
300+
print(len(wake.WP))
301+
assert len(wake.WP) == 5195, "Wake potential samples length mismatch"
295302
assert np.allclose(wake.WP[::50], self.WP, rtol=0.1), "Wake potential samples failed"
296303
assert np.cumsum(np.abs(wake.WP))[-1] == pytest.approx(184.43818552913254, 0.1), "Wake potential cumsum MPI failed"
297304
else:
298305
#print(wake.WP[::50])
306+
assert len(wake.WP) == 5195, "Wake potential samples length mismatch"
299307
assert np.allclose(wake.WP[::50], self.WP, rtol=0.1), "Wake potential samples failed"
300308
assert np.cumsum(np.abs(wake.WP))[-1] == pytest.approx(184.43818552913254, 0.1), "Wake potential cumsum MPI failed"
301309

@@ -305,12 +313,15 @@ def test_long_impedance(self):
305313
if use_mpi:
306314
if solver.rank == 0:
307315
#print(wake.Z[::20])
316+
print(len(wake.Z))
317+
assert len(wake.Z) == 998, "Impedance samples length mismatch"
308318
assert np.allclose(np.abs(wake.Z)[::20], np.abs(self.Z), rtol=0.1), "Abs Impedance samples MPI failed"
309319
assert np.allclose(np.real(wake.Z)[::20], np.real(self.Z), rtol=0.1), "Real Impedance samples MPI failed"
310320
assert np.allclose(np.imag(wake.Z)[::20], np.imag(self.Z), rtol=0.1), "Imag Impedance samples MPI failed"
311321
assert np.cumsum(np.abs(wake.Z))[-1] == pytest.approx(250910.51090497518, 0.1), "Abs Impedance cumsum MPI failed"
312322
else:
313323
#print(wake.Z[::20])
324+
assert len(wake.Z) == 998, "Impedance samples length mismatch"
314325
assert np.allclose(np.abs(wake.Z)[::20], np.abs(self.Z), rtol=0.1), "Abs Impedance samples failed"
315326
assert np.allclose(np.real(wake.Z)[::20], np.real(self.Z), rtol=0.1), "Real Impedance samples failed"
316327
assert np.allclose(np.imag(wake.Z)[::20], np.imag(self.Z), rtol=0.1), "Imag Impedance samples failed"

wakis/gridFIT3D.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,7 @@ def __init__(self, xmin=None, xmax=None,
9595
if stl_solids is not None:
9696
self._prepare_stl_dicts()
9797

98-
# primal Grid G base axis x, y, z
99-
self.x = x
100-
self.y = y
101-
self.z = z
102-
98+
# Grid data
10399
# generate from file
104100
if load_from_h5 is not None:
105101
t0 = time.time()
@@ -114,8 +110,11 @@ def __init__(self, xmin=None, xmax=None,
114110
return
115111

116112
# generate from custom x,y,z arrays
117-
elif self.x is not None and self.y is not None and self.z is not None:
113+
elif x is not None and y is not None and z is not None:
118114
# allow user to set the grid axis manually
115+
self.x = x
116+
self.y = y
117+
self.z = z
119118
self.Nx = len(self.x) - 1
120119
self.Ny = len(self.y) - 1
121120
self.Nz = len(self.z) - 1
@@ -146,7 +145,8 @@ def __init__(self, xmin=None, xmax=None,
146145

147146
self.dx = np.min(np.diff(self.x))
148147
self.dy = np.min(np.diff(self.y))
149-
self.dz = np.min(np.diff(self.z))
148+
#self.dz = np.min(np.diff(self.z))
149+
self.dz = (self.zmax - self.zmin)/self.Nz
150150

151151
# refine self.x, self.y, self.z using snap points
152152
self.use_mesh_refinement = use_mesh_refinement
@@ -201,7 +201,6 @@ def __init__(self, xmin=None, xmax=None,
201201

202202

203203
def compute_grid(self):
204-
205204
X, Y, Z = np.meshgrid(self.x, self.y, self.z, indexing='ij')
206205
self.grid = pv.StructuredGrid(X.transpose(), Y.transpose(), Z.transpose())
207206

@@ -261,13 +260,13 @@ def mpi_initialize(self):
261260

262261
# MPI subdomain quantities
263262
self.Nz = self.NZ // (self.size)
264-
self.dz = (self.ZMAX - self.ZMIN) / self.Nz
263+
self.dz = (self.ZMAX - self.ZMIN) / self.NZ
265264
self.zmin = self.rank * self.Nz * self.dz + self.ZMIN
266265
self.zmax = (self.rank+1) * self.Nz * self.dz + self.ZMIN
267266

268267
if self.verbose:
269268
print(f"MPI rank {self.rank} of {self.size} initialized with \
270-
zmin={self.zmin}, zmax={self.zmax}, Nz={self.Nz}")
269+
zmin={self.zmin}, zmax={self.zmax}, Nz={self.Nz}")
271270
# Add ghost cells
272271
self.n_ghosts = 1
273272
if self.rank > 0:
@@ -341,9 +340,9 @@ def mark_cells_in_stl(self):
341340

342341
# mark cells in stl [True == in stl, False == out stl]
343342
try:
344-
select = self.grid.select_enclosed_points(surf, tolerance=tol)
343+
select = self.grid.select_enclosed_points(surf, tolerance=stl_tolerance)
345344
except Exception:
346-
select = self.grid.select_enclosed_points(surf, tolerance=tol, check_surface=False)
345+
select = self.grid.select_enclosed_points(surf, tolerance=stl_tolerance, check_surface=False)
347346
if self.verbose > 1:
348347
print(f'[!] Warning: stl solid {key} may have issues with closed surfaces. Consider checking the STL file.')
349348

0 commit comments

Comments
 (0)