Skip to content

Pre-transpose constant MatMul operand #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions rten-convert/rten_convert/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ConstantNode(Node):
"""

shape: list[int]
strides: Optional[list[int]]
data: np.ndarray

def __init__(self, name: str, shape: list[int], data: np.ndarray):
Expand Down Expand Up @@ -861,6 +862,12 @@ def op_node_from_onnx_operator(
op_reader.check_attr("input_forget", "int", 0)
op_reader.check_attr("layout", "int", 0)

case "MatMul":
b = constant_nodes.get(onnx_op.input[-1])
if b and len(b.shape) == 2 and b.shape[-1] > 1:
b.data = np.ascontiguousarray(b.data.transpose())
b.strides = [1, b.shape[0]]

case "MaxPool":
attrs = sg.MaxPoolAttrsT()
kernel_shape = op_reader.require_attr("kernel_shape", "ints")
Expand Down Expand Up @@ -1202,6 +1209,12 @@ def build_constant_node(
shape_vec = write_vec(
builder, sg.ConstantNodeStartShapeVector, constant.shape, "u32"
)
if getattr(constant, "strides", None):
strides_vec = write_vec(
builder, sg.ConstantNodeStartStridesVector, constant.strides, "u32"
)
else:
strides_vec = None
n_elems = reduce(mul, constant.shape, 1)
assert n_elems == constant.data.size, "constant shape does not match element count"

Expand Down Expand Up @@ -1261,6 +1274,8 @@ def build_constant_node(
sg.ConstantNodeStart(builder)
sg.ConstantNodeAddShape(builder, shape_vec)
sg.ConstantNodeAddDtype(builder, dtype)
if strides_vec:
sg.ConstantNodeAddStrides(builder, strides_vec)

if inline_data:
sg.ConstantNodeAddDataType(builder, inline_data_type)
Expand Down
157 changes: 104 additions & 53 deletions rten-convert/rten_convert/schema_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,91 +205,91 @@ def OperatorAttrsCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == OperatorAttrs().ArgMaxAttrs:
if unionType == OperatorAttrs.ArgMaxAttrs:
return ArgMaxAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().AveragePoolAttrs:
if unionType == OperatorAttrs.AveragePoolAttrs:
return AveragePoolAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().BatchNormalizationAttrs:
if unionType == OperatorAttrs.BatchNormalizationAttrs:
return BatchNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().CastAttrs:
if unionType == OperatorAttrs.CastAttrs:
return CastAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConcatAttrs:
if unionType == OperatorAttrs.ConcatAttrs:
return ConcatAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConstantOfShapeAttrs:
if unionType == OperatorAttrs.ConstantOfShapeAttrs:
return ConstantOfShapeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConvAttrs:
if unionType == OperatorAttrs.ConvAttrs:
return ConvAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ConvTransposeAttrs:
if unionType == OperatorAttrs.ConvTransposeAttrs:
return ConvTransposeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().FlattenAttrs:
if unionType == OperatorAttrs.FlattenAttrs:
return FlattenAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GatherAttrs:
if unionType == OperatorAttrs.GatherAttrs:
return GatherAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GemmAttrs:
if unionType == OperatorAttrs.GemmAttrs:
return GemmAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GRUAttrs:
if unionType == OperatorAttrs.GRUAttrs:
return GRUAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LeakyReluAttrs:
if unionType == OperatorAttrs.LeakyReluAttrs:
return LeakyReluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LSTMAttrs:
if unionType == OperatorAttrs.LSTMAttrs:
return LSTMAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().MaxPoolAttrs:
if unionType == OperatorAttrs.MaxPoolAttrs:
return MaxPoolAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ReduceMeanAttrs:
if unionType == OperatorAttrs.ReduceMeanAttrs:
return ReduceMeanAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ReshapeAttrs:
if unionType == OperatorAttrs.ReshapeAttrs:
return ReshapeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ResizeAttrs:
if unionType == OperatorAttrs.ResizeAttrs:
return ResizeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().SplitAttrs:
if unionType == OperatorAttrs.SplitAttrs:
return SplitAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().SoftmaxAttrs:
if unionType == OperatorAttrs.SoftmaxAttrs:
return SoftmaxAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().TransposeAttrs:
if unionType == OperatorAttrs.TransposeAttrs:
return TransposeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ModAttrs:
if unionType == OperatorAttrs.ModAttrs:
return ModAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ScatterElementsAttrs:
if unionType == OperatorAttrs.ScatterElementsAttrs:
return ScatterElementsAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().OneHotAttrs:
if unionType == OperatorAttrs.OneHotAttrs:
return OneHotAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().TopKAttrs:
if unionType == OperatorAttrs.TopKAttrs:
return TopKAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().HardSigmoidAttrs:
if unionType == OperatorAttrs.HardSigmoidAttrs:
return HardSigmoidAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().TriluAttrs:
if unionType == OperatorAttrs.TriluAttrs:
return TriluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().ScatterNDAttrs:
if unionType == OperatorAttrs.ScatterNDAttrs:
return ScatterNDAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().NonMaxSuppressionAttrs:
if unionType == OperatorAttrs.NonMaxSuppressionAttrs:
return NonMaxSuppressionAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().LayerNormalizationAttrs:
if unionType == OperatorAttrs.LayerNormalizationAttrs:
return LayerNormalizationAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomUniformAttrs:
if unionType == OperatorAttrs.RandomUniformAttrs:
return RandomUniformAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().EluAttrs:
if unionType == OperatorAttrs.EluAttrs:
return EluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomUniformLikeAttrs:
if unionType == OperatorAttrs.RandomUniformLikeAttrs:
return RandomUniformLikeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomNormalAttrs:
if unionType == OperatorAttrs.RandomNormalAttrs:
return RandomNormalAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().RandomNormalLikeAttrs:
if unionType == OperatorAttrs.RandomNormalLikeAttrs:
return RandomNormalLikeAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GatherNDAttrs:
if unionType == OperatorAttrs.GatherNDAttrs:
return GatherNDAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().GeluAttrs:
if unionType == OperatorAttrs.GeluAttrs:
return GeluAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().EinsumAttrs:
if unionType == OperatorAttrs.EinsumAttrs:
return EinsumAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().IfAttrs:
if unionType == OperatorAttrs.IfAttrs:
return IfAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().PadAttrs:
if unionType == OperatorAttrs.PadAttrs:
return PadAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().DequantizeLinearAttrs:
if unionType == OperatorAttrs.DequantizeLinearAttrs:
return DequantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().QuantizeLinearAttrs:
if unionType == OperatorAttrs.QuantizeLinearAttrs:
return QuantizeLinearAttrsT.InitFromBuf(table.Bytes, table.Pos)
if unionType == OperatorAttrs().DepthToSpaceAttrs:
if unionType == OperatorAttrs.DepthToSpaceAttrs:
return DepthToSpaceAttrsT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand All @@ -308,9 +308,9 @@ def ScalarCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == Scalar().IntScalar:
if unionType == Scalar.IntScalar:
return IntScalarT.InitFromBuf(table.Bytes, table.Pos)
if unionType == Scalar().FloatScalar:
if unionType == Scalar.FloatScalar:
return FloatScalarT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand Down Expand Up @@ -343,11 +343,11 @@ def NodeKindCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == NodeKind().OperatorNode:
if unionType == NodeKind.OperatorNode:
return OperatorNodeT.InitFromBuf(table.Bytes, table.Pos)
if unionType == NodeKind().ConstantNode:
if unionType == NodeKind.ConstantNode:
return ConstantNodeT.InitFromBuf(table.Bytes, table.Pos)
if unionType == NodeKind().ValueNode:
if unionType == NodeKind.ValueNode:
return ValueNodeT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand All @@ -363,13 +363,13 @@ def ConstantDataCreator(unionType, table):
from flatbuffers.table import Table
if not isinstance(table, Table):
return None
if unionType == ConstantData().FloatData:
if unionType == ConstantData.FloatData:
return FloatDataT.InitFromBuf(table.Bytes, table.Pos)
if unionType == ConstantData().Int32Data:
if unionType == ConstantData.Int32Data:
return Int32DataT.InitFromBuf(table.Bytes, table.Pos)
if unionType == ConstantData().Int8Data:
if unionType == ConstantData.Int8Data:
return Int8DataT.InitFromBuf(table.Bytes, table.Pos)
if unionType == ConstantData().UInt8Data:
if unionType == ConstantData.UInt8Data:
return UInt8DataT.InitFromBuf(table.Bytes, table.Pos)
return None

Expand Down Expand Up @@ -5784,8 +5784,35 @@ def DataOffset(self):
return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos)
return None

# ConstantNode
def Strides(self, j):
    """Return element ``j`` of the ``strides`` vector.

    The field is looked up at vtable offset 14. When the field is absent
    (the lookup yields 0) this returns 0 rather than raising.
    """
    number_types = flatbuffers.number_types
    field_offset = number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
    if field_offset == 0:
        return 0
    vector_start = self._tab.Vector(field_offset)
    # Elements are read as Uint32, so entry j sits j * 4 bytes into the vector.
    element_pos = vector_start + number_types.UOffsetTFlags.py_type(j * 4)
    return self._tab.Get(number_types.Uint32Flags, element_pos)

# ConstantNode
def StridesAsNumpy(self):
    """Return the whole ``strides`` vector via ``GetVectorAsNumpy``.

    When the field is absent (vtable lookup at offset 14 yields 0) this
    returns the integer 0, not an empty array.
    """
    field_offset = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
    if field_offset == 0:
        return 0
    return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, field_offset)

# ConstantNode
def StridesLength(self):
    """Return the number of entries in the ``strides`` vector, or 0 if it is absent."""
    field_offset = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
    if field_offset == 0:
        return 0
    return self._tab.VectorLen(field_offset)

# ConstantNode
def StridesIsNone(self):
    """Report whether the ``strides`` field is missing from this table.

    The field counts as missing when its vtable lookup (offset 14) yields 0.
    """
    field_offset = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
    is_absent = field_offset == 0
    return is_absent

def ConstantNodeStart(builder):
builder.StartObject(5)
builder.StartObject(6)

def ConstantNodeAddShape(builder, shape):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
Expand All @@ -5805,6 +5832,12 @@ def ConstantNodeAddDtype(builder, dtype):
def ConstantNodeAddDataOffset(builder, dataOffset):
    """Write ``dataOffset`` into slot 4 of the table being built, as a uint64.

    ``None`` is passed as the slot's default value.
    """
    slot_index = 4
    slot_default = None
    builder.PrependUint64Slot(slot_index, dataOffset, slot_default)

def ConstantNodeAddStrides(builder, strides):
    """Store the offset of an already-written strides vector in slot 5.

    ``strides`` is converted with ``UOffsetTFlags.py_type`` before being
    handed to the builder; 0 is passed as the slot default.
    """
    strides_offset = flatbuffers.number_types.UOffsetTFlags.py_type(strides)
    builder.PrependUOffsetTRelativeSlot(5, strides_offset, 0)

def ConstantNodeStartStridesVector(builder, numElems):
    """Begin a strides vector of ``numElems`` entries on ``builder``.

    Returns whatever ``builder.StartVector`` returns.
    """
    # Both the element size and the alignment passed to StartVector are 4.
    elem_size = 4
    alignment = 4
    return builder.StartVector(elem_size, numElems, alignment)

def ConstantNodeEnd(builder):
    """Finish the table under construction and return ``builder.EndObject()``'s result."""
    finished = builder.EndObject()
    return finished

Expand All @@ -5823,6 +5856,7 @@ def __init__(self):
self.data = None # type: Union[None, FloatDataT, Int32DataT, Int8DataT, UInt8DataT]
self.dtype = None # type: Optional[int]
self.dataOffset = None # type: Optional[int]
self.strides = None # type: List[int]

@classmethod
def InitFromBuf(cls, buf, pos):
Expand Down Expand Up @@ -5856,6 +5890,13 @@ def _UnPack(self, constantNode):
self.data = ConstantDataCreator(self.dataType, constantNode.Data())
self.dtype = constantNode.Dtype()
self.dataOffset = constantNode.DataOffset()
if not constantNode.StridesIsNone():
if np is None:
self.strides = []
for i in range(constantNode.StridesLength()):
self.strides.append(constantNode.Strides(i))
else:
self.strides = constantNode.StridesAsNumpy()

# ConstantNodeT
def Pack(self, builder):
Expand All @@ -5869,6 +5910,14 @@ def Pack(self, builder):
shape = builder.EndVector()
if self.data is not None:
data = self.data.Pack(builder)
if self.strides is not None:
if np is not None and type(self.strides) is np.ndarray:
strides = builder.CreateNumpyVector(self.strides)
else:
ConstantNodeStartStridesVector(builder, len(self.strides))
for i in reversed(range(len(self.strides))):
builder.PrependUint32(self.strides[i])
strides = builder.EndVector()
ConstantNodeStart(builder)
if self.shape is not None:
ConstantNodeAddShape(builder, shape)
Expand All @@ -5877,6 +5926,8 @@ def Pack(self, builder):
ConstantNodeAddData(builder, data)
ConstantNodeAddDtype(builder, self.dtype)
ConstantNodeAddDataOffset(builder, self.dataOffset)
if self.strides is not None:
ConstantNodeAddStrides(builder, strides)
constantNode = ConstantNodeEnd(builder)
return constantNode

Expand Down
Loading
Loading