Skip to content

Commit 4a5ced0

Browse files
committed
[JTH] little edits to format code
1 parent 61f296b commit 4a5ced0

File tree

3 files changed

+45
-43
lines changed

3 files changed

+45
-43
lines changed

bluemath_tk/topo_bathy/swan_grid.py

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import numpy as np
22
import xarray as xr
33

4+
45
def generate_grid_parameters(
56
bathy_data: xr.DataArray,
67
alpc: float = 0,
78
xpc: float = None,
89
ypc: float = None,
910
xlenc: float = None,
1011
ylenc: float = None,
11-
buffer_distance: float = None,
12+
buffer_distance: float = None,
1213
) -> dict:
1314
"""
1415
Generate grid parameters for the SWAN model based on bathymetry.
@@ -31,29 +32,15 @@ def generate_grid_parameters(
3132
-------
3233
dict
3334
Dictionary with grid configuration for SWAN input.
34-
35-
36-
Contact
37-
-------
38-
@
39-
"""
4035
"""
41-
Generate the grid parameters for the SWAN model.
4236

43-
Returns
44-
-------
45-
dict
46-
Grid parameters for the SWAN model.
47-
"""
4837
# Determine coordinate system based on coordinate names
4938
coord_names = list(bathy_data.coords)
5039

5140
# Get coordinate variables
5241
if any(name in ["lon", "longitude"] for name in coord_names):
5342
x_coord = next(name for name in coord_names if name in ["lon", "longitude"])
5443
y_coord = next(name for name in coord_names if name in ["lat", "latitude"])
55-
is_geographic = True
56-
# coord_type = 'geographic'
5744
else:
5845
x_coord = next(
5946
name for name in coord_names if name in ["x", "X", "cx", "easting"]
@@ -62,7 +49,6 @@ def generate_grid_parameters(
6249
name for name in coord_names if name in ["y", "Y", "cy", "northing"]
6350
)
6451

65-
6652
# Get resolution from cropped data
6753
grid_resolution_x = abs(
6854
bathy_data[x_coord][1].values - bathy_data[x_coord][0].values
@@ -93,7 +79,7 @@ def generate_grid_parameters(
9379
x = rotated[:, 0] + xpc
9480
y = rotated[:, 1] + ypc
9581
corners = np.column_stack([x, y])
96-
82+
9783
x_min = np.min(x) - buffer_distance
9884
x_max = np.max(x) + buffer_distance
9985
y_min = np.min(y) - buffer_distance
@@ -128,8 +114,6 @@ def generate_grid_parameters(
128114
return grid_parameters, cropped, corners
129115

130116
else:
131-
132-
133117
# Compute parameters from full bathymetry
134118
grid_parameters = {
135119
"xpc": float(np.nanmin(bathy_data[x_coord])), # origin x
@@ -145,7 +129,7 @@ def generate_grid_parameters(
145129
"myc": len(bathy_data[y_coord]) - 1, # num mesh y
146130
"xpinp": float(np.nanmin(bathy_data[x_coord])), # origin x
147131
"ypinp": float(np.nanmin(bathy_data[y_coord])), # origin y
148-
"alpinp":0, # x-axis direction
132+
"alpinp": 0, # x-axis direction
149133
"mxinp": len(bathy_data[x_coord]) - 1, # num mesh x
150134
"myinp": len(bathy_data[y_coord]) - 1, # num mesh y
151135
"dxinp": float(
@@ -154,6 +138,5 @@ def generate_grid_parameters(
154138
"dyinp": float(
155139
abs(bathy_data[y_coord][1].values - bathy_data[y_coord][0].values)
156140
), # resolution y
157-
}
141+
}
158142
return grid_parameters
159-

bluemath_tk/waves/calibration.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,6 @@ def _create_vec_direc(self, waves: np.ndarray, direcs: np.ndarray) -> np.ndarray
354354
else:
355355
bin_idx = int(direcs[i] / self.direction_bin_size)
356356
data[i, bin_idx] = waves[i]
357-
358357

359358
return data
360359

@@ -562,7 +561,9 @@ def correct(
562561
correction_coeffs[n_part, :] = np.array(
563562
[
564563
self.calibration_params["sea_correction"][
565-
int(peak_direction / self.direction_bin_size) if peak_direction < 360 else 0 #TODO: Check if this with Javi
564+
int(peak_direction / self.direction_bin_size)
565+
if peak_direction < 360
566+
else 0 # TODO: Check if this with Javi
566567
]
567568
for peak_direction in peak_directions.isel(
568569
part=n_part
@@ -573,7 +574,9 @@ def correct(
573574
correction_coeffs[n_part, :] = np.array(
574575
[
575576
self.calibration_params["swell_correction"][
576-
int(peak_direction / self.direction_bin_size) if peak_direction < 360 else 0 #TODO: Check if this with Javi
577+
int(peak_direction / self.direction_bin_size)
578+
if peak_direction < 360
579+
else 0 # TODO: Check if this with Javi
577580
]
578581
for peak_direction in peak_directions.isel(
579582
part=n_part
@@ -599,7 +602,9 @@ def correct(
599602
* np.array(
600603
[
601604
self.calibration_params["sea_correction"][
602-
int(peak_direction / self.direction_bin_size) if peak_direction < 360 else 0
605+
int(peak_direction / self.direction_bin_size)
606+
if peak_direction < 360
607+
else 0
603608
]
604609
for peak_direction in corrected_data["Dirsea"]
605610
]
@@ -613,7 +618,9 @@ def correct(
613618
* np.array(
614619
[
615620
self.calibration_params["swell_correction"][
616-
int(peak_direction / self.direction_bin_size) if peak_direction < 360 else 0
621+
int(peak_direction / self.direction_bin_size)
622+
if peak_direction < 360
623+
else 0
617624
]
618625
for peak_direction in corrected_data[f"Dirswell{n_part}"]
619626
]
@@ -764,7 +771,7 @@ def plot_calibration_results(self) -> Tuple[Figure, list]:
764771
valid_mask = np.isfinite(sea_dirs) & np.isfinite(sea_heights)
765772
sea_dirs_valid = sea_dirs[valid_mask]
766773
sea_heights_valid = sea_heights[valid_mask]
767-
774+
768775
if len(sea_dirs_valid) > 0:
769776
x, y, z = density_scatter(sea_dirs_valid, sea_heights_valid)
770777
ax5.scatter(x, y, c=z, s=3, cmap="jet")
@@ -784,7 +791,7 @@ def plot_calibration_results(self) -> Tuple[Figure, list]:
784791
valid_mask = np.isfinite(swell_dirs) & np.isfinite(swell_heights)
785792
swell_dirs_valid = swell_dirs[valid_mask]
786793
swell_heights_valid = swell_heights[valid_mask]
787-
794+
788795
if len(swell_dirs_valid) > 0:
789796
x, y, z = density_scatter(swell_dirs_valid, swell_heights_valid)
790797
ax6.scatter(x, y, c=z, s=3, cmap="jet")

bluemath_tk/waves/superpoint.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -56,36 +56,48 @@ def load_station_data(ds: xr.Dataset) -> xr.DataArray:
5656
superpoint_dataarray = xr.zeros_like(
5757
stations_data.isel({stations_dimension_name: 0})
5858
)
59-
59+
6060
if overlap_angle == 0:
6161
for station_id, (dir_min, dir_max) in sectors_for_each_station.items():
6262
if dir_min < dir_max:
63-
mask = (stations_data["dir"] >= dir_min) & (stations_data["dir"] < dir_max)
63+
mask = (stations_data["dir"] >= dir_min) & (
64+
stations_data["dir"] < dir_max
65+
)
6466
else:
6567
# Handle wrap-around (e.g., 350° to 10°)
66-
mask = (stations_data["dir"] >= dir_min) | (stations_data["dir"] < dir_max)
67-
superpoint_dataarray += stations_data.sel({stations_dimension_name: station_id}).where(mask, 0.0)
68-
68+
mask = (stations_data["dir"] >= dir_min) | (
69+
stations_data["dir"] < dir_max
70+
)
71+
superpoint_dataarray += stations_data.sel(
72+
{stations_dimension_name: station_id}
73+
).where(mask, 0.0)
74+
6975
else:
7076
# With overlap - expand sectors and average overlaps
7177
directions = stations_data["dir"]
7278
count_array = xr.zeros_like(superpoint_dataarray) # Counter for overlaps
73-
79+
7480
for station_id, (dir_min, dir_max) in sectors_for_each_station.items():
7581
station_data = stations_data.sel({stations_dimension_name: station_id})
76-
82+
7783
# Expand sector boundaries by overlap_angle
78-
if (dir_max - dir_min) < 0:
79-
mask = (directions >= dir_min - overlap_angle) | (directions <= dir_max + overlap_angle)
80-
else:
81-
mask = (directions >= dir_min - overlap_angle) & (directions <= dir_max + overlap_angle)
82-
84+
if (dir_max - dir_min) < 0:
85+
mask = (directions >= dir_min - overlap_angle) | (
86+
directions <= dir_max + overlap_angle
87+
)
88+
else:
89+
mask = (directions >= dir_min - overlap_angle) & (
90+
directions <= dir_max + overlap_angle
91+
)
92+
8393
# Add contribution where mask is true
8494
superpoint_dataarray += station_data.where(mask, 0.0)
8595
count_array += xr.where(mask, 1, 0)
86-
96+
8797
# Average where there are overlaps (count > 1)
8898
overlap_mask = count_array > 1
89-
superpoint_dataarray = xr.where(overlap_mask, superpoint_dataarray / count_array, superpoint_dataarray)
99+
superpoint_dataarray = xr.where(
100+
overlap_mask, superpoint_dataarray / count_array, superpoint_dataarray
101+
)
90102

91103
return superpoint_dataarray

0 commit comments

Comments
 (0)