Skip to content

Commit 0c1cf06

Browse files
committed
Refactor model info and shapes
The previous design of the ModelInfo was coupled to the representation of the ModelSession message used for the grpc communication. The new design creates a more feature-rich interface to work easily with the concept of shape, and the compression of the ModelInfo needed to be transferred by the server is localized.
1 parent 609cccf commit 0c1cf06

File tree

5 files changed

+363
-100
lines changed

5 files changed

+363
-100
lines changed

tests/test_converters.py

+14-34
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
from numpy.testing import assert_array_equal
55

66
from tiktorch.converters import (
7-
NamedExplicitOutputShape,
8-
NamedImplicitOutputShape,
9-
NamedParametrizedShape,
107
input_shape_to_pb_input_shape,
118
numpy_to_pb_tensor,
129
output_shape_to_pb_output_shape,
@@ -15,6 +12,7 @@
1512
xarray_to_pb_tensor,
1613
)
1714
from tiktorch.proto import inference_pb2
15+
from tiktorch.server.session.process import AxisWithValue, ParameterizedShape, ShapeWithHalo, ShapeWithReference
1816

1917

2018
def _numpy_to_pb_tensor(arr):
@@ -186,32 +184,12 @@ def test_should_same_data(self, shape):
186184

187185

188186
class TestShapeConversions:
189-
def to_named_explicit_shape(self, shape, axes, halo):
190-
return NamedExplicitOutputShape(
191-
halo=[(name, dim) for name, dim in zip(axes, halo)], shape=[(name, dim) for name, dim in zip(axes, shape)]
192-
)
193-
194-
def to_named_implicit_shape(self, axes, halo, offset, scales, reference_tensor):
195-
return NamedImplicitOutputShape(
196-
halo=[(name, dim) for name, dim in zip(axes, halo)],
197-
offset=[(name, dim) for name, dim in zip(axes, offset)],
198-
scale=[(name, scale) for name, scale in zip(axes, scales)],
199-
reference_tensor=reference_tensor,
200-
)
201-
202-
def to_named_paramtrized_shape(self, min_shape, axes, step):
203-
return NamedParametrizedShape(
204-
min_shape=[(name, dim) for name, dim in zip(axes, min_shape)],
205-
step_shape=[(name, dim) for name, dim in zip(axes, step)],
206-
)
207-
208187
@pytest.mark.parametrize(
209188
"shape,axes,halo",
210-
[((42,), "x", (0,)), ((42, 128, 5), "abc", (1, 1, 1)), ((5, 4, 3, 2, 1, 42), "btzyxc", (1, 2, 3, 4, 5, 24))],
189+
[((42,), "x", (0,)), ((42, 128, 5), "xyz", (1, 1, 1)), ((5, 4, 3, 2, 1, 42), "btzyxc", (1, 2, 3, 4, 5, 24))],
211190
)
212191
def test_explicit_output_shape(self, shape, axes, halo):
213-
named_shape = self.to_named_explicit_shape(shape, axes, halo)
214-
pb_shape = output_shape_to_pb_output_shape(named_shape)
192+
pb_shape = output_shape_to_pb_output_shape(ShapeWithHalo.from_values(shape=shape, halo=halo, axes=axes))
215193

216194
assert pb_shape.shapeType == 0
217195
assert pb_shape.referenceTensor == ""
@@ -223,11 +201,13 @@ def test_explicit_output_shape(self, shape, axes, halo):
223201

224202
@pytest.mark.parametrize(
225203
"axes,halo,offset,scales,reference_tensor",
226-
[("x", (0,), (10,), (1.0,), "forty-two"), ("abc", (1, 1, 1), (1, 2, 3), (1.0, 2.0, 3.0), "helloworld")],
204+
[("x", (0,), (10,), (1.0,), "forty-two"), ("xyz", (1, 1, 1), (1, 2, 3), (1.0, 2.0, 3.0), "helloworld")],
227205
)
228206
def test_implicit_output_shape(self, axes, halo, offset, scales, reference_tensor):
229-
named_shape = self.to_named_implicit_shape(axes, halo, offset, scales, reference_tensor)
230-
pb_shape = output_shape_to_pb_output_shape(named_shape)
207+
shape = ShapeWithReference.from_values(
208+
axes=axes, halo=halo, offset=offset, scale=scales, reference_tensor=reference_tensor
209+
)
210+
pb_shape = output_shape_to_pb_output_shape(shape)
231211

232212
assert pb_shape.shapeType == 1
233213
assert pb_shape.referenceTensor == reference_tensor
@@ -248,11 +228,10 @@ def test_output_shape_raises(self):
248228

249229
@pytest.mark.parametrize(
250230
"shape,axes",
251-
[((42,), "x"), ((42, 128, 5), "abc"), ((5, 4, 3, 2, 1, 42), "btzyxc")],
231+
[((42,), "x"), ((42, 128, 5), "xyz"), ((5, 4, 3, 2, 1, 42), "btzyxc")],
252232
)
253233
def test_explicit_input_shape(self, shape, axes):
254-
named_shape = [(name, dim) for name, dim in zip(axes, shape)]
255-
pb_shape = input_shape_to_pb_input_shape(named_shape)
234+
pb_shape = input_shape_to_pb_input_shape(AxisWithValue(axes=axes, values=shape))
256235

257236
assert pb_shape.shapeType == 0
258237
assert [(d.name, d.size) for d in pb_shape.shape.namedInts] == [(name, size) for name, size in zip(axes, shape)]
@@ -261,13 +240,14 @@ def test_explicit_input_shape(self, shape, axes):
261240
"min_shape,axes,step",
262241
[
263242
((42,), "x", (5,)),
264-
((42, 128, 5), "abc", (1, 2, 3)),
243+
((42, 128, 5), "xyz", (1, 2, 3)),
265244
((5, 4, 3, 2, 1, 42), "btzyxc", (15, 24, 33, 42, 51, 642)),
266245
],
267246
)
268247
def test_parametrized_input_shape(self, min_shape, axes, step):
269-
named_shape = self.to_named_paramtrized_shape(min_shape, axes, step)
270-
pb_shape = input_shape_to_pb_input_shape(named_shape)
248+
pb_shape = input_shape_to_pb_input_shape(
249+
ParameterizedShape.from_values(axes=axes, steps=step, min_shape=min_shape)
250+
)
271251

272252
assert pb_shape.shapeType == 1
273253
assert [(d.name, d.size) for d in pb_shape.shape.namedInts] == [
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import pytest
2+
3+
from tiktorch.server.session.process import AxisWithValue, ParameterizedShape
4+
5+
6+
@pytest.mark.parametrize(
7+
"min_shape, step, axes, expected",
8+
[
9+
((512, 512), (10, 10), "yx", (512, 512)),
10+
((256, 512), (10, 10), "yx", (256, 512)),
11+
((256, 256), (2, 2), "yx", (512, 512)),
12+
((128, 256), (2, 2), "yx", (384, 512)),
13+
((64, 64, 64), (1, 1, 1), "zyx", (64, 64, 64)),
14+
((2, 64, 64), (1, 1, 1), "zyx", (2, 64, 64)),
15+
((2, 2, 64), (1, 1, 1), "zyx", (2, 2, 64)),
16+
((2, 2, 32), (1, 1, 1), "zyx", (34, 34, 64)),
17+
((42, 10, 512, 512), (0, 0, 10, 10), "tcyx", (42, 10, 512, 512)),
18+
],
19+
)
20+
def test_enforce_min_shape(min_shape, step, axes, expected):
21+
shape = ParameterizedShape.from_values(min_shape, step, axes)
22+
assert shape.get_total_shape().values == expected
23+
24+
25+
def test_param_shape_set_custom_multiplier():
26+
min_shape = (512, 512, 256)
27+
step = (2, 2, 2)
28+
axes = "zyx"
29+
30+
shape = ParameterizedShape.from_values(min_shape, step, axes)
31+
shape.multiplier = 2
32+
assert shape.get_total_shape().values == (516, 516, 260)
33+
34+
assert shape.get_total_shape(4).values == (520, 520, 264)
35+
assert shape.multiplier == 4
36+
37+
with pytest.raises(ValueError):
38+
shape.multiplier = -1
39+
40+
41+
@pytest.mark.parametrize(
42+
"sizes, axes, spatial_axes, spatial_sizes",
43+
[
44+
((512, 512), "yx", "yx", (512, 512)),
45+
((1, 256, 512), "tyx", "yx", (256, 512)),
46+
((256, 1, 512), "ytx", "yx", (256, 512)),
47+
((128, 256, 1), "yxt", "yx", (128, 256)),
48+
((64, 64, 64), "zyx", "zyx", (64, 64, 64)),
49+
((1, 2, 64, 64), "bzyx", "zyx", (2, 64, 64)),
50+
((1, 2, 3, 64), "zbyx", "zyx", (1, 3, 64)),
51+
((1, 2, 3, 4), "zybx", "zyx", (1, 2, 4)),
52+
((1, 2, 3, 4, 5), "tczyx", "zyx", (3, 4, 5)),
53+
],
54+
)
55+
def test_spatial_axes(sizes, axes, spatial_axes, spatial_sizes):
56+
shape = AxisWithValue(axes, sizes)
57+
assert shape.spatial_values == spatial_sizes
58+
assert shape.spatial_axes == spatial_axes

tiktorch/converters.py

+43-12
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
import dataclasses
2-
from typing import List, Tuple, Union
2+
from typing import List, Tuple
33

44
import numpy as np
55
import xarray as xr
66

77
from tiktorch.proto import inference_pb2
8+
from tiktorch.server.session.process import (
9+
AxisWithValue,
10+
InputShapes,
11+
ModelInfo,
12+
OutputShapes,
13+
ParameterizedShape,
14+
ShapeWithHalo,
15+
ShapeWithReference,
16+
)
817

918
# pairs of axis-shape for a single tensor
1019
NamedInt = Tuple[str, int]
@@ -33,6 +42,28 @@ class NamedImplicitOutputShape:
3342
halo: NamedShape
3443

3544

45+
def info2session(session_id: str, model_info: ModelInfo) -> inference_pb2.ModelSession:
46+
inputAxes = "".join(input_shape.axes for input_shape in model_info.input_shapes.values())
47+
outputAxes = "".join(output_shape.axes for output_shape in model_info.output_shapes.values())
48+
pb_input_shapes = [input_shape_to_pb_input_shape(shape) for shape in model_info.input_shapes.values()]
49+
pb_output_shapes = [output_shape_to_pb_output_shape(shape) for shape in model_info.output_shapes.values()]
50+
return inference_pb2.ModelSession(
51+
id=session_id,
52+
name=model_info.name,
53+
inputAxes=inputAxes,
54+
outputAxes=outputAxes,
55+
inputShapes=pb_input_shapes,
56+
outputShapes=pb_output_shapes,
57+
hasTraining=False,
58+
inputNames=list(model_info.input_shapes.keys()),
59+
outputNames=list(model_info.output_shapes.keys()),
60+
)
61+
62+
63+
def session2info(model_session: inference_pb2.ModelSession) -> ModelInfo:
64+
pass
65+
66+
3667
def numpy_to_pb_tensor(array: np.ndarray, axistags=None) -> inference_pb2.Tensor:
3768
if axistags:
3869
shape = [inference_pb2.NamedInt(size=dim, name=name) for dim, name in zip(array.shape, axistags)]
@@ -46,44 +77,44 @@ def xarray_to_pb_tensor(array: xr.DataArray) -> inference_pb2.Tensor:
4677
return inference_pb2.Tensor(dtype=str(array.dtype), shape=shape, buffer=bytes(array.data))
4778

4879

49-
def name_int_tuples_to_pb_NamedInts(name_int_tuples) -> inference_pb2.NamedInts:
80+
def name_int_tuples_to_pb_NamedInts(name_int_tuples: AxisWithValue[int]) -> inference_pb2.NamedInts:
5081
return inference_pb2.NamedInts(
5182
namedInts=[inference_pb2.NamedInt(size=dim, name=name) for name, dim in name_int_tuples]
5283
)
5384

5485

55-
def name_float_tuples_to_pb_NamedFloats(name_float_tuples) -> inference_pb2.NamedFloats:
86+
def name_float_tuples_to_pb_NamedFloats(name_float_tuples: AxisWithValue[float]) -> inference_pb2.NamedFloats:
5687
return inference_pb2.NamedFloats(
5788
namedFloats=[inference_pb2.NamedFloat(size=dim, name=name) for name, dim in name_float_tuples]
5889
)
5990

6091

61-
def input_shape_to_pb_input_shape(input_shape: Union[NamedShape, NamedParametrizedShape]) -> inference_pb2.InputShape:
62-
if isinstance(input_shape, NamedParametrizedShape):
92+
def input_shape_to_pb_input_shape(input_shape: InputShapes) -> inference_pb2.InputShape:
93+
if isinstance(input_shape, ParameterizedShape):
6394
return inference_pb2.InputShape(
6495
shapeType=1,
6596
shape=name_int_tuples_to_pb_NamedInts(input_shape.min_shape),
66-
stepShape=name_int_tuples_to_pb_NamedInts(input_shape.step_shape),
97+
stepShape=name_int_tuples_to_pb_NamedInts(input_shape.steps),
6798
)
68-
else:
99+
elif isinstance(input_shape, AxisWithValue):
69100
return inference_pb2.InputShape(
70101
shapeType=0,
71102
shape=name_int_tuples_to_pb_NamedInts(input_shape),
72103
)
104+
else:
105+
raise ValueError(f"Unexpected shape {input_shape}")
73106

74107

75-
def output_shape_to_pb_output_shape(
76-
output_shape: Union[NamedExplicitOutputShape, NamedImplicitOutputShape]
77-
) -> inference_pb2.InputShape:
78-
if isinstance(output_shape, NamedImplicitOutputShape):
108+
def output_shape_to_pb_output_shape(output_shape: OutputShapes) -> inference_pb2.InputShape:
109+
if isinstance(output_shape, ShapeWithReference):
79110
return inference_pb2.OutputShape(
80111
shapeType=1,
81112
halo=name_int_tuples_to_pb_NamedInts(output_shape.halo),
82113
referenceTensor=output_shape.reference_tensor,
83114
scale=name_float_tuples_to_pb_NamedFloats(output_shape.scale),
84115
offset=name_float_tuples_to_pb_NamedFloats(output_shape.offset),
85116
)
86-
elif isinstance(output_shape, NamedExplicitOutputShape):
117+
elif isinstance(output_shape, ShapeWithHalo):
87118
return inference_pb2.OutputShape(
88119
shapeType=0,
89120
shape=name_int_tuples_to_pb_NamedInts(output_shape.shape),

tiktorch/server/grpc/inference_servicer.py

+2-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import xarray
99

1010
from tiktorch import converters
11+
from tiktorch.converters import info2session
1112
from tiktorch.proto import inference_pb2, inference_pb2_grpc
1213
from tiktorch.server.data_store import IDataStore
1314
from tiktorch.server.device_pool import DeviceStatus, IDevicePool
@@ -52,20 +53,7 @@ def CreateModelSession(
5253
lease.terminate()
5354
raise
5455

55-
pb_input_shapes = [converters.input_shape_to_pb_input_shape(shape) for shape in model_info.input_shapes]
56-
pb_output_shapes = [converters.output_shape_to_pb_output_shape(shape) for shape in model_info.output_shapes]
57-
58-
return inference_pb2.ModelSession(
59-
id=session.id,
60-
name=model_info.name,
61-
inputAxes=model_info.input_axes,
62-
outputAxes=model_info.output_axes,
63-
inputShapes=pb_input_shapes,
64-
hasTraining=False,
65-
outputShapes=pb_output_shapes,
66-
inputNames=model_info.input_names,
67-
outputNames=model_info.output_names,
68-
)
56+
return info2session(session.id, model_info)
6957

7058
def CreateDatasetDescription(
7159
self, request: inference_pb2.CreateDatasetDescriptionRequest, context

0 commit comments

Comments
 (0)