@@ -5145,15 +5145,42 @@ def ShapeIsNone(self):
5145
5145
return o == 0
5146
5146
5147
5147
# ConstantNode
5148
- def DataType (self ):
5148
+ def Strides (self , j ):
5149
+ o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (6 ))
5150
+ if o != 0 :
5151
+ a = self ._tab .Vector (o )
5152
+ return self ._tab .Get (flatbuffers .number_types .Uint32Flags , a + flatbuffers .number_types .UOffsetTFlags .py_type (j * 4 ))
5153
+ return 0
5154
+
5155
+ # ConstantNode
5156
+ def StridesAsNumpy (self ):
5149
5157
o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (6 ))
5158
+ if o != 0 :
5159
+ return self ._tab .GetVectorAsNumpy (flatbuffers .number_types .Uint32Flags , o )
5160
+ return 0
5161
+
5162
+ # ConstantNode
5163
+ def StridesLength (self ):
5164
+ o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (6 ))
5165
+ if o != 0 :
5166
+ return self ._tab .VectorLen (o )
5167
+ return 0
5168
+
5169
+ # ConstantNode
5170
+ def StridesIsNone (self ):
5171
+ o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (6 ))
5172
+ return o == 0
5173
+
5174
+ # ConstantNode
5175
+ def DataType (self ):
5176
+ o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (8 ))
5150
5177
if o != 0 :
5151
5178
return self ._tab .Get (flatbuffers .number_types .Uint8Flags , o + self ._tab .Pos )
5152
5179
return 0
5153
5180
5154
5181
# ConstantNode
5155
5182
def Data (self ):
5156
- o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (8 ))
5183
+ o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (10 ))
5157
5184
if o != 0 :
5158
5185
from flatbuffers .table import Table
5159
5186
obj = Table (bytearray (), 0 )
@@ -5163,38 +5190,44 @@ def Data(self):
5163
5190
5164
5191
# ConstantNode
5165
5192
def Dtype (self ):
5166
- o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (10 ))
5193
+ o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (12 ))
5167
5194
if o != 0 :
5168
5195
return self ._tab .Get (flatbuffers .number_types .Uint16Flags , o + self ._tab .Pos )
5169
5196
return None
5170
5197
5171
5198
# ConstantNode
5172
5199
def DataOffset (self ):
5173
- o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (12 ))
5200
+ o = flatbuffers .number_types .UOffsetTFlags .py_type (self ._tab .Offset (14 ))
5174
5201
if o != 0 :
5175
5202
return self ._tab .Get (flatbuffers .number_types .Uint64Flags , o + self ._tab .Pos )
5176
5203
return None
5177
5204
5178
5205
def ConstantNodeStart (builder ):
5179
- builder .StartObject (5 )
5206
+ builder .StartObject (6 )
5180
5207
5181
5208
def ConstantNodeAddShape (builder , shape ):
5182
5209
builder .PrependUOffsetTRelativeSlot (0 , flatbuffers .number_types .UOffsetTFlags .py_type (shape ), 0 )
5183
5210
5184
5211
def ConstantNodeStartShapeVector (builder , numElems ):
5185
5212
return builder .StartVector (4 , numElems , 4 )
5186
5213
5214
+ def ConstantNodeAddStrides (builder , strides ):
5215
+ builder .PrependUOffsetTRelativeSlot (1 , flatbuffers .number_types .UOffsetTFlags .py_type (strides ), 0 )
5216
+
5217
+ def ConstantNodeStartStridesVector (builder , numElems ):
5218
+ return builder .StartVector (4 , numElems , 4 )
5219
+
5187
5220
def ConstantNodeAddDataType (builder , dataType ):
5188
- builder .PrependUint8Slot (1 , dataType , 0 )
5221
+ builder .PrependUint8Slot (2 , dataType , 0 )
5189
5222
5190
5223
def ConstantNodeAddData (builder , data ):
5191
- builder .PrependUOffsetTRelativeSlot (2 , flatbuffers .number_types .UOffsetTFlags .py_type (data ), 0 )
5224
+ builder .PrependUOffsetTRelativeSlot (3 , flatbuffers .number_types .UOffsetTFlags .py_type (data ), 0 )
5192
5225
5193
5226
def ConstantNodeAddDtype (builder , dtype ):
5194
- builder .PrependUint16Slot (3 , dtype , None )
5227
+ builder .PrependUint16Slot (4 , dtype , None )
5195
5228
5196
5229
def ConstantNodeAddDataOffset (builder , dataOffset ):
5197
- builder .PrependUint64Slot (4 , dataOffset , None )
5230
+ builder .PrependUint64Slot (5 , dataOffset , None )
5198
5231
5199
5232
def ConstantNodeEnd (builder ):
5200
5233
return builder .EndObject ()
@@ -5210,6 +5243,7 @@ class ConstantNodeT(object):
5210
5243
# ConstantNodeT
5211
5244
def __init__ (self ):
5212
5245
self .shape = None # type: List[int]
5246
+ self .strides = None # type: List[int]
5213
5247
self .dataType = 0 # type: int
5214
5248
self .data = None # type: Union[None, FloatDataT, IntDataT]
5215
5249
self .dtype = None # type: Optional[int]
@@ -5243,6 +5277,13 @@ def _UnPack(self, constantNode):
5243
5277
self .shape .append (constantNode .Shape (i ))
5244
5278
else :
5245
5279
self .shape = constantNode .ShapeAsNumpy ()
5280
+ if not constantNode .StridesIsNone ():
5281
+ if np is None :
5282
+ self .strides = []
5283
+ for i in range (constantNode .StridesLength ()):
5284
+ self .strides .append (constantNode .Strides (i ))
5285
+ else :
5286
+ self .strides = constantNode .StridesAsNumpy ()
5246
5287
self .dataType = constantNode .DataType ()
5247
5288
self .data = ConstantDataCreator (self .dataType , constantNode .Data ())
5248
5289
self .dtype = constantNode .Dtype ()
@@ -5258,11 +5299,21 @@ def Pack(self, builder):
5258
5299
for i in reversed (range (len (self .shape ))):
5259
5300
builder .PrependUint32 (self .shape [i ])
5260
5301
shape = builder .EndVector ()
5302
+ if self .strides is not None :
5303
+ if np is not None and type (self .strides ) is np .ndarray :
5304
+ strides = builder .CreateNumpyVector (self .strides )
5305
+ else :
5306
+ ConstantNodeStartStridesVector (builder , len (self .strides ))
5307
+ for i in reversed (range (len (self .strides ))):
5308
+ builder .PrependUint32 (self .strides [i ])
5309
+ strides = builder .EndVector ()
5261
5310
if self .data is not None :
5262
5311
data = self .data .Pack (builder )
5263
5312
ConstantNodeStart (builder )
5264
5313
if self .shape is not None :
5265
5314
ConstantNodeAddShape (builder , shape )
5315
+ if self .strides is not None :
5316
+ ConstantNodeAddStrides (builder , strides )
5266
5317
ConstantNodeAddDataType (builder , self .dataType )
5267
5318
if self .data is not None :
5268
5319
ConstantNodeAddData (builder , data )
0 commit comments