@@ -350,6 +350,8 @@ typedef union TfLitePtrUnion {
350350// as constant inputs for downstream ops (also in prepare).
351351// * kTfLiteCustom: Custom memory allocation provided by the user. See
352352// TfLiteCustomAllocation below.
353+ // * kTfLiteVariantObject: Allocation is an arbitrary type-erased C++ object.
354+ // Allocation and deallocation are done through `new` and `delete`.
353355typedef enum TfLiteAllocationType {
354356 kTfLiteMemNone = 0 ,
355357 kTfLiteMmapRo ,
@@ -358,6 +360,7 @@ typedef enum TfLiteAllocationType {
358360 kTfLiteDynamic ,
359361 kTfLitePersistentRo ,
360362 kTfLiteCustom ,
363+ kTfLiteVariantObject ,
361364} TfLiteAllocationType ;
362365
363366// The delegates should use zero or positive integers to represent handles.
@@ -1201,5 +1204,74 @@ void* TfLiteOpaqueDelegateGetData(const TfLiteOpaqueDelegate* delegate);
12011204
12021205#ifdef __cplusplus
12031206} // extern "C"
1207+
1208+ #include <utility>
1209+
1210+ // `kTfLiteVariant` type tensors encode arbitrary C++ objects behind their
1211+ // `data.data : void*` member. This is the type-erased interface for interacting
1212+ // with such objects at runtime. Deleting or Cloning any `VariantData`
1213+ // will call the destructor and copy constructor of the erased type
1214+ // automatically. For example usage, see `common_test.cc`.
1215+ class VariantData {
1216+ public :
1217+ // All variant objects must be able to be destroyed and copied.
1218+ virtual ~VariantData () = default ;
1219+ // This allows for a "virtual copy-constructor" like pattern.
1220+ // In most cases, we will be copying from an input to an output tensor.
1221+ // Often, the output tensor is already allocated so we can pass
1222+ // a pointer to its buffer for reuse.
1223+ virtual VariantData * Clone (char * maybe_alloc ) const = 0 ;
1224+ };
1225+
1226+ // An abstract base class for variant objects. The template parameter
1227+ // is the type we are erasing.
1228+ template < typename ErasedDerived >
1229+ class AbstractVariantData : public VariantData {
1230+ public :
1231+ VariantData * Clone (char * maybe_alloc ) const override {
1232+ if (maybe_alloc ) {
1233+ // We assume that the output tensor is already a variant of the same
1234+ // derived type. If the output is still allocated, then it still may have
1235+ // state that was not destroyed, so we must call the destructor before
1236+ // using the buffer.
1237+ // This may actual have a non-negligle effect on perfomance if the
1238+ // destructor is complex. In a future optimization we would want to
1239+ // introduce something like "move to" semantics, allowing for the
1240+ // underlying implementation to optimize for this case.
1241+ reinterpret_cast < VariantData * > (maybe_alloc )-> ~VariantData ();
1242+ return new (maybe_alloc )
1243+ ErasedDerived (static_cast < ErasedDerived const & > (* this ));
1244+ }
1245+ return new ErasedDerived (static_cast < ErasedDerived const & > (* this ));
1246+ }
1247+
1248+ protected :
1249+ AbstractVariantData () = default ;
1250+ AbstractVariantData (const AbstractVariantData & ) = default ;
1251+ AbstractVariantData (AbstractVariantData && ) = delete ;
1252+ };
1253+
1254+ // Analogous to `TfLiteTensorRealloc` for allocation of tensors whose
1255+ // data member points to an arbitrary C++ object. `VariantType` refers
1256+ // to the erased type of said object and `VariantArgs` refers to
1257+ // a list of argument types with which to construct a new `VariantType`
1258+ // `VariantArgs` must match constructor in `VariantType`.
1259+ template < class VariantType , class ... VariantArgs >
1260+ TfLiteStatus TfLiteTensorVariantRealloc (TfLiteTensor * t ,
1261+ VariantArgs && ... args ) {
1262+ if (t -> type != kTfLiteVariant ) return kTfLiteError ;
1263+ if (t -> data .raw ) {
1264+ reinterpret_cast < VariantData * > (t -> data .data )-> ~VariantData ();
1265+ // For now we assume if `t` is already allocated then it was allocated
1266+ // with the same `VariantType` as templated.
1267+ t -> data .data =
1268+ new (t -> data .raw ) VariantType (std ::forward < VariantArgs ...> (args ...));
1269+ } else {
1270+ t -> data .data = new VariantType (std ::forward < VariantArgs ...> (args ...));
1271+ }
1272+ t -> allocation_type = kTfLiteVariantObject ;
1273+ return kTfLiteOk ;
1274+ }
1275+
12041276#endif // __cplusplus
12051277#endif // TENSORFLOW_LITE_CORE_C_COMMON_H_
0 commit comments