Skip to content

Commit fa5469c

Browse files
committed
Pre-transpose MatMul's RHS operands
1 parent 2024f4e commit fa5469c

File tree

6 files changed

+152
-19
lines changed

6 files changed

+152
-19
lines changed

rten-convert/rten_convert/converter.py

+15
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class ConstantNode(Node):
5555
"""
5656

5757
shape: list[int]
58+
strides: Optional[list[int]]
5859
data: np.ndarray
5960

6061
def __init__(self, name: str, shape: list[int], data: np.ndarray):
@@ -811,6 +812,12 @@ def op_node_from_onnx_operator(
811812
op_reader.check_attr("input_forget", "int", 0)
812813
op_reader.check_attr("layout", "int", 0)
813814

815+
case "MatMul":
816+
b = constant_nodes.get(onnx_op.input[-1])
817+
if b and len(b.shape) == 2 and b.shape[-1] > 1:
818+
b.data = np.ascontiguousarray(b.data.transpose())
819+
b.strides = [1, b.shape[0]]
820+
814821
case "MaxPool":
815822
attrs = sg.MaxPoolAttrsT()
816823
kernel_shape = op_reader.require_attr("kernel_shape", "ints")
@@ -1141,6 +1148,12 @@ def build_constant_node(
11411148
shape_vec = write_vec(
11421149
builder, sg.ConstantNodeStartShapeVector, constant.shape, "u32"
11431150
)
1151+
if getattr(constant, "strides", None):
1152+
strides_vec = write_vec(
1153+
builder, sg.ConstantNodeStartStridesVector, constant.strides, "u32"
1154+
)
1155+
else:
1156+
strides_vec = None
11441157
n_elems = reduce(mul, constant.shape, 1)
11451158
assert n_elems == constant.data.size, "constant shape does not match element count"
11461159

@@ -1182,6 +1195,8 @@ def build_constant_node(
11821195
sg.ConstantNodeStart(builder)
11831196
sg.ConstantNodeAddShape(builder, shape_vec)
11841197
sg.ConstantNodeAddDtype(builder, dtype)
1198+
if strides_vec:
1199+
sg.ConstantNodeAddStrides(builder, strides_vec)
11851200

11861201
if inline_data:
11871202
sg.ConstantNodeAddDataType(builder, inline_data_type)

rten-convert/rten_convert/schema_generated.py

+60-9
Original file line numberDiff line numberDiff line change
@@ -5145,15 +5145,42 @@ def ShapeIsNone(self):
51455145
return o == 0
51465146

51475147
# ConstantNode
5148-
def DataType(self):
5148+
def Strides(self, j):
5149+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
5150+
if o != 0:
5151+
a = self._tab.Vector(o)
5152+
return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
5153+
return 0
5154+
5155+
# ConstantNode
5156+
def StridesAsNumpy(self):
51495157
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
5158+
if o != 0:
5159+
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
5160+
return 0
5161+
5162+
# ConstantNode
5163+
def StridesLength(self):
5164+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
5165+
if o != 0:
5166+
return self._tab.VectorLen(o)
5167+
return 0
5168+
5169+
# ConstantNode
5170+
def StridesIsNone(self):
5171+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
5172+
return o == 0
5173+
5174+
# ConstantNode
5175+
def DataType(self):
5176+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
51505177
if o != 0:
51515178
return self._tab.Get(flatbuffers.number_types.Uint8Flags, o + self._tab.Pos)
51525179
return 0
51535180

51545181
# ConstantNode
51555182
def Data(self):
5156-
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
5183+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
51575184
if o != 0:
51585185
from flatbuffers.table import Table
51595186
obj = Table(bytearray(), 0)
@@ -5163,38 +5190,44 @@ def Data(self):
51635190

51645191
# ConstantNode
51655192
def Dtype(self):
5166-
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
5193+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
51675194
if o != 0:
51685195
return self._tab.Get(flatbuffers.number_types.Uint16Flags, o + self._tab.Pos)
51695196
return None
51705197

51715198
# ConstantNode
51725199
def DataOffset(self):
5173-
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
5200+
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(14))
51745201
if o != 0:
51755202
return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos)
51765203
return None
51775204

51785205
def ConstantNodeStart(builder):
5179-
builder.StartObject(5)
5206+
builder.StartObject(6)
51805207

51815208
def ConstantNodeAddShape(builder, shape):
51825209
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(shape), 0)
51835210

51845211
def ConstantNodeStartShapeVector(builder, numElems):
51855212
return builder.StartVector(4, numElems, 4)
51865213

5214+
def ConstantNodeAddStrides(builder, strides):
5215+
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(strides), 0)
5216+
5217+
def ConstantNodeStartStridesVector(builder, numElems):
5218+
return builder.StartVector(4, numElems, 4)
5219+
51875220
def ConstantNodeAddDataType(builder, dataType):
5188-
builder.PrependUint8Slot(1, dataType, 0)
5221+
builder.PrependUint8Slot(2, dataType, 0)
51895222

51905223
def ConstantNodeAddData(builder, data):
5191-
builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0)
5224+
builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(data), 0)
51925225

51935226
def ConstantNodeAddDtype(builder, dtype):
5194-
builder.PrependUint16Slot(3, dtype, None)
5227+
builder.PrependUint16Slot(4, dtype, None)
51955228

51965229
def ConstantNodeAddDataOffset(builder, dataOffset):
5197-
builder.PrependUint64Slot(4, dataOffset, None)
5230+
builder.PrependUint64Slot(5, dataOffset, None)
51985231

51995232
def ConstantNodeEnd(builder):
52005233
return builder.EndObject()
@@ -5210,6 +5243,7 @@ class ConstantNodeT(object):
52105243
# ConstantNodeT
52115244
def __init__(self):
52125245
self.shape = None # type: List[int]
5246+
self.strides = None # type: List[int]
52135247
self.dataType = 0 # type: int
52145248
self.data = None # type: Union[None, FloatDataT, IntDataT]
52155249
self.dtype = None # type: Optional[int]
@@ -5243,6 +5277,13 @@ def _UnPack(self, constantNode):
52435277
self.shape.append(constantNode.Shape(i))
52445278
else:
52455279
self.shape = constantNode.ShapeAsNumpy()
5280+
if not constantNode.StridesIsNone():
5281+
if np is None:
5282+
self.strides = []
5283+
for i in range(constantNode.StridesLength()):
5284+
self.strides.append(constantNode.Strides(i))
5285+
else:
5286+
self.strides = constantNode.StridesAsNumpy()
52465287
self.dataType = constantNode.DataType()
52475288
self.data = ConstantDataCreator(self.dataType, constantNode.Data())
52485289
self.dtype = constantNode.Dtype()
@@ -5258,11 +5299,21 @@ def Pack(self, builder):
52585299
for i in reversed(range(len(self.shape))):
52595300
builder.PrependUint32(self.shape[i])
52605301
shape = builder.EndVector()
5302+
if self.strides is not None:
5303+
if np is not None and type(self.strides) is np.ndarray:
5304+
strides = builder.CreateNumpyVector(self.strides)
5305+
else:
5306+
ConstantNodeStartStridesVector(builder, len(self.strides))
5307+
for i in reversed(range(len(self.strides))):
5308+
builder.PrependUint32(self.strides[i])
5309+
strides = builder.EndVector()
52615310
if self.data is not None:
52625311
data = self.data.Pack(builder)
52635312
ConstantNodeStart(builder)
52645313
if self.shape is not None:
52655314
ConstantNodeAddShape(builder, shape)
5315+
if self.strides is not None:
5316+
ConstantNodeAddStrides(builder, strides)
52665317
ConstantNodeAddDataType(builder, self.dataType)
52675318
if self.data is not None:
52685319
ConstantNodeAddData(builder, data)

src/model.rs

+40-6
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,9 @@ impl Model {
473473
tensor_data_offset: Option<u64>,
474474
) -> Result<NodeId, ModelLoadError> {
475475
let shape: Vec<usize> = constant.shape().iter().map(|x| x as usize).collect();
476+
let strides: Option<Vec<usize>> = constant
477+
.strides()
478+
.map(|strides| strides.iter().map(|x| x as usize).collect());
476479

477480
if let Some(data_offset) = constant.data_offset() {
478481
// Constant data is stored outside the model buffer, in the same file.
@@ -486,13 +489,21 @@ impl Model {
486489

487490
let graph_node = match constant.dtype() {
488491
Some(sg::ConstantDataType::Int32) => {
489-
let const_data =
490-
constant_data_from_storage_offset::<i32>(storage, &shape, data_offset)?;
492+
let const_data = constant_data_from_storage_offset::<i32>(
493+
storage,
494+
&shape,
495+
strides.as_deref(),
496+
data_offset,
497+
)?;
491498
graph.add_constant(name, const_data)
492499
}
493500
Some(sg::ConstantDataType::Float32) => {
494-
let const_data =
495-
constant_data_from_storage_offset::<f32>(storage, &shape, data_offset)?;
501+
let const_data = constant_data_from_storage_offset::<f32>(
502+
storage,
503+
&shape,
504+
strides.as_deref(),
505+
data_offset,
506+
)?;
496507
graph.add_constant(name, const_data)
497508
}
498509
_ => {
@@ -717,6 +728,7 @@ fn transmute_bytes<T: Pod>(bytes: &[u8]) -> Option<&[T]> {
717728
fn constant_data_from_storage_offset<T: LeBytes + Pod>(
718729
storage: &Arc<ConstantStorage>,
719730
shape: &[usize],
731+
strides: Option<&[usize]>,
720732
offset: usize,
721733
) -> Result<ConstantNodeData<T>, ModelLoadError> {
722734
let n_elements: usize = shape.iter().product();
@@ -731,14 +743,36 @@ fn constant_data_from_storage_offset<T: LeBytes + Pod>(
731743
if let Some(elements) = transmute_bytes(bytes) {
732744
let storage =
733745
ArcSlice::new(storage.clone(), elements).expect("storage does not contain data");
734-
let const_data: ConstantNodeData<T> = ArcTensorView::from_data(shape, storage).into();
746+
let const_data: ConstantNodeData<T> = if let Some(strides) = strides {
747+
ArcTensorView::from_data_with_strides(shape, storage, strides)
748+
.map_err(|_| {
749+
ModelLoadError::GraphError(format!(
750+
"bad strides = {:?}, shape = {:?}",
751+
strides, shape
752+
))
753+
})?
754+
.into()
755+
} else {
756+
ArcTensorView::from_data(shape, storage).into()
757+
};
735758
Ok(const_data)
736759
} else {
737760
let data: Vec<T> = bytes
738761
.chunks(std::mem::size_of::<T>())
739762
.map(|chunk| T::from_le_bytes(chunk.try_into().unwrap()))
740763
.collect();
741-
Ok(Tensor::from_data(shape, data).into())
764+
Ok(if let Some(strides) = strides {
765+
Tensor::from_data_with_strides(shape, data, strides)
766+
.map_err(|_| {
767+
ModelLoadError::GraphError(format!(
768+
"bad strides = {:?}, shape = {:?}",
769+
strides, shape
770+
))
771+
})?
772+
.into()
773+
} else {
774+
Tensor::from_data(shape, data).into()
775+
})
742776
}
743777
}
744778

src/model_builder.rs

+2
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
283283

284284
sg::ConstantNodeArgs {
285285
shape: Some(shape_vec),
286+
strides: None,
286287
data_type: sg::ConstantData::NONE,
287288
data: None,
288289
data_offset: Some(offset),
@@ -294,6 +295,7 @@ impl<'mb, 'a> GraphBuilder<'mb, 'a> {
294295

295296
sg::ConstantNodeArgs {
296297
shape: Some(shape_vec),
298+
strides: None,
297299
data_type: inline_dtype,
298300
data: Some(data),
299301
data_offset: None,

src/schema.fbs

+1
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,7 @@ enum ConstantDataType: ushort {
480480
// Graph node for a constant tensor value, whose data is part of the model.
481481
table ConstantNode {
482482
shape:[uint] (required);
483+
strides:[uint];
483484

484485
// Tensor data embedded within the model file.
485486
data:ConstantData;

src/schema_generated.rs

+34-4
Original file line numberDiff line numberDiff line change
@@ -8734,10 +8734,11 @@ impl<'a> flatbuffers::Follow<'a> for ConstantNode<'a> {
87348734

87358735
impl<'a> ConstantNode<'a> {
87368736
pub const VT_SHAPE: flatbuffers::VOffsetT = 4;
8737-
pub const VT_DATA_TYPE: flatbuffers::VOffsetT = 6;
8738-
pub const VT_DATA: flatbuffers::VOffsetT = 8;
8739-
pub const VT_DTYPE: flatbuffers::VOffsetT = 10;
8740-
pub const VT_DATA_OFFSET: flatbuffers::VOffsetT = 12;
8737+
pub const VT_STRIDES: flatbuffers::VOffsetT = 6;
8738+
pub const VT_DATA_TYPE: flatbuffers::VOffsetT = 8;
8739+
pub const VT_DATA: flatbuffers::VOffsetT = 10;
8740+
pub const VT_DTYPE: flatbuffers::VOffsetT = 12;
8741+
pub const VT_DATA_OFFSET: flatbuffers::VOffsetT = 14;
87418742

87428743
#[inline]
87438744
pub unsafe fn init_from_table(table: flatbuffers::Table<'a>) -> Self {
@@ -8755,6 +8756,9 @@ impl<'a> ConstantNode<'a> {
87558756
if let Some(x) = args.data {
87568757
builder.add_data(x);
87578758
}
8759+
if let Some(x) = args.strides {
8760+
builder.add_strides(x);
8761+
}
87588762
if let Some(x) = args.shape {
87598763
builder.add_shape(x);
87608764
}
@@ -8780,6 +8784,19 @@ impl<'a> ConstantNode<'a> {
87808784
}
87818785
}
87828786
#[inline]
8787+
pub fn strides(&self) -> Option<flatbuffers::Vector<'a, u32>> {
8788+
// Safety:
8789+
// Created from valid Table for this object
8790+
// which contains a valid value in this slot
8791+
unsafe {
8792+
self._tab
8793+
.get::<flatbuffers::ForwardsUOffset<flatbuffers::Vector<'a, u32>>>(
8794+
ConstantNode::VT_STRIDES,
8795+
None,
8796+
)
8797+
}
8798+
}
8799+
#[inline]
87838800
pub fn data_type(&self) -> ConstantData {
87848801
// Safety:
87858802
// Created from valid Table for this object
@@ -8864,6 +8881,11 @@ impl flatbuffers::Verifiable for ConstantNode<'_> {
88648881
Self::VT_SHAPE,
88658882
true,
88668883
)?
8884+
.visit_field::<flatbuffers::ForwardsUOffset<flatbuffers::Vector<'_, u32>>>(
8885+
"strides",
8886+
Self::VT_STRIDES,
8887+
false,
8888+
)?
88678889
.visit_union::<ConstantData, _>(
88688890
"data_type",
88698891
Self::VT_DATA_TYPE,
@@ -8892,6 +8914,7 @@ impl flatbuffers::Verifiable for ConstantNode<'_> {
88928914
}
88938915
pub struct ConstantNodeArgs<'a> {
88948916
pub shape: Option<flatbuffers::WIPOffset<flatbuffers::Vector<'a, u32>>>,
8917+
pub strides: Option<flatbuffers::WIPOffset<flatbuffers::Vector<'a, u32>>>,
88958918
pub data_type: ConstantData,
88968919
pub data: Option<flatbuffers::WIPOffset<flatbuffers::UnionWIPOffset>>,
88978920
pub dtype: Option<ConstantDataType>,
@@ -8902,6 +8925,7 @@ impl<'a> Default for ConstantNodeArgs<'a> {
89028925
fn default() -> Self {
89038926
ConstantNodeArgs {
89048927
shape: None, // required field
8928+
strides: None,
89058929
data_type: ConstantData::NONE,
89068930
data: None,
89078931
dtype: None,
@@ -8921,6 +8945,11 @@ impl<'a: 'b, 'b, A: flatbuffers::Allocator + 'a> ConstantNodeBuilder<'a, 'b, A>
89218945
.push_slot_always::<flatbuffers::WIPOffset<_>>(ConstantNode::VT_SHAPE, shape);
89228946
}
89238947
#[inline]
8948+
pub fn add_strides(&mut self, strides: flatbuffers::WIPOffset<flatbuffers::Vector<'b, u32>>) {
8949+
self.fbb_
8950+
.push_slot_always::<flatbuffers::WIPOffset<_>>(ConstantNode::VT_STRIDES, strides);
8951+
}
8952+
#[inline]
89248953
pub fn add_data_type(&mut self, data_type: ConstantData) {
89258954
self.fbb_.push_slot::<ConstantData>(
89268955
ConstantNode::VT_DATA_TYPE,
@@ -8965,6 +8994,7 @@ impl core::fmt::Debug for ConstantNode<'_> {
89658994
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
89668995
let mut ds = f.debug_struct("ConstantNode");
89678996
ds.field("shape", &self.shape());
8997+
ds.field("strides", &self.strides());
89688998
ds.field("data_type", &self.data_type());
89698999
match self.data_type() {
89709000
ConstantData::FloatData => {

0 commit comments

Comments
 (0)