Skip to content

Commit 4cd56d8

Browse files
Implement direct getter/setter on data, swap utest default + new utest to confirm data is modified
1 parent f9bb58e commit 4cd56d8

File tree

2 files changed

+45
-23
lines changed

2 files changed

+45
-23
lines changed

ndsl/quantity/quantity.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,14 @@ def data_as_xarray(self) -> xr.DataArray:
329329
def np(self) -> ModuleType:
330330
return self.metadata.np
331331

332+
def __getitem__(self, subscript: Any) -> Any:
333+
"""Slicing operator accessing the full buffer"""
334+
return self.data[subscript]
335+
336+
def __setitem__(self, subscript: Any, value: Any) -> None:
337+
"""Slicing operator setting the full buffer"""
338+
self.data[subscript] = value
339+
332340
@property
333341
def __array_interface__(self): # type: ignore[no-untyped-def]
334342
return self.data.__array_interface__

tests/quantity/test_quantity.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_smaller_extent_raises(data, origin, extent, dims, units):
129129

130130
def test_data_change_affects_quantity(data, quantity, numpy):
131131
data[:] = 5.0
132-
numpy.testing.assert_array_equal(quantity.data, 5.0)
132+
numpy.testing.assert_array_equal(quantity, 5.0)
133133

134134

135135
def test_quantity_units(quantity, units):
@@ -150,7 +150,7 @@ def test_quantity_extent(quantity, extent):
150150

151151

152152
def test_compute_view_get_value(quantity, extent_1d, n_halo, n_dims):
153-
quantity.data[:] = 0.0
153+
quantity[:] = 0.0
154154
if extent_1d == 0 and n_halo == 0:
155155
with pytest.raises(IndexError):
156156
quantity.view[[0] * n_dims]
@@ -160,52 +160,50 @@ def test_compute_view_get_value(quantity, extent_1d, n_halo, n_dims):
160160

161161

162162
def test_compute_view_edit_start_halo(quantity, extent_1d, n_halo, n_dims):
163-
quantity.data[:] = 0.0
163+
quantity[:] = 0.0
164164
if extent_1d == 0 and n_halo == 0:
165165
with pytest.raises(IndexError):
166166
quantity.view[[-1] * n_dims] = 1
167167
else:
168168
quantity.view[[-1] * n_dims] = 1
169-
assert quantity.np.sum(quantity.data) == 1.0
170-
assert quantity.data[(n_halo - 1,) * n_dims] == 1
169+
assert quantity.np.sum(quantity) == 1.0
170+
assert quantity[(n_halo - 1,) * n_dims] == 1
171171

172172

173173
def test_compute_view_edit_end_halo(quantity, extent_1d, n_halo, n_dims):
174-
quantity.data[:] = 0.0
174+
quantity[:] = 0.0
175175
if n_halo == 0:
176176
with pytest.raises(IndexError):
177177
quantity.view[[extent_1d] * n_dims] = 1
178178
else:
179179
quantity.view[(extent_1d,) * n_dims] = 1
180-
assert quantity.np.sum(quantity.data) == 1.0
181-
assert quantity.data[(n_halo + extent_1d,) * n_dims] == 1
180+
assert quantity.np.sum(quantity) == 1.0
181+
assert quantity[(n_halo + extent_1d,) * n_dims] == 1
182182

183183

184184
def test_compute_view_edit_start_of_domain(quantity, extent_1d, n_halo, n_dims):
185185
if extent_1d == 0:
186186
return # cannot edit an empty domain
187187

188-
quantity.data[:] = 0.0
188+
quantity[:] = 0.0
189189
quantity.view[(0,) * n_dims] = 1
190-
assert quantity.data[(n_halo,) * n_dims] == 1
191-
assert quantity.np.sum(quantity.data) == 1.0
190+
assert quantity[(n_halo,) * n_dims] == 1
191+
assert quantity.np.sum(quantity) == 1.0
192192

193193

194194
def test_compute_view_edit_all_domain(quantity, n_halo, n_dims, extent_1d):
195195
if extent_1d == 0:
196196
return # cannot edit an empty domain
197197

198-
quantity.data[:] = 0.0
198+
quantity[:] = 0.0
199199
quantity.view[:] = 1
200-
assert quantity.np.sum(quantity.data) == extent_1d**n_dims
200+
assert quantity.np.sum(quantity) == extent_1d**n_dims
201201
if n_dims > 1:
202-
quantity.np.testing.assert_array_equal(quantity.data[:n_halo, :], 0.0)
203-
quantity.np.testing.assert_array_equal(
204-
quantity.data[n_halo + extent_1d :, :], 0.0
205-
)
202+
quantity.np.testing.assert_array_equal(quantity[:n_halo, :], 0.0)
203+
quantity.np.testing.assert_array_equal(quantity[n_halo + extent_1d :, :], 0.0)
206204
else:
207-
quantity.np.testing.assert_array_equal(quantity.data[:n_halo], 0.0)
208-
quantity.np.testing.assert_array_equal(quantity.data[n_halo + extent_1d :], 0.0)
205+
quantity.np.testing.assert_array_equal(quantity[:n_halo], 0.0)
206+
quantity.np.testing.assert_array_equal(quantity[n_halo + extent_1d :], 0.0)
209207

210208

211209
@pytest.mark.parametrize(
@@ -298,19 +296,35 @@ def test_to_data_array(quantity):
298296
assert quantity.field_as_xarray.dims == quantity.dims
299297
assert quantity.field_as_xarray.shape == quantity.extent
300298
np.testing.assert_array_equal(quantity.field_as_xarray.values, quantity.view[:])
301-
if quantity.extent == quantity.data.shape:
299+
if quantity.extent == quantity.shape:
302300
assert (
303301
quantity.field_as_xarray.data.ctypes.data == quantity.data.ctypes.data
304302
), "data memory address is not equal"
305303

306304

305+
def test_data_attribute_and_default_setter_are_the_same():
306+
quantity = Quantity(
307+
np.arange(
308+
5,
309+
),
310+
dims=["dim1"],
311+
units="",
312+
backend=Backend.python(),
313+
)
314+
315+
assert quantity.shape == quantity.shape
316+
assert quantity[3] == quantity[3]
317+
quantity[2] = 42.0
318+
assert quantity[2] == 42.0
319+
320+
307321
def test_data_setter():
308322
quantity = Quantity(
309323
np.ones((5,)), dims=["dim1"], units="", backend=Backend.python()
310324
)
311325

312326
# After allocation - field and data are the same (origin is 0)
313-
assert quantity.data.shape == quantity.field.shape
327+
assert quantity.shape == quantity.field.shape
314328

315329
# Allows swap: new array is bigger than Q.shape
316330
new_array = np.ones((10,))
@@ -319,9 +333,9 @@ def test_data_setter():
319333

320334
# After swap - field and data points to the same memory
321335
# BUT field still respects the original origin/extent
322-
assert (quantity.data[:] == 2).all()
336+
assert (quantity[:] == 2).all()
323337
assert (quantity.field[:] == 2).all()
324-
assert quantity.data.shape != quantity.field.shape
338+
assert quantity.shape != quantity.field.shape
325339
assert quantity.field.shape == (5,)
326340

327341
# Expected fail: new array is too small

0 commit comments

Comments
 (0)