Skip to content

Commit 60c1968

Browse files
Merge commit '9a20c4ae9414f61d9377ca7e54e8e3ec3d43f9aa' into feature/data_dimnesion_fields
2 parents 1daedff + 9a20c4a commit 60c1968

File tree

8 files changed

+32
-32
lines changed

8 files changed

+32
-32
lines changed

.github/workflows/fv3_translate_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010

1111
jobs:
1212
fv3_translate_tests:
13-
uses: twicki/pyFV3/.github/workflows/translate.yaml@update/numpy_2x
13+
uses: NOAA-GFDL/pyFV3/.github/workflows/translate.yaml@develop
1414
with:
1515
component_trigger: true
1616
component_name: NDSL

.github/workflows/pace_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010

1111
jobs:
1212
pace_main_tests:
13-
uses: floriandeconinck/pace/.github/workflows/main_unit_tests.yaml@update/numpy_2x
13+
uses: NOAA-GFDL/pace/.github/workflows/main_unit_tests.yaml@develop
1414
with:
1515
component_trigger: true
1616
component_name: NDSL

.github/workflows/shield_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ on:
1010

1111
jobs:
1212
shield_translate_tests:
13-
uses: floriandeconinck/pySHiELD/.github/workflows/translate.yaml@update/numpy_2x
13+
uses: NOAA-GFDL/pySHiELD/.github/workflows/translate.yaml@develop
1414
with:
1515
component_trigger: true
1616
component_name: NDSL

external/gt4py

ndsl/stencils/column_operations.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,16 @@ def column_max(field, start_index, end_index):
1717
Returns: [max value, index of max value]
1818
"""
1919
max_index = start_index
20+
max_value = field.at(K=max_index)
2021
level = start_index
2122
while level <= end_index:
22-
new = field.at(K=level)
23-
old = field.at(K=max_index)
24-
if new > old:
23+
value = field.at(K=level)
24+
if value > max_value:
25+
max_value = value
2526
max_index = level
2627
level += 1
2728

28-
return field.at(K=max_index), max_index
29+
return max_value, max_index
2930

3031

3132
@typing.no_type_check
@@ -42,15 +43,16 @@ def column_max_ddim(field, ddim, start_index, end_index):
4243
Returns: [max value, index of max value]
4344
"""
4445
max_index = start_index
46+
max_value = field.at(K=max_index, ddim=[ddim])
4547
level = start_index
4648
while level <= end_index:
47-
new = field.at(K=level, ddim=[ddim])
48-
old = field.at(K=max_index, ddim=[ddim])
49-
if new > old:
49+
value = field.at(K=level, ddim=[ddim])
50+
if value > max_value:
51+
max_value = value
5052
max_index = level
5153
level += 1
5254

53-
return field.at(K=max_index, ddim=[ddim]), max_index
55+
return max_value, max_index
5456

5557

5658
@typing.no_type_check
@@ -67,15 +69,16 @@ def column_min(field, start_index, end_index):
6769
Returns: [min value, index of min value]
6870
"""
6971
min_index = start_index
72+
min_value = field.at(K=min_index)
7073
level = start_index
7174
while level <= end_index:
72-
new = field.at(K=level)
73-
old = field.at(K=min_index)
74-
if new < old:
75+
value = field.at(K=level)
76+
if value < min_value:
77+
min_value = value
7578
min_index = level
7679
level += 1
7780

78-
return field.at(K=min_index), min_index
81+
return min_value, min_index
7982

8083

8184
@typing.no_type_check
@@ -92,12 +95,13 @@ def column_min_ddim(field, ddim, start_index, end_index):
9295
Returns: [min value, index of min value]
9396
"""
9497
min_index = start_index
98+
min_value = field.at(K=min_index, ddim=[ddim])
9599
level = start_index
96100
while level <= end_index:
97-
new = field.at(K=level, ddim=[ddim])
98-
old = field.at(K=min_index, ddim=[ddim])
99-
if new < old:
101+
value = field.at(K=level, ddim=[ddim])
102+
if value < min_value:
103+
min_value = value
100104
min_index = level
101105
level += 1
102106

103-
return field.at(K=min_index, ddim=[ddim]), min_index
107+
return min_value, min_index

ndsl/stencils/testing/translate.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,13 @@ def __init__(
6666
self.out_vars: dict[str, Any] = {}
6767
self.write_vars: list = []
6868
self.grid = grid
69-
self.maxshape: tuple[int, ...] = grid.domain_shape_full(add=(1, 1, 1))
7069
self.ordered_input_vars = None
7170
self.ignore_near_zero_errors: dict[str, Any] = {}
7271
self.skip_test = skip_test
72+
if self.stencil_factory.backend.is_fortran_aligned():
73+
self.maxshape = self.grid.domain_shape_full()
74+
else:
75+
self.maxshape = self.grid.domain_shape_full(add=(1, 1, 1))
7376

7477
def extra_data_load(self, data_loader: DataLoader):
7578
pass

ndsl/testing/comparison.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ def __init__(
2727
reference_values: np.ndarray,
2828
computed_values: np.ndarray,
2929
):
30-
self.references = np.atleast_1d(reference_values)
31-
self.computed = np.atleast_1d(computed_values)
30+
self.references = np.squeeze(np.atleast_1d(reference_values))
31+
self.computed = np.squeeze(np.atleast_1d(computed_values))
3232
self.check = False
3333

3434
@abstractmethod

ndsl/xumpy/alloc.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44
import numpy.typing as npt
5-
from numpy._typing import _SupportsDType
65

76
from ndsl.config import Backend
87
from ndsl.dsl.typing import Float
@@ -14,12 +13,6 @@
1413

1514
# Taking a page from cupy's playbook to have tuple & ndarray
1615
_ShapeLike = SupportsIndex | Sequence[SupportsIndex]
17-
_DTypeLikeFloat32 = (
18-
np.dtype[np.float32] | _SupportsDType[np.dtype[np.float32]] | type[np.float32]
19-
)
20-
_DTypeLikeFloat64 = (
21-
np.dtype[np.float64] | _SupportsDType[np.dtype[np.float64]] | type[np.float64]
22-
)
2316

2417

2518
def zeros(
@@ -55,7 +48,7 @@ def empty(
5548
def full(
5649
shape: _ShapeLike,
5750
backend: Backend,
58-
value: np.ScalarType,
51+
value: npt.DTypeLike,
5952
dtype: npt.DTypeLike = Float,
6053
) -> np.ndarray | cp.ndarray:
6154
if backend.is_gpu_backend():
@@ -66,7 +59,7 @@ def full(
6659
def random(
6760
shape: _ShapeLike,
6861
backend: Backend,
69-
dtype: _DTypeLikeFloat32 | _DTypeLikeFloat64 = Float, # type: ignore [valid-type]
62+
dtype: np.floating = Float,
7063
) -> np.ndarray | cp.ndarray:
7164
if backend.is_gpu_backend():
7265
gen = cp.random.default_rng()

0 commit comments

Comments
 (0)