Skip to content

Commit d5d6f04

Browse files
authored
Merge pull request #648 from robertknight/shape-start-end-attrs
Support `start` and `end` attributes for `Shape` operator
2 parents 37aa6d3 + acbd980 commit d5d6f04

File tree

9 files changed

+413
-43
lines changed

9 files changed

+413
-43
lines changed

Diff for: rten-convert/rten_convert/converter.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -991,8 +991,13 @@ def op_node_from_onnx_operator(
991991
)
992992

993993
case "Shape":
994-
op_reader.check_attr("end", "int", 0)
995-
op_reader.check_attr("start", "int", 0)
994+
attrs = sg.ShapeAttrsT()
995+
start = op_reader.get_attr("start", "int", None)
996+
if start is not None:
997+
attrs.start = start
998+
end = op_reader.get_attr("end", "int", None)
999+
if end is not None:
1000+
attrs.end = end
9961001

9971002
case "Softmax":
9981003
attrs = sg.SoftmaxAttrsT()

Diff for: rten-convert/rten_convert/schema_generated.py

+94-1
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ class OperatorAttrs(object):
204204
QuantizeLinearAttrs = 42
205205
DepthToSpaceAttrs = 43
206206
CastLikeAttrs = 44
207+
ShapeAttrs = 45
207208

208209
def OperatorAttrsCreator(unionType, table):
209210
from flatbuffers.table import Table
@@ -297,6 +298,8 @@ def OperatorAttrsCreator(unionType, table):
297298
return DepthToSpaceAttrsT.InitFromBuf(table.Bytes, table.Pos)
298299
if unionType == OperatorAttrs.CastLikeAttrs:
299300
return CastLikeAttrsT.InitFromBuf(table.Bytes, table.Pos)
301+
if unionType == OperatorAttrs.ShapeAttrs:
302+
return ShapeAttrsT.InitFromBuf(table.Bytes, table.Pos)
300303
return None
301304

302305

@@ -4631,6 +4634,96 @@ def Pack(self, builder):
46314634
return scatterNdattrs
46324635

46334636

4637+
class ShapeAttrs(object):
4638+
__slots__ = ['_tab']
4639+
4640+
@classmethod
4641+
def GetRootAs(cls, buf, offset=0):
4642+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
4643+
x = ShapeAttrs()
4644+
x.Init(buf, n + offset)
4645+
return x
4646+
4647+
@classmethod
4648+
def GetRootAsShapeAttrs(cls, buf, offset=0):
4649+
"""This method is deprecated. Please switch to GetRootAs."""
4650+
return cls.GetRootAs(buf, offset)
4651+
@classmethod
4652+
def ShapeAttrsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
4653+
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x52\x54\x45\x4E", size_prefixed=size_prefixed)
4654+
4655+
# ShapeAttrs
4656+
def Init(self, buf, pos):
4657+
self._tab = flatbuffers.table.Table(buf, pos)
4658+
4659+
# ShapeAttrs
4660+
def Start(self):
4661+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
4662+
if o != 0:
4663+
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
4664+
return None
4665+
4666+
# ShapeAttrs
4667+
def End(self):
4668+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
4669+
if o != 0:
4670+
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
4671+
return None
4672+
4673+
def ShapeAttrsStart(builder):
4674+
builder.StartObject(2)
4675+
4676+
def ShapeAttrsAddStart(builder, start):
4677+
builder.PrependInt32Slot(0, start, None)
4678+
4679+
def ShapeAttrsAddEnd(builder, end):
4680+
builder.PrependInt32Slot(1, end, None)
4681+
4682+
def ShapeAttrsEnd(builder):
4683+
return builder.EndObject()
4684+
4685+
4686+
4687+
class ShapeAttrsT(object):
4688+
4689+
# ShapeAttrsT
4690+
def __init__(self):
4691+
self.start = None # type: Optional[int]
4692+
self.end = None # type: Optional[int]
4693+
4694+
@classmethod
4695+
def InitFromBuf(cls, buf, pos):
4696+
shapeAttrs = ShapeAttrs()
4697+
shapeAttrs.Init(buf, pos)
4698+
return cls.InitFromObj(shapeAttrs)
4699+
4700+
@classmethod
4701+
def InitFromPackedBuf(cls, buf, pos=0):
4702+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
4703+
return cls.InitFromBuf(buf, pos+n)
4704+
4705+
@classmethod
4706+
def InitFromObj(cls, shapeAttrs):
4707+
x = ShapeAttrsT()
4708+
x._UnPack(shapeAttrs)
4709+
return x
4710+
4711+
# ShapeAttrsT
4712+
def _UnPack(self, shapeAttrs):
4713+
if shapeAttrs is None:
4714+
return
4715+
self.start = shapeAttrs.Start()
4716+
self.end = shapeAttrs.End()
4717+
4718+
# ShapeAttrsT
4719+
def Pack(self, builder):
4720+
ShapeAttrsStart(builder)
4721+
ShapeAttrsAddStart(builder, self.start)
4722+
ShapeAttrsAddEnd(builder, self.end)
4723+
shapeAttrs = ShapeAttrsEnd(builder)
4724+
return shapeAttrs
4725+
4726+
46344727
class SoftmaxAttrs(object):
46354728
__slots__ = ['_tab']
46364729

@@ -5223,7 +5316,7 @@ class OperatorNodeT(object):
52235316
def __init__(self):
52245317
self.type = 0 # type: int
52255318
self.attrsType = 0 # type: int
5226-
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT, DepthToSpaceAttrsT, CastLikeAttrsT]
5319+
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT, DepthToSpaceAttrsT, CastLikeAttrsT, ShapeAttrsT]
52275320
self.inputs = None # type: List[int]
52285321
self.outputs = None # type: List[int]
52295322

Diff for: src/graph.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -1369,7 +1369,7 @@ mod tests {
13691369
let input_b_id = g.add_value(Some("input_b"), None, None);
13701370

13711371
let (add_op, add_out) = g.add_simple_op("add", Add {}, &[input_a_id, input_b_id]);
1372-
let (shape_op, shape_out) = g.add_simple_op("shape", Shape {}, &[input_a_id]);
1372+
let (shape_op, shape_out) = g.add_simple_op("shape", Shape::default(), &[input_a_id]);
13731373

13741374
// The execution plan could run operators in either order and produce
13751375
// the correct output. Since the `Add` op has the _potential_ to run in
@@ -1624,7 +1624,12 @@ mod tests {
16241624
// as opposed to passing a shorter input list. This enables omitting
16251625
// an input but still providing subsequent ones.
16261626
let output = g.add_value(None, None, None);
1627-
g.add_op(Some("shape"), Box::new(Shape {}), &[None], &[Some(output)]);
1627+
g.add_op(
1628+
Some("shape"),
1629+
Box::new(Shape::default()),
1630+
&[None],
1631+
&[Some(output)],
1632+
);
16281633

16291634
let results = g.run(vec![], &[output], None, None);
16301635

Diff for: src/model.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ mod tests {
887887
use crate::ops;
888888
use crate::ops::{
889889
BoxOrder, CoordTransformMode, DataType, DepthToSpaceMode, NearestMode, OpError, Output,
890-
ResizeMode, Scalar,
890+
ResizeMode, Scalar, Shape,
891891
};
892892
use crate::{ModelLoadError, OpRegistry, ReadOpError};
893893

@@ -1214,7 +1214,12 @@ mod tests {
12141214

12151215
let output_node = graph_builder.add_value("output", None, None);
12161216
graph_builder.add_output(output_node);
1217-
graph_builder.add_operator("shape", OpType::Shape, &[None], &[output_node]);
1217+
graph_builder.add_operator(
1218+
"shape",
1219+
OpType::Shape(Shape::default()),
1220+
&[None],
1221+
&[output_node],
1222+
);
12181223

12191224
let graph = graph_builder.finish();
12201225
builder.set_graph(graph);
@@ -1597,7 +1602,10 @@ mod tests {
15971602

15981603
add_operator!(Round, [input_node]);
15991604

1600-
add_operator!(Shape, [input_node]);
1605+
add_operator!(Shape, [input_node], {
1606+
start: Some(1),
1607+
end: Some(-1),
1608+
});
16011609
add_operator!(Sigmoid, [input_node]);
16021610
add_operator!(Sign, [input_node]);
16031611
add_operator!(Sin, [input_node]);

Diff for: src/model_builder.rs

+8-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::ops::{
1212
Gelu, Gemm, HardSigmoid, InstanceNormalization, LayerNormalization, LeakyRelu, LogSoftmax,
1313
MaxPool, Mod, NearestMode, NonMaxSuppression, OneHot, Padding, QuantizeLinear, ReduceMax,
1414
ReduceMean, ReduceMin, ReduceProd, ReduceSum, ReduceSumSquare, Reshape, Resize, ResizeMode,
15-
Scalar, ScatterElements, ScatterReduction, Softmax, Split, TopK, Transpose, Trilu,
15+
Scalar, ScatterElements, ScatterReduction, Shape, Softmax, Split, TopK, Transpose, Trilu,
1616
};
1717
use crate::schema_generated as sg;
1818

@@ -119,7 +119,7 @@ pub enum OpType<'a> {
119119
Round,
120120
QuantizeLinear(QuantizeLinear),
121121
ScatterElements(ScatterElements),
122-
Shape,
122+
Shape(Shape),
123123
Sigmoid,
124124
Sign,
125125
Sin,
@@ -828,7 +828,12 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
828828
}
829829
})
830830
}
831-
OpType::Shape => op!(Shape),
831+
OpType::Shape(args) => op_with_attrs!(Shape, ShapeAttrs, {
832+
sg::ShapeAttrsArgs {
833+
start: args.start,
834+
end: args.end,
835+
}
836+
}),
832837
OpType::Sigmoid => op!(Sigmoid),
833838
OpType::Slice => op!(Slice),
834839
OpType::Sin => op!(Sin),

Diff for: src/op_registry.rs

+15-1
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,21 @@ impl_read_op!(
847847
})
848848
}
849849
);
850-
impl_read_op!(Shape);
850+
851+
impl ReadOp for ops::Shape {
852+
fn op_type() -> sg::OperatorType {
853+
OperatorType::Shape
854+
}
855+
856+
fn read(op: &OperatorNode, _ctx: &dyn OpLoadContext) -> Result<Self, ReadOpError> {
857+
// Shape attributes are optional for backwards compatibility
858+
let attrs = op.attrs_as_shape_attrs();
859+
let start = attrs.and_then(|a| a.start());
860+
let end = attrs.and_then(|a| a.end());
861+
Ok(ops::Shape { start, end })
862+
}
863+
}
864+
851865
impl_read_op!(Sigmoid);
852866
impl_read_op!(Sign);
853867
impl_read_op!(Sin);

0 commit comments

Comments
 (0)