class SedovAnalytical:
    """
    Approximate Sedov-Taylor blast wave reference solution.

    The shock radius follows the self-similar scaling
    ``R_s = xi_0 * (E_blast * t**2 / rho_0) ** (1 / (ndim + 2))`` and the
    post-shock state uses the strong-shock jump conditions.  The interior
    profiles returned by :meth:`solution_at_time` are simple closed-form
    expressions in ``lambda = r / R_s``; they are not obtained by integrating
    the Sedov similarity equations.

    References:
    - Sedov, L.I. (1959) "Similarity and Dimensional Methods in Mechanics"
    - Taylor, G.I. (1950) "The Formation of a Blast Wave by a Very Intense Explosion"
    """

    # Tabulated xi_0 as (ndim, gamma, xi_0); gamma is matched within _GAMMA_TOL.
    _XI0_TABLE = (
        (3, 5.0 / 3.0, 1.15167),
        (3, 1.4, 1.03275),
        (2, 1.4, 1.033),
        (1, 1.4, 0.911),
    )
    _GAMMA_TOL = 0.01

    def __init__(self, gamma=5.0 / 3.0, E_blast=1.0, rho_0=1.0, ndim=3):
        """Store the gas/explosion parameters and look up the Sedov constant."""
        self.gamma = gamma
        self.E_blast = E_blast
        self.rho_0 = rho_0
        self.ndim = ndim
        # Time exponent of the shock radius: R_s ~ t**alpha.
        self.alpha = 2.0 / (ndim + 2)
        self.xi_0 = self._compute_xi0()

    def _compute_xi0(self):
        """Return the tabulated Sedov constant xi_0, or 1.0 (with a warning) if untabulated."""
        for ndim, gamma, xi_0 in self._XI0_TABLE:
            if ndim == self.ndim and abs(self.gamma - gamma) < self._GAMMA_TOL:
                return xi_0

        import warnings

        # The fallback keeps the script running, but the shock position is then
        # only correct up to an O(1) factor, so say so instead of failing silently.
        warnings.warn(
            f"No tabulated xi_0 for ndim={self.ndim}, gamma={self.gamma}; using xi_0 = 1.0",
            stacklevel=3,
        )
        return 1.0

    def shock_radius(self, t):
        """Compute shock radius at time t (0.0 for t <= 0)."""
        if t <= 0:
            return 0.0
        return self.xi_0 * (self.E_blast * t**2 / self.rho_0) ** (1.0 / (self.ndim + 2))

    def shock_velocity(self, t):
        """Compute shock velocity dR_s/dt = alpha * R_s / t at time t (0.0 for t <= 0)."""
        if t <= 0:
            return 0.0
        return self.alpha * self.shock_radius(t) / t

    def post_shock_density(self):
        """Compute the strong-shock post-shock density rho_0 * (gamma + 1) / (gamma - 1)."""
        return self.rho_0 * (self.gamma + 1) / (self.gamma - 1)

    def solution_at_time(self, t, r_max=None, n_points=500):
        """
        Compute the radial profile at time t.

        Args:
            t: Time since the explosion.
            r_max: Outer radius of the returned profile.  Defaults to
                ``1.5 * R_s`` (``0.01`` while ``t <= 1e-10``).
            n_points: Number of radial samples.

        Returns:
            Tuple ``(r, rho, v, p)`` of arrays of length ``n_points``.  Samples
            beyond the shock carry the undisturbed state (``rho_0``, zero
            velocity, negligible pressure), as in the ``t <= 1e-10`` branch.
        """
        ambient_pressure = 1e-10  # strong-shock limit: upstream pressure is negligible

        if t <= 1e-10:
            r = np.linspace(0, r_max or 0.01, n_points)
            return (
                r,
                np.full(n_points, self.rho_0, dtype=float),
                np.zeros(n_points),
                np.full(n_points, ambient_pressure),
            )

        gamma = self.gamma
        ndim = self.ndim
        R_s = self.shock_radius(t)
        v_s = self.shock_velocity(t)

        if r_max is None:
            r_max = R_s * 1.5

        r = np.linspace(0.0, r_max, n_points)
        # Number of leading samples that lie inside (or on) the shock.
        n_inside = int(np.searchsorted(r, R_s, side="right"))
        if 2 <= n_inside < n_points:
            # Put one sample on each side of the discontinuity exactly at R_s so
            # the post-shock peak is resolved and the jump is drawn vertically.
            r[n_inside - 1] = R_s
            r[n_inside] = R_s
        lam = np.minimum(r[:n_inside] / R_s, 1.0)

        # Strong-shock jump conditions.
        rho_s = self.post_shock_density()
        v_shock = 2.0 / (gamma + 1) * v_s
        p_s = 2.0 / (gamma + 1) * self.rho_0 * v_s**2

        rho = np.full(n_points, self.rho_0, dtype=float)
        v = np.zeros(n_points)
        p = np.full(n_points, ambient_pressure)

        v[:n_inside] = v_shock * lam
        omega = (ndim + 2) * gamma / (2 + ndim * (gamma - 1))
        rho[:n_inside] = rho_s * lam ** (omega - 1) * np.maximum(0.1, 1 - 0.8 * (1 - lam) ** 2)
        # lam = 0 makes the power law degenerate at the centre; reuse the
        # neighbouring interior value there.
        rho[0] = rho[1] if n_inside > 1 else rho_s * 0.1
        p[:n_inside] = p_s * (0.5 + 0.5 * lam**2)

        return r, rho, v, p


def parse_args():
    """Parse the command line: data directory, optional output file, solver label."""
    parser = argparse.ArgumentParser(
        description="Animate Sedov blast wave CSV results with analytical comparison"
    )
    parser.add_argument("data_dir", help="Directory containing snapshot CSV files")
    parser.add_argument(
        "output_file",
        nargs="?",
        default=None,
        help="Output GIF file (default: {solver}_sedov_animation.gif)",
    )
    parser.add_argument(
        "--solver",
        choices=["SPH", "GSPH"],
        default="GSPH",
        help="Solver type for labeling (default: GSPH)",
    )
    return parser.parse_args()


def load_snapshot(filename):
    """
    Load a single snapshot CSV file.

    Leading ``# key: value`` lines are collected as metadata, the first
    non-comment line is the comma-separated header and the remaining lines are
    numeric rows.

    Returns:
        Dict mapping each (whitespace-stripped) column name to a float array,
        plus a ``"metadata"`` dict of strings, or ``None`` when the file has
        fewer than two non-comment lines.  Missing or unparsable cells become
        NaN so that all columns keep one entry per row.
    """
    metadata = {}
    rows = []

    with open(filename, "r") as f:
        in_preamble = True
        for line in f:
            if line.startswith("#"):
                # Only the comment block before the header carries metadata.
                if in_preamble and ":" in line:
                    key, value = line[1:].strip().split(":", 1)
                    metadata[key.strip()] = value.strip()
                continue
            in_preamble = False
            rows.append(line)

    if len(rows) < 2:
        return None

    header = [name.strip() for name in rows[0].strip().split(",")]
    columns = {name: [] for name in header}

    for line in rows[1:]:
        stripped = line.strip()
        if not stripped:
            continue
        values = stripped.split(",")
        for i, name in enumerate(header):
            try:
                columns[name].append(float(values[i]))
            except (ValueError, IndexError):
                # Keep the row aligned across columns instead of dropping the cell.
                columns[name].append(np.nan)

    data = {name: np.array(cells) for name, cells in columns.items()}
    data["metadata"] = metadata
    return data


def find_snapshots(data_dir):
    """Find all ``snapshot_*.csv`` files in ``data_dir``, ordered by their embedded numbers."""
    import re

    def natural_key(path):
        # Alternating text / integer tokens, so snapshot_2 sorts before snapshot_10.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", path)]

    # Escape the directory so brackets or wildcards in its name are taken literally.
    files = glob.glob(f"{glob.escape(data_dir)}/snapshot_*.csv")
    return sorted(files, key=natural_key)
_SIM_COLOR = "#0173B2"  # Blue for simulation
_ANA_COLOR = "#D55E00"  # Red-orange for analytical

# (y-axis label, panel title) for the density, velocity and pressure panels.
_PROFILE_PANELS = (
    (r"Density $\rho$", "Density Profile"),
    (r"Radial Velocity $v_r$", "Velocity Profile"),
    (r"Pressure $P$", "Pressure Profile"),
)


def compute_radial_profiles(data, n_bins=100):
    """
    Compute radially averaged profiles from 3D particle data.

    Args:
        data: Mapping with optional arrays ``pos_x/y/z``, ``vel_x/y/z``,
            ``dens`` and ``pres``; missing positions/velocities default to
            zero, missing density/pressure to one.
        n_bins: Number of equal-width radial bins between 0 and the largest radius.

    Returns:
        ``(bin_centers, rho_profile, vel_profile, pres_profile, mask)`` where the
        profiles are per-bin means (radial velocity for ``vel_profile``) and
        ``mask`` flags the bins that contain at least one particle.
    """
    x = np.asarray(data.get("pos_x", np.zeros(1)), dtype=float)
    y = np.asarray(data.get("pos_y", np.zeros(1)), dtype=float)
    z = np.asarray(data.get("pos_z", np.zeros(1)), dtype=float)
    r = np.sqrt(x**2 + y**2 + z**2)

    rho = data.get("dens", np.ones_like(r))
    vel_x = data.get("vel_x", np.zeros_like(r))
    vel_y = data.get("vel_y", np.zeros_like(r))
    vel_z = data.get("vel_z", np.zeros_like(r))
    pres = data.get("pres", np.ones_like(r))

    # Radial velocity v . r_hat; particles sitting exactly at the origin get 0.
    # np.where evaluates both branches, so silence the 0/0 warning for those.
    with np.errstate(divide="ignore", invalid="ignore"):
        vel_r = np.where(r > 0, (x * vel_x + y * vel_y + z * vel_z) / r, 0.0)

    r_max = r.max() if len(r) > 0 else 1.0
    if not r_max > 0:
        # All particles at the origin: avoid a zero-width binning range.
        r_max = 1.0
    bin_edges = np.linspace(0, r_max, n_bins + 1)
    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])

    # digitize is 1-based; the particle at r == r_max lands past the last edge,
    # so clip it back into the last bin.
    indices = np.clip(np.digitize(r, bin_edges) - 1, 0, n_bins - 1)
    counts = np.bincount(indices, minlength=n_bins).astype(float)
    mask = counts > 0

    def bin_mean(values):
        sums = np.bincount(
            indices, weights=np.broadcast_to(values, r.shape), minlength=n_bins
        ).astype(float)
        sums[mask] /= counts[mask]
        return sums

    return bin_centers, bin_mean(rho), bin_mean(vel_r), bin_mean(pres), mask


def _snapshot_time(data):
    """Return the time stored in a snapshot's metadata (0.0 if absent or malformed)."""
    time_str = data["metadata"].get("time", "0.0")
    try:
        # The value may carry a trailing unit; only the first token is numeric.
        return float(time_str.split()[0])
    except (ValueError, IndexError):
        return 0.0


def _select_frames(n_frames, max_frames=50):
    """Return (step, indices): at most max_frames + 1 indices, always ending on the last snapshot."""
    if n_frames <= 0:
        return 1, []
    step = max(1, -(-n_frames // max_frames))  # ceiling division
    indices = list(range(0, n_frames, step))
    if indices[-1] != n_frames - 1:
        # The final-state plot relies on the last loaded frame being the last snapshot.
        indices.append(n_frames - 1)
    return step, indices


def _final_plot_path(output_file):
    """Return the PNG path for the final-state figure; never equal to output_file."""
    root, _ext = os.path.splitext(output_file)
    return f"{root}_final.png"


def _draw_profile_panels(axes, analytical, simulated, solver_name, *, animated, x_max=None):
    """Draw analytical curve and simulation points on the density/velocity/pressure axes."""
    r_ana, *ana_fields = analytical
    r_sim, *sim_fields = simulated
    # Animation frames use smaller markers, explicit layering and a pinned
    # legend corner; the static figure keeps matplotlib's defaults.
    marker_size = 15 if animated else 20
    line_kw = {"zorder": 1} if animated else {}
    scatter_kw = {"zorder": 2} if animated else {}
    legend_kw = {"loc": "upper right"} if animated else {}

    for ax, y_ana, y_sim, (ylabel, title) in zip(axes, ana_fields, sim_fields, _PROFILE_PANELS):
        ax.plot(r_ana, y_ana, color=_ANA_COLOR, linewidth=2.5, label="Analytical", **line_kw)
        ax.scatter(
            r_sim,
            y_sim,
            color=_SIM_COLOR,
            s=marker_size,
            alpha=0.6,
            label=solver_name,
            **scatter_kw,
        )
        ax.set_ylabel(ylabel, fontsize=12, fontweight="bold")
        ax.set_title(title, fontsize=13, fontweight="bold")
        ax.legend(fontsize=10, **legend_kw)
        ax.grid(True, alpha=0.3)
        if x_max is not None:
            ax.set_xlim(0, x_max)

    axes[2].set_xlabel(r"Radius $r$", fontsize=12, fontweight="bold")


def main():
    """Load the snapshots, render the comparison GIF and save a final-state PNG."""
    args = parse_args()
    data_dir = args.data_dir
    solver_name = args.solver
    output_file = args.output_file or f"{solver_name.lower()}_sedov_animation.gif"

    print("=" * 70)
    print(f"{solver_name} Sedov-Taylor Blast Wave Animation")
    print("=" * 70)
    print(f"Data directory: {data_dir}")
    print(f"Output file: {output_file}")
    print()

    # Find snapshot files
    print("Scanning for snapshot files...")
    files = find_snapshots(data_dir)

    if len(files) == 0:
        print(f"ERROR: No snapshot files found in {data_dir}")
        print("Looking for: snapshot_*.csv")
        sys.exit(1)

    print(f"Found {len(files)} snapshot files")
    print()

    # Load first snapshot to get parameters
    first_data = load_snapshot(files[0])
    if first_data is None:
        print("ERROR: Could not load first snapshot")
        sys.exit(1)

    # Extract parameters from metadata
    gamma = float(first_data["metadata"].get("gamma", "1.666667"))
    E_blast = float(first_data["metadata"].get("E_blast", "1.0"))
    rho_0 = float(first_data["metadata"].get("rho_0", "1.0"))

    print("Simulation parameters:")
    print(f" gamma = {gamma}")
    print(f" E_blast = {E_blast}")
    print(f" rho_0 = {rho_0}")
    print()

    sedov_analytical = SedovAnalytical(gamma=gamma, E_blast=E_blast, rho_0=rho_0)

    # Subsample the snapshots to keep the animation size reasonable.
    frame_skip, frame_indices = _select_frames(len(files), max_frames=50)

    print(f"Animation: {len(frame_indices)} frames (every {frame_skip} snapshots)")
    print()

    # Pre-load all frame data
    print("Loading snapshot data...")
    frame_data = []

    pbar = tqdm(total=len(frame_indices), desc="Loading") if HAS_TQDM else None

    for idx in frame_indices:
        data = load_snapshot(files[idx])
        if data is not None:
            frame_data.append(data)
        if pbar:
            pbar.update(1)

    if pbar:
        pbar.close()

    print(f"Loaded {len(frame_data)} frames")
    print()

    # Create animation
    print("Creating animation...")

    fig, axes_grid = plt.subplots(2, 2, figsize=(14, 12))
    ax1, ax2, ax3, ax4 = axes_grid.flat
    fig.suptitle(
        f"{solver_name} Sedov-Taylor Blast Wave - Comparison with Analytical Solution",
        fontsize=16,
        fontweight="bold",
    )

    pbar_anim = tqdm(total=len(frame_data), desc="Rendering") if HAS_TQDM else None

    def update(frame_num):
        """Redraw all four panels for one animation frame."""
        data = frame_data[frame_num]
        time = _snapshot_time(data)

        r_sim, rho_sim, vel_sim, pres_sim, mask = compute_radial_profiles(data)
        analytical = sedov_analytical.solution_at_time(
            time, r_max=r_sim.max() * 1.2, n_points=500
        )

        for ax in (ax1, ax2, ax3, ax4):
            ax.clear()

        _draw_profile_panels(
            (ax1, ax2, ax3),
            analytical,
            (r_sim[mask], rho_sim[mask], vel_sim[mask], pres_sim[mask]),
            solver_name,
            animated=True,
            x_max=r_sim.max() * 1.1,
        )

        shock_radius = sedov_analytical.shock_radius(time)
        ax1.axvline(
            shock_radius,
            color="gray",
            linestyle="--",
            alpha=0.5,
            label=f"$R_s$ = {shock_radius:.3f}",
        )

        # Text panel with the analytical shock parameters
        ax4.text(
            0.5,
            0.5,
            f"Shock Radius: $R_s$ = {shock_radius:.4f}\n\n"
            f"Post-shock density: {sedov_analytical.post_shock_density():.2f}\n\n"
            f"Density ratio: {sedov_analytical.post_shock_density() / rho_0:.1f}",
            transform=ax4.transAxes,
            fontsize=14,
            verticalalignment="center",
            horizontalalignment="center",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
        )
        ax4.set_title("Sedov-Taylor Parameters", fontsize=13, fontweight="bold")
        ax4.axis("off")

        # Add time label
        fig.suptitle(
            f"{solver_name} Sedov-Taylor Blast Wave - t = {time:.4f}\n"
            f"Comparison with Self-Similar Solution",
            fontsize=14,
            fontweight="bold",
        )

        if pbar_anim:
            pbar_anim.update(1)

        return ax1, ax2, ax3, ax4

    # Both outputs go next to output_file, so make sure the directory exists
    # even when the animation itself is skipped.
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

    if HAS_ANIMATION and len(frame_data) > 0:
        anim = FuncAnimation(
            fig, update, frames=len(frame_data), interval=150, blit=False, repeat=True
        )

        writer = PillowWriter(fps=8)
        anim.save(output_file, writer=writer, dpi=150)

        if pbar_anim:
            pbar_anim.close()

        plt.close(fig)

        print()
        print("=" * 70)
        print("Animation Complete!")
        print("=" * 70)
        print(f"Saved: {output_file}")
    else:
        print("ERROR: Cannot create animation (no data or missing dependencies)")

    # Also create final state comparison plot
    if len(frame_data) > 0:
        print()
        print("Creating final state comparison plot...")

        data = frame_data[-1]
        time = _snapshot_time(data)

        r_sim, rho_sim, vel_sim, pres_sim, mask = compute_radial_profiles(data)
        analytical = sedov_analytical.solution_at_time(
            time, r_max=r_sim.max() * 1.2, n_points=500
        )

        fig2, axes_grid = plt.subplots(2, 2, figsize=(14, 12))
        ax1, ax2, ax3, ax4 = axes_grid.flat
        fig2.suptitle(
            f"{solver_name} Sedov-Taylor Blast Wave - Final State (t = {time:.4f})\n"
            f"Comparison with Analytical Solution",
            fontsize=14,
            fontweight="bold",
        )

        _draw_profile_panels(
            (ax1, ax2, ax3),
            analytical,
            (r_sim[mask], rho_sim[mask], vel_sim[mask], pres_sim[mask]),
            solver_name,
            animated=False,
        )
        ax1.axvline(sedov_analytical.shock_radius(time), color="gray", linestyle="--", alpha=0.5)

        # Info panel
        info_text = (
            f"Sedov-Taylor Parameters:\n\n"
            f"$\\gamma$ = {gamma:.4f}\n"
            f"$E_{{blast}}$ = {E_blast:.2f}\n"
            f"$\\rho_0$ = {rho_0:.2f}\n\n"
            f"At t = {time:.4f}:\n"
            f"$R_s$ = {sedov_analytical.shock_radius(time):.4f}\n"
            f"$\\rho_s$ = {sedov_analytical.post_shock_density():.2f}"
        )
        ax4.text(
            0.5,
            0.5,
            info_text,
            transform=ax4.transAxes,
            fontsize=12,
            verticalalignment="center",
            horizontalalignment="center",
            bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
            family="monospace",
        )
        ax4.set_title("Simulation Parameters", fontsize=13, fontweight="bold")
        ax4.axis("off")

        plt.tight_layout()

        final_plot = _final_plot_path(output_file)
        plt.savefig(final_plot, dpi=150, bbox_inches="tight")
        print(f"Saved: {final_plot}")
        plt.close()

    print("=" * 70)


if __name__ == "__main__":
    main()
_SIM_COLOR = "#0173B2"  # Blue for simulation
_ANA_COLOR = "#D55E00"  # Red-orange for analytical

# (y-axis label, panel title, legend corner used in animation frames)
_SOD_PANELS = (
    ("Density", "Density Profile", "upper right"),
    ("Velocity", "Velocity Profile", "upper left"),
    ("Pressure", "Pressure Profile", "upper right"),
    ("Internal Energy", "Internal Energy Profile", "upper right"),
)


def parse_args():
    """Parse the command line: data directory, optional output file, solver label."""
    parser = argparse.ArgumentParser(
        description="Animate Sod shock tube CSV results with analytical comparison"
    )
    parser.add_argument("data_dir", help="Directory containing snapshot CSV files")
    parser.add_argument(
        "output_file",
        nargs="?",
        default=None,
        help="Output GIF file (default: {solver}_sod_animation.gif)",
    )
    parser.add_argument(
        "--solver",
        choices=["SPH", "GSPH"],
        default="GSPH",
        help="Solver type for labeling (default: GSPH)",
    )
    return parser.parse_args()


def load_snapshot(filename):
    """
    Load a single snapshot CSV file.

    Leading ``# key: value`` lines are collected as metadata, the first
    non-comment line is the comma-separated header and the remaining lines are
    numeric rows.

    Returns:
        Dict mapping each (whitespace-stripped) column name to a float array,
        plus a ``"metadata"`` dict of strings, or ``None`` when the file has
        fewer than two non-comment lines.  Missing or unparsable cells become
        NaN so that all columns keep one entry per row.
    """
    metadata = {}
    rows = []

    with open(filename, "r") as f:
        in_preamble = True
        for line in f:
            if line.startswith("#"):
                # Only the comment block before the header carries metadata.
                if in_preamble and ":" in line:
                    key, value = line[1:].strip().split(":", 1)
                    metadata[key.strip()] = value.strip()
                continue
            in_preamble = False
            rows.append(line)

    if len(rows) < 2:
        return None

    header = [name.strip() for name in rows[0].strip().split(",")]
    columns = {name: [] for name in header}

    for line in rows[1:]:
        stripped = line.strip()
        if not stripped:
            continue
        values = stripped.split(",")
        for i, name in enumerate(header):
            try:
                columns[name].append(float(values[i]))
            except (ValueError, IndexError):
                # Keep the row aligned across columns instead of dropping the cell.
                columns[name].append(np.nan)

    data = {name: np.array(cells) for name, cells in columns.items()}
    data["metadata"] = metadata
    return data


def find_snapshots(data_dir):
    """Find all ``snapshot_*.csv`` files in ``data_dir``, ordered by their embedded numbers."""
    import re

    def natural_key(path):
        # Alternating text / integer tokens, so snapshot_2 sorts before snapshot_10.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", path)]

    # Escape the directory so brackets or wildcards in its name are taken literally.
    files = glob.glob(f"{glob.escape(data_dir)}/snapshot_*.csv")
    return sorted(files, key=natural_key)


def get_analytical_solution(sod, gamma, t, x_array):
    """
    Get analytical solution at multiple x positions.

    Args:
        sod: Object whose ``get_value(t, x)`` returns ``(rho, vel, pres)``.
        gamma: Adiabatic index used to derive the specific internal energy.
        t: Time at which to evaluate the solution.
        x_array: Positions to evaluate.

    Returns:
        ``(rho, vel, pres, ene)`` arrays with ``ene = pres / ((gamma - 1) * rho)``.
    """
    states = np.array([sod.get_value(t, x) for x in x_array], dtype=float).reshape(-1, 3)
    rho, vel, pres = states.T
    # Floor the density so a vacuum state cannot divide by zero.
    ene = pres / ((gamma - 1) * np.maximum(rho, 1e-10))
    return rho, vel, pres, ene


def _snapshot_time(data):
    """Return the time stored in a snapshot's metadata (0.0 if absent or malformed)."""
    time_str = data["metadata"].get("time", "0.0")
    try:
        # The value may carry a trailing unit; only the first token is numeric.
        return float(time_str.split()[0])
    except (ValueError, IndexError):
        return 0.0


def _select_frames(n_frames, max_frames=50):
    """Return (step, indices): at most max_frames + 1 indices, always ending on the last snapshot."""
    if n_frames <= 0:
        return 1, []
    step = max(1, -(-n_frames // max_frames))  # ceiling division
    indices = list(range(0, n_frames, step))
    if indices[-1] != n_frames - 1:
        # The final-state plot relies on the last loaded frame being the last snapshot.
        indices.append(n_frames - 1)
    return step, indices


def _final_plot_path(output_file):
    """Return the PNG path for the final-state figure; never equal to output_file."""
    root, _ext = os.path.splitext(output_file)
    return f"{root}_final.png"


def _sorted_sod_fields(data):
    """Return x positions in ascending order and (dens, vel_x, pres, ene) in the same order."""
    x = data["pos_x"]
    order = np.argsort(x)
    return x[order], tuple(data[key][order] for key in ("dens", "vel_x", "pres", "ene"))


def _draw_sod_panels(axes, x_ana, ana_fields, x_sim, sim_fields, solver_name, *, animated):
    """Draw the four Sod panels; ana_fields is None when no analytical solution is available."""
    # Animation frames use smaller markers, explicit layering, pinned legend
    # corners and a fixed x-range; the static figure keeps matplotlib's defaults.
    marker_size = 10 if animated else 15
    line_kw = {"zorder": 1} if animated else {}
    scatter_kw = {"zorder": 2} if animated else {}

    for i, (ax, y_sim, (ylabel, title, legend_loc)) in enumerate(
        zip(axes, sim_fields, _SOD_PANELS)
    ):
        if ana_fields is not None:
            ax.plot(
                x_ana,
                ana_fields[i],
                color=_ANA_COLOR,
                linewidth=2.5,
                label="Analytical",
                **line_kw,
            )
        ax.scatter(
            x_sim,
            y_sim,
            color=_SIM_COLOR,
            s=marker_size,
            alpha=0.6,
            label=solver_name,
            **scatter_kw,
        )
        ax.set_ylabel(ylabel, fontsize=12, fontweight="bold")
        ax.set_title(title, fontsize=13, fontweight="bold")
        if animated:
            ax.legend(loc=legend_loc, fontsize=10)
        else:
            ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)
        if animated:
            ax.set_xlim(x_sim.min(), x_sim.max())

    # Only the bottom row carries an x-axis label.
    for ax in axes[2:]:
        ax.set_xlabel("Position x", fontsize=12, fontweight="bold")


def main():
    """Load the snapshots, render the comparison GIF and save a final-state PNG."""
    args = parse_args()
    data_dir = args.data_dir
    solver_name = args.solver
    output_file = args.output_file or f"{solver_name.lower()}_sod_animation.gif"

    print("=" * 70)
    print(f"{solver_name} Sod Shock Tube Animation")
    print("=" * 70)
    print(f"Data directory: {data_dir}")
    print(f"Output file: {output_file}")
    print()

    # Find snapshot files
    print("Scanning for snapshot files...")
    files = find_snapshots(data_dir)

    if len(files) == 0:
        print(f"ERROR: No snapshot files found in {data_dir}")
        print("Looking for: snapshot_*.csv")
        sys.exit(1)

    print(f"Found {len(files)} snapshot files")
    print()

    # Load first snapshot to get parameters
    first_data = load_snapshot(files[0])
    if first_data is None:
        print("ERROR: Could not load first snapshot")
        sys.exit(1)

    # Extract parameters from metadata
    gamma = float(first_data["metadata"].get("gamma", "1.4"))
    rho_L = float(first_data["metadata"].get("rho_L", "1.0"))
    rho_R = float(first_data["metadata"].get("rho_R", "0.125"))
    p_L = float(first_data["metadata"].get("p_L", "1.0"))
    p_R = float(first_data["metadata"].get("p_R", "0.1"))

    print("Simulation parameters:")
    print(f" gamma = {gamma}")
    print(f" Left: rho = {rho_L}, P = {p_L}")
    print(f" Right: rho = {rho_R}, P = {p_R}")
    print()

    # Create analytical solution object using shamrock's built-in SodTube
    sod_analytical = None
    if HAS_SHAMROCK:
        sod_analytical = shamrock.phys.SodTube(
            gamma=gamma, rho_1=rho_L, P_1=p_L, rho_5=rho_R, P_5=p_R
        )

    def analytical_fields(time, x_sim):
        """Return (x_ana, fields or None) on 500 points spanning the simulated x-range."""
        x_ana = np.linspace(x_sim.min(), x_sim.max(), 500)
        if sod_analytical is None:
            return x_ana, None
        return x_ana, get_analytical_solution(sod_analytical, gamma, time, x_ana)

    # Subsample the snapshots to keep the animation size reasonable.
    frame_skip, frame_indices = _select_frames(len(files), max_frames=50)

    print(f"Animation: {len(frame_indices)} frames (every {frame_skip} snapshots)")
    print()

    # Pre-load all frame data
    print("Loading snapshot data...")
    frame_data = []

    pbar = tqdm(total=len(frame_indices), desc="Loading") if HAS_TQDM else None

    for idx in frame_indices:
        data = load_snapshot(files[idx])
        if data is not None:
            frame_data.append(data)
        if pbar:
            pbar.update(1)

    if pbar:
        pbar.close()

    print(f"Loaded {len(frame_data)} frames")
    print()

    # Create animation
    print("Creating animation...")

    fig, axes_grid = plt.subplots(2, 2, figsize=(14, 12))
    ax1, ax2, ax3, ax4 = axes_grid.flat
    fig.suptitle(
        f"{solver_name} Sod Shock Tube - Comparison with Analytical Solution",
        fontsize=16,
        fontweight="bold",
    )

    pbar_anim = tqdm(total=len(frame_data), desc="Rendering") if HAS_TQDM else None

    def update(frame_num):
        """Redraw all four panels for one animation frame."""
        data = frame_data[frame_num]
        time = _snapshot_time(data)

        x_sim, sim_fields = _sorted_sod_fields(data)
        x_ana, ana_fields = analytical_fields(time, x_sim)

        for ax in (ax1, ax2, ax3, ax4):
            ax.clear()

        _draw_sod_panels(
            (ax1, ax2, ax3, ax4),
            x_ana,
            ana_fields,
            x_sim,
            sim_fields,
            solver_name,
            animated=True,
        )

        # Add time label
        fig.suptitle(
            f"{solver_name} Sod Shock Tube - t = {time:.4f}\n"
            f"Comparison with Analytical Solution",
            fontsize=14,
            fontweight="bold",
        )

        if pbar_anim:
            pbar_anim.update(1)

        return ax1, ax2, ax3, ax4

    # Both outputs go next to output_file, so make sure the directory exists
    # even when the animation itself is skipped.
    os.makedirs(os.path.dirname(output_file) or ".", exist_ok=True)

    if HAS_ANIMATION and len(frame_data) > 0:
        anim = FuncAnimation(
            fig, update, frames=len(frame_data), interval=100, blit=False, repeat=True
        )

        writer = PillowWriter(fps=10)
        anim.save(output_file, writer=writer, dpi=150)

        if pbar_anim:
            pbar_anim.close()

        plt.close(fig)

        print()
        print("=" * 70)
        print("Animation Complete!")
        print("=" * 70)
        print(f"Saved: {output_file}")
    else:
        print("ERROR: Cannot create animation (no data or missing dependencies)")

    # Also create final state comparison plot
    if len(frame_data) > 0:
        print()
        print("Creating final state comparison plot...")

        data = frame_data[-1]
        time = _snapshot_time(data)

        x_sim, sim_fields = _sorted_sod_fields(data)
        x_ana, ana_fields = analytical_fields(time, x_sim)

        fig2, axes_grid = plt.subplots(2, 2, figsize=(14, 12))
        fig2.suptitle(
            f"{solver_name} Sod Shock Tube - Final State (t = {time:.4f})\n"
            f"Comparison with Analytical Solution",
            fontsize=14,
            fontweight="bold",
        )

        _draw_sod_panels(
            tuple(axes_grid.flat),
            x_ana,
            ana_fields,
            x_sim,
            sim_fields,
            solver_name,
            animated=False,
        )

        plt.tight_layout()

        final_plot = _final_plot_path(output_file)
        plt.savefig(final_plot, dpi=150, bbox_inches="tight")
        print(f"Saved: {final_plot}")
        plt.close()

    print("=" * 70)


if __name__ == "__main__":
    main()
def _positive_int(text):
    """argparse type: parse an integer and reject values below 1."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {text!r}")
    return value


def parse_args():
    """Parse the command line: VTK directory, optional output directory and run parameters."""
    parser = argparse.ArgumentParser(description="Animate Sod shock tube VTK results")
    parser.add_argument("vtk_dir", help="Directory containing VTK files")
    parser.add_argument(
        "output_dir",
        nargs="?",
        default=None,
        help="Output directory (defaults to parent of vtk_dir)",
    )
    parser.add_argument(
        "--solver",
        choices=["SPH", "GSPH"],
        default="GSPH",
        help="Solver type (affects file naming)",
    )
    parser.add_argument("--gamma", type=float, default=1.4, help="Adiabatic index (default: 1.4)")
    parser.add_argument(
        "--t-final",
        type=float,
        default=0.245,
        help="Final simulation time (default: 0.245)",
    )
    parser.add_argument(
        "--fps",
        # A frame rate below 1 cannot be turned into a GIF frame duration, so
        # reject it here rather than after all frames have been rendered.
        type=_positive_int,
        default=10,
        help="Animation frames per second (default: 10)",
    )
    return parser.parse_args()


def get_analytical_solution(sod, t, x_array):
    """
    Get analytical solution at multiple x positions using shamrock.phys.SodTube.

    Args:
        sod: Object whose ``get_value(t, x)`` returns ``(rho, vel, pres)``.
        t: Time at which to evaluate the solution.
        x_array: Positions to evaluate.

    Returns:
        ``(x_array, rho, vel, pres)``; the input positions are passed through unchanged.
    """
    states = np.array([sod.get_value(t, x) for x in x_array], dtype=float).reshape(-1, 3)
    rho, vel, pres = states.T
    return x_array, rho, vel, pres


def read_vtk(filename):
    """
    Read VTK file using pyvista.

    Returns:
        ``(points, velocities, hpart, rho, P)`` as NumPy arrays, taken from the
        mesh points and its ``"v"``, ``"h"``, ``"rho"`` and ``"P"`` arrays.
    """
    mesh = pv.read(filename)
    points = np.array(mesh.points)
    velocities = np.array(mesh["v"])
    hpart = np.array(mesh["h"])
    rho = np.array(mesh["rho"])
    P = np.array(mesh["P"])
    return points, velocities, hpart, rho, P
rho_ana, "r-", lw=2, label="Analytical") + axes[0, 1].plot(x_ana, vx_ana, "r-", lw=2, label="Analytical") + axes[1, 0].plot(x_ana, P_ana, "r-", lw=2, label="Analytical") + + # Density + axes[0, 0].scatter(x_sort, rho_sort, s=1, alpha=0.5, label=solver_name) + axes[0, 0].set_ylabel("Density") + axes[0, 0].set_title("Density") + axes[0, 0].legend() + axes[0, 0].set_xlim(-1.1, 1.1) + axes[0, 0].set_ylim(0, 1.2) + + # Velocity + axes[0, 1].scatter(x_sort, vx_sort, s=1, alpha=0.5, label=solver_name) + axes[0, 1].set_ylabel("Velocity") + axes[0, 1].set_title("Velocity") + axes[0, 1].legend() + axes[0, 1].set_xlim(-1.1, 1.1) + axes[0, 1].set_ylim(-0.1, 1.1) + + # Pressure + axes[1, 0].scatter(x_sort, P_sort, s=1, alpha=0.5, label=solver_name) + axes[1, 0].set_ylabel("Pressure") + axes[1, 0].set_xlabel("x") + axes[1, 0].set_title("Pressure") + axes[1, 0].legend() + axes[1, 0].set_xlim(-1.1, 1.1) + axes[1, 0].set_ylim(0, 1.2) + + # Smoothing length + axes[1, 1].scatter(x_sort, h_sort, s=1, alpha=0.5) + axes[1, 1].set_ylabel("h") + axes[1, 1].set_xlabel("x") + axes[1, 1].set_title("Smoothing Length h") + axes[1, 1].set_xlim(-1.1, 1.1) + + fig.suptitle( + f"{solver_name} Sod Shock Tube (t = {t:.3f})", + fontsize=14, + fontweight="bold", + ) + plt.tight_layout() + + return axes.flat + + # Create animation + print("Creating animation...") + anim = FuncAnimation(fig, update, frames=len(vtk_files), interval=100) + + # Save as GIF + solver_lower = solver_name.lower() + gif_path = os.path.join(output_dir, f"{solver_lower}_sod_animation.gif") + os.makedirs(output_dir, exist_ok=True) + print(f"Saving to {gif_path}...") + anim.save(gif_path, writer=PillowWriter(fps=args.fps)) + print(f"Animation saved to {gif_path}") + + # Save final frame as PNG + print("Saving final frame...") + update(len(vtk_files) - 1) + final_path = os.path.join(output_dir, f"{solver_lower}_sod_final.png") + plt.savefig(final_path, dpi=150) + print(f"Final frame saved to {final_path}") + + print() + print(f"{'=' 
* 70}") + print("Done!") + print(f"{'=' * 70}") + + +if __name__ == "__main__": + main() diff --git a/exemples/gsph/scripts/gsph_sod_shock_tube.py b/exemples/gsph/scripts/gsph_sod_shock_tube.py new file mode 100644 index 000000000..e8e676e90 --- /dev/null +++ b/exemples/gsph/scripts/gsph_sod_shock_tube.py @@ -0,0 +1,153 @@ +""" +GSPH Sod Shock Tube Simulation with VTK Output +=============================================== + +Runs the Sod shock tube test using Godunov SPH (GSPH) with HLLC Riemann solver +and outputs VTK files for visualization. + +Uses the same initial conditions as the SPH Sod test for direct comparison. + +Output: VTK files in output/ directory with simulation time metadata +""" + +import json +import os + +import shamrock + +# Physical parameters (same as SPH test) +gamma = 1.4 + +rho_L = 1.0 # Left density +rho_R = 0.125 # Right density + +P_L = 1.0 # Left pressure +P_R = 0.1 # Right pressure + +# Derived quantities +fact = (rho_L / rho_R) ** (1.0 / 3.0) +u_L = P_L / ((gamma - 1) * rho_L) # Left internal energy +u_R = P_R / ((gamma - 1) * rho_R) # Right internal energy + +# Resolution (same as SPH test) +resol = 128 + +# Initialize context and model +ctx = shamrock.Context() +ctx.pdata_layout_new() + +# Use GSPH model with M6 kernel (same as SPH test) +model = shamrock.get_Model_GSPH(context=ctx, vector_type="f64_3", sph_kernel="M6") + +# Configure solver +cfg = model.gen_default_config() + +# Set HLLC Riemann solver +cfg.set_riemann_hllc() + +# Set piecewise constant reconstruction (first-order, most stable) +cfg.set_reconstruct_piecewise_constant() + +# Set periodic boundaries (with wall particles for shock tube) +cfg.set_boundary_periodic() + +# Set adiabatic EOS +cfg.set_eos_adiabatic(gamma) + +# Print configuration +cfg.print_status() +model.set_solver_config(cfg) + +model.init_scheduler(int(1e8), 1) + +# Setup domain (same as SPH test) +(xs, ys, zs) = model.get_box_dim_fcc_3d(1, resol, 24, 24) +dr = 1 / xs +(xs, ys, zs) = 
model.get_box_dim_fcc_3d(dr, resol, 24, 24) + +model.resize_simulation_box((-xs, -ys / 2, -zs / 2), (xs, ys / 2, zs / 2)) + +# Setup initial conditions using HCP lattice (same as SPH test) +# Left side: high density (smaller spacing) +model.add_cube_hcp_3d(dr, (-xs, -ys / 2, -zs / 2), (0, ys / 2, zs / 2)) +# Right side: low density (larger spacing) +model.add_cube_hcp_3d(dr * fact, (0, -ys / 2, -zs / 2), (xs, ys / 2, zs / 2)) + +# Set internal energy for left and right states (discontinuity at x=0) +model.set_field_in_box("uint", "f64", u_L, (-xs, -ys / 2, -zs / 2), (0, ys / 2, zs / 2)) +model.set_field_in_box("uint", "f64", u_R, (0, -ys / 2, -zs / 2), (xs, ys / 2, zs / 2)) + +# Set particle mass (same as SPH test) +vol_b = xs * ys * zs +totmass = (rho_R * vol_b) + (rho_L * vol_b) +pmass = model.total_mass_to_part_mass(totmass) +model.set_particle_mass(pmass) + +print(f"Total mass: {totmass}") +print(f"Particle mass: {pmass}") + +# Set CFL conditions (same as SPH test) +model.set_cfl_cour(0.1) +model.set_cfl_force(0.1) + +# Simulation parameters (same as SPH test) +t_final = 0.245 +n_outputs = 50 +dt_output = t_final / n_outputs + +# Track output times +times = [] +output_count = 0 + +# Create output directory +output_dir = "simulations_data/gsph_sod/vtk" +os.makedirs(output_dir, exist_ok=True) + +# Initial output +filename = f"{output_dir}/gsph_sod_{output_count:04d}.vtk" +model.do_vtk_dump(filename, True) +times.append({"index": output_count, "time": 0.0, "file": filename}) +print(f"Saved: {filename} (t = 0.0)") +output_count += 1 + +# Time evolution with outputs +t_current = 0.0 +t_next_output = dt_output + +while t_current < t_final: + # Evolve to next output time or final time + t_target = min(t_next_output, t_final) + model.evolve_until(t_target) + t_current = t_target + + # Output VTK + filename = f"{output_dir}/gsph_sod_{output_count:04d}.vtk" + model.do_vtk_dump(filename, True) + times.append({"index": output_count, "time": t_current, "file": filename}) + 
print(f"Saved: {filename} (t = {t_current:.6f})") + output_count += 1 + + t_next_output += dt_output + +# Save times metadata +with open("simulations_data/gsph_sod/times_gsph_sod.json", "w") as f: + json.dump( + { + "method": "GSPH", + "riemann_solver": "HLLC", + "kernel": "M6", + "gamma": gamma, + "rho_L": rho_L, + "rho_R": rho_R, + "P_L": P_L, + "P_R": P_R, + "t_final": t_final, + "outputs": times, + }, + f, + indent=2, + ) + +print(f"\nSimulation complete! {output_count} VTK files saved to {output_dir}/") +print("\nNote: L2 error analysis not available for GSPH model.") +print("Use post-processing scripts for comparison with analytical solution.") diff --git a/src/shammodels/gsph/CMakeLists.txt b/src/shammodels/gsph/CMakeLists.txt index 82a5dc694..6227ecc76 100644 --- a/src/shammodels/gsph/CMakeLists.txt +++ b/src/shammodels/gsph/CMakeLists.txt @@ -11,10 +11,13 @@ cmake_minimum_required(VERSION 3.9) project(Shammodels_gsph CXX C) -# Sources: Core infrastructure + Physics modules +# Sources: GSPH solver, Model, and VTK I/O set(Sources src/SolverConfig.cpp + src/Solver.cpp + src/Model.cpp src/modules/UpdateDerivs.cpp + src/modules/io/VTKDump.cpp ) if(SHAMROCK_USE_SHARED_LIB) diff --git a/src/shammodels/gsph/include/shammodels/gsph/Model.hpp b/src/shammodels/gsph/include/shammodels/gsph/Model.hpp new file mode 100644 index 000000000..38d6a870c --- /dev/null +++ b/src/shammodels/gsph/include/shammodels/gsph/Model.hpp @@ -0,0 +1,394 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file Model.hpp + * @author Guo Yansong (guo.yansong.ngy@gmail.com) + * @author Yona Lapeyre (yona.lapeyre@ens-lyon.fr) --no git 
blame-- + * @brief GSPH Model class - high-level interface for GSPH simulations + * + * The GSPH method originated from: + * - Inutsuka, S. (2002) "Reformulation of Smoothed Particle Hydrodynamics + * with Riemann Solver" + * + * This implementation follows: + * - Cha, S.-H. & Whitworth, A.P. (2003) "Implementations and tests of + * Godunov-type particle hydrodynamics" + */ + +#include "shambase/constants.hpp" +#include "shambase/exception.hpp" +#include "shambase/string.hpp" +#include "shamalgs/collective/exchanges.hpp" +#include "shambackends/BufferMirror.hpp" +#include "shambackends/vec.hpp" +#include "shamcomm/collectives.hpp" +#include "shamcomm/logs.hpp" +#include "shammodels/common/setup/generators.hpp" +#include "shammodels/gsph/Solver.hpp" +#include "shammodels/sph/math/density.hpp" +#include "shamrock/io/ShamrockDump.hpp" +#include "shamrock/patch/PatchDataLayer.hpp" +#include "shamrock/scheduler/ReattributeDataUtility.hpp" +#include "shamrock/scheduler/ShamrockCtx.hpp" +#include "shamsys/NodeInstance.hpp" +#include "shamsys/legacy/log.hpp" +#include "shamtree/kernels/geometry_utils.hpp" +#include +#include +#include + +namespace shammodels::gsph { + + /** + * @brief The GSPH Model class + * + * Provides a high-level interface for setting up and running GSPH simulations. + * The GSPH method uses Riemann solvers at particle interfaces instead of + * artificial viscosity, giving sharper shock resolution. 
+ * + * @tparam Tvec Vector type (e.g., f64_3) + * @tparam SPHKernel Kernel type (e.g., M4, M6, C2, C4, C6 for Wendland) + */ + template class SPHKernel> + class Model { + public: + using Tscal = shambase::VecComponent; + static constexpr u32 dim = shambase::VectorProperties::dimension; + using Kernel = SPHKernel; + + using Solver = Solver; + using SolverConfig = typename Solver::Config; + + ShamrockCtx &ctx; + Solver solver; + + Model(ShamrockCtx &ctx) : ctx(ctx), solver(ctx) {}; + + //////////////////////////////////////////////////////////////////////////////////////////// + // Setup functions + //////////////////////////////////////////////////////////////////////////////////////////// + + void init_scheduler(u32 crit_split, u32 crit_merge); + + template = 0> + inline Tvec get_box_dim_fcc_3d(Tscal dr, u32 xcnt, u32 ycnt, u32 zcnt) { + return generic::setup::generators::get_box_dim(dr, xcnt, ycnt, zcnt); + } + + inline void set_cfl_cour(Tscal cfl_cour) { + solver.solver_config.cfl_config.cfl_cour = cfl_cour; + } + + inline void set_cfl_force(Tscal cfl_force) { + solver.solver_config.cfl_config.cfl_force = cfl_force; + } + + inline void set_particle_mass(Tscal gpart_mass) { + solver.solver_config.gpart_mass = gpart_mass; + } + + inline Tscal get_particle_mass() { return solver.solver_config.gpart_mass; } + + inline void resize_simulation_box(std::pair box) { + ctx.set_coord_domain_bound({box.first, box.second}); + } + + void do_vtk_dump(std::string filename, bool add_patch_world_id) { + solver.vtk_do_dump(filename, add_patch_world_id); + } + + u64 get_total_part_count(); + + f64 total_mass_to_part_mass(f64 totmass); + + std::pair get_ideal_fcc_box(Tscal dr, std::pair box); + std::pair get_ideal_hcp_box(Tscal dr, std::pair box); + + Tscal get_hfact() { return Kernel::hfactd; } + + Tscal rho_h(Tscal h) { + return shamrock::sph::rho_h(solver.solver_config.gpart_mass, h, Kernel::hfactd); + } + + void add_cube_fcc_3d(Tscal dr, std::pair _box); + void 
add_cube_hcp_3d(Tscal dr, std::pair _box); + + //////////////////////////////////////////////////////////////////////////////////////////// + // Field manipulation + //////////////////////////////////////////////////////////////////////////////////////////// + + /** + * @brief Apply a position-dependent function to initialize a field + * + * Sets field values by evaluating a function at each particle position. + * Useful for setting up spatially-varying initial conditions. + * + * @tparam T Field type (e.g., Tscal for density, Tvec for velocity) + * @param field_name Name of the field to modify (e.g., "uint", "vxyz") + * @param pos_to_val Function mapping position to field value + * + * Example: + * @code + * // Set velocity as a function of position + * model.apply_field_from_position("vxyz", [](Tvec pos) { + * return Tvec{pos[0], 0.0, 0.0}; // Linear velocity profile + * }); + * @endcode + */ + template + inline void apply_field_from_position( + std::string field_name, const std::function pos_to_val) { + + StackEntry stack_loc{}; + PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + sched.patch_data.for_each_patchdata( + [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) { + PatchDataField &xyz + = pdat.template get_field(sched.pdl().get_field_idx("xyz")); + + PatchDataField &f + = pdat.template get_field(sched.pdl().get_field_idx(field_name)); + + if (f.get_nvar() != 1) { + shambase::throw_unimplemented(); + } + + { + auto &buf = f.get_buf(); + auto acc = buf.copy_to_stdvec(); + + auto &buf_xyz = xyz.get_buf(); + auto acc_xyz = buf_xyz.copy_to_stdvec(); + + for (u32 i = 0; i < f.get_obj_cnt(); i++) { + Tvec r = acc_xyz[i]; + acc[i] = pos_to_val(r); + } + + buf.copy_from_stdvec(acc); + buf_xyz.copy_from_stdvec(acc_xyz); + } + }); + } + + /** + * @brief Set field value for particles within a box region + * + * Sets the specified field to a constant value for all particles + * whose positions fall within the given axis-aligned box. 
+ * Useful for setting up discontinuous initial conditions (e.g., Sod shock tube). + * + * @tparam T Field type (e.g., Tscal for scalars, Tvec for vectors) + * @param field_name Name of the field to modify (e.g., "uint", "vxyz") + * @param val Value to set for particles in the region + * @param box Bounding box as (min_corner, max_corner) + * @param ivar Variable index for multi-variable fields (default: 0) + * + * Example: + * @code + * // Sod shock tube: set left state internal energy + * model.set_field_in_box("uint", u_left, {box_min, interface_pos}); + * // Set right state + * model.set_field_in_box("uint", u_right, {interface_pos, box_max}); + * @endcode + */ + template + inline void set_field_in_box( + std::string field_name, T val, std::pair box, u32 ivar = 0) { + StackEntry stack_loc{}; + PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + sched.patch_data.for_each_patchdata( + [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) { + PatchDataField &xyz + = pdat.template get_field(sched.pdl().get_field_idx("xyz")); + + PatchDataField &f + = pdat.template get_field(sched.pdl().get_field_idx(field_name)); + + u32 nvar = f.get_nvar(); + + // Validate ivar parameter to prevent out-of-bounds access + if (ivar >= nvar) { + shambase::throw_with_loc(shambase::format( + "set_field_in_box: ivar ({}) >= f.get_nvar ({}) for field {}", + ivar, + nvar, + field_name)); + } + + { + auto acc = f.get_buf().template mirror_to(); + auto acc_xyz = xyz.get_buf().template mirror_to(); + + for (u32 i = 0; i < f.get_obj_cnt(); i++) { + Tvec r = acc_xyz[i]; + + if (BBAA::is_coord_in_range(r, std::get<0>(box), std::get<1>(box))) { + acc[i * nvar + ivar] = val; + } + } + } + }); + } + + /** + * @brief Set field value for particles within a spherical region + * + * Sets the specified field to a constant value for all particles + * whose positions fall within the given sphere. + * Useful for setting up point-source initial conditions (e.g., Sedov blast). 
+ * + * @tparam T Field type (must be single-variable, e.g., Tscal) + * @param field_name Name of the field to modify (e.g., "uint") + * @param val Value to set for particles in the region + * @param center Center of the sphere + * @param radius Radius of the sphere + * + * Example: + * @code + * // Sedov blast: inject energy in central sphere + * Tscal blast_energy_per_particle = E_blast / n_particles_in_sphere; + * model.set_field_in_sphere("uint", blast_energy_per_particle, origin, r_blast); + * @endcode + */ + template + inline void set_field_in_sphere(std::string field_name, T val, Tvec center, Tscal radius) { + StackEntry stack_loc{}; + PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + sched.patch_data.for_each_patchdata( + [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) { + PatchDataField &xyz + = pdat.template get_field(sched.pdl().get_field_idx("xyz")); + + PatchDataField &f + = pdat.template get_field(sched.pdl().get_field_idx(field_name)); + + if (f.get_nvar() != 1) { + shambase::throw_unimplemented(); + } + + Tscal r2 = radius * radius; + { + auto acc = f.get_buf().template mirror_to(); + auto acc_xyz = xyz.get_buf().template mirror_to(); + + for (u32 i = 0; i < f.get_obj_cnt(); i++) { + Tvec dr = acc_xyz[i] - center; + + if (sycl::dot(dr, dr) < r2) { + acc[i] = val; + } + } + } + }); + } + + template + inline T get_sum(std::string name) { + PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + T sum = shambase::VectorProperties::get_zero(); + + StackEntry stack_loc{}; + sched.patch_data.for_each_patchdata( + [&](u64 patch_id, shamrock::patch::PatchDataLayer &pdat) { + PatchDataField &xyz + = pdat.template get_field(sched.pdl().get_field_idx(name)); + + sum += xyz.compute_sum(); + }); + + return shamalgs::collective::allreduce_sum(sum); + } + + //////////////////////////////////////////////////////////////////////////////////////////// + // Solver configuration + 
//////////////////////////////////////////////////////////////////////////////////////////// + + inline SolverConfig gen_default_config() { + SolverConfig cfg; + cfg.set_riemann_iterative(); // Default to iterative Riemann solver + cfg.set_reconstruct_piecewise_constant(); // Default to 1st order (piecewise constant) + cfg.set_eos_adiabatic(Tscal{1.4}); + cfg.set_boundary_periodic(); + return cfg; + } + + inline void set_solver_config(SolverConfig cfg) { + if (ctx.is_scheduler_initialized()) { + shambase::throw_with_loc( + "Cannot change solver config after scheduler is initialized"); + } + cfg.check_config(); + solver.solver_config = cfg; + } + + inline f64 solver_logs_last_rate() { return solver.solve_logs.get_last_rate(); } + inline u64 solver_logs_last_obj_count() { return solver.solve_logs.get_last_obj_count(); } + + //////////////////////////////////////////////////////////////////////////////////////////// + // I/O (uses shared ShamrockDump mechanism like SPH) + //////////////////////////////////////////////////////////////////////////////////////////// + + inline void load_from_dump(std::string fname) { + if (shamcomm::world_rank() == 0) { + logger::info_ln("GSPH", "Loading state from dump", fname); + } + + std::string metadata_user{}; + shamrock::load_shamrock_dump(fname, metadata_user, ctx); + + nlohmann::json j = nlohmann::json::parse(metadata_user); + j.at("solver_config").get_to(solver.solver_config); + + solver.init_ghost_layout(); + solver.init_solver_graph(); + + PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + sched.owned_patch_id = sched.patch_list.build_local(); + sched.patch_list.build_local_idx_map(); + sched.patch_list.build_global_idx_map(); + sched.update_local_load_value([&](shamrock::patch::Patch p) { + return sched.patch_data.owned_data.get(p.id_patch).get_obj_cnt(); + }); + } + + inline void dump(std::string fname) { + if (shamcomm::world_rank() == 0) { + logger::info_ln("GSPH", "Dumping state to", fname); + } + + 
solver.update_sync_load_values(); + + nlohmann::json metadata; + metadata["solver_config"] = solver.solver_config; + + shamrock::write_shamrock_dump( + fname, metadata.dump(4), shambase::get_check_ref(ctx.sched)); + } + + //////////////////////////////////////////////////////////////////////////////////////////// + // Simulation control + //////////////////////////////////////////////////////////////////////////////////////////// + + TimestepLog timestep() { return solver.evolve_once(); } + + inline void evolve_once() { + solver.evolve_once(); + solver.print_timestep_logs(); + } + + inline bool evolve_until(Tscal target_time, i32 niter_max = -1) { + return solver.evolve_until(target_time, niter_max); + } + }; + +} // namespace shammodels::gsph diff --git a/src/shammodels/gsph/include/shammodels/gsph/Solver.hpp b/src/shammodels/gsph/include/shammodels/gsph/Solver.hpp new file mode 100644 index 000000000..d2f15e08e --- /dev/null +++ b/src/shammodels/gsph/include/shammodels/gsph/Solver.hpp @@ -0,0 +1,216 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file Solver.hpp + * @author Guo Yansong (guo.yansong.ngy@gmail.com) + * @author Yona Lapeyre (yona.lapeyre@ens-lyon.fr) --no git blame-- + * @brief GSPH Solver class + * + * The GSPH method originated from: + * - Inutsuka, S. (2002) "Reformulation of Smoothed Particle Hydrodynamics + * with Riemann Solver" + * + * This implementation follows: + * - Cha, S.-H. & Whitworth, A.P. 
(2003) "Implementations and tests of + * Godunov-type particle hydrodynamics" + */ + +#include "shambase/exception.hpp" +#include "SolverConfig.hpp" +#include "shambackends/vec.hpp" +#include "shammodels/gsph/modules/SolverStorage.hpp" +#include "shammodels/sph/BasicSPHGhosts.hpp" +#include "shammodels/sph/SPHUtilities.hpp" +#include "shammodels/sph/SolverLog.hpp" +#include "shamrock/patch/PatchDataLayerLayout.hpp" +#include "shamrock/scheduler/ComputeField.hpp" +#include "shamrock/scheduler/InterfacesUtility.hpp" +#include "shamrock/scheduler/SerialPatchTree.hpp" +#include "shamrock/scheduler/ShamrockCtx.hpp" +#include "shamsys/legacy/log.hpp" +#include "shamtree/TreeTraversalCache.hpp" +#include +#include +#include + +namespace shammodels::gsph { + + struct TimestepLog { + i32 rank; + f64 rate; + u64 npart; + f64 tcompute; + + inline f64 rate_sum() { return shamalgs::collective::allreduce_sum(rate); } + inline u64 npart_sum() { return shamalgs::collective::allreduce_sum(npart); } + inline f64 tcompute_max() { return shamalgs::collective::allreduce_max(tcompute); } + }; + + /** + * @brief The GSPH Solver class + * + * Implements the Godunov SPH method using Riemann solvers at particle + * interfaces instead of artificial viscosity. 
+ * + * @tparam Tvec Vector type (e.g., f64_3) + * @tparam SPHKernel Kernel type (e.g., M4, M6, C2, C4, C6) + */ + template class SPHKernel> + class Solver { + public: + using Tscal = shambase::VecComponent; + static constexpr u32 dim = shambase::VectorProperties::dimension; + using Kernel = SPHKernel; + + using Config = SolverConfig; + + using u_morton = u32; + + static constexpr Tscal Rkern = Kernel::Rkern; + + ShamrockCtx &context; + inline PatchScheduler &scheduler() { return shambase::get_check_ref(context.sched); } + + SolverStorage storage{}; + + Config solver_config; + sph::SolverLog solve_logs; + + inline void init_required_fields() { solver_config.set_layout(context.get_pdl_write()); } + + // Serial patch tree control + void gen_serial_patch_tree(); + inline void reset_serial_patch_tree() { storage.serial_patch_tree.reset(); } + + // Ghost handling - reuse SPH ghost handler + using GhostHandle = sph::BasicSPHGhostHandler; + using GhostHandleCache = typename GhostHandle::CacheMap; + + void gen_ghost_handler(Tscal time_val); + inline void reset_ghost_handler() { storage.ghost_handler.reset(); } + + void build_ghost_cache(); + void clear_ghost_cache(); + + void merge_position_ghost(); + + // Tree operations + using RTree = typename Config::RTree; + void build_merged_pos_trees(); + void clear_merged_pos_trees(); + + void compute_presteps_rint(); + void reset_presteps_rint(); + + void start_neighbors_cache(); + void reset_neighbors_cache(); + + void gsph_prestep(Tscal time_val, Tscal dt); + + void apply_position_boundary(Tscal time_val); + + void do_predictor_leapfrog(Tscal dt); + + void init_ghost_layout(); + + void communicate_merge_ghosts_fields(); + void reset_merge_ghosts_fields(); + + void compute_omega(); + void compute_eos_fields(); + void reset_eos_fields(); + + void prepare_corrector(); + + /** + * @brief Update derivatives using GSPH Riemann solver + * + * This is the core GSPH step: for each particle pair, solve + * the 1D Riemann problem and 
compute forces from the interface + * pressure p*. + */ + void update_derivs(); + + /** + * @brief Compute CFL timestep constraint + * + * Computes timestep from: + * - Courant condition: dt_cour = C_cour * h / vsig + * - Force condition: dt_force = C_force * sqrt(h / |a|) + * + * @return Minimum CFL timestep across all particles + */ + Tscal compute_dt_cfl(); + + bool apply_corrector(Tscal dt, u64 Npart_all); + + void update_sync_load_values(); + + Solver(ShamrockCtx &context) : context(context) {} + + void init_solver_graph(); + + void vtk_do_dump(std::string filename, bool add_patch_world_id); + + inline void print_timestep_logs() { + if (shamcomm::world_rank() == 0) { + logger::info_ln( + "GSPH", "iteration since start :", solve_logs.get_iteration_count()); + logger::info_ln( + "GSPH", "time since start :", shambase::details::get_wtime(), "(s)"); + } + } + + TimestepLog evolve_once(); + + Tscal evolve_once_time_expl(Tscal t_current, Tscal dt_input) { + solver_config.set_time(t_current); + solver_config.set_next_dt(dt_input); + evolve_once(); + return solver_config.get_dt(); + } + + inline bool evolve_until(Tscal target_time, i32 niter_max = -1) { + auto step = [&]() { + Tscal dt = solver_config.get_dt(); + Tscal t = solver_config.get_time(); + + if (t > target_time) { + throw shambase::make_except_with_loc( + "the target time is higher than the current time"); + } + + if (t + dt > target_time) { + solver_config.set_next_dt(target_time - t); + } + evolve_once(); + }; + + i32 iter_count = 0; + + while (solver_config.get_time() < target_time) { + step(); + iter_count++; + + if ((iter_count >= niter_max) && (niter_max != -1)) { + logger::info_ln("GSPH", "stopping evolve until because of niter =", iter_count); + return false; + } + } + + print_timestep_logs(); + + return true; + } + }; + +} // namespace shammodels::gsph diff --git a/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp 
b/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp index 10b5b031b..ebf8fbc6b 100644 --- a/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp +++ b/src/shammodels/gsph/include/shammodels/gsph/modules/SolverStorage.hpp @@ -103,6 +103,7 @@ namespace shammodels::gsph { std::shared_ptr> omega; /// Ghost data layout and merged data + std::shared_ptr xyzh_ghost_layout; Component> ghost_layout; Component> merged_patchdata_ghost; diff --git a/src/shammodels/gsph/include/shammodels/gsph/modules/io/VTKDump.hpp b/src/shammodels/gsph/include/shammodels/gsph/modules/io/VTKDump.hpp new file mode 100644 index 000000000..c0d2fa831 --- /dev/null +++ b/src/shammodels/gsph/include/shammodels/gsph/modules/io/VTKDump.hpp @@ -0,0 +1,47 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +#pragma once + +/** + * @file VTKDump.hpp + * @author Guo Yansong (guo.yansong.ngy@gmail.com) + * @author Yona Lapeyre (yona.lapeyre@ens-lyon.fr) --no git blame-- + * @brief VTK dump module for GSPH solver + */ + +#include "shambackends/typeAliasVec.hpp" +#include "shambackends/vec.hpp" +#include "shammodels/gsph/SolverConfig.hpp" +#include "shamrock/scheduler/ShamrockCtx.hpp" + +namespace shammodels::gsph::modules { + + template class SPHKernel> + class VTKDump { + public: + using Tscal = shambase::VecComponent; + static constexpr u32 dim = shambase::VectorProperties::dimension; + using Kernel = SPHKernel; + + using Config = SolverConfig; + + ShamrockCtx &context; + Config &solver_config; + + VTKDump(ShamrockCtx &context, Config &solver_config) + : context(context), solver_config(solver_config) {} + + void do_dump(std::string filename, 
bool add_patch_world_id); + + private: + inline PatchScheduler &scheduler() { return shambase::get_check_ref(context.sched); } + }; + +} // namespace shammodels::gsph::modules diff --git a/src/shammodels/gsph/src/Model.cpp b/src/shammodels/gsph/src/Model.cpp new file mode 100644 index 000000000..b0d31044c --- /dev/null +++ b/src/shammodels/gsph/src/Model.cpp @@ -0,0 +1,312 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +/** + * @file Model.cpp + * @author Guo Yansong (guo.yansong.ngy@gmail.com) + * @author Yona Lapeyre (yona.lapeyre@ens-lyon.fr) --no git blame-- + * @brief GSPH Model implementation + */ + +#include "shambase/aliases_float.hpp" +#include "shambase/exception.hpp" +#include "shambase/memory.hpp" +#include "shambase/stacktrace.hpp" +#include "shambase/string.hpp" +#include "shamcomm/logs.hpp" +#include "shammath/CoordRange.hpp" +#include "shammath/crystalLattice.hpp" +#include "shammath/sphkernels.hpp" +#include "shammodels/common/setup/generators.hpp" +#include "shammodels/gsph/Model.hpp" +#include "shamrock/patch/PatchDataLayer.hpp" +#include "shamrock/scheduler/DataInserterUtility.hpp" +#include "shamrock/scheduler/PatchScheduler.hpp" +#include "shamsys/NodeInstance.hpp" +#include "shamsys/legacy/log.hpp" +#include +#include +#include + +template class SPHKernel> +void shammodels::gsph::Model::init_scheduler(u32 crit_split, u32 crit_merge) { + solver.init_required_fields(); + ctx.init_sched(crit_split, crit_merge); + + using namespace shamrock::patch; + + PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + + sched.add_root_patch(); + + shamlog_debug_ln("Sys", "build local scheduler tables"); + 
sched.owned_patch_id = sched.patch_list.build_local(); + sched.patch_list.build_local_idx_map(); + sched.update_local_load_value([&](shamrock::patch::Patch p) { + return sched.patch_data.owned_data.get(p.id_patch).get_obj_cnt(); + }); + solver.init_ghost_layout(); + + solver.init_solver_graph(); +} + +template class SPHKernel> +u64 shammodels::gsph::Model::get_total_part_count() { + PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + return shamalgs::collective::allreduce_sum(sched.get_rank_count()); +} + +template class SPHKernel> +f64 shammodels::gsph::Model::total_mass_to_part_mass(f64 totmass) { + return totmass / get_total_part_count(); +} + +template class SPHKernel> +auto shammodels::gsph::Model::get_ideal_fcc_box( + Tscal dr, std::pair box) -> std::pair { + StackEntry stack_loc{}; + auto [a, b] = generic::setup::generators::get_ideal_fcc_box( + dr, std::make_tuple(box.first, box.second)); + return {a, b}; +} + +template class SPHKernel> +auto shammodels::gsph::Model::get_ideal_hcp_box( + Tscal dr, std::pair box) -> std::pair { + StackEntry stack_loc{}; + auto [a, b] = generic::setup::generators::get_ideal_fcc_box( + dr, std::make_tuple(box.first, box.second)); + return {a, b}; +} + +template class SPHKernel> +void shammodels::gsph::Model::add_cube_fcc_3d( + Tscal dr, std::pair _box) { + StackEntry stack_loc{}; + + shammath::CoordRange box = _box; + + using namespace shamrock::patch; + + PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + + std::string log = ""; + + auto make_sliced = [&]() { + std::vector vec_lst; + generic::setup::generators::add_particles_fcc( + dr, + std::make_tuple(box.lower, box.upper), + [&](Tvec r) { + return box.contain_pos(r); + }, + [&](Tvec r, Tscal h) { + vec_lst.push_back(r); + }); + + std::vector> sliced_buf; + + u32 sz_buf = sched.crit_patch_split * 4; + + std::vector cur_buf; + for (u32 i = 0; i < vec_lst.size(); i++) { + cur_buf.push_back(vec_lst[i]); + + if (cur_buf.size() > sz_buf) { + 
sliced_buf.push_back(std::exchange(cur_buf, std::vector{})); + } + } + + if (cur_buf.size() > 0) { + sliced_buf.push_back(std::exchange(cur_buf, std::vector{})); + } + + return sliced_buf; + }; + + std::vector> sliced_buf = make_sliced(); + + for (std::vector to_ins : sliced_buf) { + + sched.for_each_local_patchdata([&](const Patch p, PatchDataLayer &pdat) { + PatchCoordTransform ptransf + = sched.get_sim_box().template get_patch_transform(); + + shammath::CoordRange patch_coord = ptransf.to_obj_coord(p); + + std::vector vec_acc; + for (Tvec r : to_ins) { + if (patch_coord.contain_pos(r)) { + vec_acc.push_back(r); + } + } + + if (vec_acc.size() == 0) { + return; + } + + log += shambase::format( + "\n rank = {} patch id={}, add N={} particles, coords = {} {}", + shamcomm::world_rank(), + p.id_patch, + vec_acc.size(), + patch_coord.lower, + patch_coord.upper); + + PatchDataLayer tmp(sched.get_layout_ptr()); + tmp.resize(vec_acc.size()); + tmp.fields_raz(); + + { + u32 len = vec_acc.size(); + PatchDataField &f + = tmp.template get_field(sched.pdl().template get_field_idx("xyz")); + sycl::buffer buf(vec_acc.data(), len); + f.override(buf, len); + } + + { + PatchDataField &f = tmp.template get_field( + sched.pdl().template get_field_idx("hpart")); + using Kernel = SPHKernel; + f.override(Kernel::hfactd * dr); + } + + pdat.insert_elements(tmp); + }); + + sched.check_patchdata_locality_corectness(); + sched.scheduler_step(true, true); + } + + sched.owned_patch_id = sched.patch_list.build_local(); + sched.patch_list.build_local_idx_map(); + sched.update_local_load_value([&](Patch p) { + return sched.patch_data.owned_data.get(p.id_patch).get_obj_cnt(); + }); + + shamlog_debug_ln("setup", log); +} + +template class SPHKernel> +void shammodels::gsph::Model::add_cube_hcp_3d( + Tscal dr, std::pair _box) { + StackEntry stack_loc{}; + + shammath::CoordRange box = _box; + + using namespace shamrock::patch; + + PatchScheduler &sched = shambase::get_check_ref(ctx.sched); + + 
std::string log = ""; + + auto make_sliced = [&]() { + std::vector vec_lst; + generic::setup::generators::add_particles_fcc( + dr, + std::make_tuple(box.lower, box.upper), + [&](Tvec r) { + return box.contain_pos(r); + }, + [&](Tvec r, Tscal h) { + vec_lst.push_back(r); + }); + + std::vector> sliced_buf; + + u32 sz_buf = sched.crit_patch_split * 4; + + std::vector cur_buf; + for (u32 i = 0; i < vec_lst.size(); i++) { + cur_buf.push_back(vec_lst[i]); + + if (cur_buf.size() > sz_buf) { + sliced_buf.push_back(std::exchange(cur_buf, std::vector{})); + } + } + + if (cur_buf.size() > 0) { + sliced_buf.push_back(std::exchange(cur_buf, std::vector{})); + } + + return sliced_buf; + }; + + std::vector> sliced_buf = make_sliced(); + + for (std::vector to_ins : sliced_buf) { + + sched.for_each_local_patchdata([&](const Patch p, PatchDataLayer &pdat) { + PatchCoordTransform ptransf + = sched.get_sim_box().template get_patch_transform(); + + shammath::CoordRange patch_coord = ptransf.to_obj_coord(p); + + std::vector vec_acc; + for (Tvec r : to_ins) { + if (patch_coord.contain_pos(r)) { + vec_acc.push_back(r); + } + } + + if (vec_acc.size() == 0) { + return; + } + + log += shambase::format( + "\n rank = {} patch id={}, add N={} particles, coords = {} {}", + shamcomm::world_rank(), + p.id_patch, + vec_acc.size(), + patch_coord.lower, + patch_coord.upper); + + PatchDataLayer tmp(sched.get_layout_ptr()); + tmp.resize(vec_acc.size()); + tmp.fields_raz(); + + { + u32 len = vec_acc.size(); + PatchDataField &f + = tmp.template get_field(sched.pdl().template get_field_idx("xyz")); + sycl::buffer buf(vec_acc.data(), len); + f.override(buf, len); + } + + { + PatchDataField &f = tmp.template get_field( + sched.pdl().template get_field_idx("hpart")); + using Kernel = SPHKernel; + f.override(Kernel::hfactd * dr); + } + + pdat.insert_elements(tmp); + }); + + sched.check_patchdata_locality_corectness(); + sched.scheduler_step(true, true); + } + + sched.owned_patch_id = 
sched.patch_list.build_local(); + sched.patch_list.build_local_idx_map(); + sched.update_local_load_value([&](Patch p) { + return sched.patch_data.owned_data.get(p.id_patch).get_obj_cnt(); + }); + + shamlog_debug_ln("setup", log); +} + +// Explicit template instantiations for all supported kernel types +template class shammodels::gsph::Model; +template class shammodels::gsph::Model; +template class shammodels::gsph::Model; +template class shammodels::gsph::Model; +template class shammodels::gsph::Model; +template class shammodels::gsph::Model; diff --git a/src/shammodels/gsph/src/Solver.cpp b/src/shammodels/gsph/src/Solver.cpp new file mode 100644 index 000000000..98b0a5936 --- /dev/null +++ b/src/shammodels/gsph/src/Solver.cpp @@ -0,0 +1,1523 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +/** + * @file Solver.cpp + * @author Guo Yansong (guo.yansong.ngy@gmail.com) + * @author Yona Lapeyre (yona.lapeyre@ens-lyon.fr) --no git blame-- + * @brief GSPH Solver implementation + * + * The GSPH method originated from: + * - Inutsuka, S. (2002) "Reformulation of Smoothed Particle Hydrodynamics + * with Riemann Solver" + * + * This implementation follows: + * - Cha, S.-H. & Whitworth, A.P. 
(2003) "Implementations and tests of + * Godunov-type particle hydrodynamics" + */ + +#include "shambase/exception.hpp" +#include "shambase/memory.hpp" +#include "shambase/string.hpp" +#include "shambase/time.hpp" +#include "shamalgs/collective/exchanges.hpp" +#include "shamalgs/collective/reduction.hpp" +#include "shambackends/kernel_call.hpp" +#include "shambackends/math.hpp" +#include "shamcomm/collectives.hpp" +#include "shamcomm/logs.hpp" +#include "shamcomm/worldInfo.hpp" +#include "shammath/sphkernels.hpp" +#include "shammodels/gsph/Solver.hpp" +#include "shammodels/gsph/SolverConfig.hpp" +#include "shammodels/gsph/modules/UpdateDerivs.hpp" +#include "shammodels/gsph/modules/io/VTKDump.hpp" +#include "shammodels/sph/BasicSPHGhosts.hpp" +#include "shammodels/sph/SPHUtilities.hpp" +#include "shammodels/sph/modules/IterateSmoothingLengthDensity.hpp" +#include "shammodels/sph/modules/LoopSmoothingLengthIter.hpp" +#include "shammodels/sph/modules/NeighbourCache.hpp" +#include "shamrock/patch/Patch.hpp" +#include "shamrock/patch/PatchDataLayer.hpp" +#include "shamrock/patch/PatchDataLayerLayout.hpp" +#include "shamrock/scheduler/ComputeField.hpp" +#include "shamrock/scheduler/InterfacesUtility.hpp" +#include "shamrock/scheduler/PatchScheduler.hpp" +#include "shamrock/scheduler/ReattributeDataUtility.hpp" +#include "shamrock/scheduler/SchedulerUtility.hpp" +#include "shamrock/scheduler/SerialPatchTree.hpp" +#include "shamrock/solvergraph/Field.hpp" +#include "shamrock/solvergraph/FieldRefs.hpp" +#include "shamrock/solvergraph/Indexes.hpp" +#include "shamsys/NodeInstance.hpp" +#include "shamsys/legacy/log.hpp" +#include "shamtree/KarrasRadixTreeField.hpp" +#include "shamtree/TreeTraversal.hpp" +#include "shamtree/TreeTraversalCache.hpp" +#include "shamtree/kernels/geometry_utils.hpp" +#include +#include +#include + +template class Kern> +void shammodels::gsph::Solver::init_solver_graph() { + + storage.part_counts + = std::make_shared>("part_counts", "N_{\\rm 
part}"); + + storage.part_counts_with_ghost = std::make_shared>( + "part_counts_with_ghost", "N_{\\rm part, with ghost}"); + + storage.patch_rank_owner + = std::make_shared>("patch_rank_owner", "rank"); + + // Merged ghost spans + storage.positions_with_ghosts + = std::make_shared>("part_pos", "\\mathbf{r}"); + storage.hpart_with_ghosts + = std::make_shared>("h_part", "h"); + + storage.neigh_cache + = std::make_shared("neigh_cache", "neigh"); + + storage.omega = std::make_shared>(1, "omega", "\\Omega"); + storage.density = std::make_shared>(1, "density", "\\rho"); + storage.pressure = std::make_shared>(1, "pressure", "P"); + storage.soundspeed + = std::make_shared>(1, "soundspeed", "c_s"); +} + +template class Kern> +void shammodels::gsph::Solver::vtk_do_dump( + std::string filename, bool add_patch_world_id) { + + modules::VTKDump(context, solver_config).do_dump(filename, add_patch_world_id); +} + +template class Kern> +void shammodels::gsph::Solver::gen_serial_patch_tree() { + StackEntry stack_loc{}; + + SerialPatchTree _sptree = SerialPatchTree::build(scheduler()); + _sptree.attach_buf(); + storage.serial_patch_tree.set(std::move(_sptree)); +} + +template class Kern> +void shammodels::gsph::Solver::gen_ghost_handler(Tscal time_val) { + StackEntry stack_loc{}; + + using CfgClass = sph::BasicSPHGhostHandlerConfig; + using BCConfig = typename CfgClass::Variant; + + using BCFree = typename CfgClass::Free; + using BCPeriodic = typename CfgClass::Periodic; + using BCShearingPeriodic = typename CfgClass::ShearingPeriodic; + + using SolverConfigBC = typename Config::BCConfig; + using SolverBCFree = typename SolverConfigBC::Free; + using SolverBCPeriodic = typename SolverConfigBC::Periodic; + using SolverBCShearingPeriodic = typename SolverConfigBC::ShearingPeriodic; + + // Boundary condition selection - similar to SPH solver + // Note: Wall boundaries use Periodic with dynamic wall particles + if (SolverBCFree *c = std::get_if(&solver_config.boundary_config.config)) { + 
storage.ghost_handler.set( + GhostHandle{ + scheduler(), BCFree{}, storage.patch_rank_owner, storage.xyzh_ghost_layout}); + } else if ( + SolverBCPeriodic *c + = std::get_if(&solver_config.boundary_config.config)) { + storage.ghost_handler.set( + GhostHandle{ + scheduler(), BCPeriodic{}, storage.patch_rank_owner, storage.xyzh_ghost_layout}); + } else if ( + SolverBCShearingPeriodic *c + = std::get_if(&solver_config.boundary_config.config)) { + // Shearing periodic boundaries (Stone 2010) - reuse SPH implementation + storage.ghost_handler.set( + GhostHandle{ + scheduler(), + BCShearingPeriodic{ + c->shear_base, c->shear_dir, c->shear_speed * time_val, c->shear_speed}, + storage.patch_rank_owner, + storage.xyzh_ghost_layout}); + } else { + shambase::throw_with_loc("GSPH: Unsupported boundary condition type."); + } +} + +template class Kern> +void shammodels::gsph::Solver::build_ghost_cache() { + StackEntry stack_loc{}; + + using SPHUtils = sph::SPHUtilities; + SPHUtils sph_utils(scheduler()); + + storage.ghost_patch_cache.set(sph_utils.build_interf_cache( + storage.ghost_handler.get(), + storage.serial_patch_tree.get(), + solver_config.htol_up_coarse_cycle)); +} + +template class Kern> +void shammodels::gsph::Solver::clear_ghost_cache() { + StackEntry stack_loc{}; + storage.ghost_patch_cache.reset(); +} + +template class Kern> +void shammodels::gsph::Solver::merge_position_ghost() { + StackEntry stack_loc{}; + + storage.merged_xyzh.set( + storage.ghost_handler.get().build_comm_merge_positions(storage.ghost_patch_cache.get())); + + // Set element counts + shambase::get_check_ref(storage.part_counts).indexes + = storage.merged_xyzh.get().template map( + [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) { + return scheduler().patch_data.get_pdat(id).get_obj_cnt(); + }); + + // Set element counts with ghost + shambase::get_check_ref(storage.part_counts_with_ghost).indexes + = storage.merged_xyzh.get().template map( + [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) 
{ + return mpdat.get_obj_cnt(); + }); + + // Attach spans to block coords + shambase::get_check_ref(storage.positions_with_ghosts) + .set_refs( + storage.merged_xyzh.get().template map>>( + [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) { + return std::ref(mpdat.get_field(0)); + })); + + shambase::get_check_ref(storage.hpart_with_ghosts) + .set_refs( + storage.merged_xyzh.get().template map>>( + [&](u64 id, shamrock::patch::PatchDataLayer &mpdat) { + return std::ref(mpdat.get_field(1)); + })); +} + +template class Kern> +void shammodels::gsph::Solver::build_merged_pos_trees() { + StackEntry stack_loc{}; + + auto &merged_xyzh = storage.merged_xyzh.get(); + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + shambase::DistributedData trees + = merged_xyzh.template map([&](u64 id, shamrock::patch::PatchDataLayer &merged) { + PatchDataField &pos = merged.template get_field(0); + Tvec bmax = pos.compute_max(); + Tvec bmin = pos.compute_min(); + + shammath::AABB aabb(bmin, bmax); + + Tscal infty = std::numeric_limits::infinity(); + + // Ensure that no particle is on the boundary of the AABB + aabb.lower[0] = std::nextafter(aabb.lower[0], -infty); + aabb.lower[1] = std::nextafter(aabb.lower[1], -infty); + aabb.lower[2] = std::nextafter(aabb.lower[2], -infty); + aabb.upper[0] = std::nextafter(aabb.upper[0], infty); + aabb.upper[1] = std::nextafter(aabb.upper[1], infty); + aabb.upper[2] = std::nextafter(aabb.upper[2], infty); + + auto bvh = RTree::make_empty(dev_sched); + bvh.rebuild_from_positions( + pos.get_buf(), pos.get_obj_cnt(), aabb, solver_config.tree_reduction_level); + + return bvh; + }); + + storage.merged_pos_trees.set(std::move(trees)); +} + +template class Kern> +void shammodels::gsph::Solver::clear_merged_pos_trees() { + StackEntry stack_loc{}; + storage.merged_pos_trees.reset(); +} + +template class Kern> +void shammodels::gsph::Solver::compute_presteps_rint() { + StackEntry stack_loc{}; + + auto &xyzh_merged = 
storage.merged_xyzh.get(); + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + storage.rtree_rint_field.set( + storage.merged_pos_trees.get().template map>( + [&](u64 id, RTree &rtree) -> shamtree::KarrasRadixTreeField { + shamrock::patch::PatchDataLayer &tmp = xyzh_merged.get(id); + auto &buf = tmp.get_field_buf_ref(1); + auto buf_int = shamtree::new_empty_karras_radix_tree_field(); + + auto ret = shamtree::compute_tree_field_max_field( + rtree.structure, + rtree.reduced_morton_set.get_leaf_cell_iterator(), + std::move(buf_int), + buf); + + // Increase the size by tolerance factor + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{}, + sham::MultiRef{ret.buf_field}, + ret.buf_field.get_size(), + [htol = solver_config.htol_up_coarse_cycle](u32 i, Tscal *h_tree) { + h_tree[i] *= htol; + }); + + return std::move(ret); + })); +} + +template class Kern> +void shammodels::gsph::Solver::reset_presteps_rint() { + storage.rtree_rint_field.reset(); +} + +template class Kern> +void shammodels::gsph::Solver::start_neighbors_cache() { + StackEntry stack_loc{}; + + shambase::Timer time_neigh; + time_neigh.start(); + + Tscal h_tolerance = solver_config.htol_up_coarse_cycle; + + // Build neighbor cache using tree traversal - same approach as SPH module + auto build_neigh_cache = [&](u64 patch_id) -> shamrock::tree::ObjectCache { + auto &mfield = storage.merged_xyzh.get().get(patch_id); + + sham::DeviceBuffer &buf_xyz = mfield.template get_field_buf_ref(0); + sham::DeviceBuffer &buf_hpart = mfield.template get_field_buf_ref(1); + + sham::DeviceBuffer &tree_field_rint + = storage.rtree_rint_field.get().get(patch_id).buf_field; + + RTree &tree = storage.merged_pos_trees.get().get(patch_id); + auto obj_it = tree.get_object_iterator(); + + u32 obj_cnt = shambase::get_check_ref(storage.part_counts).indexes.get(patch_id); + + constexpr Tscal Rker2 = Kernel::Rkern * Kernel::Rkern; + + // Allocate neighbor count buffer + sham::DeviceBuffer neigh_count( + 
obj_cnt, shamsys::instance::get_compute_scheduler_ptr()); + + shamsys::instance::get_compute_queue().wait_and_throw(); + + // First pass: count neighbors + { + sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue(); + sham::EventList depends_list; + + auto xyz = buf_xyz.get_read_access(depends_list); + auto hpart = buf_hpart.get_read_access(depends_list); + auto rint_tree = tree_field_rint.get_read_access(depends_list); + auto neigh_cnt = neigh_count.get_write_access(depends_list); + auto particle_looper = obj_it.get_read_access(depends_list); + + auto e = q.submit(depends_list, [&, h_tolerance](sycl::handler &cgh) { + shambase::parallel_for(cgh, obj_cnt, "gsph_count_neighbors", [=](u64 gid) { + u32 id_a = (u32) gid; + + Tscal rint_a = hpart[id_a] * h_tolerance; + Tvec xyz_a = xyz[id_a]; + + Tvec inter_box_a_min = xyz_a - rint_a * Kernel::Rkern; + Tvec inter_box_a_max = xyz_a + rint_a * Kernel::Rkern; + + u32 cnt = 0; + + particle_looper.rtree_for( + [&](u32 node_id, shammath::AABB node_aabb) -> bool { + Tscal int_r_max_cell = rint_tree[node_id] * Kernel::Rkern; + + using namespace walker::interaction_crit; + + return sph_radix_cell_crit( + xyz_a, + inter_box_a_min, + inter_box_a_max, + node_aabb.lower, + node_aabb.upper, + int_r_max_cell); + }, + [&](u32 id_b) { + Tvec dr = xyz_a - xyz[id_b]; + Tscal rab2 = sycl::dot(dr, dr); + Tscal rint_b = hpart[id_b] * h_tolerance; + + bool no_interact + = rab2 > rint_a * rint_a * Rker2 && rab2 > rint_b * rint_b * Rker2; + + cnt += (no_interact) ? 
0 : 1; + }); + + neigh_cnt[id_a] = cnt; + }); + }); + + buf_xyz.complete_event_state(e); + buf_hpart.complete_event_state(e); + neigh_count.complete_event_state(e); + tree_field_rint.complete_event_state(e); + obj_it.complete_event_state(e); + } + + // Use tree::prepare_object_cache to do prefix sum and allocate buffers + shamrock::tree::ObjectCache pcache + = shamrock::tree::prepare_object_cache(std::move(neigh_count), obj_cnt); + + // Second pass: fill neighbor indices + { + sham::DeviceQueue &q = shamsys::instance::get_compute_scheduler().get_queue(); + sham::EventList depends_list; + + auto xyz = buf_xyz.get_read_access(depends_list); + auto hpart = buf_hpart.get_read_access(depends_list); + auto rint_tree = tree_field_rint.get_read_access(depends_list); + auto scanned_neigh_cnt = pcache.scanned_cnt.get_read_access(depends_list); + auto neigh = pcache.index_neigh_map.get_write_access(depends_list); + auto particle_looper = obj_it.get_read_access(depends_list); + + auto e = q.submit(depends_list, [&, h_tolerance](sycl::handler &cgh) { + shambase::parallel_for(cgh, obj_cnt, "gsph_fill_neighbors", [=](u64 gid) { + u32 id_a = (u32) gid; + + Tscal rint_a = hpart[id_a] * h_tolerance; + Tvec xyz_a = xyz[id_a]; + + Tvec inter_box_a_min = xyz_a - rint_a * Kernel::Rkern; + Tvec inter_box_a_max = xyz_a + rint_a * Kernel::Rkern; + + u32 write_idx = scanned_neigh_cnt[id_a]; + + particle_looper.rtree_for( + [&](u32 node_id, shammath::AABB node_aabb) -> bool { + Tscal int_r_max_cell = rint_tree[node_id] * Kernel::Rkern; + + using namespace walker::interaction_crit; + + return sph_radix_cell_crit( + xyz_a, + inter_box_a_min, + inter_box_a_max, + node_aabb.lower, + node_aabb.upper, + int_r_max_cell); + }, + [&](u32 id_b) { + Tvec dr = xyz_a - xyz[id_b]; + Tscal rab2 = sycl::dot(dr, dr); + Tscal rint_b = hpart[id_b] * h_tolerance; + + bool no_interact + = rab2 > rint_a * rint_a * Rker2 && rab2 > rint_b * rint_b * Rker2; + + if (!no_interact) { + neigh[write_idx++] = id_b; + } + 
}); + }); + }); + + buf_xyz.complete_event_state(e); + buf_hpart.complete_event_state(e); + tree_field_rint.complete_event_state(e); + pcache.scanned_cnt.complete_event_state(e); + pcache.index_neigh_map.complete_event_state(e); + obj_it.complete_event_state(e); + } + + return pcache; + }; + + shambase::get_check_ref(storage.neigh_cache).free_alloc(); + + using namespace shamrock::patch; + scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) { + auto &ncache = shambase::get_check_ref(storage.neigh_cache); + ncache.neigh_cache.add_obj(cur_p.id_patch, build_neigh_cache(cur_p.id_patch)); + }); + + time_neigh.end(); + storage.timings_details.neighbors += time_neigh.elasped_sec(); +} + +template class Kern> +void shammodels::gsph::Solver::reset_neighbors_cache() { + storage.neigh_cache->neigh_cache = {}; +} + +template class Kern> +void shammodels::gsph::Solver::gsph_prestep(Tscal time_val, Tscal dt) { + StackEntry stack_loc{}; + + shamlog_debug_ln("GSPH", "Prestep at t =", time_val, "dt =", dt); +} + +template class Kern> +void shammodels::gsph::Solver::apply_position_boundary(Tscal time_val) { + StackEntry stack_loc{}; + + shamlog_debug_ln("GSPH", "apply position boundary"); + + PatchScheduler &sched = scheduler(); + shamrock::SchedulerUtility integrators(sched); + shamrock::ReattributeDataUtility reatrib(sched); + + auto &pdl = sched.pdl(); + const u32 ixyz = pdl.get_field_idx("xyz"); + const u32 ivxyz = pdl.get_field_idx("vxyz"); + auto [bmin, bmax] = sched.get_box_volume(); + + using SolverConfigBC = typename Config::BCConfig; + using SolverBCFree = typename SolverConfigBC::Free; + using SolverBCPeriodic = typename SolverConfigBC::Periodic; + using SolverBCShearingPeriodic = typename SolverConfigBC::ShearingPeriodic; + + if (SolverBCFree *c = std::get_if(&solver_config.boundary_config.config)) { + if (shamcomm::world_rank() == 0) { + logger::info_ln("PositionUpdated", "free boundaries skipping geometry update"); + } + } else if ( + 
SolverBCPeriodic *c + = std::get_if(&solver_config.boundary_config.config)) { + integrators.fields_apply_periodicity(ixyz, std::pair{bmin, bmax}); + } else if ( + SolverBCShearingPeriodic *c + = std::get_if(&solver_config.boundary_config.config)) { + // Apply shearing periodic boundaries (Stone 2010) - reuse SPH implementation + integrators.fields_apply_shearing_periodicity( + ixyz, + ivxyz, + std::pair{bmin, bmax}, + c->shear_base, + c->shear_dir, + c->shear_speed * time_val, + c->shear_speed); + } else { + shambase::throw_with_loc("GSPH: Unsupported boundary condition type."); + } + + reatrib.reatribute_patch_objects(storage.serial_patch_tree.get(), "xyz"); +} + +template class Kern> +void shammodels::gsph::Solver::do_predictor_leapfrog(Tscal dt) { + StackEntry stack_loc{}; + using namespace shamrock::patch; + + PatchDataLayerLayout &pdl = scheduler().pdl(); + const u32 ixyz = pdl.get_field_idx("xyz"); + const u32 ivxyz = pdl.get_field_idx("vxyz"); + const u32 iaxyz = pdl.get_field_idx("axyz"); + + const bool has_uint = solver_config.has_field_uint(); + const u32 iuint = has_uint ? pdl.get_field_idx("uint") : 0; + const u32 iduint = has_uint ? 
pdl.get_field_idx("duint") : 0; + + Tscal half_dt = dt / 2; + + // Predictor step: leapfrog kick-drift-kick + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + u32 cnt = pdat.get_obj_cnt(); + if (cnt == 0) + return; + + auto &xyz_field = pdat.get_field(ixyz); + auto &vxyz_field = pdat.get_field(ivxyz); + auto &axyz_field = pdat.get_field(iaxyz); + + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + // Kick-drift-kick update: v += a*dt/2, x += v*dt, v += a*dt/2 (both half kicks read the same axyz values) + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{axyz_field.get_buf()}, + sham::MultiRef{xyz_field.get_buf(), vxyz_field.get_buf()}, + cnt, + [half_dt, dt](u32 i, const Tvec *axyz, Tvec *xyz, Tvec *vxyz) { + // Kick: v += a*dt/2 + vxyz[i] += axyz[i] * half_dt; + // Drift: x += v*dt + xyz[i] += vxyz[i] * dt; + // Kick: v += a*dt/2 + vxyz[i] += axyz[i] * half_dt; + }); + + // Internal energy integration (if adiabatic EOS) + if (has_uint) { + auto &uint_field = pdat.get_field(iuint); + auto &duint_field = pdat.get_field(iduint); + + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{duint_field.get_buf()}, + sham::MultiRef{uint_field.get_buf()}, + cnt, + [dt](u32 i, const Tscal *duint, Tscal *uint) { + // u += du*dt + uint[i] += duint[i] * dt; + }); + } + }); +} + +template class Kern> +void shammodels::gsph::Solver::init_ghost_layout() { + StackEntry stack_loc{}; + + // Initialize xyzh_ghost_layout for BasicSPHGhostHandler (position + smoothing length) + storage.xyzh_ghost_layout = std::make_shared(); + storage.xyzh_ghost_layout->template add_field("xyz", 1); + storage.xyzh_ghost_layout->template add_field("hpart", 1); + + // Reset first in case it was set from a previous timestep + storage.ghost_layout.reset(); + storage.ghost_layout.set(std::make_shared()); + + shamrock::patch::PatchDataLayerLayout &ghost_layout + = shambase::get_check_ref(storage.ghost_layout.get()); +
solver_config.set_ghost_layout(ghost_layout); +} + +template class Kern> +void shammodels::gsph::Solver::communicate_merge_ghosts_fields() { + StackEntry stack_loc{}; + + shambase::Timer timer_interf; + timer_interf.start(); + + using namespace shamrock; + using namespace shamrock::patch; + + PatchDataLayerLayout &pdl = scheduler().pdl(); + const u32 ixyz = pdl.get_field_idx("xyz"); + const u32 ivxyz = pdl.get_field_idx("vxyz"); + const u32 ihpart = pdl.get_field_idx("hpart"); + + const bool has_uint = solver_config.has_field_uint(); + const u32 iuint = has_uint ? pdl.get_field_idx("uint") : 0; + + auto ghost_layout_ptr = storage.ghost_layout.get(); + shamrock::patch::PatchDataLayerLayout &ghost_layout = shambase::get_check_ref(ghost_layout_ptr); + u32 ihpart_interf = ghost_layout.get_field_idx("hpart"); + u32 ivxyz_interf = ghost_layout.get_field_idx("vxyz"); + u32 iomega_interf = ghost_layout.get_field_idx("omega"); + u32 idensity_interf = ghost_layout.get_field_idx("density"); + u32 iuint_interf = has_uint ? 
ghost_layout.get_field_idx("uint") : 0; + + using InterfaceBuildInfos = typename sph::BasicSPHGhostHandler::InterfaceBuildInfos; + + sph::BasicSPHGhostHandler &ghost_handle = storage.ghost_handler.get(); + shamrock::solvergraph::Field &omega = shambase::get_check_ref(storage.omega); + shamrock::solvergraph::Field &density = shambase::get_check_ref(storage.density); + + // Build interface data from ghost cache + auto pdat_interf = ghost_handle.template build_interface_native( + storage.ghost_patch_cache.get(), + [&](u64 sender, u64, InterfaceBuildInfos binfo, sham::DeviceBuffer &buf_idx, u32 cnt) { + PatchDataLayer pdat(ghost_layout_ptr); + pdat.reserve(cnt); + return pdat; + }); + + // Populate interface data with field values + ghost_handle.template modify_interface_native( + storage.ghost_patch_cache.get(), + pdat_interf, + [&](u64 sender, + u64, + InterfaceBuildInfos binfo, + sham::DeviceBuffer &buf_idx, + u32 cnt, + PatchDataLayer &pdat) { + PatchDataLayer &sender_patch = scheduler().patch_data.get_pdat(sender); + PatchDataField &sender_omega = omega.get(sender); + PatchDataField &sender_density = density.get(sender); + + sender_patch.get_field(ihpart).append_subset_to( + buf_idx, cnt, pdat.get_field(ihpart_interf)); + sender_patch.get_field(ivxyz).append_subset_to( + buf_idx, cnt, pdat.get_field(ivxyz_interf)); + sender_omega.append_subset_to(buf_idx, cnt, pdat.get_field(iomega_interf)); + sender_density.append_subset_to(buf_idx, cnt, pdat.get_field(idensity_interf)); + + if (has_uint) { + sender_patch.get_field(iuint).append_subset_to( + buf_idx, cnt, pdat.get_field(iuint_interf)); + } + }); + + // Apply velocity offset for periodic boundaries + ghost_handle.template modify_interface_native( + storage.ghost_patch_cache.get(), + pdat_interf, + [&](u64 sender, + u64, + InterfaceBuildInfos binfo, + sham::DeviceBuffer &buf_idx, + u32 cnt, + PatchDataLayer &pdat) { + if (sycl::length(binfo.offset_speed) > 0) { + 
pdat.get_field(ivxyz_interf).apply_offset(binfo.offset_speed); + } + }); + + // Communicate ghost data across MPI ranks + shambase::DistributedDataShared interf_pdat + = ghost_handle.communicate_pdat(ghost_layout_ptr, std::move(pdat_interf)); + + // Count total ghost particles per patch + std::map sz_interf_map; + interf_pdat.for_each([&](u64 s, u64 r, PatchDataLayer &pdat_interf) { + sz_interf_map[r] += pdat_interf.get_obj_cnt(); + }); + + // Merge local and ghost data + storage.merged_patchdata_ghost.set( + ghost_handle.template merge_native( + std::move(interf_pdat), + [&](const shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) { + PatchDataLayer pdat_new(ghost_layout_ptr); + + u32 or_elem = pdat.get_obj_cnt(); + pdat_new.reserve(or_elem + sz_interf_map[p.id_patch]); + + PatchDataField &cur_omega = omega.get(p.id_patch); + PatchDataField &cur_density = density.get(p.id_patch); + + // Insert local particle data + pdat_new.get_field(ihpart_interf).insert(pdat.get_field(ihpart)); + pdat_new.get_field(ivxyz_interf).insert(pdat.get_field(ivxyz)); + pdat_new.get_field(iomega_interf).insert(cur_omega); + pdat_new.get_field(idensity_interf).insert(cur_density); + + if (has_uint) { + pdat_new.get_field(iuint_interf).insert(pdat.get_field(iuint)); + } + + pdat_new.check_field_obj_cnt_match(); + return pdat_new; + }, + [](PatchDataLayer &pdat, PatchDataLayer &pdat_interf) { + pdat.insert_elements(pdat_interf); + })); + + timer_interf.end(); + storage.timings_details.interface += timer_interf.elasped_sec(); +} + +template class Kern> +void shammodels::gsph::Solver::reset_merge_ghosts_fields() { + storage.merged_patchdata_ghost.reset(); +} + +template class Kern> +void shammodels::gsph::Solver::compute_omega() { + StackEntry stack_loc{}; + + using namespace shamrock; + using namespace shamrock::patch; + + const Tscal pmass = solver_config.gpart_mass; + + // Verify particle mass is valid + if (shamcomm::world_rank() == 0) { + if (pmass <= Tscal(0) || pmass < 
Tscal(1e-100) || !std::isfinite(pmass)) { + logger::warn_ln("GSPH", "Invalid particle mass in compute_omega: pmass =", pmass); + } + } + + shamrock::solvergraph::Field &omega_field = shambase::get_check_ref(storage.omega); + shamrock::solvergraph::Field &density_field = shambase::get_check_ref(storage.density); + + // Create sizes directly from scheduler to ensure we have all patches + std::shared_ptr> sizes + = std::make_shared>("sizes", "N"); + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + sizes->indexes.add_obj(p.id_patch, pdat.get_obj_cnt()); + }); + + // Ensure fields are allocated for all patches with correct sizes + omega_field.ensure_sizes(sizes->indexes); + density_field.ensure_sizes(sizes->indexes); + + // Get patchdata layout for hpart field + PatchDataLayerLayout &pdl = scheduler().pdl(); + const u32 ihpart = pdl.get_field_idx("hpart"); + + // ========================================================================= + // OUTER-LOOP SMOOTHING LENGTH ITERATION (FIX FOR CACHE CONSISTENCY BUG) + // ========================================================================= + // The original implementation had an inner-loop Newton-Raphson iteration + // inside a GPU kernel. This caused issues because: + // 1. Neighbor cache was built with OLD h values (+ 10% tolerance) + // 2. Inner iteration could change h by more than 10% + // 3. Particles that should be neighbors weren't found in the cache + // 4. Result: underestimated density at discontinuities -> wrong forces + // + // The fix uses the SPH-style outer-loop approach: + // 1. Create GSPH IterateSmoothingLengthDensity module (ONE step per call) + // 2. Wrap in LoopSmoothingLengthIter for multiple iterations + // 3. 
If h grows beyond tolerance, signal for cache rebuild + // ========================================================================= + + auto &merged_xyzh = storage.merged_xyzh.get(); + + // Create field references for the iteration module + // Position spans (from merged xyzh) + std::shared_ptr> pos_merged + = std::make_shared>("pos", "r"); + shamrock::solvergraph::DDPatchDataFieldRef pos_refs = {}; + + // Old h spans (from merged xyzh - read only during iteration) + std::shared_ptr> hold + = std::make_shared>("h_old", "h^{old}"); + shamrock::solvergraph::DDPatchDataFieldRef hold_refs = {}; + + // New h spans (local patchdata - written during iteration) + std::shared_ptr> hnew + = std::make_shared>("h_new", "h^{new}"); + shamrock::solvergraph::DDPatchDataFieldRef hnew_refs = {}; + + // Populate field references + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + auto &mfield = merged_xyzh.get(p.id_patch); + + // Position from merged data (includes ghosts for neighbor search) + pos_refs.add_obj(p.id_patch, std::ref(mfield.template get_field(0))); + + // h_old from merged data + hold_refs.add_obj(p.id_patch, std::ref(mfield.template get_field(1))); + + // h_new to local patchdata (this is updated during iteration) + hnew_refs.add_obj(p.id_patch, std::ref(pdat.get_field(ihpart))); + }); + + pos_merged->set_refs(pos_refs); + hold->set_refs(hold_refs); + hnew->set_refs(hnew_refs); + + // Initialize hnew with hold values + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + u32 cnt = pdat.get_obj_cnt(); + if (cnt == 0) + return; + + auto &mfield = merged_xyzh.get(p.id_patch); + auto &buf_hpart_merged = mfield.template get_field_buf_ref(1); + auto &buf_hpart_local = pdat.get_field_buf_ref(ihpart); + + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{buf_hpart_merged}, + sham::MultiRef{buf_hpart_local}, + cnt, + [](u32 i, 
const Tscal *h_old, Tscal *h_new) { + h_new[i] = h_old[i]; + }); + }); + + // Create epsilon field for convergence tracking + shamrock::SchedulerUtility utility(scheduler()); + ComputeField _epsilon_h = utility.make_compute_field("epsilon_h", 1); + + // Initialize epsilon to large value (not converged) + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + u32 cnt = pdat.get_obj_cnt(); + if (cnt == 0) + return; + + auto &eps_buf = _epsilon_h.get_buf_check(p.id_patch); + + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{}, + sham::MultiRef{eps_buf}, + cnt, + [](u32 i, Tscal *eps) { + eps[i] = Tscal(1.0); // Start with large epsilon + }); + }); + + // Create epsilon field references + std::shared_ptr> eps_h + = std::make_shared>("eps_h", "\\epsilon_h"); + shamrock::solvergraph::DDPatchDataFieldRef eps_h_refs = {}; + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + auto &field = _epsilon_h.get_field(p.id_patch); + eps_h_refs.add_obj(p.id_patch, std::ref(field)); + }); + eps_h->set_refs(eps_h_refs); + + // Use SPH's IterateSmoothingLengthDensity module (reuse, no duplication) + std::shared_ptr> smth_h_iter + = std::make_shared>( + solver_config.gpart_mass, + solver_config.htol_up_coarse_cycle, + solver_config.htol_up_fine_cycle); + + // SPH's module only iterates h, no density/omega outputs + smth_h_iter->set_edges(sizes, storage.neigh_cache, pos_merged, hold, hnew, eps_h); + + // Create convergence flag + std::shared_ptr> is_converged + = std::make_shared>("is_converged", "converged"); + + // Use LoopSmoothingLengthIter from SPH module for outer loop iteration + shammodels::sph::modules::LoopSmoothingLengthIter loop_smth_h_iter( + smth_h_iter, solver_config.epsilon_h, solver_config.h_iter_per_subcycles, false); + loop_smth_h_iter.set_edges(eps_h, is_converged); + + // Run the outer loop iteration + loop_smth_h_iter.evaluate(); + + // Check convergence + if (!is_converged->value) { + // 
Get convergence statistics + Tscal local_max_eps = shamrock::solvergraph::get_rank_max(*eps_h); + Tscal global_max_eps = shamalgs::collective::allreduce_max(local_max_eps); + + // Count particles that need cache rebuild (eps == -1) + u64 cnt_unconverged = 0; + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + auto res = _epsilon_h.get_field(p.id_patch).get_ids_buf_where([](auto access, u32 id) { + return access[id] < Tscal(0); + }); + cnt_unconverged += std::get<1>(res); + }); + u64 global_cnt_unconverged = shamalgs::collective::allreduce_sum(cnt_unconverged); + + if (shamcomm::world_rank() == 0) { + if (global_cnt_unconverged > 0) { + logger::warn_ln( + "GSPH", + "Smoothing length iteration: ", + global_cnt_unconverged, + " particles need cache rebuild (h grew beyond tolerance)"); + } else { + logger::warn_ln( + "GSPH", + "Smoothing length iteration did not converge, max eps =", + global_max_eps); + } + } + } + + // ========================================================================= + // COMPUTE DENSITY AND OMEGA AFTER H CONVERGENCE + // ========================================================================= + // Now that h has converged, compute the final density and omega values. + // This is done ONCE here instead of on every iteration (more efficient). 
+ // ========================================================================= + + static constexpr Tscal Rkern = Kernel::Rkern; + + auto &neigh_cache = storage.neigh_cache->neigh_cache; + + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + u32 cnt = pdat.get_obj_cnt(); + if (cnt == 0) + return; + + auto &mfield = merged_xyzh.get(p.id_patch); + auto &pcache = neigh_cache.get(p.id_patch); + + // Get position and h from merged data (includes ghosts for neighbor search) + auto &buf_xyz = mfield.template get_field_buf_ref(0); + auto &buf_hpart = pdat.get_field_buf_ref(ihpart); + + // Get density and omega output fields + auto &dens_field = density_field.get_field(p.id_patch); + auto &omeg_field = omega_field.get_field(p.id_patch); + + sham::DeviceQueue &q = dev_sched->get_queue(); + sham::EventList depends_list; + + auto ploop_ptrs = pcache.get_read_access(depends_list); + auto xyz_acc = buf_xyz.get_read_access(depends_list); + auto h_acc = buf_hpart.get_read_access(depends_list); + auto density_acc = dens_field.get_buf().get_write_access(depends_list); + auto omega_acc = omeg_field.get_buf().get_write_access(depends_list); + + auto e = q.submit(depends_list, [&](sycl::handler &cgh) { + shamrock::tree::ObjectCacheIterator particle_looper(ploop_ptrs); + + shambase::parallel_for(cgh, cnt, "gsph_compute_density_omega", [=](u64 gid) { + u32 id_a = (u32) gid; + + Tvec xyz_a = xyz_acc[id_a]; + Tscal h_a = h_acc[id_a]; + Tscal dint = h_a * h_a * Rkern * Rkern; + + // SPH density summation + Tscal rho_sum = Tscal(0); + Tscal sumdWdh = Tscal(0); + + particle_looper.for_each_object(id_a, [&](u32 id_b) { + Tvec dr = xyz_a - xyz_acc[id_b]; + Tscal rab2 = sycl::dot(dr, dr); + + if (rab2 > dint) { + return; + } + + Tscal rab = sycl::sqrt(rab2); + + rho_sum += pmass * Kernel::W_3d(rab, h_a); + sumdWdh += pmass * Kernel::dhW_3d(rab, h_a); + }); + + // Store density + density_acc[id_a] = sycl::max(rho_sum, Tscal(1e-30)); + + // Compute omega (grad-h 
correction factor) + // omega = 1 / (1 + h/(dim*rho) * dh_rho) + Tscal omega_val = Tscal(1); + if (rho_sum > Tscal(1e-30)) { + omega_val = Tscal(1) / (Tscal(1) + h_a / (Tscal(dim) * rho_sum) * sumdWdh); + omega_val = sycl::clamp(omega_val, Tscal(0.5), Tscal(1.5)); + } + omega_acc[id_a] = omega_val; + }); + }); + + // Complete event states for all accessed buffers + pcache.complete_event_state({e}); + buf_xyz.complete_event_state(e); + buf_hpart.complete_event_state(e); + dens_field.get_buf().complete_event_state(e); + omeg_field.get_buf().complete_event_state(e); + }); +} + +template class Kern> +void shammodels::gsph::Solver::compute_eos_fields() { + StackEntry stack_loc{}; + + using namespace shamrock; + using namespace shamrock::patch; + + // GSPH EOS: Following reference implementation (g_pre_interaction.cpp) + // P = (γ - 1) * ρ * u where ρ is from SPH summation + // c = sqrt(γ * (γ - 1) * u) -- from internal energy, not from P/ρ + + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + const Tscal gamma = solver_config.get_eos_gamma(); + const bool has_uint = solver_config.has_field_uint(); + + // Get ghost layout field indices + shamrock::patch::PatchDataLayerLayout &ghost_layout + = shambase::get_check_ref(storage.ghost_layout.get()); + u32 idensity_interf = ghost_layout.get_field_idx("density"); + u32 iuint_interf = has_uint ? ghost_layout.get_field_idx("uint") : 0; + + shamrock::solvergraph::Field &pressure_field = shambase::get_check_ref(storage.pressure); + shamrock::solvergraph::Field &soundspeed_field + = shambase::get_check_ref(storage.soundspeed); + + // Size buffers to part_counts_with_ghost (includes ghosts!) 
+ shambase::DistributedData &counts_with_ghosts + = shambase::get_check_ref(storage.part_counts_with_ghost).indexes; + + pressure_field.ensure_sizes(counts_with_ghosts); + soundspeed_field.ensure_sizes(counts_with_ghosts); + + // Iterate over merged_patchdata_ghost (includes local + ghost particles) + storage.merged_patchdata_ghost.get().for_each([&](u64 id, PatchDataLayer &mpdat) { + u32 total_elements + = shambase::get_check_ref(storage.part_counts_with_ghost).indexes.get(id); + if (total_elements == 0) + return; + + // Use SPH-summation density from communicated ghost data + sham::DeviceBuffer &buf_density = mpdat.get_field_buf_ref(idensity_interf); + auto &pressure_buf = pressure_field.get_field(id).get_buf(); + auto &soundspeed_buf = soundspeed_field.get_field(id).get_buf(); + + sham::DeviceQueue &q = dev_sched->get_queue(); + sham::EventList depends_list; + + auto density = buf_density.get_read_access(depends_list); + auto pressure = pressure_buf.get_write_access(depends_list); + auto soundspeed = soundspeed_buf.get_write_access(depends_list); + + const Tscal *uint_ptr = nullptr; + if (has_uint) { + uint_ptr = mpdat.get_field_buf_ref(iuint_interf).get_read_access(depends_list); + } + + auto e = q.submit(depends_list, [&](sycl::handler &cgh) { + shambase::parallel_for(cgh, total_elements, "compute_eos_gsph", [=](u64 gid) { + u32 i = (u32) gid; + + // Use SPH-summation density (from compute_omega, communicated to ghosts) + Tscal rho = density[i]; + rho = sycl::max(rho, Tscal(1e-30)); + + if (has_uint && uint_ptr != nullptr) { + // Adiabatic EOS (reference: g_pre_interaction.cpp line 107) + // P = (γ - 1) * ρ * u + Tscal u = uint_ptr[i]; + u = sycl::max(u, Tscal(1e-30)); + Tscal P = (gamma - Tscal(1.0)) * rho * u; + + // Sound speed from internal energy (reference: solver.cpp line 2661) + // c = sqrt(γ * (γ - 1) * u) + Tscal cs = sycl::sqrt(gamma * (gamma - Tscal(1.0)) * u); + + // Clamp to reasonable values + P = sycl::clamp(P, Tscal(1e-30), Tscal(1e30)); + cs 
= sycl::clamp(cs, Tscal(1e-10), Tscal(1e10)); + + pressure[i] = P; + soundspeed[i] = cs; + } else { + // Isothermal case + Tscal cs = Tscal(1.0); + Tscal P = cs * cs * rho; + + pressure[i] = P; + soundspeed[i] = cs; + } + }); + }); + + // Complete all buffer event states + buf_density.complete_event_state(e); + if (has_uint) { + mpdat.get_field_buf_ref(iuint_interf).complete_event_state(e); + } + pressure_buf.complete_event_state(e); + soundspeed_buf.complete_event_state(e); + }); +} + +template class Kern> +void shammodels::gsph::Solver::reset_eos_fields() { + // Reset computed EOS fields - they're recomputed each timestep +} + +template class Kern> +void shammodels::gsph::Solver::prepare_corrector() { + StackEntry stack_loc{}; + + shamrock::SchedulerUtility utility(scheduler()); + shamrock::patch::PatchDataLayerLayout &pdl = scheduler().pdl(); + + const u32 iaxyz = pdl.get_field_idx("axyz"); + + // Create compute field to store old acceleration + auto old_axyz = utility.make_compute_field("old_axyz", 1); + + // Copy current acceleration to old_axyz + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + scheduler().for_each_patchdata_nonempty( + [&](const shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) { + u32 cnt = pdat.get_obj_cnt(); + if (cnt == 0) + return; + + auto &axyz_field = pdat.get_field(iaxyz); + auto &old_axyz_field = old_axyz.get_field(p.id_patch); + + // Copy using kernel_call + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{axyz_field.get_buf()}, + sham::MultiRef{old_axyz_field.get_buf()}, + cnt, + [](u32 i, const Tvec *src, Tvec *dst) { + dst[i] = src[i]; + }); + }); + + storage.old_axyz.set(std::move(old_axyz)); + + if (solver_config.has_field_uint()) { + const u32 iduint = pdl.get_field_idx("duint"); + auto old_duint = utility.make_compute_field("old_duint", 1); + + scheduler().for_each_patchdata_nonempty( + [&](const shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) { + u32 cnt = 
pdat.get_obj_cnt(); + if (cnt == 0) + return; + + auto &duint_field = pdat.get_field(iduint); + auto &old_duint_field = old_duint.get_field(p.id_patch); + + // Copy using kernel_call + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{duint_field.get_buf()}, + sham::MultiRef{old_duint_field.get_buf()}, + cnt, + [](u32 i, const Tscal *src, Tscal *dst) { + dst[i] = src[i]; + }); + }); + + storage.old_duint.set(std::move(old_duint)); + } +} + +template class Kern> +void shammodels::gsph::Solver::update_derivs() { + StackEntry stack_loc{}; + // GSPH derivative update using Riemann solver + gsph::modules::UpdateDerivs(context, solver_config, storage).update_derivs(); +} + +template class Kern> +typename shammodels::gsph::Solver::Tscal shammodels::gsph::Solver:: + compute_dt_cfl() { + StackEntry stack_loc{}; + + using namespace shamrock; + using namespace shamrock::patch; + + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + PatchDataLayerLayout &pdl = scheduler().pdl(); + const u32 ihpart = pdl.get_field_idx("hpart"); + const u32 iaxyz = pdl.get_field_idx("axyz"); + + shamrock::solvergraph::Field &soundspeed_field + = shambase::get_check_ref(storage.soundspeed); + + Tscal C_cour = solver_config.cfl_config.cfl_cour; + Tscal C_force = solver_config.cfl_config.cfl_force; + + // Use ComputeField for proper reduction support + shamrock::SchedulerUtility utility(scheduler()); + ComputeField cfl_dt = utility.make_compute_field("cfl_dt", 1); + + scheduler().for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) { + u32 cnt = pdat.get_obj_cnt(); + if (cnt == 0) + return; + + auto &buf_hpart = pdat.get_field_buf_ref(ihpart); + auto &buf_axyz = pdat.get_field_buf_ref(iaxyz); + auto &buf_cs = soundspeed_field.get_field(cur_p.id_patch).get_buf(); + auto &cfl_dt_buf = cfl_dt.get_buf_check(cur_p.id_patch); + + sham::DeviceQueue &q = dev_sched->get_queue(); + sham::EventList depends_list; + + auto hpart = 
buf_hpart.get_read_access(depends_list); + auto axyz = buf_axyz.get_read_access(depends_list); + auto cs = buf_cs.get_read_access(depends_list); + auto cfl_dt_acc = cfl_dt_buf.get_write_access(depends_list); + + auto e = q.submit(depends_list, [&](sycl::handler &cgh) { + shambase::parallel_for(cgh, cnt, "gsph_compute_cfl_dt", [=](u64 gid) { + u32 i = (u32) gid; + + Tscal h_i = hpart[i]; + Tscal cs_i = cs[i]; + Tscal abs_a = sycl::length(axyz[i]); + + // Guard against invalid values (NaN/Inf) + if (!sycl::isfinite(h_i) || h_i <= Tscal(0)) + h_i = Tscal(1e-10); + if (!sycl::isfinite(cs_i) || cs_i <= Tscal(0)) + cs_i = Tscal(1e-10); + if (!sycl::isfinite(abs_a)) + abs_a = Tscal(1e30); + + // Sound CFL condition: dt = C_cour * h / c_s + // Following Kitajima et al. (2025) simple form for GSPH + Tscal dt_c = C_cour * h_i / cs_i; + + // Force condition: dt = C_force * sqrt(h / |a|) + Tscal dt_f = C_force * sycl::sqrt(h_i / (abs_a + Tscal(1e-30))); + + Tscal dt_min = sycl::min(dt_c, dt_f); + + // Ensure a valid finite timestep with minimum floor + if (!sycl::isfinite(dt_min) || dt_min <= Tscal(0)) { + dt_min = Tscal(1e-10); // Minimum timestep floor + } + + cfl_dt_acc[i] = dt_min; + }); + }); + + buf_hpart.complete_event_state(e); + buf_axyz.complete_event_state(e); + buf_cs.complete_event_state(e); + cfl_dt_buf.complete_event_state(e); + }); + + // Compute minimum across all patches on this rank + Tscal rank_dt = cfl_dt.compute_rank_min(); + + // Guard against invalid reduction result + if (!std::isfinite(rank_dt) || rank_dt <= Tscal(0)) { + rank_dt = Tscal(1e-6); // Reasonable floor for SPH simulations + } + + // Global reduction across MPI ranks + Tscal global_min_dt = shamalgs::collective::allreduce_min(rank_dt); + + // Final safety floor to prevent simulation stalling + // For typical SPH simulations, timestep should be O(h/cs) ~ O(1e-4) + // Use 1e-6 as minimum floor to prevent extreme stalling + const Tscal dt_min_floor = Tscal(1e-6); + if 
(!std::isfinite(global_min_dt) || global_min_dt < dt_min_floor) { + global_min_dt = dt_min_floor; + } + + return global_min_dt; +} + +template class Kern> +bool shammodels::gsph::Solver::apply_corrector(Tscal dt, u64 Npart_all) { + StackEntry stack_loc{}; + + shamrock::patch::PatchDataLayerLayout &pdl = scheduler().pdl(); + + const u32 ivxyz = pdl.get_field_idx("vxyz"); + const u32 iaxyz = pdl.get_field_idx("axyz"); + + Tscal half_dt = Tscal{0.5} * dt; + + // Corrector: v = v + 0.5*(a_new - a_old)*dt + scheduler().for_each_patchdata_nonempty( + [&](const shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) { + u32 cnt = pdat.get_obj_cnt(); + if (cnt == 0) + return; + + auto &vxyz = pdat.get_field(ivxyz); + auto &axyz = pdat.get_field(iaxyz); + auto &old_axyz = storage.old_axyz.get().get_field(p.id_patch); + + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{axyz.get_buf(), old_axyz.get_buf()}, + sham::MultiRef{vxyz.get_buf()}, + cnt, + [half_dt](u32 i, const Tvec *axyz_new, const Tvec *axyz_old, Tvec *vxyz) { + vxyz[i] += half_dt * (axyz_new[i] - axyz_old[i]); + }); + }); + + if (solver_config.has_field_uint()) { + const u32 iuint = pdl.get_field_idx("uint"); + const u32 iduint = pdl.get_field_idx("duint"); + + scheduler().for_each_patchdata_nonempty( + [&](const shamrock::patch::Patch p, shamrock::patch::PatchDataLayer &pdat) { + u32 cnt = pdat.get_obj_cnt(); + if (cnt == 0) + return; + + auto &uint_field = pdat.get_field(iuint); + auto &duint = pdat.get_field(iduint); + auto &old_duint = storage.old_duint.get().get_field(p.id_patch); + + auto dev_sched = shamsys::instance::get_compute_scheduler_ptr(); + + sham::kernel_call( + dev_sched->get_queue(), + sham::MultiRef{duint.get_buf(), old_duint.get_buf()}, + sham::MultiRef{uint_field.get_buf()}, + cnt, + [half_dt](u32 i, const Tscal *duint_new, const Tscal *duint_old, Tscal *uint) { + uint[i] += half_dt * (duint_new[i] - 
duint_old[i]); + }); + }); + + storage.old_duint.reset(); + } + + storage.old_axyz.reset(); + + return true; +} + +template class Kern> +void shammodels::gsph::Solver::update_sync_load_values() {} + +template class Kern> +shammodels::gsph::TimestepLog shammodels::gsph::Solver::evolve_once() { + + // Validate configuration before running + solver_config.check_config_runtime(); + + Tscal t_current = solver_config.get_time(); + Tscal dt = solver_config.get_dt(); + + StackEntry stack_loc{}; + + if (shamcomm::world_rank() == 0) { + shamcomm::logs::raw_ln( + shambase::format( + "---------------- GSPH t = {}, dt = {} ----------------", t_current, dt)); + } + + shambase::Timer tstep; + tstep.start(); + + // Load balancing step + scheduler().scheduler_step(true, true); + scheduler().scheduler_step(false, false); + + // Give to the solvergraph the patch rank owners + storage.patch_rank_owner->values = {}; + scheduler().for_each_global_patch([&](const shamrock::patch::Patch p) { + storage.patch_rank_owner->values.add_obj( + p.id_patch, scheduler().get_patch_rank_owner(p.id_patch)); + }); + + using namespace shamrock; + using namespace shamrock::patch; + + u64 Npart_all = scheduler().get_total_obj_count(); + + // ========================================================================= + // CORRECTED SIMULATION LOOP ORDER (matching reference SPH code) + // ========================================================================= + // The key insight from the reference code is that density/EOS must be + // computed AFTER the predictor step, on the NEW positions. Otherwise, + // the forces are computed using stale EOS values. + // + // Loop order: + // 1. PREDICTOR: move particles using OLD accelerations + // 2. BOUNDARY: apply periodic/free boundary conditions + // 3. TREE BUILD: build spatial trees on NEW positions + // 4. DENSITY/EOS: compute density, pressure, soundspeed on NEW positions + // 5. FORCES: compute accelerations using FRESH EOS + // 6. 
CORRECTOR: refine velocities using average of old/new accelerations + // 7. CFL: compute next timestep + // ========================================================================= + + // STEP 1: PREDICTOR - move particles using OLD accelerations + // (On first iteration, accelerations are zero, so this is just position drift) + do_predictor_leapfrog(dt); + + // STEP 2: BOUNDARY - apply boundary conditions to NEW positions + // Build serial patch tree first (needed for boundary application) + gen_serial_patch_tree(); + apply_position_boundary(t_current + dt); + + // STEP 3: TREE BUILD - build trees on NEW positions + // Generate ghost handler for the new positions + gen_ghost_handler(t_current + dt); + + // Build ghost cache for interface exchange + build_ghost_cache(); + + // Merge positions with ghosts + merge_position_ghost(); + + // Build trees over merged positions + build_merged_pos_trees(); + + // Compute interaction ranges + compute_presteps_rint(); + + // Build neighbor cache + start_neighbors_cache(); + + // STEP 4: DENSITY/OMEGA - compute on NEW positions + // Compute omega (grad-h correction factor) - needed for force computation + compute_omega(); + + // Initialize ghost layout BEFORE communication + init_ghost_layout(); + + // Communicate ghost fields (hpart, uint, vxyz, omega) + // This MUST happen BEFORE compute_eos_fields so EOS can be computed for ghosts + communicate_merge_ghosts_fields(); + + // STEP 4b: EOS - compute AFTER ghost communication (CRITICAL!) 
+ // This ensures P and cs are computed for ALL particles (local + ghost) + // Following SPH pattern: EOS is computed on merged_patchdata_ghost + compute_eos_fields(); + + // STEP 5: FORCES - compute accelerations using FRESH EOS + // Save old accelerations for corrector + prepare_corrector(); + + // Update derivatives using GSPH Riemann solver + update_derivs(); + + // STEP 6: CORRECTOR - refine velocities + apply_corrector(dt, Npart_all); + + // STEP 7: CFL - compute next timestep + Tscal dt_next = compute_dt_cfl(); + + // Ensure dt doesn't grow too fast (max 2x per step), but allow any value if dt was 0 + if (dt > Tscal(0)) { + dt_next = sham::min(dt_next, Tscal(2) * dt); + } + + // Cleanup for next iteration + reset_neighbors_cache(); + reset_presteps_rint(); + clear_merged_pos_trees(); + reset_merge_ghosts_fields(); + storage.merged_xyzh.reset(); + clear_ghost_cache(); + reset_serial_patch_tree(); + reset_ghost_handler(); + storage.ghost_layout.reset(); + + // Update time + solver_config.set_time(t_current + dt); + solver_config.set_next_dt(dt_next); + + solve_logs.step_count++; + + tstep.end(); + + // Prepare timing log + TimestepLog log; + log.rank = shamcomm::world_rank(); + log.rate = Tscal(Npart_all) / tstep.elasped_sec(); + log.npart = Npart_all; + log.tcompute = tstep.elasped_sec(); + + return log; +} + +// Template instantiations +using namespace shammath; + +// M-spline kernels (Monaghan) +template class shammodels::gsph::Solver; +template class shammodels::gsph::Solver; +template class shammodels::gsph::Solver; + +// Wendland kernels (C2, C4, C6) - recommended for GSPH (Inutsuka 2002) +template class shammodels::gsph::Solver; +template class shammodels::gsph::Solver; +template class shammodels::gsph::Solver; diff --git a/src/shammodels/gsph/src/modules/io/VTKDump.cpp b/src/shammodels/gsph/src/modules/io/VTKDump.cpp new file mode 100644 index 000000000..b3e1f6938 --- /dev/null +++ b/src/shammodels/gsph/src/modules/io/VTKDump.cpp @@ -0,0 +1,293 @@ +// 
-------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +/** + * @file VTKDump.cpp + * @author Guo Yansong (guo.yansong.ngy@gmail.com) + * @author Yona Lapeyre (yona.lapeyre@ens-lyon.fr) --no git blame-- + * @brief VTK dump implementation for GSPH solver + */ + +#include "shammodels/gsph/modules/io/VTKDump.hpp" +#include "shamalgs/memory.hpp" +#include "shambackends/kernel_call.hpp" +#include "shamcomm/worldInfo.hpp" +#include "shammodels/sph/math/density.hpp" +#include "shamrock/io/LegacyVtkWritter.hpp" +#include "shamrock/scheduler/SchedulerUtility.hpp" +#include "shamsys/NodeInstance.hpp" + +namespace { + + template + shamrock::LegacyVtkWritter start_dump(PatchScheduler &sched, std::string dump_name) { + StackEntry stack_loc{}; + shamrock::LegacyVtkWritter writer(dump_name, true, shamrock::UnstructuredGrid); + + using namespace shamrock::patch; + + u64 num_obj = sched.get_rank_count(); + + shamlog_debug_mpi_ln("gsph::vtk", "rank count =", num_obj); + + std::unique_ptr> pos = sched.rankgather_field(0); + + writer.write_points(pos, num_obj); + + return writer; + } + + void vtk_dump_add_patch_id(PatchScheduler &sched, shamrock::LegacyVtkWritter &writter) { + StackEntry stack_loc{}; + + u64 num_obj = sched.get_rank_count(); + + using namespace shamrock::patch; + + if (num_obj > 0) { + sycl::buffer idp(num_obj); + + u64 ptr = 0; + sched.for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) { + using namespace shamalgs::memory; + using namespace shambase; + + write_with_offset_into( + shamsys::instance::get_compute_queue(), + idp, + cur_p.id_patch, + ptr, + pdat.get_obj_cnt()); + + ptr += pdat.get_obj_cnt(); + }); + + 
writter.write_field("patchid", idp, num_obj); + } else { + writter.write_field_no_buf("patchid"); + } + } + + void vtk_dump_add_worldrank(PatchScheduler &sched, shamrock::LegacyVtkWritter &writter) { + StackEntry stack_loc{}; + + using namespace shamrock::patch; + u64 num_obj = sched.get_rank_count(); + + if (num_obj > 0) { + sycl::buffer idp(num_obj); + + u64 ptr = 0; + sched.for_each_patchdata_nonempty([&](Patch cur_p, PatchDataLayer &pdat) { + using namespace shamalgs::memory; + using namespace shambase; + + write_with_offset_into( + shamsys::instance::get_compute_queue(), + idp, + shamcomm::world_rank(), + ptr, + pdat.get_obj_cnt()); + + ptr += pdat.get_obj_cnt(); + }); + + writter.write_field("world_rank", idp, num_obj); + } else { + writter.write_field_no_buf("world_rank"); + } + } + + template + void vtk_dump_add_compute_field( + PatchScheduler &sched, + shamrock::LegacyVtkWritter &writter, + shamrock::ComputeField &field, + std::string field_dump_name) { + StackEntry stack_loc{}; + + using namespace shamrock::patch; + u64 num_obj = sched.get_rank_count(); + + if (num_obj > 0) { + std::unique_ptr> field_vals = field.rankgather_computefield(sched); + + writter.write_field(field_dump_name, field_vals, num_obj); + } else { + writter.write_field_no_buf(field_dump_name); + } + } + + template + void vtk_dump_add_field( + PatchScheduler &sched, + shamrock::LegacyVtkWritter &writter, + u32 field_idx, + std::string field_dump_name) { + StackEntry stack_loc{}; + + using namespace shamrock::patch; + u64 num_obj = sched.get_rank_count(); + + if (num_obj > 0) { + std::unique_ptr> field_vals = sched.rankgather_field(field_idx); + + writter.write_field(field_dump_name, field_vals, num_obj); + } else { + writter.write_field_no_buf(field_dump_name); + } + } + +} // anonymous namespace + +namespace shammodels::gsph::modules { + + template class SPHKernel> + void VTKDump::do_dump(std::string filename, bool add_patch_world_id) { + + StackEntry stack_loc{}; + + using namespace 
shamrock; + using namespace shamrock::patch; + shamrock::SchedulerUtility utility(scheduler()); + + PatchDataLayerLayout &pdl = scheduler().pdl(); + const u32 ixyz = pdl.get_field_idx("xyz"); + const u32 ivxyz = pdl.get_field_idx("vxyz"); + const u32 iaxyz = pdl.get_field_idx("axyz"); + const u32 ihpart = pdl.get_field_idx("hpart"); + + // Check for optional internal energy field + const bool has_uint = solver_config.has_field_uint(); + const u32 iuint = has_uint ? pdl.get_field_idx("uint") : 0; + + // Compute density field from smoothing length + ComputeField density = utility.make_compute_field("rho", 1); + + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + shamlog_debug_ln("gsph::vtk", "compute rho field for patch ", p.id_patch); + + auto &buf_hpart = pdat.get_field(ihpart).get_buf(); + + auto sptr = shamsys::instance::get_compute_scheduler_ptr(); + auto &q = sptr->get_queue(); + + sham::EventList depends_list; + const Tscal *acc_h = buf_hpart.get_read_access(depends_list); + auto acc_rho = density.get_buf(p.id_patch).get_write_access(depends_list); + + auto e = q.submit(depends_list, [&](sycl::handler &cgh) { + const Tscal part_mass = solver_config.gpart_mass; + + cgh.parallel_for(sycl::range<1>{pdat.get_obj_cnt()}, [=](sycl::item<1> item) { + u32 gid = (u32) item.get_id(); + using namespace shamrock::sph; + Tscal rho_ha = rho_h(part_mass, acc_h[gid], Kernel::hfactd); + acc_rho[gid] = rho_ha; + }); + }); + + buf_hpart.complete_event_state(e); + density.get_buf(p.id_patch).complete_event_state(e); + }); + + // Compute pressure field from EOS + ComputeField pressure_field = utility.make_compute_field("P", 1); + + scheduler().for_each_patchdata_nonempty([&](const Patch p, PatchDataLayer &pdat) { + auto &buf_hpart = pdat.get_field(ihpart).get_buf(); + + auto sptr = shamsys::instance::get_compute_scheduler_ptr(); + auto &q = sptr->get_queue(); + + sham::EventList depends_list; + const Tscal *acc_h = 
buf_hpart.get_read_access(depends_list); + auto acc_P = pressure_field.get_buf(p.id_patch).get_write_access(depends_list); + + const Tscal *acc_u = nullptr; + if (has_uint) { + acc_u = pdat.get_field(iuint).get_buf().get_read_access(depends_list); + } + + auto e = q.submit(depends_list, [&](sycl::handler &cgh) { + const Tscal part_mass = solver_config.gpart_mass; + const Tscal gamma = solver_config.get_eos_gamma(); + const bool do_uint = has_uint; + + cgh.parallel_for(sycl::range<1>{pdat.get_obj_cnt()}, [=](sycl::item<1> item) { + u32 gid = (u32) item.get_id(); + using namespace shamrock::sph; + Tscal rho = rho_h(part_mass, acc_h[gid], Kernel::hfactd); + + if (do_uint && acc_u != nullptr) { + // Adiabatic EOS: P = (gamma - 1) * rho * u + acc_P[gid] = (gamma - Tscal(1)) * rho * acc_u[gid]; + } else { + // Isothermal: use cs = 1 by default + acc_P[gid] = rho; // P = cs^2 * rho with cs = 1 + } + }); + }); + + buf_hpart.complete_event_state(e); + pressure_field.get_buf(p.id_patch).complete_event_state(e); + if (has_uint) { + pdat.get_field(iuint).get_buf().complete_event_state(e); + } + }); + + shamrock::LegacyVtkWritter writter = start_dump(scheduler(), filename); + writter.add_point_data_section(); + + // Count fields to write + u32 fnum = 0; + if (add_patch_world_id) { + fnum += 2; // patchid and world_rank + } + fnum++; // h + fnum++; // v + fnum++; // a + fnum++; // rho + fnum++; // P + + if (has_uint) { + fnum++; // u + } + + writter.add_field_data_section(fnum); + + if (add_patch_world_id) { + vtk_dump_add_patch_id(scheduler(), writter); + vtk_dump_add_worldrank(scheduler(), writter); + } + + vtk_dump_add_field(scheduler(), writter, ihpart, "h"); + vtk_dump_add_field(scheduler(), writter, ivxyz, "v"); + vtk_dump_add_field(scheduler(), writter, iaxyz, "a"); + + if (has_uint) { + vtk_dump_add_field(scheduler(), writter, iuint, "u"); + } + + vtk_dump_add_compute_field(scheduler(), writter, density, "rho"); + vtk_dump_add_compute_field(scheduler(), writter, 
pressure_field, "P"); + } + +} // namespace shammodels::gsph::modules + +// Explicit template instantiations +using namespace shammath; + +template class shammodels::gsph::modules::VTKDump; +template class shammodels::gsph::modules::VTKDump; +template class shammodels::gsph::modules::VTKDump; + +template class shammodels::gsph::modules::VTKDump; +template class shammodels::gsph::modules::VTKDump; +template class shammodels::gsph::modules::VTKDump; diff --git a/src/shammodels/gsph/src/pyGSPHModel.cpp b/src/shammodels/gsph/src/pyGSPHModel.cpp new file mode 100644 index 000000000..f04ee2196 --- /dev/null +++ b/src/shammodels/gsph/src/pyGSPHModel.cpp @@ -0,0 +1,503 @@ +// -------------------------------------------------------// +// +// SHAMROCK code for hydrodynamics +// Copyright (c) 2021-2025 Timothée David--Cléris +// SPDX-License-Identifier: CeCILL Free Software License Agreement v2.1 +// Shamrock is licensed under the CeCILL 2.1 License, see LICENSE for more information +// +// -------------------------------------------------------// + +/** + * @file pyGSPHModel.cpp + * @author Guo Yansong (guo.yansong.ngy@gmail.com) + * @author Yona Lapeyre (yona.lapeyre@ens-lyon.fr) --no git blame-- + * @brief Python bindings for the GSPH (Godunov SPH) model + * + * This provides Python interface for GSPH simulations using Riemann solvers. + * + * References: + * - Inutsuka, S. (2002) "Reformulation of Smoothed Particle Hydrodynamics + * with Riemann Solver" + * - Cha, S.-H. & Whitworth, A.P. 
(2003) "Implementations and tests of + * Godunov-type particle hydrodynamics" + */ + +#include "shambase/exception.hpp" +#include "shambase/memory.hpp" +#include "shambindings/pybindaliases.hpp" +#include "shambindings/pytypealias.hpp" +#include "shamcomm/worldInfo.hpp" +#include "shammath/sphkernels.hpp" +#include "shammodels/gsph/Model.hpp" +#include "shamrock/scheduler/PatchScheduler.hpp" +#include +#include +#include + +template class SPHKernel> +void add_gsph_instance(py::module &m, std::string name_config, std::string name_model) { + using namespace shammodels::gsph; + + using Tscal = shambase::VecComponent; + + using T = Model; + using TConfig = typename T::SolverConfig; + + shamlog_debug_ln("[Py]", "registering class :", name_config, typeid(T).name()); + shamlog_debug_ln("[Py]", "registering class :", name_model, typeid(T).name()); + + py::class_(m, name_config.c_str()) + .def("print_status", &TConfig::print_status) + .def("set_tree_reduction_level", &TConfig::set_tree_reduction_level) + .def("set_two_stage_search", &TConfig::set_two_stage_search) + // Riemann solver config + .def( + "set_riemann_iterative", + [](TConfig &self, Tscal tol, u32 max_iter) { + self.set_riemann_iterative(tol, max_iter); + }, + py::kw_only(), + py::arg("tolerance") = Tscal{1e-6}, + py::arg("max_iter") = 20, + R"==( + Set iterative Riemann solver (van Leer 1997). + + This is the most accurate but slower Riemann solver. + Uses Newton-Raphson iteration to find the pressure in the star region. + + Parameters + ---------- + tolerance : float + Convergence tolerance for Newton-Raphson iteration (default: 1e-6) + max_iter : int + Maximum number of iterations (default: 20) +)==") + .def( + "set_riemann_hllc", + [](TConfig &self) { + self.set_riemann_hllc(); + }, + R"==( + Set HLLC approximate Riemann solver. + + Fast approximate Riemann solver that captures contact discontinuities. + Recommended for general use - good balance of accuracy and speed. 
+)==") + // Reconstruction config + .def( + "set_reconstruct_piecewise_constant", + [](TConfig &self) { + self.set_reconstruct_piecewise_constant(); + }, + R"==( + Set first-order piecewise constant reconstruction. + + Sets all gradients to zero. Most diffusive but most stable. + Good for very strong shocks or initial testing. +)==") + .def( + "set_reconstruct_muscl", + [](TConfig &self) { + self.set_reconstruct_muscl(); + }, + R"==( + Set second-order MUSCL reconstruction with Van Leer limiter. + + Uses computed gradients with slope limiter for monotonicity. + Better accuracy at smooth regions while maintaining stability at shocks. +)==") + // EOS config + .def( + "set_eos_adiabatic", + [](TConfig &self, Tscal gamma) { + self.set_eos_adiabatic(gamma); + }, + py::arg("gamma"), + R"==( + Set adiabatic equation of state: P = (γ-1) × ρ × u + + Parameters + ---------- + gamma : float + Adiabatic index (e.g., 5/3 for monatomic gas, 7/5 for diatomic) +)==") + .def( + "set_eos_isothermal", + [](TConfig &self, Tscal cs) { + self.set_eos_isothermal(cs); + }, + py::arg("cs"), + R"==( + Set isothermal equation of state: P = cs² × ρ + + Parameters + ---------- + cs : float + Sound speed +)==") + // Boundary config + .def("set_boundary_free", &TConfig::set_boundary_free) + .def("set_boundary_periodic", &TConfig::set_boundary_periodic) + // External forces + .def( + "add_ext_force_point_mass", + [](TConfig &self, Tscal central_mass, Tscal Racc) { + self.add_ext_force_point_mass(central_mass, Racc); + }, + py::kw_only(), + py::arg("central_mass"), + py::arg("Racc")) + // Units + .def("set_units", &TConfig::set_units) + // CFL + .def( + "set_cfl_cour", + [](TConfig &self, Tscal cfl_cour) { + self.cfl_config.cfl_cour = cfl_cour; + }) + .def( + "set_cfl_force", + [](TConfig &self, Tscal cfl_force) { + self.cfl_config.cfl_force = cfl_force; + }) + .def( + "set_particle_mass", + [](TConfig &self, Tscal gpart_mass) { + self.gpart_mass = gpart_mass; + }) + .def("to_json", [](TConfig 
&self) { + return nlohmann::json{self}.dump(4); + }); + + py::class_(m, name_model.c_str()) + .def(py::init([](ShamrockCtx &ctx) { + return std::make_unique(ctx); + })) + .def("init_scheduler", &T::init_scheduler) + .def("evolve_once", &T::evolve_once) + .def( + "evolve_until", + [](T &self, f64 target_time, i32 niter_max) { + return self.evolve_until(target_time, niter_max); + }, + py::arg("target_time"), + py::kw_only(), + py::arg("niter_max") = -1) + .def("timestep", &T::timestep) + .def("set_cfl_cour", &T::set_cfl_cour, py::arg("cfl_cour")) + .def("set_cfl_force", &T::set_cfl_force, py::arg("cfl_force")) + .def("set_particle_mass", &T::set_particle_mass, py::arg("gpart_mass")) + .def("get_particle_mass", &T::get_particle_mass) + .def("rho_h", &T::rho_h) + .def("get_hfact", &T::get_hfact) + .def( + "get_box_dim_fcc_3d", + [](T &self, f64 dr, u32 xcnt, u32 ycnt, u32 zcnt) { + return self.get_box_dim_fcc_3d(dr, xcnt, ycnt, zcnt); + }) + .def( + "get_ideal_fcc_box", + [](T &self, f64 dr, f64_3 box_min, f64_3 box_max) { + return self.get_ideal_fcc_box(dr, {box_min, box_max}); + }) + .def( + "get_ideal_hcp_box", + [](T &self, f64 dr, f64_3 box_min, f64_3 box_max) { + return self.get_ideal_hcp_box(dr, {box_min, box_max}); + }) + .def( + "resize_simulation_box", + [](T &self, f64_3 box_min, f64_3 box_max) { + return self.resize_simulation_box({box_min, box_max}); + }) + .def( + "add_cube_fcc_3d", + [](T &self, f64 dr, f64_3 box_min, f64_3 box_max) { + return self.add_cube_fcc_3d(dr, {box_min, box_max}); + }) + .def( + "add_cube_hcp_3d", + [](T &self, f64 dr, f64_3 box_min, f64_3 box_max) { + return self.add_cube_hcp_3d(dr, {box_min, box_max}); + }) + .def("get_total_part_count", &T::get_total_part_count) + .def("total_mass_to_part_mass", &T::total_mass_to_part_mass) + .def( + "set_field_in_box", + [](T &self, + std::string field_name, + std::string field_type, + pybind11::object value, + f64_3 box_min, + f64_3 box_max, + u32 ivar) { + if (field_type == "f64") { + f64 
val = value.cast(); + self.set_field_in_box(field_name, val, {box_min, box_max}, ivar); + } else if (field_type == "f64_3") { + f64_3 val = value.cast(); + self.set_field_in_box(field_name, val, {box_min, box_max}, ivar); + } else if (field_type == "u32") { + u32 val = value.cast(); + self.set_field_in_box(field_name, val, {box_min, box_max}, ivar); + } else { + throw shambase::make_except_with_loc( + "unknown field type: " + field_type + ". Valid types: f64, f64_3, u32"); + } + }, + py::arg("field_name"), + py::arg("field_type"), + py::arg("value"), + py::arg("box_min"), + py::arg("box_max"), + py::kw_only(), + py::arg("ivar") = 0, + R"==( + Set field value for particles within a box region. + + Useful for setting up discontinuous initial conditions like Sod shock tube. + + Parameters + ---------- + field_name : str + Name of the field to set (e.g., "vxyz", "uint", "hpart") + field_type : str + Type of the field: "f64", "f64_3", or "u32" + value : float, tuple, or int + Value to set (type must match field_type) + box_min : tuple + Minimum corner of the box (x, y, z) + box_max : tuple + Maximum corner of the box (x, y, z) + ivar : int + Variable index for multi-component fields (default: 0) + + Examples + -------- + >>> # Sod shock tube: set left state internal energy + >>> model.set_field_in_box("uint", "f64", u_left, (-1,-1,-1), (0,1,1)) + >>> # Set right state + >>> model.set_field_in_box("uint", "f64", u_right, (0,-1,-1), (1,1,1)) +)==") + .def( + "set_field_in_sphere", + [](T &self, + std::string field_name, + std::string field_type, + pybind11::object value, + f64_3 center, + f64 radius) { + if (field_type == "f64") { + f64 val = value.cast(); + self.set_field_in_sphere(field_name, val, center, radius); + } else if (field_type == "f64_3") { + f64_3 val = value.cast(); + self.set_field_in_sphere(field_name, val, center, radius); + } else { + throw shambase::make_except_with_loc( + "unknown field type"); + } + }, + py::arg("field_name"), + 
py::arg("field_type"), + py::arg("value"), + py::arg("center"), + py::arg("radius"), + R"==( + Set field value for particles within a spherical region. + + Useful for setting up point-source initial conditions like Sedov blast. + + Parameters + ---------- + field_name : str + Name of the field to set (e.g., "uint") + field_type : str + Type of the field: "f64" or "f64_3" + value : float or tuple + Value to set (type must match field_type) + center : tuple + Center of the sphere (x, y, z) + radius : float + Radius of the sphere + + Examples + -------- + >>> # Sedov blast: inject energy in central sphere + >>> model.set_field_in_sphere("uint", "f64", u_blast, (0,0,0), r_blast) +)==") + .def("apply_field_from_position_f64_3", &T::template apply_field_from_position) + .def("apply_field_from_position_f64", &T::template apply_field_from_position) + .def( + "get_sum", + [](T &self, std::string field_name, std::string field_type) { + if (field_type == "f64") { + return py::cast(self.template get_sum(field_name)); + } else if (field_type == "f64_3") { + return py::cast(self.template get_sum(field_name)); + } else { + throw shambase::make_except_with_loc( + "unknown field type"); + } + }) + .def( + "gen_default_config", + [](T &self) { + return self.gen_default_config(); + }) + .def( + "get_current_config", + [](T &self) { + return self.solver.solver_config; + }) + .def("set_solver_config", &T::set_solver_config) + .def("do_vtk_dump", &T::do_vtk_dump) + .def("solver_logs_last_rate", &T::solver_logs_last_rate) + .def("solver_logs_last_obj_count", &T::solver_logs_last_obj_count) + .def( + "get_time", + [](T &self) { + return self.solver.solver_config.get_time(); + }) + .def( + "get_dt", + [](T &self) { + return self.solver.solver_config.get_dt(); + }) + .def( + "set_time", + [](T &self, Tscal t) { + return self.solver.solver_config.set_time(t); + }) + .def( + "set_next_dt", + [](T &self, Tscal dt) { + return self.solver.solver_config.set_next_dt(dt); + }) + .def( + 
"load_from_dump", + &T::load_from_dump, + py::arg("filename"), + R"==( + Load simulation state from a Shamrock dump file. + + Uses the shared ShamrockDump mechanism (same as SPH). + + Parameters + ---------- + filename : str + Path to the dump file + + Example + ------- + >>> model.load_from_dump("checkpoint.shamrock") +)==") + .def( + "dump", + &T::dump, + py::arg("filename"), + R"==( + Write simulation state to a Shamrock dump file. + + Uses the shared ShamrockDump mechanism (same as SPH). + + Parameters + ---------- + filename : str + Path to the dump file + + Example + ------- + >>> model.dump("checkpoint.shamrock") +)=="); +} + +using namespace shammodels::gsph; + +Register_pymod(pygsphmodel) { + + py::module mgsph = m.def_submodule("model_gsph", "Shamrock GSPH (Godunov SPH) solver"); + + using namespace shammodels::gsph; + + // Register GSPH models for different kernels + add_gsph_instance( + mgsph, "GSPHModel_f64_3_M4_SolverConfig", "GSPHModel_f64_3_M4"); + add_gsph_instance( + mgsph, "GSPHModel_f64_3_M6_SolverConfig", "GSPHModel_f64_3_M6"); + add_gsph_instance( + mgsph, "GSPHModel_f64_3_M8_SolverConfig", "GSPHModel_f64_3_M8"); + + add_gsph_instance( + mgsph, "GSPHModel_f64_3_C2_SolverConfig", "GSPHModel_f64_3_C2"); + add_gsph_instance( + mgsph, "GSPHModel_f64_3_C4_SolverConfig", "GSPHModel_f64_3_C4"); + add_gsph_instance( + mgsph, "GSPHModel_f64_3_C6_SolverConfig", "GSPHModel_f64_3_C6"); + + using VariantGSPHModelBind = std::variant< + std::unique_ptr>, + std::unique_ptr>, + std::unique_ptr>, + std::unique_ptr>, + std::unique_ptr>, + std::unique_ptr>>; + + m.def( + "get_Model_GSPH", + [](ShamrockCtx &ctx, std::string vector_type, std::string kernel) -> VariantGSPHModelBind { + VariantGSPHModelBind ret; + + if (vector_type == "f64_3" && kernel == "M4") { + ret = std::make_unique>(ctx); + } else if (vector_type == "f64_3" && kernel == "M6") { + ret = std::make_unique>(ctx); + } else if (vector_type == "f64_3" && kernel == "M8") { + ret = 
std::make_unique>(ctx); + } else if (vector_type == "f64_3" && kernel == "C2") { + ret = std::make_unique>(ctx); + } else if (vector_type == "f64_3" && kernel == "C4") { + ret = std::make_unique>(ctx); + } else if (vector_type == "f64_3" && kernel == "C6") { + ret = std::make_unique>(ctx); + } else { + throw shambase::make_except_with_loc( + "unknown combination of representation and kernel"); + } + + return ret; + }, + py::kw_only(), + py::arg("context"), + py::arg("vector_type") = "f64_3", + py::arg("sph_kernel") = "C4", + R"==( + Create a GSPH (Godunov SPH) model. + + GSPH uses Riemann solvers at particle interfaces instead of artificial viscosity, + giving sharper shock resolution. + + Parameters + ---------- + context : ShamrockCtx + Shamrock context + vector_type : str + Vector type, e.g., "f64_3" for 3D double precision (default: "f64_3") + sph_kernel : str + SPH kernel type: "C4" (Wendland, default), "M4" (cubic spline), "M6", "M8", "C2", "C6" + + Returns + ------- + GSPHModel + A GSPH model instance + + Examples + -------- + >>> ctx = shamrock.ShamrockCtx() + >>> model = shamrock.get_Model_GSPH(context=ctx) # Uses C4 kernel by default + >>> config = model.gen_default_config() + >>> config.set_riemann_hllc() + >>> config.set_eos_adiabatic(1.4) + >>> model.set_solver_config(config) +)=="); +}