4
4
from numpy .testing import assert_array_equal
5
5
6
6
from tiktorch .converters import (
7
- NamedExplicitOutputShape ,
8
- NamedImplicitOutputShape ,
9
- NamedParametrizedShape ,
10
7
input_shape_to_pb_input_shape ,
11
8
numpy_to_pb_tensor ,
12
9
output_shape_to_pb_output_shape ,
15
12
xarray_to_pb_tensor ,
16
13
)
17
14
from tiktorch .proto import inference_pb2
15
+ from tiktorch .server .session .process import AxisWithValue , ParameterizedShape , ShapeWithHalo , ShapeWithReference
18
16
19
17
20
18
def _numpy_to_pb_tensor (arr ):
@@ -186,32 +184,12 @@ def test_should_same_data(self, shape):
186
184
187
185
188
186
class TestShapeConversions :
189
- def to_named_explicit_shape (self , shape , axes , halo ):
190
- return NamedExplicitOutputShape (
191
- halo = [(name , dim ) for name , dim in zip (axes , halo )], shape = [(name , dim ) for name , dim in zip (axes , shape )]
192
- )
193
-
194
- def to_named_implicit_shape (self , axes , halo , offset , scales , reference_tensor ):
195
- return NamedImplicitOutputShape (
196
- halo = [(name , dim ) for name , dim in zip (axes , halo )],
197
- offset = [(name , dim ) for name , dim in zip (axes , offset )],
198
- scale = [(name , scale ) for name , scale in zip (axes , scales )],
199
- reference_tensor = reference_tensor ,
200
- )
201
-
202
- def to_named_paramtrized_shape (self , min_shape , axes , step ):
203
- return NamedParametrizedShape (
204
- min_shape = [(name , dim ) for name , dim in zip (axes , min_shape )],
205
- step_shape = [(name , dim ) for name , dim in zip (axes , step )],
206
- )
207
-
208
187
@pytest .mark .parametrize (
209
188
"shape,axes,halo" ,
210
- [((42 ,), "x" , (0 ,)), ((42 , 128 , 5 ), "abc " , (1 , 1 , 1 )), ((5 , 4 , 3 , 2 , 1 , 42 ), "btzyxc" , (1 , 2 , 3 , 4 , 5 , 24 ))],
189
+ [((42 ,), "x" , (0 ,)), ((42 , 128 , 5 ), "xyz " , (1 , 1 , 1 )), ((5 , 4 , 3 , 2 , 1 , 42 ), "btzyxc" , (1 , 2 , 3 , 4 , 5 , 24 ))],
211
190
)
212
191
def test_explicit_output_shape (self , shape , axes , halo ):
213
- named_shape = self .to_named_explicit_shape (shape , axes , halo )
214
- pb_shape = output_shape_to_pb_output_shape (named_shape )
192
+ pb_shape = output_shape_to_pb_output_shape (ShapeWithHalo .from_values (shape = shape , halo = halo , axes = axes ))
215
193
216
194
assert pb_shape .shapeType == 0
217
195
assert pb_shape .referenceTensor == ""
@@ -223,11 +201,13 @@ def test_explicit_output_shape(self, shape, axes, halo):
223
201
224
202
@pytest .mark .parametrize (
225
203
"axes,halo,offset,scales,reference_tensor" ,
226
- [("x" , (0 ,), (10 ,), (1.0 ,), "forty-two" ), ("abc " , (1 , 1 , 1 ), (1 , 2 , 3 ), (1.0 , 2.0 , 3.0 ), "helloworld" )],
204
+ [("x" , (0 ,), (10 ,), (1.0 ,), "forty-two" ), ("xyz " , (1 , 1 , 1 ), (1 , 2 , 3 ), (1.0 , 2.0 , 3.0 ), "helloworld" )],
227
205
)
228
206
def test_implicit_output_shape (self , axes , halo , offset , scales , reference_tensor ):
229
- named_shape = self .to_named_implicit_shape (axes , halo , offset , scales , reference_tensor )
230
- pb_shape = output_shape_to_pb_output_shape (named_shape )
207
+ shape = ShapeWithReference .from_values (
208
+ axes = axes , halo = halo , offset = offset , scale = scales , reference_tensor = reference_tensor
209
+ )
210
+ pb_shape = output_shape_to_pb_output_shape (shape )
231
211
232
212
assert pb_shape .shapeType == 1
233
213
assert pb_shape .referenceTensor == reference_tensor
@@ -248,11 +228,10 @@ def test_output_shape_raises(self):
248
228
249
229
@pytest .mark .parametrize (
250
230
"shape,axes" ,
251
- [((42 ,), "x" ), ((42 , 128 , 5 ), "abc " ), ((5 , 4 , 3 , 2 , 1 , 42 ), "btzyxc" )],
231
+ [((42 ,), "x" ), ((42 , 128 , 5 ), "xyz " ), ((5 , 4 , 3 , 2 , 1 , 42 ), "btzyxc" )],
252
232
)
253
233
def test_explicit_input_shape (self , shape , axes ):
254
- named_shape = [(name , dim ) for name , dim in zip (axes , shape )]
255
- pb_shape = input_shape_to_pb_input_shape (named_shape )
234
+ pb_shape = input_shape_to_pb_input_shape (AxisWithValue (axes = axes , values = shape ))
256
235
257
236
assert pb_shape .shapeType == 0
258
237
assert [(d .name , d .size ) for d in pb_shape .shape .namedInts ] == [(name , size ) for name , size in zip (axes , shape )]
@@ -261,13 +240,14 @@ def test_explicit_input_shape(self, shape, axes):
261
240
"min_shape,axes,step" ,
262
241
[
263
242
((42 ,), "x" , (5 ,)),
264
- ((42 , 128 , 5 ), "abc " , (1 , 2 , 3 )),
243
+ ((42 , 128 , 5 ), "xyz " , (1 , 2 , 3 )),
265
244
((5 , 4 , 3 , 2 , 1 , 42 ), "btzyxc" , (15 , 24 , 33 , 42 , 51 , 642 )),
266
245
],
267
246
)
268
247
def test_parametrized_input_shape (self , min_shape , axes , step ):
269
- named_shape = self .to_named_paramtrized_shape (min_shape , axes , step )
270
- pb_shape = input_shape_to_pb_input_shape (named_shape )
248
+ pb_shape = input_shape_to_pb_input_shape (
249
+ ParameterizedShape .from_values (axes = axes , steps = step , min_shape = min_shape )
250
+ )
271
251
272
252
assert pb_shape .shapeType == 1
273
253
assert [(d .name , d .size ) for d in pb_shape .shape .namedInts ] == [
0 commit comments