
Commit 160eb2c

Merge pull request #652 from robertknight/dropout-op
Implement Dropout operator
2 parents: 4161071 + 3cf89b3

File tree: 9 files changed (+474, -20 lines)

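For orientation: ONNX Dropout takes a data input plus optional ratio and training_mode inputs, and produces an output tensor together with an optional boolean mask; outside of training it reduces to an identity op. The kernel added by this commit is not among the diffs shown below, so the following is only a minimal sketch of the "inverted dropout" math the ONNX spec describes, with illustrative names and using the rand crate:

use rand::Rng;

// Sketch only: zero each element with probability `ratio` and scale the
// survivors by 1 / (1 - ratio), returning the output and the keep-mask.
fn dropout<R: Rng>(input: &[f32], ratio: f32, rng: &mut R) -> (Vec<f32>, Vec<bool>) {
    let scale = 1.0 / (1.0 - ratio);
    let mask: Vec<bool> = input.iter().map(|_| rng.gen::<f32>() >= ratio).collect();
    let output = input
        .iter()
        .zip(&mask)
        .map(|(&x, &keep)| if keep { x * scale } else { 0.0 })
        .collect();
    (output, mask)
}

The seed attribute threaded through the diffs below exists so that this masking can be made deterministic when a model asks for it.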

Diff for: rten-convert/rten_convert/converter.py (+4)

@@ -816,6 +816,10 @@ def op_node_from_onnx_operator(
             attrs.blockSize = attr_reader.require_attr("blocksize", "int")
             attrs.mode = attr_reader.get_enum_attr("mode", sg.DepthToSpaceMode, "dcr")
 
+        case "Dropout":
+            attrs = sg.DropoutAttrsT()
+            attrs.seed = attr_reader.get_attr("seed", "int", None)
+
         case "Einsum":
             attrs = sg.EinsumAttrsT()
             attrs.equation = attr_reader.require_attr("equation", "string")

Diff for: rten-convert/rten_convert/schema_generated.py (+82, -1)

@@ -119,6 +119,7 @@ class OperatorType(object):
     DepthToSpace = 109
     ConvInteger = 110
     CastLike = 111
+    Dropout = 112
 
 
 class RNNDirection(object):

@@ -205,6 +206,7 @@ class OperatorAttrs(object):
     DepthToSpaceAttrs = 43
     CastLikeAttrs = 44
     ShapeAttrs = 45
+    DropoutAttrs = 46
 
 def OperatorAttrsCreator(unionType, table):
     from flatbuffers.table import Table

@@ -300,6 +302,8 @@ def OperatorAttrsCreator(unionType, table):
         return CastLikeAttrsT.InitFromBuf(table.Bytes, table.Pos)
     if unionType == OperatorAttrs.ShapeAttrs:
         return ShapeAttrsT.InitFromBuf(table.Bytes, table.Pos)
+    if unionType == OperatorAttrs.DropoutAttrs:
+        return DropoutAttrsT.InitFromBuf(table.Bytes, table.Pos)
     return None
 
 
@@ -1113,6 +1117,83 @@ def Pack(self, builder):
         return depthToSpaceAttrs
 
 
+class DropoutAttrs(object):
+    __slots__ = ['_tab']
+
+    @classmethod
+    def GetRootAs(cls, buf, offset=0):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
+        x = DropoutAttrs()
+        x.Init(buf, n + offset)
+        return x
+
+    @classmethod
+    def GetRootAsDropoutAttrs(cls, buf, offset=0):
+        """This method is deprecated. Please switch to GetRootAs."""
+        return cls.GetRootAs(buf, offset)
+    @classmethod
+    def DropoutAttrsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
+        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x52\x54\x45\x4E", size_prefixed=size_prefixed)
+
+    # DropoutAttrs
+    def Init(self, buf, pos):
+        self._tab = flatbuffers.table.Table(buf, pos)
+
+    # DropoutAttrs
+    def Seed(self):
+        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
+        if o != 0:
+            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
+        return None
+
+def DropoutAttrsStart(builder):
+    builder.StartObject(1)
+
+def DropoutAttrsAddSeed(builder, seed):
+    builder.PrependInt32Slot(0, seed, None)
+
+def DropoutAttrsEnd(builder):
+    return builder.EndObject()
+
+
+
+class DropoutAttrsT(object):
+
+    # DropoutAttrsT
+    def __init__(self):
+        self.seed = None # type: Optional[int]
+
+    @classmethod
+    def InitFromBuf(cls, buf, pos):
+        dropoutAttrs = DropoutAttrs()
+        dropoutAttrs.Init(buf, pos)
+        return cls.InitFromObj(dropoutAttrs)
+
+    @classmethod
+    def InitFromPackedBuf(cls, buf, pos=0):
+        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
+        return cls.InitFromBuf(buf, pos+n)
+
+    @classmethod
+    def InitFromObj(cls, dropoutAttrs):
+        x = DropoutAttrsT()
+        x._UnPack(dropoutAttrs)
+        return x
+
+    # DropoutAttrsT
+    def _UnPack(self, dropoutAttrs):
+        if dropoutAttrs is None:
+            return
+        self.seed = dropoutAttrs.Seed()
+
+    # DropoutAttrsT
+    def Pack(self, builder):
+        DropoutAttrsStart(builder)
+        DropoutAttrsAddSeed(builder, self.seed)
+        dropoutAttrs = DropoutAttrsEnd(builder)
+        return dropoutAttrs
+
+
 class IntScalar(object):
     __slots__ = ['_tab']
 

@@ -5316,7 +5397,7 @@ class OperatorNodeT(object):
     def __init__(self):
         self.type = 0 # type: int
         self.attrsType = 0 # type: int
-        self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT, DepthToSpaceAttrsT, CastLikeAttrsT, ShapeAttrsT]
+        self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT, DepthToSpaceAttrsT, CastLikeAttrsT, ShapeAttrsT, DropoutAttrsT]
         self.inputs = None # type: List[int]
         self.outputs = None # type: List[int]
 

Diff for: src/model.rs (+23, -2)

@@ -1387,6 +1387,17 @@ mod tests {
        });
 
        add_operator!(Div, [input_node, input_node]);
+        #[cfg(feature = "random")]
+        {
+            let dropout_out = graph_builder.add_value("Dropout_out", None, None);
+            let dropout_out_mask = graph_builder.add_value("Dropout_out_mask", None, None);
+            graph_builder.add_operator(
+                "Dropout",
+                OpType::Dropout(ops::Dropout { seed: None }),
+                &[input_2d].map(Some),
+                &[dropout_out, dropout_out_mask],
+            );
+        }
        add_operator!(Elu, [input_node], { alpha: 1.0 });
        add_operator!(Equal, [input_node, input_node]);
        add_operator!(Erf, [input_node]);

@@ -1697,6 +1708,8 @@ mod tests {
 
        for output in op_outputs {
            if [
+                "Dropout_out",
+                "Dropout_out_mask",
                "Gemm_out",
                "MatMul_out",
                "Range_out",

@@ -1753,15 +1766,23 @@ mod tests {
            assert_eq!(result.len(), 1);
        }
 
-        // Outputs of ops tested with a 2D input.
-        let outputs = vec![
+        // Outputs of ops which either have multiple outputs, or which are tested
+        // with a 2D input.
+        #[allow(unused_mut)]
+        let mut outputs = vec![
            "Gemm_out",
            "MatMul_out",
            "Split_out_1",
            "Split_out_2",
            "TopK_out_indices",
            "TopK_out_values",
        ];
+
+        #[cfg(feature = "random")]
+        {
+            outputs.extend(["Dropout_out", "Dropout_out_mask"]);
+        }
+
        let input = Tensor::from_data(&[3, 3], vec![1., 2., 3., 4., 5., 6., 7., 8., 9.]);
 
        for output in outputs {

Diff for: src/model_builder.rs (+9, -1)

@@ -17,7 +17,7 @@ use crate::ops::{
 use crate::schema_generated as sg;
 
 #[cfg(feature = "random")]
-use crate::ops::{RandomNormal, RandomNormalLike, RandomUniform, RandomUniformLike};
+use crate::ops::{Dropout, RandomNormal, RandomNormalLike, RandomUniform, RandomUniformLike};
 
 /// Struct like `crate::ops::If` with subgraph attributes replaced by
 /// pre-serialized graphs.

@@ -51,6 +51,8 @@ pub enum OpType<'a> {
    DequantizeLinear(DequantizeLinear),
    DepthToSpace(DepthToSpace),
    Div,
+    #[cfg(feature = "random")]
+    Dropout(Dropout),
    DynamicQuantizeLinear,
    Einsum(Einsum),
    Elu(Elu),

@@ -546,6 +548,12 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
                }
            ),
            OpType::Div => op!(Div),
+            #[cfg(feature = "random")]
+            OpType::Dropout(args) => op_with_attrs!(
+                Dropout,
+                DropoutAttrs,
+                sg::DropoutAttrsArgs { seed: args.seed }
+            ),
            OpType::DynamicQuantizeLinear => op!(DynamicQuantizeLinear),
            OpType::Einsum(args) => {
                let equation = self.builder.create_string(&args.equation);

Diff for: src/op_registry.rs (+18, -7)

@@ -105,6 +105,10 @@ impl OpRegistry {
        register_op!(DequantizeLinear);
        register_op!(DepthToSpace);
        register_op!(Div);
+
+        #[cfg(feature = "random")]
+        register_op!(Dropout);
+
        register_op!(DynamicQuantizeLinear);
        register_op!(Einsum);
        register_op!(Elu);

@@ -154,13 +158,12 @@ impl OpRegistry {
        register_op!(QuantizeLinear);
 
        #[cfg(feature = "random")]
-        register_op!(RandomNormal);
-        #[cfg(feature = "random")]
-        register_op!(RandomNormalLike);
-        #[cfg(feature = "random")]
-        register_op!(RandomUniform);
-        #[cfg(feature = "random")]
-        register_op!(RandomUniformLike);
+        {
+            register_op!(RandomNormal);
+            register_op!(RandomNormalLike);
+            register_op!(RandomUniform);
+            register_op!(RandomUniformLike);
+        }
 
        register_op!(Range);
        register_op!(Reciprocal);

@@ -511,6 +514,14 @@ impl_read_op!(
    }
);
impl_read_op!(Div);
+
+#[cfg(feature = "random")]
+impl_read_op!(
+    Dropout,
+    attrs_as_dropout_attrs,
+    |attrs: sg::DropoutAttrs| { Ok(ops::Dropout { seed: attrs.seed() }) }
+);
+
impl_read_op!(DynamicQuantizeLinear);
impl_read_op!(Einsum, attrs_as_einsum_attrs, |attrs: sg::EinsumAttrs| {
    Ok(ops::Einsum {
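Both the registration and the impl_read_op! deserializer above are gated on the crate's "random" feature, and the only attribute read back from the model file is the optional i32 seed. As a hedged illustration (the commit does not show which RNG rten actually uses), an optional seed like this typically feeds a seedable RNG along these lines, using the rand crate:

use rand::rngs::StdRng;
use rand::SeedableRng;

// Sketch only: a fixed seed makes the dropout mask reproducible across runs,
// while no seed falls back to a randomly chosen one.
fn rng_from_seed(seed: Option<i32>) -> StdRng {
    match seed {
        Some(seed) => StdRng::seed_from_u64(seed as u64),
        None => StdRng::seed_from_u64(rand::random()),
    }
}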

Diff for: src/ops/mod.rs (+8, -1)

@@ -101,7 +101,7 @@ pub use quantize::{
 };
 
 #[cfg(feature = "random")]
-pub use random::{RandomNormal, RandomNormalLike, RandomUniform, RandomUniformLike};
+pub use random::{Dropout, RandomNormal, RandomNormalLike, RandomUniform, RandomUniformLike};
 
 pub use reduce::{
    arg_max, arg_min, cum_sum, nonzero, reduce_l2, reduce_max, reduce_mean, reduce_min,

@@ -952,6 +952,13 @@ impl<'a> InputList<'a> {
        self.inputs.to_mut().push(Some(inp.into()))
    }
 
+    /// Append an optional input to the list.
+    ///
+    /// This will copy the existing inputs into a new owned vector.
+    pub fn push_optional<I: Into<Input<'a>>>(&mut self, inp: Option<I>) {
+        self.inputs.to_mut().push(inp.map(|inp| inp.into()))
+    }
+
    /// Construct an input list from a slice of non-optional inputs.
    ///
    /// This copies the inputs into a new vector of `Optional<Input>`s. Using
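The new InputList::push_optional helper keeps the Option intact instead of dropping missing inputs, which is what an operator such as Dropout needs for its optional ratio and training_mode inputs: ONNX inputs are positional, so a missing ratio must still occupy its slot when training_mode is supplied after it. A hypothetical in-module usage sketch (the names below are illustrative, not taken from this commit):

// Sketch only, assuming this lives alongside InputList in src/ops:
// append Dropout's optional inputs while preserving their ONNX positions.
fn extend_dropout_inputs<'a>(
    inputs: &mut InputList<'a>,
    ratio: Option<Input<'a>>,
    training_mode: Option<Input<'a>>,
) {
    // A `None` becomes a missing entry rather than being skipped, so later
    // positional inputs keep their ONNX positions.
    inputs.push_optional(ratio);
    inputs.push_optional(training_mode);
}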
