Skip to content

Commit 4329a65

Browse files
authored
Merge pull request #646 from robertknight/cast-like-op
Implement `CastLike` operator
2 parents ee70cec + 61ea8b1 commit 4329a65

File tree

8 files changed

+291
-16
lines changed

8 files changed

+291
-16
lines changed

rten-convert/rten_convert/schema_generated.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ class OperatorType(object):
118118
MatMulInteger = 108
119119
DepthToSpace = 109
120120
ConvInteger = 110
121+
CastLike = 111
121122

122123

123124
class RNNDirection(object):
@@ -201,6 +202,7 @@ class OperatorAttrs(object):
201202
DequantizeLinearAttrs = 41
202203
QuantizeLinearAttrs = 42
203204
DepthToSpaceAttrs = 43
205+
CastLikeAttrs = 44
204206

205207
def OperatorAttrsCreator(unionType, table):
206208
from flatbuffers.table import Table
@@ -292,6 +294,8 @@ def OperatorAttrsCreator(unionType, table):
292294
return QuantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos)
293295
if unionType == OperatorAttrs.DepthToSpaceAttrs:
294296
return DepthToSpaceAttrsT.InitFromBuf(table.Bytes, table.Pos)
297+
if unionType == OperatorAttrs.CastLikeAttrs:
298+
return CastLikeAttrsT.InitFromBuf(table.Bytes, table.Pos)
295299
return None
296300

297301

@@ -873,6 +877,71 @@ def Pack(self, builder):
873877
return castAttrs
874878

875879

880+
class CastLikeAttrs(object):
881+
__slots__ = ['_tab']
882+
883+
@classmethod
884+
def GetRootAs(cls, buf, offset=0):
885+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
886+
x = CastLikeAttrs()
887+
x.Init(buf, n + offset)
888+
return x
889+
890+
@classmethod
891+
def GetRootAsCastLikeAttrs(cls, buf, offset=0):
892+
"""This method is deprecated. Please switch to GetRootAs."""
893+
return cls.GetRootAs(buf, offset)
894+
@classmethod
895+
def CastLikeAttrsBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
896+
return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x52\x54\x45\x4E", size_prefixed=size_prefixed)
897+
898+
# CastLikeAttrs
899+
def Init(self, buf, pos):
900+
self._tab = flatbuffers.table.Table(buf, pos)
901+
902+
def CastLikeAttrsStart(builder):
903+
builder.StartObject(0)
904+
905+
def CastLikeAttrsEnd(builder):
906+
return builder.EndObject()
907+
908+
909+
910+
class CastLikeAttrsT(object):
911+
912+
# CastLikeAttrsT
913+
def __init__(self):
914+
pass
915+
916+
@classmethod
917+
def InitFromBuf(cls, buf, pos):
918+
castLikeAttrs = CastLikeAttrs()
919+
castLikeAttrs.Init(buf, pos)
920+
return cls.InitFromObj(castLikeAttrs)
921+
922+
@classmethod
923+
def InitFromPackedBuf(cls, buf, pos=0):
924+
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
925+
return cls.InitFromBuf(buf, pos+n)
926+
927+
@classmethod
928+
def InitFromObj(cls, castLikeAttrs):
929+
x = CastLikeAttrsT()
930+
x._UnPack(castLikeAttrs)
931+
return x
932+
933+
# CastLikeAttrsT
934+
def _UnPack(self, castLikeAttrs):
935+
if castLikeAttrs is None:
936+
return
937+
938+
# CastLikeAttrsT
939+
def Pack(self, builder):
940+
CastLikeAttrsStart(builder)
941+
castLikeAttrs = CastLikeAttrsEnd(builder)
942+
return castLikeAttrs
943+
944+
876945
class ConcatAttrs(object):
877946
__slots__ = ['_tab']
878947

@@ -5153,7 +5222,7 @@ class OperatorNodeT(object):
51535222
def __init__(self):
51545223
self.type = 0 # type: int
51555224
self.attrsType = 0 # type: int
5156-
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT, DepthToSpaceAttrsT]
5225+
self.attrs = None # type: Union[None, ArgMaxAttrsT, AveragePoolAttrsT, BatchNormalizationAttrsT, CastAttrsT, ConcatAttrsT, ConstantOfShapeAttrsT, ConvAttrsT, ConvTransposeAttrsT, FlattenAttrsT, GatherAttrsT, GemmAttrsT, GRUAttrsT, LeakyReluAttrsT, LSTMAttrsT, MaxPoolAttrsT, ReduceMeanAttrsT, ReshapeAttrsT, ResizeAttrsT, SplitAttrsT, SoftmaxAttrsT, TransposeAttrsT, ModAttrsT, ScatterElementsAttrsT, OneHotAttrsT, TopKAttrsT, HardSigmoidAttrsT, TriluAttrsT, ScatterNDAttrsT, NonMaxSuppressionAttrsT, LayerNormalizationAttrsT, RandomUniformAttrsT, EluAttrsT, RandomUniformLikeAttrsT, RandomNormalAttrsT, RandomNormalLikeAttrsT, GatherNDAttrsT, GeluAttrsT, EinsumAttrsT, IfAttrsT, PadAttrsT, DequantizeLinearAttrsT, QuantizeLinearAttrsT, DepthToSpaceAttrsT, CastLikeAttrsT]
51575226
self.inputs = None # type: List[int]
51585227
self.outputs = None # type: List[int]
51595228

src/model.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,7 @@ mod tests {
13301330
);
13311331

13321332
add_operator!(Cast, [input_node], { to: ops::DataType::Float });
1333+
add_operator!(CastLike, [input_node, input_node], {});
13331334
add_operator!(Ceil, [input_node]);
13341335

13351336
let clip_min = graph_builder.add_constant(Tensor::from(1.).view());

src/model_builder.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ use crate::graph::{Dimension, NodeId};
66
use crate::header::Header;
77
use crate::number::LeBytes;
88
use crate::ops::{
9-
ArgMax, ArgMin, AveragePool, BatchNormalization, BoxOrder, Cast, Concat, ConstantOfShape, Conv,
10-
ConvInteger, ConvTranspose, CoordTransformMode, DataType, DepthToSpace, DepthToSpaceMode,
11-
DequantizeLinear, Einsum, Elu, Flatten, Gather, GatherElements, GatherND, Gelu, Gemm,
12-
HardSigmoid, InstanceNormalization, LayerNormalization, LeakyRelu, LogSoftmax, MaxPool, Mod,
13-
NearestMode, NonMaxSuppression, OneHot, Padding, QuantizeLinear, ReduceMax, ReduceMean,
14-
ReduceMin, ReduceProd, ReduceSum, ReduceSumSquare, Reshape, Resize, ResizeMode, Scalar,
15-
ScatterElements, ScatterReduction, Softmax, Split, TopK, Transpose, Trilu,
9+
ArgMax, ArgMin, AveragePool, BatchNormalization, BoxOrder, Cast, CastLike, Concat,
10+
ConstantOfShape, Conv, ConvInteger, ConvTranspose, CoordTransformMode, DataType, DepthToSpace,
11+
DepthToSpaceMode, DequantizeLinear, Einsum, Elu, Flatten, Gather, GatherElements, GatherND,
12+
Gelu, Gemm, HardSigmoid, InstanceNormalization, LayerNormalization, LeakyRelu, LogSoftmax,
13+
MaxPool, Mod, NearestMode, NonMaxSuppression, OneHot, Padding, QuantizeLinear, ReduceMax,
14+
ReduceMean, ReduceMin, ReduceProd, ReduceSum, ReduceSumSquare, Reshape, Resize, ResizeMode,
15+
Scalar, ScatterElements, ScatterReduction, Softmax, Split, TopK, Transpose, Trilu,
1616
};
1717
use crate::schema_generated as sg;
1818

@@ -39,6 +39,7 @@ pub enum OpType<'a> {
3939
AveragePool(AveragePool),
4040
BatchNormalization(BatchNormalization),
4141
Cast(Cast),
42+
CastLike(CastLike),
4243
Ceil,
4344
Clip,
4445
Concat(Concat),
@@ -449,6 +450,9 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
449450
to: convert_dtype(args.to),
450451
}
451452
),
453+
OpType::CastLike(_args) => {
454+
op_with_attrs!(CastLike, CastLikeAttrs, sg::CastLikeAttrsArgs {})
455+
}
452456
OpType::Ceil => op!(Ceil),
453457
OpType::Clip => op!(Clip),
454458
OpType::Concat(args) => op_with_attrs!(

src/op_registry.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ impl OpRegistry {
9292
register_op!(AveragePool);
9393
register_op!(BatchNormalization);
9494
register_op!(Cast);
95+
register_op!(CastLike);
9596
register_op!(Ceil);
9697
register_op!(Clip);
9798
register_op!(Concat);
@@ -438,6 +439,11 @@ impl_read_op!(Cast, attrs_as_cast_attrs, |attrs: sg::CastAttrs| {
438439
let to = convert_dtype(attrs.to())?;
439440
Ok(ops::Cast { to })
440441
});
442+
impl_read_op!(
443+
CastLike,
444+
attrs_as_cast_like_attrs,
445+
|_attrs: sg::CastLikeAttrs| { Ok(ops::CastLike {}) }
446+
);
441447
impl_read_op!(Ceil);
442448
impl_read_op!(Clip);
443449
impl_read_op!(Concat, attrs_as_concat_attrs, axis);

src/ops/convert.rs

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,49 @@ impl Operator for Cast {
7979
}
8080
}
8181

82+
#[derive(Debug)]
83+
pub struct CastLike {}
84+
85+
impl Operator for CastLike {
86+
fn name(&self) -> &str {
87+
"CastLike"
88+
}
89+
90+
fn run(&self, pool: &TensorPool, inputs: InputList) -> Result<OutputList, OpError> {
91+
let input = inputs.require(0)?;
92+
let to_type = inputs.require(1)?.dtype();
93+
cast(pool, input, to_type).into_op_result()
94+
}
95+
96+
fn can_run_in_place(&self) -> bool {
97+
true
98+
}
99+
100+
fn run_in_place(
101+
&self,
102+
pool: &TensorPool,
103+
input: Output,
104+
other: InputList,
105+
) -> Result<Output, OpError> {
106+
let to_type = other.require(0)?.dtype();
107+
108+
if input.dtype() == to_type {
109+
Ok(input)
110+
} else {
111+
let converted = cast(pool, input.as_input(), to_type)?;
112+
input.add_to_pool(pool);
113+
Ok(converted)
114+
}
115+
}
116+
}
117+
82118
#[cfg(test)]
83119
mod tests {
84120
use rten_tensor::Tensor;
85121
use rten_testing::TestCases;
86122

87123
use crate::ops::tests::new_pool;
88-
use crate::ops::{Cast, DataType, Operator, Output};
124+
use crate::ops::{Cast, CastLike, DataType, Operator, Output};
89125

90126
#[test]
91127
fn test_cast() {
@@ -159,4 +195,36 @@ mod tests {
159195
assert_eq!(result, case.expected);
160196
})
161197
}
198+
199+
#[test]
200+
fn test_cast_like() {
201+
#[derive(Debug)]
202+
struct Case {
203+
input: Output,
204+
other: Output,
205+
expected: Output,
206+
}
207+
208+
// `CastLike` uses the same conversions as the `Cast` operator,
209+
// so these tests don't check all data type combinations, only that the
210+
// target type is taken from the second argument.
211+
let cases = [
212+
// i32 -> f32
213+
Case {
214+
input: Tensor::from([0i32, 1, 2]).into(),
215+
other: Tensor::from([0f32]).into(),
216+
expected: Tensor::from([0., 1., 2.]).into(),
217+
},
218+
];
219+
220+
cases.test_each(|case| {
221+
let pool = new_pool();
222+
let cast_op = CastLike {};
223+
let result = cast_op
224+
.run(&pool, (&case.input, &case.other).into())
225+
.unwrap()
226+
.remove(0);
227+
assert_eq!(result, case.expected);
228+
})
229+
}
162230
}

src/ops/mod.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ pub use binary_elementwise::{
7272
pub use concat::{concat, tile, Concat, Tile};
7373
pub use control_flow::If;
7474
pub use conv::{conv, conv_integer, conv_transpose, Conv, ConvInteger, ConvTranspose};
75-
pub use convert::Cast;
75+
pub use convert::{Cast, CastLike};
7676
pub use einsum::{einsum, Einsum};
7777
pub use gather::{
7878
gather, gather_elements, gather_nd, scatter_elements, scatter_nd, Gather, GatherElements,
@@ -255,6 +255,16 @@ pub enum Input<'a> {
255255
}
256256

257257
impl Input<'_> {
258+
/// Return the data type of elements in this tensor.
259+
pub fn dtype(&self) -> DataType {
260+
match self {
261+
Self::FloatTensor(_) => DataType::Float,
262+
Self::Int32Tensor(_) => DataType::Int32,
263+
Self::Int8Tensor(_) => DataType::Int8,
264+
Self::UInt8Tensor(_) => DataType::UInt8,
265+
}
266+
}
267+
258268
pub fn to_output(&self) -> Output {
259269
match self {
260270
Input::FloatTensor(t) => t.to_tensor().into(),

src/schema.fbs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ enum OperatorType: ubyte {
124124
MatMulInteger,
125125
DepthToSpace,
126126
ConvInteger,
127+
CastLike,
127128
}
128129

129130
enum RNNDirection: ubyte {
@@ -218,6 +219,7 @@ union OperatorAttrs {
218219
DequantizeLinearAttrs,
219220
QuantizeLinearAttrs,
220221
DepthToSpaceAttrs,
222+
CastLikeAttrs,
221223
}
222224

223225
table ArgMaxAttrs {
@@ -245,6 +247,8 @@ table CastAttrs {
245247
to:DataType;
246248
}
247249

250+
table CastLikeAttrs {}
251+
248252
table ConcatAttrs {
249253
axis:int;
250254
}

0 commit comments

Comments
 (0)