Skip to content

Commit 700d93c

Browse files
committed
array now returns a reference in the original vector instead of a copy
1 parent b1fe169 commit 700d93c

File tree

5 files changed

+119
-25
lines changed

5 files changed

+119
-25
lines changed

infra-rs

python/geodynamix/geodynamix.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ class raster:
9393
@property
9494
def array(self) -> np.ndarray:
9595
"""
96-
Returns the raster as a numpy array. Data is copied so changes to the array will not affect the raster.
96+
Returns the raster as a numpy array. The data in the array is a view on the raster data so changes to the array will affect the raster.
97+
This also means that the array cannot be changed in size, only in content.
9798
"""
9899
...
99100
@property

src/raster.rs

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
use numpy::PyArrayDescr;
2+
use pyo3::exceptions::PyBufferError;
3+
use pyo3::ffi;
24
use pyo3::prelude::*;
5+
use std::ffi::c_void;
6+
use std::ffi::CString;
37
use std::ops::Add;
48
use std::ops::Div;
59
use std::ops::Mul;
@@ -71,8 +75,9 @@ impl Raster {
7175
}
7276

7377
#[getter]
74-
pub fn array<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
75-
utils::raster_array(py, &self.raster)
78+
pub fn array<'py>(slf: Bound<'py, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
79+
//utils::raster_array(py, &slf.borrow().raster)
80+
utils::raster_buffer_array(py, slf)
7681
}
7782

7883
#[getter]
@@ -109,6 +114,26 @@ impl Raster {
109114
pub fn __truediv__(&self, rhs_object: Bound<'_, PyAny>) -> PyResult<Raster> {
110115
impl_raster_op!(div, self.raster, rhs_object)
111116
}
117+
/// # Safety
118+
/// This function is unsafe because it exposes a raw pointer to the Python buffer protocol.
119+
pub unsafe fn __getbuffer__(
120+
slf: Bound<'_, Self>,
121+
view: *mut ffi::Py_buffer,
122+
flags: std::ffi::c_int,
123+
) -> PyResult<()> {
124+
fill_view_from_data(
125+
view,
126+
flags,
127+
slf.borrow().raster.raw_data_u8_slice(),
128+
slf.into_any(),
129+
)
130+
}
131+
132+
/// # Safety
133+
/// This function is unsafe because it exposes a raw pointer to the Python buffer protocol.
134+
pub unsafe fn __releasebuffer__(&self, _buffer: *mut ffi::Py_buffer) {
135+
// No need to release, the raster still owns the data
136+
}
112137
}
113138

114139
impl From<PythonDenseArray> for Raster {
@@ -144,3 +169,50 @@ pub fn raster_equal(
144169

145170
Ok(array1 == array2)
146171
}
172+
173+
/// # Safety
174+
///
175+
/// `view` must be a valid pointer to `ffi::Py_buffer`, or null
176+
/// `data` must outlive the Python lifetime of `owner` (i.e. data must be owned by owner, or data
177+
/// must be static data)
178+
unsafe fn fill_view_from_data(
179+
view: *mut ffi::Py_buffer,
180+
flags: std::ffi::c_int,
181+
data: &[u8],
182+
owner: Bound<'_, PyAny>,
183+
) -> PyResult<()> {
184+
if view.is_null() {
185+
return Err(PyBufferError::new_err("View is null"));
186+
}
187+
188+
(*view).obj = owner.into_ptr();
189+
(*view).buf = data.as_ptr() as *mut c_void;
190+
(*view).len = data.len() as isize;
191+
(*view).readonly = 0;
192+
(*view).itemsize = 1;
193+
194+
(*view).format = if (flags & ffi::PyBUF_FORMAT) == ffi::PyBUF_FORMAT {
195+
let msg = CString::new("B").unwrap();
196+
msg.into_raw()
197+
} else {
198+
std::ptr::null_mut()
199+
};
200+
201+
(*view).ndim = 1;
202+
(*view).shape = if (flags & ffi::PyBUF_ND) == ffi::PyBUF_ND {
203+
&mut (*view).len
204+
} else {
205+
std::ptr::null_mut()
206+
};
207+
208+
(*view).strides = if (flags & ffi::PyBUF_STRIDES) == ffi::PyBUF_STRIDES {
209+
&mut (*view).itemsize
210+
} else {
211+
std::ptr::null_mut()
212+
};
213+
214+
(*view).suboffsets = std::ptr::null_mut();
215+
(*view).internal = std::ptr::null_mut();
216+
217+
Ok(())
218+
}

src/utils.rs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use pyo3::{exceptions::PyValueError, prelude::*, types::IntoPyDict};
33

44
use geo::{raster::algo, Array, ArrayDataType, ArrayMetadata, ArrayNum, DenseArray};
55

6-
use crate::{PythonDenseArray, RasterMetadata};
6+
use crate::{PythonDenseArray, Raster, RasterMetadata};
77

88
fn convert_array<'py, T: ArrayNum + numpy::Element>(
99
py: Python<'py>,
@@ -47,6 +47,27 @@ pub fn raster_mask<'py>(py: Python<'py>, raster: &PythonDenseArray) -> PyResult<
4747
}
4848
}
4949

50+
pub fn raster_buffer_array<'py>(
51+
py: Python<'py>,
52+
raster: Bound<'py, Raster>,
53+
) -> PyResult<Bound<'py, PyAny>> {
54+
let np = PyModule::import(py, "numpy")?;
55+
56+
let shape = [
57+
raster.borrow().raster.rows().count(),
58+
raster.borrow().raster.columns().count(),
59+
];
60+
let dtype =
61+
array_type_to_numpy_dtype(py, raster.borrow().raster.data_type()).into_pyobject(py)?;
62+
let kwargs = vec![
63+
("dtype", dtype.into_pyobject(py)?.into_any()),
64+
("buffer", raster.into_pyobject(py)?.into_any()),
65+
];
66+
67+
np.getattr("ndarray")?
68+
.call((shape.into_pyobject(py)?,), Some(&kwargs.into_py_dict(py)?))
69+
}
70+
5071
pub fn raster_masked_array<'py>(
5172
py: Python<'py>,
5273
raster: &PythonDenseArray,

tests/testgdx.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -182,31 +182,31 @@ def test_write_map(self):
182182
os.remove("writtenraster.asc")
183183
self.assertTrue(np.allclose(raster, written_raster.array, equal_nan=True))
184184

185-
# def test_modify_raster_using_numpy(self):
186-
# # modify the raster data using the ndarry accessor
187-
# # verify that the internal raster data has changed by writing it to disk
188-
# # and comparing the result
185+
def test_modify_raster_using_numpy(self):
186+
# modify the raster data using the ndarry accessor
187+
# verify that the internal raster data has changed by writing it to disk
188+
# and comparing the result
189189

190-
# create_test_file("raster.asc", test_raster)
191-
# raster = gdx.read("raster.asc")
190+
create_test_file("raster.asc", test_raster)
191+
raster = gdx.read("raster.asc")
192192

193-
# raster.array.fill(44)
193+
raster.array.fill(44)
194194

195-
# gdx.write(raster, "writtenraster.asc")
196-
# written_raster = gdx.read("writtenraster.asc")
197-
# os.remove("writtenraster.asc")
195+
gdx.write(raster, "writtenraster.asc")
196+
written_raster = gdx.read("writtenraster.asc")
197+
os.remove("writtenraster.asc")
198198

199-
# expected = np.array(
200-
# [
201-
# [44, 44, 44, 44, 44],
202-
# [44, 44, 44, 44, 44],
203-
# [44, 44, 44, 44, 44],
204-
# [44, 44, 44, 44, 44],
205-
# ],
206-
# dtype="f",
207-
# )
199+
expected = np.array(
200+
[
201+
[44, 44, 44, 44, 44],
202+
[44, 44, 44, 44, 44],
203+
[44, 44, 44, 44, 44],
204+
[44, 44, 44, 44, 44],
205+
],
206+
dtype="f",
207+
)
208208

209-
# self.assertTrue(np.allclose(expected, written_raster.array, equal_nan=True))
209+
self.assertTrue(np.allclose(expected, written_raster.array, equal_nan=True))
210210

211211
def test_replace_value(self):
212212
expected = np.array(

0 commit comments

Comments
 (0)