Skip to content

feat: Add ml model input and output shape to allow models run on entire tiles #205

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/.backend_git_ref
Original file line number Diff line number Diff line change
@@ -1 +1 @@
af126cb150c974cf47a52d2fac5b4a96a81d2c77
7aadfa383e6eee63442e366890dfb1160114caed
2 changes: 1 addition & 1 deletion examples/expression.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
9 changes: 5 additions & 4 deletions examples/ml_pipeline.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import geoengine as ge\n",
"from geoengine.ml import MlModelConfig\n",
"\n",
"from geoengine_openapi_client.models import MlModelMetadata, RasterDataType\n",
"from geoengine_openapi_client.models import MlModelMetadata, RasterDataType, MlTensorShape3D as TensorShape3D\n",
"\n",
"from sklearn.tree import DecisionTreeClassifier\n",
"import numpy as np\n",
Expand Down Expand Up @@ -88,8 +88,9 @@
"metadata = MlModelMetadata(\n",
" file_name=\"model.onnx\",\n",
" input_type=RasterDataType.F32,\n",
" num_input_bands=2,\n",
" output_type=RasterDataType.I64,\n",
" input_shape=TensorShape3D(y=1, x=1, bands=2),\n",
" output_shape=TensorShape3D(y=1, x=1, bands=1)\n",
")\n",
"\n",
"model_config = MlModelConfig(\n",
Expand Down Expand Up @@ -179,7 +180,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down
72 changes: 62 additions & 10 deletions geoengine/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import geoengine_openapi_client.models
from onnx import TypeProto, TensorProto, ModelProto
from onnx.helper import tensor_dtype_to_string
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType
from geoengine_openapi_client.models import MlModelMetadata, MlModel, RasterDataType, MlTensorShape3D
import geoengine_openapi_client
from geoengine.auth import get_session
from geoengine.resource_identifier import UploadId, MlModelName
Expand All @@ -35,8 +35,10 @@ def register_ml_model(onnx_model: ModelProto,
onnx_model,
input_type=model_config.metadata.input_type,
output_type=model_config.metadata.output_type,
num_input_bands=model_config.metadata.num_input_bands,
input_shape=model_config.metadata.input_shape,
out_shape=model_config.metadata.output_shape
)
check_backend_constraints(model_config.metadata.input_shape, model_config.metadata.output_shape)

session = get_session()

Expand All @@ -61,10 +63,57 @@ def register_ml_model(onnx_model: ModelProto,
return MlModelName.from_response(res_name)


def model_dim_to_tensorshape(model_dims):
    '''
    Transform the dimensions of an ONNX tensor into a MlTensorShape3D.

    The result starts as `x=1, y=1, bands=1` and is filled depending on the number of dimensions.
    A leading dimension whose `dim_value` is `None`, `-1`, `0` or `1` is skipped
    (presumably the batch dimension of the model):

    - 1D: `[bands]`; a `dim_value` of `-1` or `0` keeps the default shape
    - 2D: `[skipped, bands]`, otherwise `[y, x]`
    - 3D: `[skipped, y, x]`, otherwise `[y, x, bands]`
    - 4D: `[skipped, y, x, bands]`

    Raises an `InputException` for every other layout, e.g. more than four dimensions
    or a 4D tensor whose leading dimension is larger than one.
    '''

    # `dim_value`s of a leading dimension that is not part of the tile shape.
    # NOTE(review): an unset protobuf `dim_value` presumably reads as `0`, `None` is kept for safety - confirm
    skipped_dim_values = (None, -1, 0, 1)

    mts = MlTensorShape3D(x=1, y=1, bands=1)
    if len(model_dims) == 1 and model_dims[0].dim_value in (-1, 0):
        pass  # in this case, the model will produce as many outs as inputs
    elif len(model_dims) == 1 and model_dims[0].dim_value > 0:
        mts.bands = model_dims[0].dim_value
    elif len(model_dims) == 2:
        if model_dims[0].dim_value in skipped_dim_values:
            mts.bands = model_dims[1].dim_value
        else:
            mts.y = model_dims[0].dim_value
            mts.x = model_dims[1].dim_value
    elif len(model_dims) == 3:
        if model_dims[0].dim_value in skipped_dim_values:
            mts.y = model_dims[1].dim_value
            mts.x = model_dims[2].dim_value
        else:
            mts.y = model_dims[0].dim_value
            mts.x = model_dims[1].dim_value
            mts.bands = model_dims[2].dim_value
    elif len(model_dims) == 4 and model_dims[0].dim_value in skipped_dim_values:
        mts.y = model_dims[1].dim_value
        mts.x = model_dims[2].dim_value
        mts.bands = model_dims[3].dim_value
    else:
        # The function is used for input and output tensors and accepts more than 1D and 3D layouts,
        # so the message lists what is actually supported.
        raise InputException(
            'Only 1D, 2D and 3D tensors and 4D tensors with a leading batch dimension are supported. '
            f'Got model dim {model_dims}'
        )
    return mts


def check_backend_constraints(input_shape: MlTensorShape3D, output_shape: MlTensorShape3D, ge_tile_size=(512, 512)):
    '''
    Checks that the shapes match the constraints of the backend.

    A shape is accepted if `x` is `1` or `ge_tile_size[0]`, `y` is `1` or `ge_tile_size[1]`
    and `bands` is greater than zero. The input shape is checked before the output shape.

    Raises an `InputException` naming the offending shape if one of the shapes is not accepted.
    '''

    def is_supported(shape) -> bool:
        '''Returns whether `shape` is pixel or tile sized per axis and has at least one band'''
        return shape.x in [1, ge_tile_size[0]] and shape.y in [1, ge_tile_size[1]] and shape.bands > 0

    if not is_supported(input_shape):
        raise InputException(f'Backend currently supports single pixel and full tile shaped input! Got {input_shape}!')

    if not is_supported(output_shape):
        # report the output shape here, the input shape has already passed its check
        raise InputException(
            f'Backend currently supports single pixel and full tile shaped Output! Got {output_shape}!'
        )


# pylint: disable=too-many-branches,too-many-statements
def validate_model_config(onnx_model: ModelProto, *,
input_type: RasterDataType,
output_type: RasterDataType,
num_input_bands: int):
input_shape: MlTensorShape3D,
out_shape: MlTensorShape3D):
'''Validates the model config. Raises an exception if the model config is invalid'''

def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix: 'str'):
Expand All @@ -85,18 +134,21 @@ def check_data_type(data_type: TypeProto, expected_type: RasterDataType, prefix:
raise InputException('Models with multiple inputs are not supported')
check_data_type(model_inputs[0].type, input_type, 'input')

dims = model_inputs[0].type.tensor_type.shape.dim
if len(dims) != 2:
raise InputException('Only 2D input tensors are supported')
if not dims[1].dim_value:
raise InputException('Dimension 1 of the input tensor must have a length')
if dims[1].dim_value != num_input_bands:
raise InputException(f'Model input has {dims[1].dim_value} bands, but {num_input_bands} bands are expected')
dim = model_inputs[0].type.tensor_type.shape.dim

in_ts3d = model_dim_to_tensorshape(dim)
if not in_ts3d == input_shape:
raise InputException(f"Input shape {in_ts3d} and metadata {input_shape} not equal!")

if len(model_outputs) < 1:
raise InputException('Models with no outputs are not supported')
check_data_type(model_outputs[0].type, output_type, 'output')

dim = model_outputs[0].type.tensor_type.shape.dim
out_ts3d = model_dim_to_tensorshape(dim)
if not out_ts3d == out_shape:
raise InputException(f"Output shape {out_ts3d} and metadata {out_shape} not equal!")


RASTER_TYPE_TO_ONNX_TYPE = {
RasterDataType.F32: TensorProto.FLOAT,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package_dir =
packages = find:
python_requires = >=3.10
install_requires =
geoengine-openapi-client == 0.0.23
geoengine-openapi-client == 0.0.25
geopandas >=1.0,<2.0
matplotlib >=3.5,<3.11
numpy >=1.21,<2.3
Expand Down
63 changes: 55 additions & 8 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,63 @@
'''Tests ML functionality'''

from typing import List
import unittest
from onnx import TensorShapeProto as TSP
from sklearn.ensemble import RandomForestClassifier
from skl2onnx import to_onnx
import numpy as np
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType
from geoengine_openapi_client.models import MlModelMetadata, RasterDataType, MlTensorShape3D
import geoengine as ge
from geoengine.ml import model_dim_to_tensorshape
from tests.ge_test import GeoEngineTestInstance


class WorkflowStorageTests(unittest.TestCase):
'''Test methods for storing workflows as datasets'''
class MlModelTests(unittest.TestCase):
'''Test methods for MlModels'''

def setUp(self) -> None:
    '''Runs before every test of this class and calls `ge.reset(False)`'''
    # NOTE(review): presumably this drops a previously initialized session so that the tests
    # do not influence each other; the meaning of the `False` argument is not visible here - confirm
    ge.reset(False)

def test_model_dim_to_tensorshape(self):
    ''' Check the mapping of ONNX dimensions to MlTensorShape3D for 1D up to 4D layouts '''

    def onnx_dims(*dim_values) -> List[TSP.Dimension]:
        '''Builds one ONNX dimension per given `dim_value`'''
        return [TSP.Dimension(dim_value=dim_value) for dim_value in dim_values]

    # pairs of (ONNX dimensions, expected tensor shape), checked in this order
    cases = [
        # 1D: bands only
        (onnx_dims(7), MlTensorShape3D(bands=7, y=1, x=1)),
        # 2D with an unset leading dimension: bands only
        (onnx_dims(None, 7), MlTensorShape3D(bands=7, y=1, x=1)),
        # 2D tile with a single band
        (onnx_dims(512, 512), MlTensorShape3D(bands=1, y=512, x=512)),
        # 2D with a leading dimension of one: bands only
        (onnx_dims(1, 7), MlTensorShape3D(bands=7, y=1, x=1)),
        # 3D tile with bands
        (onnx_dims(512, 512, 7), MlTensorShape3D(bands=7, y=512, x=512)),
        # 3D with an unset leading dimension: tile with a single band
        (onnx_dims(None, 512, 512), MlTensorShape3D(bands=1, y=512, x=512)),
        # 4D with an unset leading dimension: tile with bands
        (onnx_dims(None, 512, 512, 4), MlTensorShape3D(bands=4, y=512, x=512)),
    ]

    for model_dims, expected_shape in cases:
        self.assertEqual(model_dim_to_tensorshape(model_dims), expected_shape)

def test_uploading_onnx_model(self):

clf = RandomForestClassifier(random_state=42)
Expand All @@ -40,8 +83,9 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=2,
output_type=RasterDataType.I64,
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
output_shape=MlTensorShape3D(y=1, x=1, bands=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand Down Expand Up @@ -77,16 +121,17 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=4,
output_type=RasterDataType.I64,
input_shape=MlTensorShape3D(y=1, x=1, bands=4),
output_shape=MlTensorShape3D(y=1, x=1, bands=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
)
)
self.assertEqual(
str(exception.exception),
'Model input has 2 bands, but 4 bands are expected'
'Input shape bands=2 x=1 y=1 and metadata bands=4 x=1 y=1 not equal!'
)

with self.assertRaises(ge.InputException) as exception:
Expand All @@ -97,8 +142,9 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F64,
num_input_bands=2,
output_type=RasterDataType.I64,
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
output_shape=MlTensorShape3D(y=1, x=1, bands=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand All @@ -117,8 +163,9 @@ def test_uploading_onnx_model(self):
metadata=MlModelMetadata(
file_name="model.onnx",
input_type=RasterDataType.F32,
num_input_bands=2,
output_type=RasterDataType.I32,
input_shape=MlTensorShape3D(y=1, x=1, bands=2),
output_shape=MlTensorShape3D(y=1, x=1, bands=1)
),
display_name="Decision Tree",
description="A simple decision tree model",
Expand Down